@@ -16,7 +16,8 @@ module Ouroboros.Consensus.MiniProtocol.ObjectDiffusion.Inbound.V2.State
1616
1717import  Control.Concurrent.Class.MonadSTM.Strict 
1818import  Control.Concurrent.Class.MonadSTM.TSem 
19- import  Control.Exception  (assert )
19+ import  Control.Exception  (assert , throw )
20+ import  Control.Monad  (when )
2021import  Control.Tracer  (Tracer , traceWith )
2122import  Data.Foldable  qualified  as  Foldable 
2223import  Data.Map.Strict  (Map , findWithDefault )
@@ -81,6 +82,7 @@ onRequestIdsImpl
8182            let 
8283              --  We compute the ids to ack and new state of the FIFO based on the number of ids to ack given by the decision logic
8384              (idsToAck, dpsOutstandingFifo') = 
85+                 assert (StrictSeq. length  dpsOutstandingFifo >=  fromIntegral  numIdsToAck) $ 
8486                StrictSeq. splitAt 
8587                  (fromIntegral  numIdsToAck)
8688                  dpsOutstandingFifo
@@ -143,6 +145,10 @@ onRequestObjectsImpl
143145    dgsPeerStates' = 
144146      Map. adjust
145147        ( \ ps@ DecisionPeerState {dpsObjectsAvailableIds, dpsObjectsInflightIds} -> 
148+             assert
149+             (  objectIds `Set.isSubsetOf`  dpsObjectsAvailableIds
150+             &&  Set. null  (objectIds `Set.intersection`  dpsObjectsInflightIds)
151+             ) $ 
146152            ps
147153              { dpsObjectsAvailableIds =  dpsObjectsAvailableIds \\  objectIds
148154              , dpsObjectsInflightIds =  dpsObjectsInflightIds `Set.union`  objectIds
@@ -169,15 +175,32 @@ onReceiveIds ::
169175  --  |  received `objectId`s 
170176  m  () 
171177onReceiveIds odTracer decisionTracer globalStateVar peerAddr numIdsInitiallyRequested receivedIds =  do 
178+   peerState <-  atomically $  ((Map. !  peerAddr) .  dgsPeerStates) <$>  readTVar globalStateVar
179+   checkProtocolErrors peerState numIdsInitiallyRequested receivedIds
172180  globalState' <-  atomically $  do 
173181    stateTVar
174182      globalStateVar
175183      ( \ globalState -> 
176184          let  globalState' =  onReceiveIdsImpl peerAddr numIdsInitiallyRequested receivedIds globalState
177-             in  (globalState', globalState')
185+           in  (globalState', globalState')
178186      )
179187  traceWith odTracer (TraceObjectDiffusionInboundReceivedIds  (length  receivedIds))
180188  traceWith decisionTracer (TraceDecisionLogicGlobalStateUpdated  " onReceiveIds"   globalState')
189+   where 
190+     checkProtocolErrors  :: 
191+       DecisionPeerState  objectId  object -> 
192+       NumObjectIdsReq  -> 
193+       [objectId ] -> 
194+       m  () 
195+     checkProtocolErrors DecisionPeerState {dpsObjectsAvailableIds, dpsObjectsInflightIds} nReq ids =  do 
196+       when (length  ids >  fromIntegral  nReq) $  throw ProtocolErrorObjectIdsNotRequested 
197+       let  idSet =  Set. fromList ids
198+       when (length  ids /=  Set. size idSet) $  throw ProtocolErrorObjectIdsDuplicate 
199+       when
200+         --  TODO also check for IDs in pool
201+         (  (not  $  Set. null  $  idSet `Set.intersection`  dpsObjectsAvailableIds)
202+         ||  (not  $  Set. null  $  idSet `Set.intersection`  dpsObjectsInflightIds)
203+         ) $  throw ProtocolErrorObjectIdAlreadyKnown 
181204
182205onReceiveIdsImpl  :: 
183206  forall  peerAddr  object  objectId . 
@@ -253,13 +276,15 @@ onReceiveObjects ::
253276  ObjectPoolWriter  objectId  object  m  -> 
254277  ObjectPoolSem  m  -> 
255278  peerAddr  -> 
279+   --  |  requested objects 
280+   Set  objectId  -> 
256281  --  |  received objects 
257282  [object ] -> 
258283  m  () 
259- onReceiveObjects odTracer tracer globalStateVar objectPoolWriter poolSem peerAddr objectsReceived =  do 
284+ onReceiveObjects odTracer tracer globalStateVar objectPoolWriter poolSem peerAddr objectsRequestedIds  objectsReceived =  do 
260285  let  getId =  opwObjectId objectPoolWriter
261286  let  objectsReceivedMap =  Map. fromList $  (\ obj ->  (getId obj, obj)) <$>  objectsReceived
262- 
287+   checkProtocolErrors objectsRequestedIds objectsReceivedMap 
263288  globalState' <-  atomically $  do 
264289    stateTVar
265290      globalStateVar
@@ -281,6 +306,15 @@ onReceiveObjects odTracer tracer globalStateVar objectPoolWriter poolSem peerAdd
281306    poolSem
282307    peerAddr
283308    objectsReceivedMap
309+   where 
310+     checkProtocolErrors  :: 
311+       Set  objectId -> 
312+       Map  objectId  object  -> 
313+       m  () 
314+     checkProtocolErrors requested received' =  do 
315+       let  received =  Map. keysSet received'
316+       when (not  $  Set. null  $  requested \\  received) $  throw ProtocolErrorObjectMissing 
317+       when (not  $  Set. null  $  received \\  requested) $  throw ProtocolErrorObjectNotRequested 
284318
285319onReceiveObjectsImpl  :: 
286320  forall  peerAddr  object  objectId . 
@@ -314,7 +348,7 @@ onReceiveObjectsImpl
314348          dgsPeerStates
315349
316350    --  subtract requested from in-flight
317-     dpsObjectsInflightIds' = 
351+     dpsObjectsInflightIds' =  assert (objectsReceivedIds  `Set.isSubsetOf`  dpsObjectsInflightIds)  $ 
318352      dpsObjectsInflightIds \\  objectsReceivedIds
319353
320354    dpsObjectsOwtPool' =  dpsObjectsOwtPool <>  objectsReceived
0 commit comments