From c24d59ec0aced82fe69c39d87f62042051758ffe Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 17 Jul 2025 16:46:33 +0100 Subject: [PATCH 1/2] fix flaky test --- src/mcp/server/streamable_http.py | 4 +- tests/shared/test_streamable_http.py | 151 +++++++++++++++++---------- 2 files changed, 97 insertions(+), 58 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 32b63c1ae..802cb8680 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -837,9 +837,7 @@ async def message_router(): response_id = str(message.root.id) # If this response is for an existing request stream, # send it there - if response_id in self._request_streams: - target_request_id = response_id - + target_request_id = response_id else: # Extract related_request_id from meta if it exists if ( diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 1ffcc13b0..fa60513a9 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -98,32 +98,36 @@ async def replay_events_after( send_callback: EventCallback, ) -> StreamId | None: """Replay events after the specified ID.""" - # Find the index of the last event ID - start_index = None - for i, (_, event_id, _) in enumerate(self._events): + # Find the stream ID of the last event + target_stream_id = None + for stream_id, event_id, _ in self._events: if event_id == last_event_id: - start_index = i + 1 + target_stream_id = stream_id break - if start_index is None: - # If event ID not found, start from beginning - start_index = 0 + if target_stream_id is None: + # If event ID not found, return None + return None - stream_id = None - # Replay events - for _, event_id, message in self._events[start_index:]: - await send_callback(EventMessage(message, event_id)) - # Capture the stream ID from the first replayed event - if stream_id is None and len(self._events) > start_index: - stream_id = self._events[start_index][0] + # Convert last_event_id to int for comparison + last_event_id_int = int(last_event_id) - return stream_id + # Replay only events from the same stream with ID > last_event_id + for stream_id, event_id, message in self._events: + if stream_id == target_stream_id and int(event_id) > last_event_id_int: + await send_callback(EventMessage(message, event_id)) + + return target_stream_id # Test server implementation that follows MCP protocol class ServerTest(Server): def __init__(self): super().__init__(SERVER_NAME) + self._lock = anyio.Event() + # Reset the lock for each new server instance + self._lock.set() + self._lock = anyio.Event() @self.read_resource() async def handle_read_resource(uri: AnyUrl) -> str | bytes: @@ -159,6 +163,16 @@ async def handle_list_tools() -> list[Tool]: description="A tool that triggers server-side sampling", inputSchema={"type": "object", "properties": {}}, ), + Tool( + name="wait_for_lock_with_notification", + description="A tool that sends a notification and waits for lock", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="release_lock", + description="A tool that releases the lock", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() @@ -214,6 +228,33 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: ) ] + elif name == "wait_for_lock_with_notification": + # First send a notification + await ctx.session.send_log_message( + level="info", + data="First notification before lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) + + # Now wait for the lock to be released + await self._lock.wait() + + # Send second notification after lock is released + await ctx.session.send_log_message( + level="info", + data="Second notification after lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) + + return [TextContent(type="text", text="Completed")] + + elif name == "release_lock": + # Release the lock + self._lock.set() + return [TextContent(type="text", text="Lock released")] + return [TextContent(type="text", text=f"Called {name}")] @@ -825,7 +866,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session) """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 4 + assert len(tools.tools) == 6 assert tools.tools[0].name == "test_tool" # Call the tool @@ -862,7 +903,7 @@ async def test_streamablehttp_client_session_persistence(basic_server, basic_ser # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 4 + assert len(tools.tools) == 6 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -891,7 +932,7 @@ async def test_streamablehttp_client_json_response(json_response_server, json_se # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 4 + assert len(tools.tools) == 6 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -962,7 +1003,7 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 4 + assert len(tools.tools) == 6 headers = {} if captured_session_id: @@ -1026,7 +1067,7 @@ async def mock_delete(self, *args, **kwargs): # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 4 + assert len(tools.tools) == 6 headers = {} if captured_session_id: @@ -1048,32 +1089,32 @@ async def mock_delete(self, *args, **kwargs): @pytest.mark.anyio async def test_streamablehttp_client_resumption(event_server): - """Test client session to resume a long running tool.""" + """Test client session resumption using sync primitives for reliable coordination.""" _, server_url = event_server # Variables to track the state captured_session_id = None captured_resumption_token = None captured_notifications = [] - tool_started = False captured_protocol_version = None + first_notification_received = False async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, types.ServerNotification): captured_notifications.append(message) - # Look for our special notification that indicates the tool is running + # Look for our first notification if isinstance(message.root, types.LoggingMessageNotification): - if message.root.params.data == "Tool started": - nonlocal tool_started - tool_started = True + if message.root.params.data == "First notification before lock": + nonlocal first_notification_received + first_notification_received = True async def on_resumption_token_update(token: str) -> None: nonlocal captured_resumption_token captured_resumption_token = token - # First, start the client session and begin the long-running tool + # First, start the client session and begin the tool that waits on lock async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( read_stream, write_stream, @@ -1088,7 +1129,7 @@ async def on_resumption_token_update(token: str) -> None: # Capture the negotiated protocol version captured_protocol_version = result.protocolVersion - # Start a long-running tool in a task + # Start the tool that will wait on lock in a task async with anyio.create_task_group() as tg: async def run_tool(): @@ -1099,7 +1140,9 @@ async def run_tool(): types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}), + params=types.CallToolRequestParams( + name="wait_for_lock_with_notification", arguments={} + ), ) ), types.CallToolResult, @@ -1108,15 +1151,19 @@ async def run_tool(): tg.start_soon(run_tool) - # Wait for the tool to start and at least one notification - # and then kill the task group - while not tool_started or not captured_resumption_token: + # Wait for the first notification and resumption token + while not first_notification_received or not captured_resumption_token: await anyio.sleep(0.1) + + # Kill the client session while tool is waiting on lock tg.cancel_scope.cancel() - # Store pre notifications and clear the captured notifications - # for the post-resumption check - captured_notifications_pre = captured_notifications.copy() + # Verify we received exactly one notification + assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) + assert captured_notifications[0].root.params.data == "First notification before lock" + + # Clear notifications for the second phase captured_notifications = [] # Now resume the session with the same mcp-session-id and protocol version @@ -1125,54 +1172,48 @@ async def run_tool(): headers[MCP_SESSION_ID_HEADER] = captured_session_id if captured_protocol_version: headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version - async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( read_stream, write_stream, _, ): async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Don't initialize - just use the existing session - - # Resume the tool with the resumption token - assert captured_resumption_token is not None - + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="release_lock", arguments={}), + ) + ), + types.CallToolResult, + ) metadata = ClientMessageMetadata( resumption_token=captured_resumption_token, ) + result = await session.send_request( types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}), + params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), ) ), types.CallToolResult, metadata=metadata, ) - - # We should get a complete result assert len(result.content) == 1 assert result.content[0].type == "text" - assert "Completed" in result.content[0].text + assert result.content[0].text == "Completed" # We should have received the remaining notifications - assert len(captured_notifications) > 0 + assert len(captured_notifications) == 1 - # Should not have the first notification - # Check that "Tool started" notification isn't repeated when resuming - assert not any( - isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started" - for n in captured_notifications - ) - # there is no intersection between pre and post notifications - assert not any(n in captured_notifications_pre for n in captured_notifications) + assert captured_notifications[0].root.params.data == "Second notification after lock" @pytest.mark.anyio async def test_streamablehttp_server_sampling(basic_server, basic_server_url): """Test server-initiated sampling request through streamable HTTP transport.""" - print("Testing server sampling...") # Variable to track if sampling callback was invoked sampling_callback_invoked = False captured_message_params = None From 38d67133a5ab3ce7a06442b9f4416d89fec4a2a8 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 17 Jul 2025 17:02:03 +0100 Subject: [PATCH 2/2] fix lowest version --- tests/shared/test_streamable_http.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index fa60513a9..3fea54f0b 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -124,10 +124,7 @@ async def replay_events_after( class ServerTest(Server): def __init__(self): super().__init__(SERVER_NAME) - self._lock = anyio.Event() - # Reset the lock for each new server instance - self._lock.set() - self._lock = anyio.Event() + self._lock = None # Will be initialized in async context @self.read_resource() async def handle_read_resource(uri: AnyUrl) -> str | bytes: @@ -229,6 +226,10 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: ] elif name == "wait_for_lock_with_notification": + # Initialize lock if not already done + if self._lock is None: + self._lock = anyio.Event() + # First send a notification await ctx.session.send_log_message( level="info", @@ -251,6 +252,8 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text="Completed")] elif name == "release_lock": + assert self._lock is not None, "Lock must be initialized before releasing" + # Release the lock self._lock.set() return [TextContent(type="text", text="Lock released")]