Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Per-session kore-rpc server state #3702

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cabal.project.freeze
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ constraints: any.Cabal ==3.6.3.0,
any.json-rpc ==1.0.4,
any.junit-xml ==0.1.0.2,
any.kan-extensions ==5.2.5,
kore -threaded,
kore +threaded,
any.lens ==5.1.1,
lens -benchmark-uniplate -dump-splices +inlining -j +test-hunit +test-properties +test-templates +trustworthy,
any.libyaml ==0.1.2,
Expand Down
39 changes: 28 additions & 11 deletions kore-rpc-types/src/Kore/JsonRpc/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ module Kore.JsonRpc.Server (
JsonRpcHandler (..),
) where

import Control.Concurrent (forkIO, throwTo)
import Control.Concurrent (forkIO, newMVar, readMVar, swapMVar, throwTo)
import Control.Concurrent.STM.TChan (newTChan, readTChan, writeTChan)
import Control.Exception (Exception (fromException), catch, mask, throw)
import Control.Monad (forM_, forever)
import Control.Monad (forM_, forever, (>=>))
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Logger (MonadLoggerIO)
import Control.Monad.Logger qualified as Log
Expand Down Expand Up @@ -79,11 +79,13 @@ jsonRpcServer ::
(MonadLoggerIO m, MonadUnliftIO m, FromRequestCancellable q, ToJSON r) =>
-- | Connection settings
ServerSettings ->
-- | Init session state
state ->
-- | Action to perform on connecting client thread
(Request -> Respond q (Log.LoggingT IO) r) ->
(Request -> state -> Respond q (Log.LoggingT IO) (r, Maybe state)) ->
[JsonRpcHandler] ->
m a
jsonRpcServer serverSettings respond handlers =
jsonRpcServer serverSettings initState respond handlers =
runGeneralTCPServer serverSettings $ \cl ->
runJSONRPCT
-- we have to ensure that the returned messages contain no newlines
Expand All @@ -93,17 +95,19 @@ jsonRpcServer serverSettings respond handlers =
False
(appSink cl)
(appSource cl)
(srv respond handlers)
(srv initState respond handlers)

data JsonRpcHandler = forall e. Exception e => JsonRpcHandler (e -> Log.LoggingT IO ErrorObj)

srv ::
forall m q r.
forall m q r state.
(MonadLoggerIO m, FromRequestCancellable q, ToJSON r) =>
(Request -> Respond q (Log.LoggingT IO) r) ->
state ->
(Request -> state -> Respond q (Log.LoggingT IO) (r, Maybe state)) ->
[JsonRpcHandler] ->
JSONRPCT m ()
srv respond handlers = do
srv initState respond handlers = do
state <- liftIO $ newMVar initState
reqQueue <- liftIO $ atomically newTChan
let mainLoop tid =
let loop =
Expand All @@ -121,7 +125,7 @@ srv respond handlers = do
liftIO $ atomically $ writeTChan reqQueue req
loop
in loop
spawnWorker reqQueue >>= mainLoop
spawnWorker state reqQueue >>= mainLoop
where
isRequest = \case
Request{} -> True
Expand All @@ -138,7 +142,7 @@ srv respond handlers = do

cancelError = ErrorObj "Request cancelled" (-32000) Null

spawnWorker reqQueue = do
spawnWorker stateMVar reqQueue = do
rpcSession <- ask
logger <- Log.askLoggerIO
let withLog :: Log.LoggingT IO a -> IO a
Expand All @@ -148,7 +152,20 @@ srv respond handlers = do
sendResponses r = flip runReaderT rpcSession $ sendBatchResponse r

respondTo :: Request -> Log.LoggingT IO (Maybe Response)
respondTo req = buildResponse (respond req) req
respondTo req = do
state <- liftIO $ readMVar stateMVar
buildResponse
( respond req state
>=> ( \case
Left err -> pure $ Left err
Right (res, Nothing) -> pure $ Right res
Right (res, Just newState) ->
do
_ <- liftIO $ swapMVar stateMVar newState
pure $ Right res
)
)
req

cancelReq :: ErrorObj -> BatchRequest -> Log.LoggingT IO ()
cancelReq err = \case
Expand Down
15 changes: 6 additions & 9 deletions kore/app/rpc/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

module Main (main) where

