From fede497c91332f30f721ac06cf535112edcd5bb3 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Tue, 20 Aug 2024 11:37:03 +0200 Subject: [PATCH] WIP: Annotated decoder This is an experiment to provide `runAnnotatedPeer`, which is like `runPeer' but allows us to run a decoder which has access to bytes used when decoding a message. This allows one to record offsets and decode record ByteString from which a piece of data was decoded, e.g. for each `tx` inside `MsgReplyTxs`. The `Codec` type in `typed-protocols` was generalised for this purpose. The core functionality is implemented in `runAnnotatedDecoderWithChannel` which runs `AnnotatedCodec` against a `Channel` which does incremental decoding & recording bytes used so far. We also expose `runAnnotatedPeer` which runs a `Peer` against `Channel` using an `AnnotatedCodec` (using `annotatedDriverSimple`). TODO: * `runAnnotatedPipelinedPeer` * `runAnnotatedPeerWithLimits` * `runAnnotatedPipelinedPeerWithLimits` It's actually the last one that we will need in `tx-submission`. TODO: Find a nice way so we won't need to maintain two codecs for `tx-submission`, e.g. `Codec` and `AnnotatedCodec`. --- cabal.project | 9 + .../src/Ouroboros/Network/Driver/Simple.hs | 154 ++++++++++++++++-- .../Network/Protocol/TxSubmission2/Codec.hs | 48 ++++-- 3 files changed, 177 insertions(+), 34 deletions(-) diff --git a/cabal.project b/cabal.project index 11db548b8c0..9e5a02f2555 100644 --- a/cabal.project +++ b/cabal.project @@ -54,3 +54,12 @@ package network-mux package ouroboros-network flags: +asserts +cddl + +source-repository-package + type: git + location: https://github.com/input-output-hk/typed-protocols + tag: 9a0acda4cd34e37b53e53986e7a71a76bba2ca8c + subdir: typed-protocols + typed-protocols-cborg +allow-newer: typed-protocols:io-classes + diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs index c7c163f3465..868d93dcee8 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs @@ -6,7 +6,6 @@ {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} -- @UndecidableInstances@ extensions is required for defining @Show@ instance -- of @'TraceSendRecv'@. @@ -19,10 +18,12 @@ module Ouroboros.Network.Driver.Simple -- $intro -- * Normal peers runPeer + , runAnnotatedPeer , TraceSendRecv (..) , DecoderFailure (..) -- * Pipelined peers , runPipelinedPeer + , runPipelinedAnnotatedPeer -- * Connected peers -- TODO: move these to a test lib , Role (..) @@ -43,6 +44,9 @@ import Ouroboros.Network.Channel import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadThrow import Control.Tracer (Tracer (..), contramap, traceWith) +import Data.Maybe (fromMaybe) +import Data.Functor.Identity (Identity) +import Control.Monad.Identity (Identity(..)) -- $intro @@ -107,18 +111,31 @@ instance Show DecoderFailure where instance Exception DecoderFailure where -driverSimple :: forall ps failure bytes m. - ( MonadThrow m - , Show failure - , forall (st :: ps). Show (ClientHasAgency st) - , forall (st :: ps). Show (ServerHasAgency st) - , ShowProxy ps - ) - => Tracer m (TraceSendRecv ps) - -> Codec ps failure m bytes - -> Channel m bytes - -> Driver ps (Maybe bytes) m -driverSimple tracer Codec{encode, decode} channel@Channel{send} = +mkSimpleDriver :: forall ps failure bytes m f annotator. + ( MonadThrow m + , Show failure + , forall (st :: ps). Show (ClientHasAgency st) + , forall (st :: ps). Show (ServerHasAgency st) + , ShowProxy ps + ) + => (forall a. + Channel m bytes + -> Maybe bytes + -> DecodeStep bytes failure m (f a) + -> m (Either failure (a, Maybe bytes)) + ) + -- ^ run incremental decoder against a channel + + -> (forall st. annotator st -> f (SomeMessage st)) + -- ^ transform annotator to a container holding the decoded + -- message + + -> Tracer m (TraceSendRecv ps) + -> Codec' ps failure m annotator bytes + -> Channel m bytes + -> Driver ps (Maybe bytes) m + +mkSimpleDriver runDecodeSteps nat tracer Codec{encode, decode} channel@Channel{send} = Driver { sendMessage, recvMessage, startDState = Nothing } where sendMessage :: forall (pr :: PeerRole) (st :: ps) (st' :: ps). @@ -135,7 +152,7 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} = -> m (SomeMessage st, Maybe bytes) recvMessage stok trailing = do decoder <- decode stok - result <- runDecoderWithChannel channel trailing decoder + result <- runDecodeSteps channel trailing (nat <$> decoder) case result of Right x@(SomeMessage msg, _trailing') -> do traceWith tracer (TraceRecvMsg (AnyMessageAndAgency stok msg)) @@ -144,6 +161,36 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} = throwIO (DecoderFailure stok failure) +simpleDriver :: forall ps failure bytes m. + ( MonadThrow m + , Show failure + , forall (st :: ps). Show (ClientHasAgency st) + , forall (st :: ps). Show (ServerHasAgency st) + , ShowProxy ps + ) + => Tracer m (TraceSendRecv ps) + -> Codec ps failure m bytes + -> Channel m bytes + -> Driver ps (Maybe bytes) m +simpleDriver = mkSimpleDriver runDecoderWithChannel Identity + + +annotatedSimpleDriver + :: forall ps failure bytes m. + ( MonadThrow m + , Monoid bytes + , Show failure + , forall (st :: ps). Show (ClientHasAgency st) + , forall (st :: ps). Show (ServerHasAgency st) + , ShowProxy ps + ) + => Tracer m (TraceSendRecv ps) + -> AnnotatedCodec ps failure m bytes + -> Channel m bytes + -> Driver ps (Maybe bytes) m +annotatedSimpleDriver = mkSimpleDriver runAnnotatedDecoderWithChannel runAnnotator + + -- | Run a peer with the given channel via the given codec. -- -- This runs the peer to completion (if the protocol allows for termination). @@ -164,7 +211,31 @@ runPeer runPeer tracer codec channel peer = runPeerWithDriver driver peer (startDState driver) where - driver = driverSimple tracer codec channel + driver = simpleDriver tracer codec channel + + +-- | Run a peer with the given channel via the given annotated codec. +-- +-- This runs the peer to completion (if the protocol allows for termination). +-- +runAnnotatedPeer + :: forall ps (st :: ps) pr failure bytes m a . + ( MonadThrow m + , Monoid bytes + , Show failure + , forall (st' :: ps). Show (ClientHasAgency st') + , forall (st' :: ps). Show (ServerHasAgency st') + , ShowProxy ps + ) + => Tracer m (TraceSendRecv ps) + -> AnnotatedCodec ps failure m bytes + -> Channel m bytes + -> Peer ps pr st m a + -> m (a, Maybe bytes) +runAnnotatedPeer tracer codec channel peer = + runPeerWithDriver driver peer (startDState driver) + where + driver = annotatedSimpleDriver tracer codec channel -- | Run a pipelined peer with the given channel via the given codec. @@ -191,7 +262,35 @@ runPipelinedPeer runPipelinedPeer tracer codec channel peer = runPipelinedPeerWithDriver driver peer (startDState driver) where - driver = driverSimple tracer codec channel + driver = simpleDriver tracer codec channel + + +-- | Run a pipelined peer with the given channel via the given annotated codec. +-- +-- This runs the peer to completion (if the protocol allows for termination). +-- +-- Unlike normal peers, running pipelined peers rely on concurrency, hence the +-- 'MonadAsync' constraint. +-- +runPipelinedAnnotatedPeer + :: forall ps (st :: ps) pr failure bytes m a. + ( MonadAsync m + , MonadThrow m + , Monoid bytes + , Show failure + , forall (st' :: ps). Show (ClientHasAgency st') + , forall (st' :: ps). Show (ServerHasAgency st') + , ShowProxy ps + ) + => Tracer m (TraceSendRecv ps) + -> AnnotatedCodec ps failure m bytes + -> Channel m bytes + -> PeerPipelined ps pr st m a + -> m (a, Maybe bytes) +runPipelinedAnnotatedPeer tracer codec channel peer = + runPipelinedPeerWithDriver driver peer (startDState driver) + where + driver = annotatedSimpleDriver tracer codec channel -- @@ -204,17 +303,36 @@ runPipelinedPeer tracer codec channel peer = runDecoderWithChannel :: Monad m => Channel m bytes -> Maybe bytes - -> DecodeStep bytes failure m a + -> DecodeStep bytes failure m (Identity a) -> m (Either failure (a, Maybe bytes)) runDecoderWithChannel Channel{recv} = go where - go _ (DecodeDone x trailing) = return (Right (x, trailing)) + go _ (DecodeDone (Identity x) trailing) = return (Right (x, trailing)) go _ (DecodeFail failure) = return (Left failure) go Nothing (DecodePartial k) = recv >>= k >>= go Nothing go (Just trailing) (DecodePartial k) = k (Just trailing) >>= go Nothing +runAnnotatedDecoderWithChannel + :: forall m bytes failure a. + ( Monad m + , Monoid bytes + ) + => Channel m bytes + -> Maybe bytes + -> DecodeStep bytes failure m (bytes -> a) + -> m (Either failure (a, Maybe bytes)) + +runAnnotatedDecoderWithChannel Channel{recv} bs0 = go (fromMaybe mempty bs0) bs0 + where + go :: bytes -> Maybe bytes -> DecodeStep bytes failure m (bytes -> a) -> m (Either failure (a, Maybe bytes)) + go bytes _ (DecodeDone f trailing) = return $ Right (f bytes, trailing) + go _bytes _ (DecodeFail failure) = return (Left failure) + go bytes Nothing (DecodePartial k) = recv >>= \bs -> k bs >>= go (bytes <> fromMaybe mempty bs) Nothing + go bytes (Just trailing) (DecodePartial k) = k (Just trailing) >>= go (bytes <> trailing) Nothing + + data Role = Client | Server -- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'. diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs index 7bfa2f0f806..e83eacd0c6f 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs @@ -64,26 +64,30 @@ timeLimitsTxSubmission2 = ProtocolTimeLimits stateToLimit codecTxSubmission2 - :: forall txid tx m. + :: forall txid tx annotator m. MonadST m => (txid -> CBOR.Encoding) -> (forall s . CBOR.Decoder s txid) -> (tx -> CBOR.Encoding) -> (forall s . CBOR.Decoder s tx) - -> Codec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString + -- the codec is polymorphic in annotator. The primary use case is an + -- `Identity` functor or `Annotator LBS.ByteString`. + -> (forall st. SomeMessage st -> annotator st) + -> Codec' (TxSubmission2 txid tx) CBOR.DeserialiseFailure m annotator ByteString codecTxSubmission2 encodeTxId decodeTxId - encodeTx decodeTx = + encodeTx decodeTx + annotate = mkCodecCborLazyBS (encodeTxSubmission2 encodeTxId encodeTx) decode where decode :: forall (pr :: PeerRole) (st :: TxSubmission2 txid tx). PeerHasAgency pr st - -> forall s. CBOR.Decoder s (SomeMessage st) + -> forall s. CBOR.Decoder s (annotator st) decode stok = do len <- CBOR.decodeListLen key <- CBOR.decodeWord - decodeTxSubmission2 decodeTxId decodeTx stok len key + decodeTxSubmission2 decodeTxId decodeTx annotate stok len key encodeTxSubmission2 :: forall txid tx. @@ -149,30 +153,31 @@ encodeTxSubmission2 encodeTxId encodeTx = encode decodeTxSubmission2 - :: forall txid tx. + :: forall txid tx annotator. (forall s . CBOR.Decoder s txid) -> (forall s . CBOR.Decoder s tx) + -> (forall st. SomeMessage st -> annotator st) -> (forall (pr :: PeerRole) (st :: TxSubmission2 txid tx) s. PeerHasAgency pr st -> Int -> Word - -> CBOR.Decoder s (SomeMessage st)) -decodeTxSubmission2 decodeTxId decodeTx = decode + -> CBOR.Decoder s (annotator st)) +decodeTxSubmission2 decodeTxId decodeTx annotate = decode where decode :: forall (pr :: PeerRole) s (st :: TxSubmission2 txid tx). PeerHasAgency pr st -> Int -> Word - -> CBOR.Decoder s (SomeMessage st) + -> CBOR.Decoder s (annotator st) decode stok len key = do case (stok, len, key) of (ClientAgency TokInit, 1, 6) -> - return (SomeMessage MsgInit) + return (annotate $ SomeMessage MsgInit) (ServerAgency TokIdle, 4, 0) -> do blocking <- CBOR.decodeBool ackNo <- NumTxIdsToAck <$> CBOR.decodeWord16 reqNo <- NumTxIdsToReq <$> CBOR.decodeWord16 - return $! + return $! annotate $ if blocking then SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo) else SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo) @@ -187,11 +192,11 @@ decodeTxSubmission2 decodeTxId decodeTx = decode return (txid, SizeInBytes sz)) case (b, txids) of (TokBlocking, t:ts) -> - return $ + return $ annotate $ SomeMessage (MsgReplyTxIds (BlockingReply (t NonEmpty.:| ts))) (TokNonBlocking, ts) -> - return $ + return $ annotate $ SomeMessage (MsgReplyTxIds (NonBlockingReply ts)) (TokBlocking, []) -> @@ -201,15 +206,26 @@ decodeTxSubmission2 decodeTxId decodeTx = decode (ServerAgency TokIdle, 2, 2) -> do CBOR.decodeListLenIndef txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTxId - return (SomeMessage (MsgRequestTxs txids)) + return (annotate $ SomeMessage (MsgRequestTxs txids)) (ClientAgency TokTxs, 2, 3) -> do CBOR.decodeListLenIndef txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTx - return (SomeMessage (MsgReplyTxs txids)) + -- ^ TODO: `txids -> txs` :grin: + + -- TODO: here we have access to bytes from which the message was decoded. + -- we can use `Codec.CBOR.Decoding.decodeWithByteSpan` + -- around each `tx` and wrap each `tx` in `WithBytes`. + -- + -- `decodeTxSubmission2` can be polymorphic by adding an + -- extra argument of type + -- `ByteString -> ByteOffSet -> ByteOffset -> tx -> a` + -- this way we could wrap `tx` in `WithBytes` or just + -- return `tx`. + return (annotate $ SomeMessage (MsgReplyTxs txids)) (ClientAgency (TokTxIds TokBlocking), 1, 4) -> - return (SomeMessage MsgDone) + return (annotate $ SomeMessage MsgDone) -- -- failures per protocol state