From d9a6317306c7f4f66d756aa2f47da937093475a1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 13:38:05 +0000 Subject: [PATCH 1/6] Client call_tool: accept input_responses/request_state; return InputRequiredResult via allow_input_required opt-in MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ClientSession.send_request: accept TypeAdapter[T] alongside type[T] for result_type so callers can parse union results. - ClientSession.call_tool (mechanics): add input_responses= / request_state= retry kwargs; return CallToolResult | InputRequiredResult; gate output-schema validation on isinstance(result, CallToolResult). - Client.call_tool / ClientSessionGroup.call_tool (policy): @overload on allow_input_required — Literal[False] (default) returns CallToolResult; Literal[True] returns the union. Default raises RuntimeError on InputRequiredResult with a retry steer (TODO(L80) marks where the auto-loop driver replaces this). - Examples and tests that call ClientSession.call_tool directly narrow with isinstance(result, CallToolResult); README.v2.md regenerated from snippets. --- README.v2.md | 8 ++- .../mcp_simple_auth_client/main.py | 3 +- .../mcp_sse_polling_client/main.py | 2 + .../snippets/clients/parsing_tool_results.py | 7 ++- examples/snippets/clients/stdio_client.py | 1 + .../clients/url_elicitation_client.py | 1 + src/mcp/client/client.py | 61 +++++++++++++++++-- src/mcp/client/session.py | 24 ++++++-- src/mcp/client/session_group.py | 53 ++++++++++++++-- tests/client/test_http_unicode.py | 1 + tests/client/test_session.py | 58 ++++++++++++++++++ tests/client/test_session_group.py | 18 ++++++ tests/issues/test_88_random_error.py | 2 + tests/shared/test_sse.py | 2 + tests/shared/test_streamable_http.py | 15 ++++- 15 files changed, 239 insertions(+), 17 deletions(-) diff --git a/README.v2.md b/README.v2.md index 6eb869a8a4..8d5eb44d20 100644 --- a/README.v2.md +++ b/README.v2.md @@ -2166,6 +2166,7 @@ async def run(): # Call a tool (add tool from mcpserver_quickstart) result = await session.call_tool("add", arguments={"a": 5, "b": 3}) + assert isinstance(result, types.CallToolResult) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): print(f"Tool result: {result_unstructured.text}") @@ -2431,19 +2432,22 @@ async def parse_tool_results(): # Example 1: Parsing text content result = await session.call_tool("get_data", {"format": "text"}) + assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.TextContent): print(f"Text: {content.text}") # Example 2: Parsing structured content from JSON tools result = await session.call_tool("get_user", {"id": "123"}) - if hasattr(result, "structured_content") and result.structured_content: + assert isinstance(result, types.CallToolResult) + if result.structured_content: # Access structured data directly user_data = result.structured_content print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}") # Example 3: Parsing embedded resources result = await session.call_tool("read_config", {}) + assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.EmbeddedResource): resource = content.resource @@ -2454,12 +2458,14 @@ async def parse_tool_results(): # Example 4: Parsing image content result = await session.call_tool("generate_chart", {"data": [1, 2, 3]}) + assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.ImageContent): print(f"Image ({content.mime_type}): {len(content.data)} bytes") # Example 5: Handling errors result = await session.call_tool("failing_tool", {}) + assert isinstance(result, types.CallToolResult) if result.is_error: print("Tool execution failed!") for content in result.content: diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 0d461d5d11..fbb484ba12 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -25,6 +25,7 @@ from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from mcp.shared.message import SessionMessage +from mcp.types import CallToolResult class InMemoryTokenStorage(TokenStorage): @@ -293,7 +294,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = Non try: result = await self.session.call_tool(tool_name, arguments or {}) print(f"\n🔧 Tool '{tool_name}' result:") - if hasattr(result, "content"): + if isinstance(result, CallToolResult): for content in result.content: if content.type == "text": print(content.text) diff --git a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py index e91ed9d527..d6c180296d 100644 --- a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py +++ b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py @@ -20,6 +20,7 @@ import click from mcp import ClientSession from mcp.client.streamable_http import streamable_http_client +from mcp.types import CallToolResult async def run_demo(url: str, items: int, checkpoint_every: int) -> None: @@ -55,6 +56,7 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None: ) print("-" * 40) + assert isinstance(result, CallToolResult) if result.content: content = result.content[0] text = getattr(content, "text", str(content)) diff --git a/examples/snippets/clients/parsing_tool_results.py b/examples/snippets/clients/parsing_tool_results.py index b166406774..be4dfd7e79 100644 --- a/examples/snippets/clients/parsing_tool_results.py +++ b/examples/snippets/clients/parsing_tool_results.py @@ -16,19 +16,22 @@ async def parse_tool_results(): # Example 1: Parsing text content result = await session.call_tool("get_data", {"format": "text"}) + assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.TextContent): print(f"Text: {content.text}") # Example 2: Parsing structured content from JSON tools result = await session.call_tool("get_user", {"id": "123"}) - if hasattr(result, "structured_content") and result.structured_content: + assert isinstance(result, types.CallToolResult) + if result.structured_content: # Access structured data directly user_data = result.structured_content print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}") # Example 3: Parsing embedded resources result = await session.call_tool("read_config", {}) + assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.EmbeddedResource): resource = content.resource @@ -39,12 +42,14 @@ async def parse_tool_results(): # Example 4: Parsing image content result = await session.call_tool("generate_chart", {"data": [1, 2, 3]}) + assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.ImageContent): print(f"Image ({content.mime_type}): {len(content.data)} bytes") # Example 5: Handling errors result = await session.call_tool("failing_tool", {}) + assert isinstance(result, types.CallToolResult) if result.is_error: print("Tool execution failed!") for content in result.content: diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index 3f7c4b981b..f2b0c0cb55 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -64,6 +64,7 @@ async def run(): # Call a tool (add tool from mcpserver_quickstart) result = await session.call_tool("add", arguments={"a": 5, "b": 3}) + assert isinstance(result, types.CallToolResult) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): print(f"Tool result: {result_unstructured.text}") diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index 2aecbeeee6..6aeb56864c 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -150,6 +150,7 @@ async def call_tool_with_error_handling( """ try: result = await session.call_tool(tool_name, arguments) + assert isinstance(result, types.CallToolResult) # Check if the tool returned an error in the result if result.is_error: diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 1ab8209b18..f926a46e9d 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable, Mapping from contextlib import AsyncExitStack from dataclasses import KW_ONLY, dataclass, field -from typing import Any, Literal, TypeVar +from typing import Any, Literal, TypeVar, overload import anyio from typing_extensions import deprecated @@ -30,6 +30,8 @@ EmptyResult, GetPromptResult, Implementation, + InputRequiredResult, + InputResponses, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -374,6 +376,7 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None """Unsubscribe from resource updates.""" return await self.session.unsubscribe_resource(uri, meta=meta) + @overload async def call_tool( self, name: str, @@ -381,8 +384,38 @@ async def call_tool( read_timeout_seconds: float | None = None, progress_callback: ProgressFnT | None = None, *, + input_responses: InputResponses | None = None, + request_state: str | None = None, meta: RequestParamsMeta | None = None, - ) -> CallToolResult: + allow_input_required: Literal[False] = False, + ) -> CallToolResult: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: Literal[True], + ) -> CallToolResult | InputRequiredResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: bool = False, + ) -> CallToolResult | InputRequiredResult: """Call a tool on the server. Args: @@ -390,18 +423,38 @@ async def call_tool( arguments: Arguments to pass to the tool read_timeout_seconds: Timeout for the tool call progress_callback: Callback for progress updates + input_responses: Responses to a prior `InputRequiredResult.input_requests` + request_state: Opaque state echoed from a prior `InputRequiredResult` meta: Additional metadata for the request + allow_input_required: When ``False`` (default), an `InputRequiredResult` + from the server raises `RuntimeError`; when ``True``, it is returned + so the caller can resolve the requests and retry. Returns: - The tool result. + The tool result. When ``allow_input_required=True``, may instead be an + `InputRequiredResult` carrying the server's input requests and opaque + ``request_state`` for the retry. + + Raises: + RuntimeError: If the server returns an `InputRequiredResult` and + ``allow_input_required`` is ``False``. """ - return await self.session.call_tool( + result = await self.session.call_tool( name=name, arguments=arguments, read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, + input_responses=input_responses, + request_state=request_state, meta=meta, ) + if isinstance(result, InputRequiredResult) and not allow_input_required: + # TODO(L80): replace this raise with the MRTR auto-loop driver (S6). + raise RuntimeError( + "Server returned InputRequiredResult; pass allow_input_required=True to receive it " + "and retry call_tool(..., input_responses=..., request_state=result.request_state)." + ) + return result async def list_prompts( self, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8ac3e22882..887b2b4f55 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -173,6 +173,10 @@ async def _default_logging_callback( ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) +_CallToolResultAdapter: TypeAdapter[types.CallToolResult | types.InputRequiredResult] = TypeAdapter( + types.CallToolResult | types.InputRequiredResult +) + class ClientSession: """Client half of an MCP connection, running on a `Dispatcher`. @@ -269,7 +273,7 @@ async def __aexit__( async def send_request( self, request: types.ClientRequest, - result_type: type[ReceiveResultT], + result_type: type[ReceiveResultT] | TypeAdapter[ReceiveResultT], request_read_timeout_seconds: float | None = None, metadata: ClientMessageMetadata | None = None, progress_callback: ProgressFnT | None = None, @@ -308,6 +312,8 @@ async def send_request( _methods.validate_server_result(method, version, raw) except KeyError: pass + if isinstance(result_type, TypeAdapter): + return result_type.validate_python(raw) return result_type.model_validate(raw, by_name=False) async def send_notification(self, notification: types.ClientNotification) -> None: @@ -603,20 +609,28 @@ async def call_tool( read_timeout_seconds: float | None = None, progress_callback: ProgressFnT | None = None, *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, meta: RequestParamsMeta | None = None, - ) -> types.CallToolResult: + ) -> types.CallToolResult | types.InputRequiredResult: """Send a tools/call request with optional progress callback support.""" result = await self.send_request( types.CallToolRequest( - params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta), + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + input_responses=input_responses, + request_state=request_state, + _meta=meta, + ), ), - types.CallToolResult, + _CallToolResultAdapter, request_read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, ) - if not result.is_error: + if isinstance(result, types.CallToolResult) and not result.is_error: await self._validate_tool_result(name, result) return result diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..17a227af48 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,7 +11,7 @@ from collections.abc import Callable from dataclasses import dataclass from types import TracebackType -from typing import Any, TypeAlias +from typing import Any, Literal, TypeAlias, overload import anyio import httpx @@ -190,6 +190,7 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools + @overload async def call_tool( self, name: str, @@ -197,18 +198,62 @@ async def call_tool( read_timeout_seconds: float | None = None, progress_callback: ProgressFnT | None = None, *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, meta: types.RequestParamsMeta | None = None, - ) -> types.CallToolResult: - """Executes a tool given its name and arguments.""" + allow_input_required: Literal[False] = False, + ) -> types.CallToolResult: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: types.RequestParamsMeta | None = None, + allow_input_required: Literal[True], + ) -> types.CallToolResult | types.InputRequiredResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: types.RequestParamsMeta | None = None, + allow_input_required: bool = False, + ) -> types.CallToolResult | types.InputRequiredResult: + """Executes a tool given its name and arguments. + + Raises: + RuntimeError: If the server returns an `InputRequiredResult` and + ``allow_input_required`` is ``False``. + """ session = self._tool_to_session[name] session_tool_name = self.tools[name].name - return await session.call_tool( + result = await session.call_tool( session_tool_name, arguments=arguments, read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, + input_responses=input_responses, + request_state=request_state, meta=meta, ) + if isinstance(result, types.InputRequiredResult) and not allow_input_required: + # TODO(L80): replace this raise with the MRTR auto-loop driver (S6). + raise RuntimeError( + "Server returned InputRequiredResult; pass allow_input_required=True to receive it " + "and retry call_tool(..., input_responses=..., request_state=result.request_state)." + ) + return result async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index 585a142617..c82946df6f 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -142,6 +142,7 @@ async def test_streamable_http_client_unicode_tool_call() -> None: # Test 2: Send Unicode text in tool call (client→server→client) for test_name, test_string in UNICODE_TEST_STRINGS.items(): result = await session.call_tool("echo_unicode", arguments={"text": test_string}) + assert isinstance(result, types.CallToolResult) # Verify server correctly received and echoed back Unicode assert len(result.content) == 1 diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c24a4569c5..f46d6e6069 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -12,7 +12,9 @@ from mcp import MCPError, types from mcp.client import ClientRequestContext +from mcp.client.client import Client from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.server import Server, ServerRequestContext from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest from mcp.shared.message import SessionMessage @@ -1656,3 +1658,59 @@ async def test_discover_reraises_unsupported_version_with_malformed_error_data() await session.discover() assert exc.value.error.code == UNSUPPORTED_PROTOCOL_VERSION assert [m for m, _ in dispatcher.calls] == ["server/discover"] + + +@pytest.mark.anyio +async def test_call_tool_returns_input_required_result_when_server_requests_input() -> None: + # `on_call_tool` is still typed `-> CallToolResult` on this branch (#2967 widens it later); + # `add_request_handler` is `HandlerResult`-typed and accepts `InputRequiredResult` cleanly. + async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult: + return types.InputRequiredResult(request_state="s") + + server = Server("test") + server.add_request_handler("tools/call", types.CallToolRequestParams, handler) + with anyio.fail_after(5): + async with Client(server, mode="2026-07-28") as client: + result = await client.call_tool("ask", allow_input_required=True) + assert isinstance(result, types.InputRequiredResult) + assert result.request_state == "s" + + +@pytest.mark.anyio +async def test_call_tool_threads_input_responses_and_request_state_into_params() -> None: + captured: list[types.CallToolRequestParams] = [] + + async def on_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + captured.append(params) + return CallToolResult(content=[]) + + async def on_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[]) + + server = Server("test", on_call_tool=on_call_tool, on_list_tools=on_list_tools) + with anyio.fail_after(5): + async with Client(server, mode="2026-07-28") as client: + await client.call_tool( + "ask", + input_responses={"k": types.ElicitResult(action="decline")}, + request_state="s", + ) + assert captured[0].input_responses == {"k": types.ElicitResult(action="decline")} + assert captured[0].request_state == "s" + + +@pytest.mark.anyio +async def test_client_call_tool_raises_on_input_required_without_opt_in() -> None: + async def handler(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.InputRequiredResult: + return types.InputRequiredResult(request_state="s") + + server = Server("test") + server.add_request_handler("tools/call", types.CallToolRequestParams, handler) + with anyio.fail_after(5): + async with Client(server, mode="2026-07-28") as client: + with pytest.raises(RuntimeError): + await client.call_tool("t") + result = await client.call_tool("t", allow_input_required=True) + assert isinstance(result, types.InputRequiredResult) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..73bec00d88 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -82,10 +82,28 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov arguments={"name": "value1", "args": {}}, read_timeout_seconds=None, progress_callback=None, + input_responses=None, + request_state=None, meta=None, ) +@pytest.mark.anyio +async def test_client_session_group_call_tool_input_required(): + mock_session = mock.AsyncMock() + mcp_session_group = ClientSessionGroup() + mcp_session_group._tools = {"my_tool": types.Tool(name="my_tool", input_schema={})} + mcp_session_group._tool_to_session = {"my_tool": mock_session} + mock_session.call_tool.return_value = types.InputRequiredResult(request_state="s") + + with pytest.raises(RuntimeError): + await mcp_session_group.call_tool(name="my_tool", arguments={}) + + result = await mcp_session_group.call_tool(name="my_tool", arguments={}, allow_input_required=True) + assert isinstance(result, types.InputRequiredResult) + assert result.request_state == "s" + + @pytest.mark.anyio async def test_client_session_group_connect_to_server(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting to a server and aggregating components.""" diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index b1c6a4f709..cf5ff5c74f 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -96,6 +96,7 @@ async def client( # First call should work (fast operation, no timeout) result = await session.call_tool("fast", read_timeout_seconds=None) + assert isinstance(result, CallToolResult) assert result.content == [TextContent(type="text", text="fast 1")] assert not slow_request_lock.is_set() @@ -111,6 +112,7 @@ async def client( # Third call should work (fast operation, no timeout), # proving server is still responsive result = await session.call_tool("fast", read_timeout_seconds=None) + assert isinstance(result, CallToolResult) assert result.content == [TextContent(type="text", text="fast 3")] scope.cancel() # pragma: lax no cover diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 675a4acb16..71d41ddcb8 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -306,6 +306,7 @@ async def test_request_context_propagation() -> None: assert isinstance(result, InitializeResult) tool_result = await session.call_tool("echo_headers", {}) + assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 content = tool_result.content[0] @@ -332,6 +333,7 @@ async def test_request_context_isolation() -> None: await session.initialize() tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) + assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 content = tool_result.content[0] diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 5360e56ff6..513f5a4484 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -899,6 +899,7 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session # Call the tool result = await initialized_client_session.call_tool("test_tool", {}) + assert isinstance(result, CallToolResult) assert len(result.content) == 1 assert result.content[0].type == "text" assert result.content[0].text == "Called test_tool" @@ -956,6 +957,7 @@ async def test_streamable_http_client_json_response(json_app: Starlette) -> None # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) + assert isinstance(result, CallToolResult) assert len(result.content) == 1 assert result.content[0].type == "text" assert result.content[0].text == "Called test_tool" @@ -1277,6 +1279,7 @@ async def sampling_callback( # Call the tool that triggers server-side sampling tool_result = await session.call_tool("test_sampling_tool", {}) + assert isinstance(tool_result, CallToolResult) # Verify the tool result contains the expected content assert len(tool_result.content) == 1 @@ -1418,6 +1421,7 @@ async def test_streamablehttp_request_context_propagation(context_app: Starlette # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) + assert isinstance(tool_result, CallToolResult) # Parse the JSON response assert len(tool_result.content) == 1 @@ -1453,6 +1457,7 @@ async def test_streamablehttp_request_context_isolation(context_app: Starlette) # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) + assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) @@ -1482,6 +1487,7 @@ async def test_client_includes_protocol_version_header_after_init(context_app: S # Call a tool that echoes headers to verify the header is present tool_result = await session.call_tool("echo_headers", {}) + assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) @@ -1782,7 +1788,7 @@ async def test_server_close_sse_stream_via_context( result = await session.call_tool("tool_with_stream_close", {}) # Client should still receive complete response (via auto-reconnect) - assert result is not None + assert isinstance(result, CallToolResult) assert len(result.content) > 0 assert result.content[0].type == "text" assert isinstance(result.content[0], TextContent) @@ -1819,6 +1825,7 @@ async def message_handler( # 3. Sends more notifications (stored in event_store) # 4. Returns response result = await session.call_tool("tool_with_stream_close", {}) + assert isinstance(result, CallToolResult) # Client should have auto-reconnected and received ALL notifications assert len(captured_notifications) >= 2, ( @@ -1848,6 +1855,7 @@ async def test_streamable_http_client_respects_retry_interval( elapsed = time.monotonic() - start_time # Verify result was received + assert isinstance(result, CallToolResult) assert result.content[0].type == "text" assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Done" @@ -1889,6 +1897,7 @@ async def message_handler( # 5. Server sends "After close" notification # 6. Server sends final response result = await session.call_tool("tool_with_stream_close", {}) + assert isinstance(result, CallToolResult) # Verify all notifications received in order assert "Before close" in all_notifications, "Should receive notification sent before stream close" @@ -1927,6 +1936,7 @@ async def message_handler( # Tool sends: notification1, close_stream, notification2, notification3, response # Client should receive all notifications even though 2&3 were sent during disconnect result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) + assert isinstance(result, CallToolResult) assert "notification1" in notification_data, "Should receive notification1 (sent before close)" assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" @@ -2063,6 +2073,7 @@ async def message_handler( # 3. Sends notification_2 (stored in event_store) # 4. Returns response result = await session.call_tool("tool_with_standalone_stream_close", {}) + assert isinstance(result, CallToolResult) # Verify the tool completed assert result.content[0].type == "text" @@ -2125,6 +2136,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults(context_app: # Use echo_headers tool to see what headers the server actually received tool_result = await session.call_tool("echo_headers", {}) + assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) headers_data = json.loads(tool_result.content[0].text) @@ -2154,6 +2166,7 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_ # Use echo_headers tool to verify both custom and MCP headers are present tool_result = await session.call_tool("echo_headers", {}) + assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) headers_data = json.loads(tool_result.content[0].text) From 473f09c05fdf9114c7d1b68228e8439d6030dacb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 13:39:21 +0000 Subject: [PATCH 2/6] docs/migration: note ClientSession.call_tool union return --- docs/migration.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/migration.md b/docs/migration.md index 46ec205ee9..cf99384c80 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -364,6 +364,12 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer` `Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it. +### `ClientSession.call_tool` returns `CallToolResult | InputRequiredResult` + +For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. `ClientSession.call_tool` now returns `CallToolResult | InputRequiredResult` to reflect this; narrow with `isinstance(result, CallToolResult)` before reading `.content` / `.is_error` / `.structured_content`. + +The high-level `Client.call_tool` still returns `CallToolResult` by default (and raises `RuntimeError` if the server requests input). Pass `allow_input_required=True` to receive the `InputRequiredResult` and retry with `input_responses=` / `request_state=`. + ### `McpError` renamed to `MCPError` The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK. From 0b4d8a9a7088df190ea721dc18534ed7f67caa47 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 13:40:56 +0000 Subject: [PATCH 3/6] Retag call_tool driver placeholder TODO to L84 --- src/mcp/client/client.py | 2 +- src/mcp/client/session_group.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index f926a46e9d..d89440761e 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -449,7 +449,7 @@ async def call_tool( meta=meta, ) if isinstance(result, InputRequiredResult) and not allow_input_required: - # TODO(L80): replace this raise with the MRTR auto-loop driver (S6). + # TODO(L84): replace this raise with the MRTR auto-loop driver (S6). raise RuntimeError( "Server returned InputRequiredResult; pass allow_input_required=True to receive it " "and retry call_tool(..., input_responses=..., request_state=result.request_state)." diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 17a227af48..f87317d3f7 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -248,7 +248,7 @@ async def call_tool( meta=meta, ) if isinstance(result, types.InputRequiredResult) and not allow_input_required: - # TODO(L80): replace this raise with the MRTR auto-loop driver (S6). + # TODO(L84): replace this raise with the MRTR auto-loop driver (S6). raise RuntimeError( "Server returned InputRequiredResult; pass allow_input_required=True to receive it " "and retry call_tool(..., input_responses=..., request_state=result.request_state)." From 18559fe6e7c822183316603e663582fc8a1df958 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:39:29 +0000 Subject: [PATCH 4/6] Move allow_input_required overload+gate down to ClientSession.call_tool The gate now lives once on ClientSession (mechanics layer); Client and ClientSessionGroup are pure passthroughs that forward allow_input_required. Third 'bool' overload on ClientSession.call_tool lets the passthrough impls type-check. Reverts the isinstance narrowing in examples/ and tests/shared/; default return is CallToolResult everywhere, so the change is additive. --- README.v2.md | 8 +-- docs/migration.md | 6 +- .../mcp_simple_auth_client/main.py | 3 +- .../mcp_sse_polling_client/main.py | 2 - .../snippets/clients/parsing_tool_results.py | 7 +- examples/snippets/clients/stdio_client.py | 1 - .../clients/url_elicitation_client.py | 1 - src/mcp/client/client.py | 11 +--- src/mcp/client/session.py | 64 ++++++++++++++++++- src/mcp/client/session_group.py | 10 +-- tests/client/test_http_unicode.py | 1 - tests/client/test_session_group.py | 7 +- tests/issues/test_88_random_error.py | 2 - tests/shared/test_sse.py | 2 - tests/shared/test_streamable_http.py | 15 +---- 15 files changed, 76 insertions(+), 64 deletions(-) diff --git a/README.v2.md b/README.v2.md index 8d5eb44d20..6eb869a8a4 100644 --- a/README.v2.md +++ b/README.v2.md @@ -2166,7 +2166,6 @@ async def run(): # Call a tool (add tool from mcpserver_quickstart) result = await session.call_tool("add", arguments={"a": 5, "b": 3}) - assert isinstance(result, types.CallToolResult) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): print(f"Tool result: {result_unstructured.text}") @@ -2432,22 +2431,19 @@ async def parse_tool_results(): # Example 1: Parsing text content result = await session.call_tool("get_data", {"format": "text"}) - assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.TextContent): print(f"Text: {content.text}") # Example 2: Parsing structured content from JSON tools result = await session.call_tool("get_user", {"id": "123"}) - assert isinstance(result, types.CallToolResult) - if result.structured_content: + if hasattr(result, "structured_content") and result.structured_content: # Access structured data directly user_data = result.structured_content print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}") # Example 3: Parsing embedded resources result = await session.call_tool("read_config", {}) - assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.EmbeddedResource): resource = content.resource @@ -2458,14 +2454,12 @@ async def parse_tool_results(): # Example 4: Parsing image content result = await session.call_tool("generate_chart", {"data": [1, 2, 3]}) - assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.ImageContent): print(f"Image ({content.mime_type}): {len(content.data)} bytes") # Example 5: Handling errors result = await session.call_tool("failing_tool", {}) - assert isinstance(result, types.CallToolResult) if result.is_error: print("Tool execution failed!") for content in result.content: diff --git a/docs/migration.md b/docs/migration.md index cf99384c80..1c965562d1 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -364,11 +364,9 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer` `Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it. -### `ClientSession.call_tool` returns `CallToolResult | InputRequiredResult` +### `call_tool` can return `InputRequiredResult` (opt-in) -For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. `ClientSession.call_tool` now returns `CallToolResult | InputRequiredResult` to reflect this; narrow with `isinstance(result, CallToolResult)` before reading `.content` / `.is_error` / `.structured_content`. - -The high-level `Client.call_tool` still returns `CallToolResult` by default (and raises `RuntimeError` if the server requests input). Pass `allow_input_required=True` to receive the `InputRequiredResult` and retry with `input_responses=` / `request_state=`. +For protocol 2026-07-28, a `tools/call` request may return an `InputRequiredResult` asking the client to supply additional input and retry. By default `call_tool` (on `ClientSession`, `Client`, and `ClientSessionGroup`) still returns `CallToolResult` and raises `RuntimeError` if the server requests input. Pass `allow_input_required=True` to receive the `InputRequiredResult` instead, then retry with `input_responses=` / `request_state=`. ### `McpError` renamed to `MCPError` diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index fbb484ba12..0d461d5d11 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -25,7 +25,6 @@ from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from mcp.shared.message import SessionMessage -from mcp.types import CallToolResult class InMemoryTokenStorage(TokenStorage): @@ -294,7 +293,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = Non try: result = await self.session.call_tool(tool_name, arguments or {}) print(f"\n🔧 Tool '{tool_name}' result:") - if isinstance(result, CallToolResult): + if hasattr(result, "content"): for content in result.content: if content.type == "text": print(content.text) diff --git a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py index d6c180296d..e91ed9d527 100644 --- a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py +++ b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py @@ -20,7 +20,6 @@ import click from mcp import ClientSession from mcp.client.streamable_http import streamable_http_client -from mcp.types import CallToolResult async def run_demo(url: str, items: int, checkpoint_every: int) -> None: @@ -56,7 +55,6 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None: ) print("-" * 40) - assert isinstance(result, CallToolResult) if result.content: content = result.content[0] text = getattr(content, "text", str(content)) diff --git a/examples/snippets/clients/parsing_tool_results.py b/examples/snippets/clients/parsing_tool_results.py index be4dfd7e79..b166406774 100644 --- a/examples/snippets/clients/parsing_tool_results.py +++ b/examples/snippets/clients/parsing_tool_results.py @@ -16,22 +16,19 @@ async def parse_tool_results(): # Example 1: Parsing text content result = await session.call_tool("get_data", {"format": "text"}) - assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.TextContent): print(f"Text: {content.text}") # Example 2: Parsing structured content from JSON tools result = await session.call_tool("get_user", {"id": "123"}) - assert isinstance(result, types.CallToolResult) - if result.structured_content: + if hasattr(result, "structured_content") and result.structured_content: # Access structured data directly user_data = result.structured_content print(f"User: {user_data.get('name')}, Age: {user_data.get('age')}") # Example 3: Parsing embedded resources result = await session.call_tool("read_config", {}) - assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.EmbeddedResource): resource = content.resource @@ -42,14 +39,12 @@ async def parse_tool_results(): # Example 4: Parsing image content result = await session.call_tool("generate_chart", {"data": [1, 2, 3]}) - assert isinstance(result, types.CallToolResult) for content in result.content: if isinstance(content, types.ImageContent): print(f"Image ({content.mime_type}): {len(content.data)} bytes") # Example 5: Handling errors result = await session.call_tool("failing_tool", {}) - assert isinstance(result, types.CallToolResult) if result.is_error: print("Tool execution failed!") for content in result.content: diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index f2b0c0cb55..3f7c4b981b 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -64,7 +64,6 @@ async def run(): # Call a tool (add tool from mcpserver_quickstart) result = await session.call_tool("add", arguments={"a": 5, "b": 3}) - assert isinstance(result, types.CallToolResult) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): print(f"Tool result: {result_unstructured.text}") diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index 6aeb56864c..2aecbeeee6 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -150,7 +150,6 @@ async def call_tool_with_error_handling( """ try: result = await session.call_tool(tool_name, arguments) - assert isinstance(result, types.CallToolResult) # Check if the tool returned an error in the result if result.is_error: diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index d89440761e..38f546f1c4 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -439,7 +439,8 @@ async def call_tool( RuntimeError: If the server returns an `InputRequiredResult` and ``allow_input_required`` is ``False``. """ - result = await self.session.call_tool( + # TODO(L84): stop forwarding allow_input_required; run the MRTR auto-loop driver here (S6). + return await self.session.call_tool( name=name, arguments=arguments, read_timeout_seconds=read_timeout_seconds, @@ -447,14 +448,8 @@ async def call_tool( input_responses=input_responses, request_state=request_state, meta=meta, + allow_input_required=allow_input_required, ) - if isinstance(result, InputRequiredResult) and not allow_input_required: - # TODO(L84): replace this raise with the MRTR auto-loop driver (S6). - raise RuntimeError( - "Server returned InputRequiredResult; pass allow_input_required=True to receive it " - "and retry call_tool(..., input_responses=..., request_state=result.request_state)." - ) - return result async def list_prompts( self, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 887b2b4f55..90dffa6e3d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass from types import TracebackType -from typing import Any, Protocol, cast +from typing import Any, Literal, Protocol, cast, overload import anyio import anyio.abc @@ -602,6 +602,7 @@ async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None types.EmptyResult, ) + @overload async def call_tool( self, name: str, @@ -612,8 +613,62 @@ async def call_tool( input_responses: types.InputResponses | None = None, request_state: str | None = None, meta: RequestParamsMeta | None = None, + allow_input_required: Literal[False] = False, + ) -> types.CallToolResult: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: Literal[True], + ) -> types.CallToolResult | types.InputRequiredResult: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: bool, + ) -> types.CallToolResult | types.InputRequiredResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + input_responses: types.InputResponses | None = None, + request_state: str | None = None, + meta: RequestParamsMeta | None = None, + allow_input_required: bool = False, ) -> types.CallToolResult | types.InputRequiredResult: - """Send a tools/call request with optional progress callback support.""" + """Send a tools/call request with optional progress callback support. + + Args: + input_responses: Responses to a prior `InputRequiredResult.input_requests`. + request_state: Opaque state echoed from a prior `InputRequiredResult`. + allow_input_required: When ``False`` (default), an `InputRequiredResult` + from the server raises `RuntimeError`; when ``True``, it is returned + so the caller can resolve the requests and retry. + + Raises: + RuntimeError: If the server returns an `InputRequiredResult` and + ``allow_input_required`` is ``False``. + """ result = await self.send_request( types.CallToolRequest( @@ -633,6 +688,11 @@ async def call_tool( if isinstance(result, types.CallToolResult) and not result.is_error: await self._validate_tool_result(name, result) + if isinstance(result, types.InputRequiredResult) and not allow_input_required: + raise RuntimeError( + "Server returned InputRequiredResult; pass allow_input_required=True to receive it " + "and retry call_tool(..., input_responses=..., request_state=result.request_state)." + ) return result async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None: diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index f87317d3f7..4f97caaa10 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -238,7 +238,7 @@ async def call_tool( """ session = self._tool_to_session[name] session_tool_name = self.tools[name].name - result = await session.call_tool( + return await session.call_tool( session_tool_name, arguments=arguments, read_timeout_seconds=read_timeout_seconds, @@ -246,14 +246,8 @@ async def call_tool( input_responses=input_responses, request_state=request_state, meta=meta, + allow_input_required=allow_input_required, ) - if isinstance(result, types.InputRequiredResult) and not allow_input_required: - # TODO(L84): replace this raise with the MRTR auto-loop driver (S6). - raise RuntimeError( - "Server returned InputRequiredResult; pass allow_input_required=True to receive it " - "and retry call_tool(..., input_responses=..., request_state=result.request_state)." - ) - return result async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index c82946df6f..585a142617 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -142,7 +142,6 @@ async def test_streamable_http_client_unicode_tool_call() -> None: # Test 2: Send Unicode text in tool call (client→server→client) for test_name, test_string in UNICODE_TEST_STRINGS.items(): result = await session.call_tool("echo_unicode", arguments={"text": test_string}) - assert isinstance(result, types.CallToolResult) # Verify server correctly received and echoed back Unicode assert len(result.content) == 1 diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 73bec00d88..faa4281e3c 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -85,23 +85,22 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov input_responses=None, request_state=None, meta=None, + allow_input_required=False, ) @pytest.mark.anyio -async def test_client_session_group_call_tool_input_required(): +async def test_client_session_group_call_tool_forwards_allow_input_required(): mock_session = mock.AsyncMock() mcp_session_group = ClientSessionGroup() mcp_session_group._tools = {"my_tool": types.Tool(name="my_tool", input_schema={})} mcp_session_group._tool_to_session = {"my_tool": mock_session} mock_session.call_tool.return_value = types.InputRequiredResult(request_state="s") - with pytest.raises(RuntimeError): - await mcp_session_group.call_tool(name="my_tool", arguments={}) - result = await mcp_session_group.call_tool(name="my_tool", arguments={}, allow_input_required=True) assert isinstance(result, types.InputRequiredResult) assert result.request_state == "s" + assert mock_session.call_tool.call_args.kwargs["allow_input_required"] is True @pytest.mark.anyio diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index cf5ff5c74f..b1c6a4f709 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -96,7 +96,6 @@ async def client( # First call should work (fast operation, no timeout) result = await session.call_tool("fast", read_timeout_seconds=None) - assert isinstance(result, CallToolResult) assert result.content == [TextContent(type="text", text="fast 1")] assert not slow_request_lock.is_set() @@ -112,7 +111,6 @@ async def client( # Third call should work (fast operation, no timeout), # proving server is still responsive result = await session.call_tool("fast", read_timeout_seconds=None) - assert isinstance(result, CallToolResult) assert result.content == [TextContent(type="text", text="fast 3")] scope.cancel() # pragma: lax no cover diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 71d41ddcb8..675a4acb16 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -306,7 +306,6 @@ async def test_request_context_propagation() -> None: assert isinstance(result, InitializeResult) tool_result = await session.call_tool("echo_headers", {}) - assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 content = tool_result.content[0] @@ -333,7 +332,6 @@ async def test_request_context_isolation() -> None: await session.initialize() tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) - assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 content = tool_result.content[0] diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 513f5a4484..5360e56ff6 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -899,7 +899,6 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session # Call the tool result = await initialized_client_session.call_tool("test_tool", {}) - assert isinstance(result, CallToolResult) assert len(result.content) == 1 assert result.content[0].type == "text" assert result.content[0].text == "Called test_tool" @@ -957,7 +956,6 @@ async def test_streamable_http_client_json_response(json_app: Starlette) -> None # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) - assert isinstance(result, CallToolResult) assert len(result.content) == 1 assert result.content[0].type == "text" assert result.content[0].text == "Called test_tool" @@ -1279,7 +1277,6 @@ async def sampling_callback( # Call the tool that triggers server-side sampling tool_result = await session.call_tool("test_sampling_tool", {}) - assert isinstance(tool_result, CallToolResult) # Verify the tool result contains the expected content assert len(tool_result.content) == 1 @@ -1421,7 +1418,6 @@ async def test_streamablehttp_request_context_propagation(context_app: Starlette # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) - assert isinstance(tool_result, CallToolResult) # Parse the JSON response assert len(tool_result.content) == 1 @@ -1457,7 +1453,6 @@ async def test_streamablehttp_request_context_isolation(context_app: Starlette) # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) - assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) @@ -1487,7 +1482,6 @@ async def test_client_includes_protocol_version_header_after_init(context_app: S # Call a tool that echoes headers to verify the header is present tool_result = await session.call_tool("echo_headers", {}) - assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) @@ -1788,7 +1782,7 @@ async def test_server_close_sse_stream_via_context( result = await session.call_tool("tool_with_stream_close", {}) # Client should still receive complete response (via auto-reconnect) - assert isinstance(result, CallToolResult) + assert result is not None assert len(result.content) > 0 assert result.content[0].type == "text" assert isinstance(result.content[0], TextContent) @@ -1825,7 +1819,6 @@ async def message_handler( # 3. Sends more notifications (stored in event_store) # 4. Returns response result = await session.call_tool("tool_with_stream_close", {}) - assert isinstance(result, CallToolResult) # Client should have auto-reconnected and received ALL notifications assert len(captured_notifications) >= 2, ( @@ -1855,7 +1848,6 @@ async def test_streamable_http_client_respects_retry_interval( elapsed = time.monotonic() - start_time # Verify result was received - assert isinstance(result, CallToolResult) assert result.content[0].type == "text" assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Done" @@ -1897,7 +1889,6 @@ async def message_handler( # 5. Server sends "After close" notification # 6. Server sends final response result = await session.call_tool("tool_with_stream_close", {}) - assert isinstance(result, CallToolResult) # Verify all notifications received in order assert "Before close" in all_notifications, "Should receive notification sent before stream close" @@ -1936,7 +1927,6 @@ async def message_handler( # Tool sends: notification1, close_stream, notification2, notification3, response # Client should receive all notifications even though 2&3 were sent during disconnect result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) - assert isinstance(result, CallToolResult) assert "notification1" in notification_data, "Should receive notification1 (sent before close)" assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" @@ -2073,7 +2063,6 @@ async def message_handler( # 3. Sends notification_2 (stored in event_store) # 4. Returns response result = await session.call_tool("tool_with_standalone_stream_close", {}) - assert isinstance(result, CallToolResult) # Verify the tool completed assert result.content[0].type == "text" @@ -2136,7 +2125,6 @@ async def test_streamable_http_client_mcp_headers_override_defaults(context_app: # Use echo_headers tool to see what headers the server actually received tool_result = await session.call_tool("echo_headers", {}) - assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) headers_data = json.loads(tool_result.content[0].text) @@ -2166,7 +2154,6 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_ # Use echo_headers tool to verify both custom and MCP headers are present tool_result = await session.call_tool("echo_headers", {}) - assert isinstance(tool_result, CallToolResult) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) headers_data = json.loads(tool_result.content[0].text) From 78b1f9d30012e4b1003f0f2e72b6cb5b4bc1e1a1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:46:32 +0000 Subject: [PATCH 5/6] send_request: pass by_name=False on the TypeAdapter branch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Aligns with the class branch and the existing validate_python(..., by_name=False) call sites in types/methods.py — wire data must use camelCase aliases only. --- src/mcp/client/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 90dffa6e3d..6fb4dc3d84 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -313,7 +313,7 @@ async def send_request( except KeyError: pass if isinstance(result_type, TypeAdapter): - return result_type.validate_python(raw) + return result_type.validate_python(raw, by_name=False) return result_type.model_validate(raw, by_name=False) async def send_notification(self, notification: types.ClientNotification) -> None: From 68fbef9e375deb6f6cfe04862eadf5d0c5cb735a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 25 Jun 2026 15:05:07 +0000 Subject: [PATCH 6/6] call_tool overloads: collapse to Literal[False] + bool on all three layers A runtime bool variable now matches the second overload (union return); literal False and the default still match the first (CallToolResult). Drops the redundant Literal[True] arm. --- src/mcp/client/client.py | 2 +- src/mcp/client/session.py | 14 -------------- src/mcp/client/session_group.py | 2 +- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 38f546f1c4..362042ba5e 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -401,7 +401,7 @@ async def call_tool( input_responses: InputResponses | None = None, request_state: str | None = None, meta: RequestParamsMeta | None = None, - allow_input_required: Literal[True], + allow_input_required: bool, ) -> CallToolResult | InputRequiredResult: ... async def call_tool( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 6fb4dc3d84..902cec80bc 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -616,20 +616,6 @@ async def call_tool( allow_input_required: Literal[False] = False, ) -> types.CallToolResult: ... - @overload - async def call_tool( - self, - name: str, - arguments: dict[str, Any] | None = None, - read_timeout_seconds: float | None = None, - progress_callback: ProgressFnT | None = None, - *, - input_responses: types.InputResponses | None = None, - request_state: str | None = None, - meta: RequestParamsMeta | None = None, - allow_input_required: Literal[True], - ) -> types.CallToolResult | types.InputRequiredResult: ... - @overload async def call_tool( self, diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 4f97caaa10..211733d6a3 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -215,7 +215,7 @@ async def call_tool( input_responses: types.InputResponses | None = None, request_state: str | None = None, meta: types.RequestParamsMeta | None = None, - allow_input_required: Literal[True], + allow_input_required: bool, ) -> types.CallToolResult | types.InputRequiredResult: ... async def call_tool(