From 8b7ed6701e107d14c3761b9882954b6630920a07 Mon Sep 17 00:00:00 2001 From: Jost Berthold Date: Wed, 22 May 2024 22:11:53 +1000 Subject: [PATCH] 3863 llvm term cache (#3882) This PR introduces a cache for the terms returned from calls to `Booster.LLVM.Internal.API.simplify`. Unpacking now uses a term store for the unpacked terms, and recognises shared terms by a shallow index into this store, using `TermF Int` as the map key (where the `Int` are indexes of symbol application/injection arguments in the term store). In small targeted tests using requests from MX-backend proofs , memory consumption was noticeably reduced, also resulting in better performance. Currently, the cache only lives for the duration of one LLVM call (not across different calls), no global variables or unsafe IO is required. Fixes #3863 --- booster/library/Booster/LLVM/Internal.hs | 2 +- .../library/Booster/Pattern/ApplyEquations.hs | 7 +- booster/library/Booster/Pattern/Binary.hs | 301 +++++++++++++----- 3 files changed, 216 insertions(+), 94 deletions(-) diff --git a/booster/library/Booster/LLVM/Internal.hs b/booster/library/Booster/LLVM/Internal.hs index 433d5065d6..ba5506721b 100644 --- a/booster/library/Booster/LLVM/Internal.hs +++ b/booster/library/Booster/LLVM/Internal.hs @@ -51,7 +51,7 @@ import System.Posix.DynamicLinker qualified as Linker import Booster.LLVM.TH (dynamicBindings) import Booster.Pattern.Base -import Booster.Pattern.Binary hiding (Block) +import Booster.Pattern.Binary import Booster.Pattern.Util (sortOfTerm) import Booster.Trace import Booster.Trace qualified as Trace diff --git a/booster/library/Booster/Pattern/ApplyEquations.hs b/booster/library/Booster/Pattern/ApplyEquations.hs index 3c1b4cc876..7486e53f64 100644 --- a/booster/library/Booster/Pattern/ApplyEquations.hs +++ b/booster/library/Booster/Pattern/ApplyEquations.hs @@ -501,12 +501,9 @@ llvmSimplify term = do withTermContext result $ emitEquationTrace t Nothing (Just "LLVM") Nothing $ Success result - toCache LLVM t result pure result - | otherwise = do - result <- cb t - toCache LLVM t result - pure result + | otherwise = + cb t ---------------------------------------- -- Interface functions diff --git a/booster/library/Booster/Pattern/Binary.hs b/booster/library/Booster/Pattern/Binary.hs index 7cf9e55744..f865b2835b 100644 --- a/booster/library/Booster/Pattern/Binary.hs +++ b/booster/library/Booster/Pattern/Binary.hs @@ -1,27 +1,24 @@ {-# LANGUAGE PatternSynonyms #-} {- | -Copyright : (c) Runtime Verification, 2023 +Copyright : (c) Runtime Verification, 2023- License : BSD-3-Clause -} module Booster.Pattern.Binary ( Version (..), - Block (..), decodeTerm, decodeTerm', decodePattern, encodeMagicHeaderAndVersion, encodePattern, encodeTerm, - encodeSingleBlock, - decodeSingleBlock, ) where import Control.Monad (forM_, unless) import Control.Monad.Extra (forM) import Control.Monad.Trans.Class (MonadTrans (..)) import Control.Monad.Trans.Reader (ReaderT (runReaderT), ask, asks) -import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) +import Control.Monad.Trans.State (StateT, evalStateT, get, gets, modify, put) import Data.Binary.Get import Data.Binary.Put import Data.Bits (Bits (complement, shiftL, (.&.), (.|.)), shiftR) @@ -30,6 +27,8 @@ import Data.ByteString qualified as BS import Data.Int (Int16) import Data.List (intercalate) import Data.Map qualified as Map +import Data.Sequence (Seq (..)) +import Data.Sequence qualified as Seq import Data.Set qualified as Set import Data.Word (Word64) import GHC.Word (Word8) @@ -71,9 +70,34 @@ data Version = Version instance Show Version where show version = printf "%d.%d.%d" version.major version.minor version.patch +{- | The marshalling algorithm below unpacks terms into a global term + store, indexed by "shallow terms", i.e., where arguments in + application nodes are replaced by indexes into a global term store. + +@ShallowTerm@s can be either simple data without recursion +(@DomainValue@, @Var@), or recursive data (@SymbolApplication@, +@AndTerm@, @Injection@), where the arguments are initially @Int@s +instead of @Term@s (internal collection types are constructed from the +unpacked symbol applications later, and not expected to occur here). + +A lookup map for these @ShallowTerm@ indexes is maintained while +unpacking blocks. + +Unpacked @ShallowTerm@s that are not found in the lookup map are added +to both the lookup map (as @ShallowTerm@s) and to the term store (as +@Term@s, resolving arguments), and the new index is returned. + +Conversely, when an unpacked @ShallowTerm@ has occurred before, the +previously-added index is returned. +-} +newtype ShallowTerm = ShallowTerm (TermF Idx) + deriving (Eq, Ord, Show) + +type Idx = Int + data Block - = BTerm Term - | BPredicate Predicate + = BTerm Idx + | BPredicate Idx -- Predicate -- problem?!? | BString ByteString | BSort Sort | BSymbol ByteString [Sort] @@ -82,6 +106,11 @@ data Block data DecoderState = DecoderState { internedStrings :: Map.Map Int BS.ByteString , stack :: [Block] + , termStore :: Seq Term + -- ^ remembers all unpacked terms in an append-only list. Needs to + -- ensure subterms will be shared. + , termCache :: Map.Map ShallowTerm Idx + -- ^ lookup index into termStore } deriving (Show) @@ -89,7 +118,10 @@ newtype DecodeM a = DecodeM {unDecodeM :: ReaderT (Version, Maybe KoreDefinition deriving newtype (Functor, Applicative, Monad, MonadFail) runDecodeM :: Version -> Maybe KoreDefinition -> DecodeM a -> Get a -runDecodeM v mDef = flip evalStateT (DecoderState mempty mempty) . flip runReaderT (v, mDef) . unDecodeM +runDecodeM v mDef = + flip evalStateT (DecoderState mempty mempty mempty mempty) + . flip runReaderT (v, mDef) + . unDecodeM liftDecode :: Get a -> DecodeM a liftDecode m = DecodeM $ lift $ lift m @@ -106,8 +138,40 @@ insertInternedString pos str = lift $ modify (\s@DecoderState{internedStrings} -> s{internedStrings = Map.insert pos str internedStrings}) -areCompatible :: Version -> Version -> Bool -areCompatible a b = a.major == b.major && a.minor == b.minor +-- | Insert a new item into the term store or return a previously-seen instance +registerTerm :: ShallowTerm -> DecodeM (Idx, Term) +registerTerm shallow = DecodeM $ do + ds@DecoderState{termCache, termStore} <- lift get + case Map.lookup shallow termCache of + Just idx -> pure (idx, termStore `Seq.index` idx) + Nothing -> do + let !new = resolve termStore shallow -- strict: fail early on inconsistent data + newStore = termStore :|> new + newIdx = Seq.length termStore -- NB index 0-based + newCache = Map.insert shallow newIdx termCache + lift $ put ds{termStore = newStore, termCache = newCache} + pure (newIdx, new) + +getTerm :: Idx -> DecodeM Term +getTerm idx = DecodeM $ do + store <- lift $ gets termStore + pure $ store `Seq.index` idx + +{- | Resolves indexes into the term store in a shallow term, assuming +all indexes exist. Returns the resolved (full) term. +-} +resolve :: Seq Term -> ShallowTerm -> Term +resolve store (ShallowTerm shallow) = case shallow of + AndTermF i1 i2 -> AndTerm (fromStore i1) (fromStore i2) + SymbolApplicationF sym sorts is -> SymbolApplication sym sorts $ map fromStore is + DomainValueF sort payload -> DomainValue sort payload + VarF v -> Var v + InjectionF s1 s2 i -> Injection s1 s2 (fromStore i) + other -> error $ "Unexpected shallow term " <> show other + where + fromStore = Seq.index store + +------------------------------------------------------------ {- | Length (non-negative integer) is encoded in one of two special formats (depending on the version). @@ -126,6 +190,9 @@ decodeLength l = do then readAndShift l 0x0 else readAndShiftV2 True 0 0 where + areCompatible :: Version -> Version -> Bool + areCompatible a b = a.major == b.major && a.minor == b.minor + readAndShift :: Int -> Int -> Get Int readAndShift counter ret | counter > 0 = do @@ -269,7 +336,8 @@ decodeBlock mbSize = do KOREVariable -> do var <- decodeString [sort] <- popStackSorts 1 - pushStack $ BTerm $ Var $ Variable sort var + (idx, _) <- registerTerm $ ShallowTerm $ VarF $ Variable sort var + pushStack $ BTerm idx h -> fail $ "Invalid header " <> show h getStack @@ -292,56 +360,100 @@ decodeBlock mbSize = do True -> pure () False -> m >> whileNotEnded m + -- The workhorse in this decoder + -- The term cache is managed here, so that we can keep the + -- distinction between terms and predicates confined. mkSymbolApplication :: ByteString -> [Sort] -> [Block] -> DecodeM Block - -- automatically transform `rawTerm(inj{SortX, KItem}(X))` to X:SortX - -- see https://github.com/runtimeverification/llvm-backend/issues/916 - mkSymbolApplication "rawTerm" [] [BTerm t] - | Injection sort SortKItem t' <- t - , sort == sortOfTerm t' = - pure $ BTerm t' - mkSymbolApplication "\\and" _ [BTerm t1, BTerm t2] = pure $ BTerm $ AndTerm t1 t2 - mkSymbolApplication "\\and" _ bs = - argError "AndTerm" [BTerm undefined, BTerm undefined] bs - mkSymbolApplication "\\bottom" _ bs = argError "Bottom" [] bs - mkSymbolApplication "\\ceil" _ bs = argError "Ceil" [BTerm undefined] bs - mkSymbolApplication "\\dv" [sort] [BString txt] = pure $ BTerm $ DomainValue sort txt - mkSymbolApplication "\\dv" _ bs = argError "DomainValue" [BString undefined] bs - mkSymbolApplication "\\equals" _ [BTerm t, BTerm TrueBool] = pure $ BPredicate $ Predicate t - mkSymbolApplication "\\equals" _ [BTerm TrueBool, BTerm t] = pure $ BPredicate $ Predicate t - mkSymbolApplication "\\equals" _ bs = - argError "EqualBTerm/EqualBPredicate" [BTerm undefined, BTerm undefined] bs - mkSymbolApplication "\\exists" _ bs = argError "Exists" [BTerm undefined, BPredicate undefined] bs - mkSymbolApplication "\\forall" _ bs = argError "Forall" [BTerm undefined, BPredicate undefined] bs - mkSymbolApplication "\\iff" _ bs = argError "Iff" [BPredicate undefined, BPredicate undefined] bs - mkSymbolApplication "\\implies" _ bs = argError "Implies" [BPredicate undefined, BPredicate undefined] bs - mkSymbolApplication "\\in" _ bs = argError "In" [BTerm undefined, BTerm undefined] bs - mkSymbolApplication "\\not" _ bs = argError "Not" [BPredicate undefined] bs - mkSymbolApplication "\\or" _ bs = argError "Or" [BPredicate undefined, BPredicate undefined] bs - mkSymbolApplication "\\top" _ bs = argError "Top" [] bs - mkSymbolApplication "inj" [source, target] [BTerm t] = pure $ BTerm $ Injection source target t - mkSymbolApplication "inj" _ bs = argError "Injection" [BTerm undefined] bs - mkSymbolApplication name sorts bs = - lookupKoreDefinitionSymbol name >>= \case - -- testing case when we don't have a KoreDefinition - Left symbol@Symbol{sortVars} -> do - args <- forM bs $ \case - BTerm trm -> pure trm - _ -> fail "Expecting term" - pure $ BTerm $ SymbolApplication symbol (zipWith (const id) sortVars sorts) args - Right (Just symbol@Symbol{sortVars, argSorts}) -> do - args <- forM (zip argSorts bs) $ \case - (srt, BTerm trm) -> - if sortOfTerm trm /= srt - then - fail $ - "Term has incorrect sort. Expecting " - <> renderDefault (pretty srt) - <> " but got " - <> renderDefault (pretty $ sortOfTerm trm) - else pure trm - _ -> fail "Expecting term" - pure $ BTerm $ SymbolApplication symbol (zipWith (const id) sortVars sorts) args - Right Nothing -> fail $ "Unknown symbol " <> show name + mkSymbolApplication name sorts args = do + store <- DecodeM $ lift $ gets termStore + case name of + -- This special symbol will be removed in post-processing, see below + "rawTerm" + | [BTerm idx] <- args + , null sorts -> + returnRegistered BTerm $ SymbolApplicationF RawTermSymbol [] [idx] + | otherwise -> + argError "rawTerm" [BTerm undefined] args + -- translate many reserved "special" symbols into their + -- respective terms or predicates. + "\\and" + | [BTerm t1, BTerm t2] <- args -> + returnRegistered BTerm $ AndTermF t1 t2 + | otherwise -> + argError "AndTerm" [BTerm undefined, BTerm undefined] args + "\\dv" + | [BString txt] <- args + , [sort] <- sorts -> + returnRegistered BTerm $ DomainValueF sort txt + | otherwise -> + argError "DomainValue" [BString undefined] args + "\\equals" + | [BTerm t1, BTerm t2] <- args + , TrueBool <- store `Seq.index` t2 -> + pure $ BPredicate t1 + | [BTerm t1, BTerm t2] <- args + , TrueBool <- store `Seq.index` t1 -> + pure $ BPredicate t2 + | otherwise -> + argError "EqualBTerm/EqualBPredicate" [BTerm undefined, BTerm undefined] args + "inj" + | [source, target] <- sorts + , [BTerm t] <- args -> + returnRegistered BTerm $ InjectionF source target t + | otherwise -> + argError "Injection" [BTerm undefined] args + -- unsupported symbols (non-boolean predicates) + "\\bottom" -> + argError "Bottom" [] args + "\\ceil" -> + argError "Ceil" [BTerm undefined] args + "\\exists" -> + argError "Exists" [BTerm undefined, BPredicate undefined] args + "\\forall" -> + argError "Forall" [BTerm undefined, BPredicate undefined] args + "\\iff" -> + argError "Iff" [BPredicate undefined, BPredicate undefined] args + "\\implies" -> + argError "Implies" [BPredicate undefined, BPredicate undefined] args + "\\in" -> + argError "In" [BTerm undefined, BTerm undefined] args + "\\not" -> + argError "Not" [BPredicate undefined] args + "\\or" -> + argError "Or" [BPredicate undefined, BPredicate undefined] args + "\\top" -> + argError "Top" [] args + _otherwise -> + lookupKoreDefinitionSymbol name >>= \case + Left symbol@Symbol{sortVars} -> do + -- testing case when we don't have a KoreDefinition: + -- only check arguments are terms + idxs <- forM args $ \case + BTerm i -> pure i + _ -> fail "Expecting term" + returnRegistered BTerm $ + SymbolApplicationF symbol (zipWith (const id) sortVars sorts) idxs + Right (Just symbol@Symbol{sortVars, argSorts}) -> do + -- check arguments and their sorts + idxs <- forM (zip argSorts args) $ \case + (srt, BTerm argIdx) -> do + trm <- getTerm argIdx + unless (sortOfTerm trm == srt) $ + fail $ + "Argument has incorrect sort. Expecting " + <> renderDefault (pretty srt) + <> " but got " + <> renderDefault (pretty $ sortOfTerm trm) + pure argIdx + _ -> fail "Expecting term" + returnRegistered BTerm $ + SymbolApplicationF symbol (zipWith (const id) sortVars sorts) idxs + Right Nothing -> + fail $ "Unknown symbol " <> show name + + returnRegistered cons shallow = do + (idx, _) <- registerTerm $ ShallowTerm shallow + pure $ cons idx argError cons expectedArgs receivedArgs = fail $ @@ -401,37 +513,58 @@ decodeMagicHeaderAndVersion = do pure Nothing hdrSize = 19 -supported :: Version -> Bool -supported version = version.major == 1 && version.minor `elem` [0 .. 2] + supported :: Version -> Bool + supported version = version.major == 1 && version.minor `elem` [0 .. 2] decodeTerm' :: Maybe KoreDefinition -> Get Term decodeTerm' mDef = do (version, mbSize) <- decodeMagicHeaderAndVersion - runDecodeM version mDef (decodeBlock mbSize) >>= \case - [BTerm trm] -> pure trm - _ -> fail "Expecting a single term on the top of the stack" + runDecodeM version mDef $ + decodeBlock mbSize >>= \case + [BTerm trmIdx] -> fmap stripRawTerm $ getTerm trmIdx + _ -> fail "Expecting a single term on the top of the stack" decodeTerm :: KoreDefinition -> Get Term decodeTerm = decodeTerm' . Just +{- | Automatically transform `rawTerm(inj{SortX, KItem}(X))` to X:SortX +see https://github.com/runtimeverification/llvm-backend/issues/916 +-} +stripRawTerm :: Term -> Term +stripRawTerm (SymbolApplication RawTermSymbol [] [Injection _ SortKItem t]) = t +stripRawTerm other = other + +pattern RawTermSymbol :: Symbol +pattern RawTermSymbol = + Symbol + "rawTerm" + [] + [SortKItem] + SortKItem + ( SymbolAttributes + Constructor + IsNotIdem + IsNotAssoc + IsNotMacroOrAlias + CannotBeEvaluated + Nothing + Nothing + Nothing + ) + decodePattern :: Maybe KoreDefinition -> Get Pattern decodePattern mDef = do (version, mbSize) <- decodeMagicHeaderAndVersion - res <- reverse <$> runDecodeM version mDef (decodeBlock mbSize) - case res of - BTerm trm : preds' -> do - preds <- forM preds' $ \case - BPredicate p -> pure p - _ -> fail "Expecting a predicate" - pure $ Pattern trm (Set.fromList preds) mempty - _ -> fail "Expecting a term on the top of the stack" - -decodeSingleBlock :: Get Block -decodeSingleBlock = do - (version, mbSize) <- decodeMagicHeaderAndVersion - runDecodeM version Nothing (decodeBlock mbSize) >>= \case - [b] -> pure b - _ -> fail "Expecting a single block on the top of the stack" + runDecodeM version mDef $ do + res <- reverse <$> decodeBlock mbSize + case res of + BTerm trmIdx : preds' -> do + trm <- stripRawTerm <$> getTerm trmIdx + preds <- forM preds' $ \case + BPredicate pIdx -> Predicate <$> getTerm pIdx + _ -> fail "Expecting a predicate" + pure $ Pattern trm (Set.fromList preds) mempty + _ -> fail "Expecting a term on the top of the stack" encodeMagicHeaderAndVersion :: Version -> Put encodeMagicHeaderAndVersion (Version major minor patch) = do @@ -509,11 +642,3 @@ encodePattern :: Pattern -> Put encodePattern Pattern{term, constraints} = do encodeTerm term forM_ constraints encodePredicate - -encodeSingleBlock :: Block -> Put -encodeSingleBlock = \case - BTerm t -> encodeTerm t - BPredicate p -> encodePredicate p - BString s -> encodeString s - BSort s -> encodeSort s - BSymbol name sorts -> encodeSymbol name sorts