@@ -16,6 +16,12 @@ namespace ModelContextProtocol.Shared;
1616/// </summary>
1717internal sealed class McpSession : IDisposable
1818{
19+ /// <summary>
20+ /// In-flight request handling, indexed by request ID. The value provides a <see cref="CancellationTokenSource"/>
21+ /// that can be used to request cancellation of the in-flight handler.
22+ /// </summary>
23+ private static readonly ConcurrentDictionary < RequestId , CancellationTokenSource > s_handlingRequests = new ( ) ;
24+
1925 private readonly ITransport _transport ;
2026 private readonly RequestHandlers _requestHandlers ;
2127 private readonly NotificationHandlers _notificationHandlers ;
@@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
6975 {
7076 _logger . TransportMessageRead ( EndpointName , message . GetType ( ) . Name ) ;
7177
72- // Fire and forget the message handling task to avoid blocking the transport
73- // If awaiting the task, the transport will not be able to read more messages,
74- // which could lead to a deadlock if the handler sends a message back
7578 _ = ProcessMessageAsync ( ) ;
7679 async Task ProcessMessageAsync ( )
7780 {
81+ IJsonRpcMessageWithId ? messageWithId = message as IJsonRpcMessageWithId ;
82+ CancellationTokenSource ? combinedCts = null ;
83+ try
84+ {
85+ // Register before we yield, so that the tracking is guaranteed to be there
86+ // when subsequent messages arrive, even if the asynchronous processing happens
87+ // out of order.
88+ if ( messageWithId is not null )
89+ {
90+ combinedCts = CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken ) ;
91+ s_handlingRequests [ messageWithId . Id ] = combinedCts ;
92+ }
93+
94+ // Fire and forget the message handling to avoid blocking the transport
95+ // If awaiting the task, the transport will not be able to read more messages,
96+ // which could lead to a deadlock if the handler sends a message back
97+
7898#if NET
79- await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
99+ await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
80100#else
81- await default ( ForceYielding ) ;
101+ await default ( ForceYielding ) ;
82102#endif
83- try
84- {
85- await HandleMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
103+
104+ // Handle the message.
105+ await HandleMessageAsync ( message , combinedCts ? . Token ?? cancellationToken ) . ConfigureAwait ( false ) ;
86106 }
87107 catch ( Exception ex )
88108 {
89- var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
90- _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
109+ // Only send responses for request errors that aren't user-initiated cancellation.
110+ bool isUserCancellation =
111+ ex is OperationCanceledException &&
112+ ! cancellationToken . IsCancellationRequested &&
113+ combinedCts ? . IsCancellationRequested is true ;
114+
115+ if ( ! isUserCancellation && message is JsonRpcRequest request )
116+ {
117+ _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
118+ await _transport . SendMessageAsync ( new JsonRpcError
119+ {
120+ Id = request . Id ,
121+ JsonRpc = "2.0" ,
122+ Error = new JsonRpcErrorDetail
123+ {
124+ Code = ErrorCodes . InternalError ,
125+ Message = ex . Message
126+ }
127+ } , cancellationToken ) . ConfigureAwait ( false ) ;
128+ }
129+ else if ( ex is not OperationCanceledException )
130+ {
131+ var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
132+ _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
133+ }
134+ }
135+ finally
136+ {
137+ if ( messageWithId is not null )
138+ {
139+ s_handlingRequests . TryRemove ( messageWithId . Id , out _ ) ;
140+ combinedCts ! . Dispose ( ) ;
141+ }
91142 }
92143 }
93144 }
@@ -123,6 +174,24 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
123174
124175 private async Task HandleNotification ( JsonRpcNotification notification )
125176 {
177+ // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
178+ if ( notification . Method == NotificationMethods . CancelledNotification )
179+ {
180+ try
181+ {
182+ if ( GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
183+ s_handlingRequests . TryGetValue ( cn . RequestId , out var cts ) )
184+ {
185+ await cts . CancelAsync ( ) . ConfigureAwait ( false ) ;
186+ }
187+ }
188+ catch
189+ {
190+ // "Invalid cancellation notifications SHOULD be ignored"
191+ }
192+ }
193+
194+ // Handle user-defined notifications.
126195 if ( _notificationHandlers . TryGetValue ( notification . Method , out var handlers ) )
127196 {
128197 foreach ( var notificationHandler in handlers )
@@ -161,33 +230,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
161230 {
162231 if ( _requestHandlers . TryGetValue ( request . Method , out var handler ) )
163232 {
164- try
233+ _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
234+ var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
235+ _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
236+ await _transport . SendMessageAsync ( new JsonRpcResponse
165237 {
166- _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
167- var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
168- _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
169- await _transport . SendMessageAsync ( new JsonRpcResponse
170- {
171- Id = request . Id ,
172- JsonRpc = "2.0" ,
173- Result = result
174- } , cancellationToken ) . ConfigureAwait ( false ) ;
175- }
176- catch ( Exception ex )
177- {
178- _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
179- // Send error response
180- await _transport . SendMessageAsync ( new JsonRpcError
181- {
182- Id = request . Id ,
183- JsonRpc = "2.0" ,
184- Error = new JsonRpcErrorDetail
185- {
186- Code = - 32000 , // Implementation defined error
187- Message = ex . Message
188- }
189- } , cancellationToken ) . ConfigureAwait ( false ) ;
190- }
238+ Id = request . Id ,
239+ JsonRpc = "2.0" ,
240+ Result = result
241+ } , cancellationToken ) . ConfigureAwait ( false ) ;
191242 }
192243 else
193244 {
@@ -273,7 +324,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
273324 }
274325 }
275326
276- public Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
327+ public async Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
277328 {
278329 Throw . IfNull ( message ) ;
279330
@@ -288,7 +339,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
288339 _logger . SendingMessage ( EndpointName , JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ) ;
289340 }
290341
291- return _transport . SendMessageAsync ( message , cancellationToken ) ;
342+ await _transport . SendMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
343+
344+ // If the sent notification was a cancellation notification, cancel the pending request's await, as either the
345+ // server won't be sending a response, or per the specification, the response should be ignored. There are inherent
346+ // race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
347+ if ( message is JsonRpcNotification { Method : NotificationMethods . CancelledNotification } notification &&
348+ GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
349+ _pendingRequests . TryRemove ( cn . RequestId , out var tcs ) )
350+ {
351+ tcs . TrySetCanceled ( default ) ;
352+ }
353+ }
354+
355+ private static CancelledNotification ? GetCancelledNotificationParams ( object ? notificationParams )
356+ {
357+ try
358+ {
359+ switch ( notificationParams )
360+ {
361+ case null :
362+ return null ;
363+
364+ case CancelledNotification cn :
365+ return cn ;
366+
367+ case JsonElement je :
368+ return JsonSerializer . Deserialize ( je , McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
369+
370+ default :
371+ return JsonSerializer . Deserialize (
372+ JsonSerializer . Serialize ( notificationParams , McpJsonUtilities . DefaultOptions . GetTypeInfo < object ? > ( ) ) ,
373+ McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
374+ }
375+ }
376+ catch
377+ {
378+ return null ;
379+ }
292380 }
293381
294382 public void Dispose ( )
0 commit comments