import Control.Concurrent.MVar as MVar
import Control.Exception (AsyncException (..))
import Control.Monad.Catch (
bracket,
Expand Down Expand Up @@ -175,18 +174,16 @@ koreRpcServerRun GlobalMain.LocalOptions{execOptions} = do
lift $ writeIORef globalInternedTextCache internedTextCache

loadedDefinition <- GlobalMain.loadDefinitions [definitionFileName]
serverState <-
lift $
MVar.newMVar
ServerState
{ serializedModules = Map.singleton mainModuleName sd
, loadedDefinition
}
let initServerState =
ServerState
{ serializedModules = Map.singleton mainModuleName sd
, loadedDefinition
}
GlobalMain.clockSomethingIO "Executing" $
-- wrap the call to runServer in the logger monad
Log.LoggerT $
ReaderT $
\loggerEnv -> runServer port serverState mainModuleName (runSMT loggerEnv) loggerEnv
\loggerEnv -> runServer port initServerState mainModuleName (runSMT loggerEnv) loggerEnv

pure ExitSuccess
where
Expand Down
75 changes: 38 additions & 37 deletions kore/src/Kore/JsonRpc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ module Kore.JsonRpc (
module Kore.JsonRpc,
) where

import Control.Concurrent.MVar qualified as MVar
import Control.Monad.Except (runExceptT)
import Control.Monad.Logger (runLoggingT)
import Data.Aeson.Types (ToJSON (..))
Expand Down Expand Up @@ -105,15 +104,15 @@ import System.Clock (Clock (Monotonic), diffTimeSpec, getTime, toNanoSecs)
respond ::
forall m.
MonadIO m =>
MVar.MVar ServerState ->
ServerState ->
ModuleName ->
( forall a.
SmtMetadataTools StepperAttributes ->
[SentenceAxiom (TermLike VariableName)] ->
SMT.SMT a ->
IO a
) ->
Respond (API 'Req) m (API 'Res)
Respond (API 'Req) m (API 'Res, Maybe ServerState)
respond serverState moduleName runSMT =
\case
Execute
Expand Down Expand Up @@ -158,7 +157,7 @@ respond serverState moduleName runSMT =
Just $
fromIntegral (toNanoSecs (diffTimeSpec stop start)) / 1e9
else Nothing
pure $ buildResult duration (TermLike.termLikeSort verifiedPattern) traversalResult
pure $ (,Nothing) <$> buildResult duration (TermLike.termLikeSort verifiedPattern) traversalResult
where
toStopLabels :: Maybe [Text] -> Maybe [Text] -> Exec.StopLabels
toStopLabels cpRs tRs =
Expand Down Expand Up @@ -376,7 +375,7 @@ respond serverState moduleName runSMT =
if (fromMaybe False logTiming)
then maybe (Just [timeLog]) (Just . (timeLog :)) simplLogs
else simplLogs
pure $ buildResult allLogs sort result
pure $ (,Nothing) <$> buildResult allLogs sort result
where
verify = do
antVerified <-
Expand Down Expand Up @@ -446,22 +445,24 @@ respond serverState moduleName runSMT =
then maybe (Just [timeLog]) (Just . (timeLog :)) simplLogs
else simplLogs
pure $
Right $
Simplify
Right
( Simplify
SimplifyResult
{ state =
PatternJson.fromTermLike $
TermLike.mapVariables getRewritingVariable $
OrPattern.toTermLike sort result
, logs = allLogs
}
, Nothing
)
AddModule AddModuleRequest{_module} ->
case parseKoreModule "<add-module>" _module of
Left err -> pure $ Left $ backendError CouldNotParsePattern err
Right parsedModule@Module{moduleName = name} -> do
LoadedDefinition{indexedModules, definedNames, kFileLocations} <-
liftIO $ loadedDefinition <$> MVar.readMVar serverState
let verified =
let LoadedDefinition{indexedModules, definedNames, kFileLocations} =
loadedDefinition serverState
verified =
verifyAndIndexDefinitionWithBase
(indexedModules, definedNames)
Builtin.koreVerifiers
Expand All @@ -480,29 +481,28 @@ respond serverState moduleName runSMT =
$ Exec.makeSerializedModule mainModule
internedTextCache <- liftIO $ readIORef globalInternedTextCache

liftIO . MVar.modifyMVar_ serverState $
\ServerState{serializedModules} -> do
let serializedDefinition =
SerializedDefinition
{ serializedModule = serializedModule'
, locations = kFileLocations
, internedTextCache
, lemmas
}
loadedDefinition =
LoadedDefinition
{ indexedModules = indexedModules'
, definedNames = definedNames'
, kFileLocations
}
pure
ServerState
{ serializedModules =
Map.insert (coerce name) serializedDefinition serializedModules
, loadedDefinition
}
let ServerState{serializedModules} = serverState
serializedDefinition =
SerializedDefinition
{ serializedModule = serializedModule'
, locations = kFileLocations
, internedTextCache
, lemmas
}
loadedDefinition =
LoadedDefinition
{ indexedModules = indexedModules'
, definedNames = definedNames'
, kFileLocations
}
newServerState =
ServerState
{ serializedModules =
Map.insert (coerce name) serializedDefinition serializedModules
, loadedDefinition
}

pure . Right . AddModule $ AddModuleResult (getModuleName name)
pure $ Right (AddModule $ AddModuleResult (getModuleName name), Just newServerState)
GetModel GetModelRequest{state, _module} ->
withMainModule (coerce _module) $ \serializedModule lemmas ->
case verifyIn serializedModule state of
Expand All @@ -525,7 +525,7 @@ respond serverState moduleName runSMT =
. SMT.Evaluator.getModelFor tools
$ NonEmpty.fromList preds

pure . Right . GetModel $
pure . Right . (,Nothing) . GetModel $
case result of
Left False ->
GetModelResult
Expand All @@ -550,7 +550,7 @@ respond serverState moduleName runSMT =
where
withMainModule module' act = do
let mainModule = fromMaybe moduleName module'
ServerState{serializedModules} <- liftIO $ MVar.readMVar serverState
ServerState{serializedModules} = serverState
case Map.lookup mainModule serializedModules of
Nothing -> pure $ Left $ backendError CouldNotFindModule mainModule
Just (SerializedDefinition{serializedModule, lemmas}) ->
Expand Down Expand Up @@ -622,7 +622,7 @@ data ServerState = ServerState

runServer ::
Int ->
MVar.MVar ServerState ->
ServerState ->
ModuleName ->
( forall a.
SmtMetadataTools StepperAttributes ->
Expand All @@ -632,11 +632,12 @@ runServer ::
) ->
Log.LoggerEnv IO ->
IO ()
runServer port serverState mainModule runSMT Log.LoggerEnv{logAction} = do
runServer port initServerState mainModule runSMT Log.LoggerEnv{logAction} = do
flip runLoggingT logFun $
jsonRpcServer
srvSettings
( \req parsed ->
initServerState
( \req serverState parsed ->
log (InfoJsonRpcProcessRequest (getReqId req) parsed)
>> respond serverState mainModule runSMT parsed
)
Expand Down
Loading