From dd4939d879cc287340c23c174cb38f0f2220e4c7 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 10:28:37 +0530 Subject: [PATCH 01/53] add transport abstraction --- src/mcp/client/transport_session.py | 133 ++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 src/mcp/client/transport_session.py diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py new file mode 100644 index 0000000000..f1cd84f17a --- /dev/null +++ b/src/mcp/client/transport_session.py @@ -0,0 +1,133 @@ +from abc import ABC +from abc import abstractmethod +from datetime import timedelta + +from typing import Any + +from pydantic import AnyUrl + +from mcp import types +from mcp.shared.session import ProgressFnT + + +class TransportSession(ABC): + """Abstract base class for communication transports.""" + + @abstractmethod + async def initialize(self) -> types.InitializeResult: + """Send an initialize request.""" + ... + + @abstractmethod + async def send_ping(self): + ... + + @abstractmethod + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + ... + + @abstractmethod + async def set_logging_level( + self, + level: types.LoggingLevel, + ) -> types.EmptyResult: + """Send a logging/setLevel request.""" + ... + + @abstractmethod + async def list_resources( + self, + cursor: str | None = None, + ) -> types.ListResourcesResult: + """Send a resources/list request.""" + ... + + @abstractmethod + async def list_resource_templates( + self, + cursor: str | None = None, + ) -> types.ListResourceTemplatesResult: + """Send a resources/templates/list request.""" + ... + + @abstractmethod + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + """Send a resources/read request.""" + ... + + @abstractmethod + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/subscribe request.""" + ... + + @abstractmethod + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + ... + + @abstractmethod + async def call_tool( + self, + name: str, + arguments: Any | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + ) -> types.CallToolResult: + """Send a tools/call request with optional progress callback support.""" + ... + + @abstractmethod + async def _validate_tool_result( + self, + name: str, + result: types.CallToolResult, + ) -> None: + """Validate the structured content of a tool result against its output + schema.""" + ... + + @abstractmethod + async def list_prompts( + self, + cursor: str | None = None, + ) -> types.ListPromptsResult: + """Send a prompts/list request.""" + ... + + @abstractmethod + async def get_prompt( + self, + name: str, + arguments: dict[str, str] | None = None, + ) -> types.GetPromptResult: + """Send a prompts/get request.""" + ... + + @abstractmethod + async def complete( + self, + ref: types.ResourceTemplateReference | types.PromptReference, + argument: dict[str, str], + context_arguments: dict[str, str] | None = None, + ) -> types.CompleteResult: + """Send a completion/complete request.""" + ... + + @abstractmethod + async def list_tools( + self, + cursor: str | None = None, + ) -> types.ListToolsResult: + """Send a tools/list request.""" + ... + + @abstractmethod + async def send_roots_list_changed(self) -> None: + """Send a roots/list_changed notification.""" + ... \ No newline at end of file From 11d12494d89edc6ca2a1cb5e14b560c6fa0d2143 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 06:10:10 +0000 Subject: [PATCH 02/53] fix ruff --- src/mcp/client/transport_session.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index f1cd84f17a..7575afa552 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from datetime import timedelta - from typing import Any from pydantic import AnyUrl From c8f3a42bf6f38ef7a4c215492709b23ea846d285 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 06:12:32 +0000 Subject: [PATCH 03/53] fix ruff format --- src/mcp/client/transport_session.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 7575afa552..5ce6fd34ea 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -17,8 +17,7 @@ async def initialize(self) -> types.InitializeResult: ... @abstractmethod - async def send_ping(self): - ... + async def send_ping(self): ... @abstractmethod async def send_progress_notification( @@ -27,8 +26,7 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, - ) -> None: - ... + ) -> None: ... @abstractmethod async def set_logging_level( @@ -128,4 +126,4 @@ async def list_tools( @abstractmethod async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - ... \ No newline at end of file + ... From 03cc6c525aa6d0d94c3f0284fc80d8ec937263b9 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:00:03 +0530 Subject: [PATCH 04/53] add transport session for server --- src/mcp/server/transport_session.py | 113 ++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 src/mcp/server/transport_session.py diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py new file mode 100644 index 0000000000..becc5a8554 --- /dev/null +++ b/src/mcp/server/transport_session.py @@ -0,0 +1,113 @@ +"""Abstract base class for transport sessions.""" + +import abc +from typing import Any + +from anyio.streams.memory import MemoryObjectReceiveStream +from pydantic import AnyUrl + +import mcp_grpc.types as types +from mcp_grpc.server.session import ServerRequestResponder + + +class TransportSession(abc.ABC): + """Abstract base class for transport sessions.""" + + @property + @abc.abstractmethod + def client_params(self) -> types.InitializeRequestParams | None: + """Client initialization parameters.""" + raise NotImplementedError + + @abc.abstractmethod + def check_client_capability(self, capability: types.ClientCapabilities) -> bool: + """Check if the client supports a specific capability.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_log_message( + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send a log message notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_resource_updated(self, uri: AnyUrl) -> None: + """Send a resource updated notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def create_message( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResult: + """Send a sampling/create_message request.""" + raise NotImplementedError + + @abc.abstractmethod + async def list_roots(self) -> types.ListRootsResult: + """Send a roots/list request.""" + raise NotImplementedError + + @abc.abstractmethod + async def elicit( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send an elicitation/create request.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + related_request_id: str | None = None, + ) -> None: + """Send a progress notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_resource_list_changed(self) -> None: + """Send a resource list changed notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_list_changed(self) -> None: + """Send a tool list changed notification.""" + raise NotImplementedError + + @abc.abstractmethod + async def send_prompt_list_changed(self) -> None: + """Send a prompt list changed notification.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def incoming_messages( + self, + ) -> MemoryObjectReceiveStream[ServerRequestResponder]: + """Incoming messages stream.""" + raise NotImplementedError From 1327a9cca39f41ec4d08d7f8672d54e28869fbad Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:01:29 +0530 Subject: [PATCH 05/53] clientsession and server session to implement abstract classes --- src/mcp/client/session.py | 3 +++ src/mcp/server/session.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3835a2a577..339c64abdc 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -14,6 +14,8 @@ from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from src.mcp.client.transport_session import TransportSession + DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") logger = logging.getLogger("client") @@ -100,6 +102,7 @@ async def _default_logging_callback( class ClientSession( + TransportSession, BaseSession[ types.ClientRequest, types.ClientNotification, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index a1bfadc9fc..3dc888843c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -54,6 +54,8 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from src.mcp.server.transport_session import TransportSession + class InitializationState(Enum): NotInitialized = 1 @@ -69,6 +71,7 @@ class InitializationState(Enum): class ServerSession( + TransportSession, BaseSession[ types.ServerRequest, types.ServerNotification, From 0018679c02095d981fa1c84c779633c3b9aa31e4 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:04:00 +0530 Subject: [PATCH 06/53] add raise not implemented --- src/mcp/client/transport_session.py | 38 ++++++++++++++++------------- src/mcp/server/transport_session.py | 4 +-- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 5ce6fd34ea..a85b397184 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,5 +1,7 @@ -from abc import ABC, abstractmethod +from abc import ABC +from abc import abstractmethod from datetime import timedelta + from typing import Any from pydantic import AnyUrl @@ -14,10 +16,11 @@ class TransportSession(ABC): @abstractmethod async def initialize(self) -> types.InitializeResult: """Send an initialize request.""" - ... + raise NotImplementedError @abstractmethod - async def send_ping(self): ... + async def send_ping(self): + raise NotImplementedError @abstractmethod async def send_progress_notification( @@ -26,7 +29,8 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, - ) -> None: ... + ) -> None: + raise NotImplementedError @abstractmethod async def set_logging_level( @@ -34,7 +38,7 @@ async def set_logging_level( level: types.LoggingLevel, ) -> types.EmptyResult: """Send a logging/setLevel request.""" - ... + raise NotImplementedError @abstractmethod async def list_resources( @@ -42,7 +46,7 @@ async def list_resources( cursor: str | None = None, ) -> types.ListResourcesResult: """Send a resources/list request.""" - ... + raise NotImplementedError @abstractmethod async def list_resource_templates( @@ -50,22 +54,22 @@ async def list_resource_templates( cursor: str | None = None, ) -> types.ListResourceTemplatesResult: """Send a resources/templates/list request.""" - ... + raise NotImplementedError @abstractmethod async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: """Send a resources/read request.""" - ... + raise NotImplementedError @abstractmethod async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/subscribe request.""" - ... + raise NotImplementedError @abstractmethod async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/unsubscribe request.""" - ... + raise NotImplementedError @abstractmethod async def call_tool( @@ -76,7 +80,7 @@ async def call_tool( progress_callback: ProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" - ... + raise NotImplementedError @abstractmethod async def _validate_tool_result( @@ -86,7 +90,7 @@ async def _validate_tool_result( ) -> None: """Validate the structured content of a tool result against its output schema.""" - ... + raise NotImplementedError @abstractmethod async def list_prompts( @@ -94,7 +98,7 @@ async def list_prompts( cursor: str | None = None, ) -> types.ListPromptsResult: """Send a prompts/list request.""" - ... + raise NotImplementedError @abstractmethod async def get_prompt( @@ -103,7 +107,7 @@ async def get_prompt( arguments: dict[str, str] | None = None, ) -> types.GetPromptResult: """Send a prompts/get request.""" - ... + raise NotImplementedError @abstractmethod async def complete( @@ -113,7 +117,7 @@ async def complete( context_arguments: dict[str, str] | None = None, ) -> types.CompleteResult: """Send a completion/complete request.""" - ... + raise NotImplementedError @abstractmethod async def list_tools( @@ -121,9 +125,9 @@ async def list_tools( cursor: str | None = None, ) -> types.ListToolsResult: """Send a tools/list request.""" - ... + raise NotImplementedError @abstractmethod async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - ... + raise NotImplementedError \ No newline at end of file diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index becc5a8554..bd0d592e57 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -6,8 +6,8 @@ from anyio.streams.memory import MemoryObjectReceiveStream from pydantic import AnyUrl -import mcp_grpc.types as types -from mcp_grpc.server.session import ServerRequestResponder +import mcp.types as types +from mcp.server.session import ServerRequestResponder class TransportSession(abc.ABC): From af7ff5a0e3bd088a0b0aba0608a774188872c24d Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:25:45 +0530 Subject: [PATCH 07/53] fix abstract server transport session --- src/mcp/server/transport_session.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index bd0d592e57..efc6aad682 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -7,8 +7,6 @@ from pydantic import AnyUrl import mcp.types as types -from mcp.server.session import ServerRequestResponder - class TransportSession(abc.ABC): """Abstract base class for transport sessions.""" @@ -103,11 +101,3 @@ async def send_tool_list_changed(self) -> None: async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" raise NotImplementedError - - @property - @abc.abstractmethod - def incoming_messages( - self, - ) -> MemoryObjectReceiveStream[ServerRequestResponder]: - """Incoming messages stream.""" - raise NotImplementedError From 7f468d0210782afabb29c90d004df3981c215956 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:26:36 +0530 Subject: [PATCH 08/53] removed unused import --- src/mcp/server/transport_session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index efc6aad682..f23a8361da 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -3,7 +3,6 @@ import abc from typing import Any -from anyio.streams.memory import MemoryObjectReceiveStream from pydantic import AnyUrl import mcp.types as types From e895d90b5101cc64fc611074c4bc77899a462230 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:35:24 +0530 Subject: [PATCH 09/53] fix type hints --- src/mcp/server/elicitation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index bba988f496..47be94b138 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerSession +from mcp.server.transport_session import TransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -74,7 +74,7 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: async def elicit_with_validation( - session: ServerSession, + session: TransportSession, message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, From d01e477a3b7b77c94f2f08c78596f8deba9b48c4 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:42:52 +0530 Subject: [PATCH 10/53] revert type hints --- src/mcp/server/elicitation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 47be94b138..bba988f496 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.transport_session import TransportSession +from mcp.server.session import ServerSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -74,7 +74,7 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: async def elicit_with_validation( - session: TransportSession, + session: ServerSession, message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, From 7bdafa384f796e16505104b089b2d5c5a636ec7b Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:44:58 +0530 Subject: [PATCH 11/53] fix import --- src/mcp/server/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 3dc888843c..00355ae9ef 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -54,7 +54,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from src.mcp.server.transport_session import TransportSession +from mcp.server.transport_session import TransportSession class InitializationState(Enum): From e9f63dd45f65624c72d000bbb91dc3bcd6790adb Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:46:44 +0530 Subject: [PATCH 12/53] fix import --- src/mcp/client/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 339c64abdc..c058de1721 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -14,7 +14,7 @@ from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from src.mcp.client.transport_session import TransportSession +from mcp.client.transport_session import TransportSession DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") From 5b156a16728cdccf22fa3991316fbd5127d636e3 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 07:18:36 +0000 Subject: [PATCH 13/53] fix ruff format --- src/mcp/client/session.py | 2 +- src/mcp/client/transport_session.py | 2 +- src/mcp/server/session.py | 2 +- src/mcp/server/transport_session.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c058de1721..c07ca8c50d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -109,7 +109,7 @@ class ClientSession( types.ClientResult, types.ServerRequest, types.ServerNotification, - ] + ], ): def __init__( self, diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index a85b397184..8dbe1a82df 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -130,4 +130,4 @@ async def list_tools( @abstractmethod async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 00355ae9ef..e50e7d0042 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -78,7 +78,7 @@ class ServerSession( types.ServerResult, types.ClientRequest, types.ClientNotification, - ] + ], ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index f23a8361da..d0288a0f35 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -7,6 +7,7 @@ import mcp.types as types + class TransportSession(abc.ABC): """Abstract base class for transport sessions.""" From f26d861db283bfb4903df0cbf48a88ff8446576b Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 12:58:12 +0530 Subject: [PATCH 14/53] request context as optional param --- src/mcp/server/fastmcp/server.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 865b8e7e72..9871063c3b 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -326,10 +326,18 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: request_context = None return Context(request_context=request_context, fastmcp=self) - async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: + async def call_tool( + self, name: str, arguments: dict[str, Any], + request_context: RequestContext | None = None + ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" - context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) + if request_context: + context = Context(request_context=request_context, fastmcp=self) + else: + context = self.get_context() + return await self._tool_manager.call_tool(name, arguments, + context=context, + convert_result=True) async def list_resources(self) -> list[MCPResource]: """List all available resources.""" From 3097cb3a3360ca5993c873b8b001099120868e28 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 07:28:45 +0000 Subject: [PATCH 15/53] fix format --- src/mcp/server/fastmcp/server.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 9871063c3b..7da7ca43d4 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -327,17 +327,14 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: return Context(request_context=request_context, fastmcp=self) async def call_tool( - self, name: str, arguments: dict[str, Any], - request_context: RequestContext | None = None + self, name: str, arguments: dict[str, Any], request_context: RequestContext | None = None ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" if request_context: context = Context(request_context=request_context, fastmcp=self) else: context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, - context=context, - convert_result=True) + return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) async def list_resources(self) -> list[MCPResource]: """List all available resources.""" From 9e8dca3a075e393c70af798c0fb13fcbdd493c30 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 08:08:19 +0000 Subject: [PATCH 16/53] ruff check --fix --- src/mcp/client/session.py | 3 +-- src/mcp/client/transport_session.py | 4 +--- src/mcp/server/session.py | 3 +-- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c07ca8c50d..02646924bd 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,13 +9,12 @@ from typing_extensions import deprecated import mcp.types as types +from mcp.client.transport_session import TransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.client.transport_session import TransportSession - DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") logger = logging.getLogger("client") diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 8dbe1a82df..9f9f3f8c47 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod +from abc import ABC, abstractmethod from datetime import timedelta - from typing import Any from pydantic import AnyUrl diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index e50e7d0042..99fdb8f3f7 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions +from mcp.server.transport_session import TransportSession from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -54,8 +55,6 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.server.transport_session import TransportSession - class InitializationState(Enum): NotInitialized = 1 From 5b7b458f963d608252efaf13b9d39fcd6bb1824e Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 08:16:17 +0000 Subject: [PATCH 17/53] fix pyright --- src/mcp/client/transport_session.py | 2 +- src/mcp/server/fastmcp/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 9f9f3f8c47..71e69ee3e4 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -17,7 +17,7 @@ async def initialize(self) -> types.InitializeResult: raise NotImplementedError @abstractmethod - async def send_ping(self): + async def send_ping(self) -> types.EmptyResult: raise NotImplementedError @abstractmethod diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 7da7ca43d4..05bf7f3b7d 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -327,7 +327,7 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: return Context(request_context=request_context, fastmcp=self) async def call_tool( - self, name: str, arguments: dict[str, Any], request_context: RequestContext | None = None + self, name: str, arguments: dict[str, Any], request_context: RequestContext[ServerSession, LifespanResultT, Request] | None = None ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" if request_context: From 8ca511ef8b55302411b3d0ef356ebc08789f5661 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 08:17:19 +0000 Subject: [PATCH 18/53] ruff fix --- src/mcp/server/fastmcp/server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 05bf7f3b7d..c9883ca566 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -327,7 +327,10 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: return Context(request_context=request_context, fastmcp=self) async def call_tool( - self, name: str, arguments: dict[str, Any], request_context: RequestContext[ServerSession, LifespanResultT, Request] | None = None + self, + name: str, + arguments: dict[str, Any], + request_context: RequestContext[ServerSession, LifespanResultT, Request] | None = None, ) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" if request_context: From 53e02fe7403c169c4453d4dc1f74f33a660f187f Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 16:18:15 +0530 Subject: [PATCH 19/53] removed fat abstract class --- src/mcp/server/transport_session.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index d0288a0f35..fcb4a21e8a 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -11,17 +11,6 @@ class TransportSession(abc.ABC): """Abstract base class for transport sessions.""" - @property - @abc.abstractmethod - def client_params(self) -> types.InitializeRequestParams | None: - """Client initialization parameters.""" - raise NotImplementedError - - @abc.abstractmethod - def check_client_capability(self, capability: types.ClientCapabilities) -> bool: - """Check if the client supports a specific capability.""" - raise NotImplementedError - @abc.abstractmethod async def send_log_message( self, @@ -38,23 +27,6 @@ async def send_resource_updated(self, uri: AnyUrl) -> None: """Send a resource updated notification.""" raise NotImplementedError - @abc.abstractmethod - async def create_message( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - related_request_id: types.RequestId | None = None, - ) -> types.CreateMessageResult: - """Send a sampling/create_message request.""" - raise NotImplementedError - @abc.abstractmethod async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" From cf0f15243b785cc6c14e94e0c0c1af4634e1e7b8 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 16:27:28 +0530 Subject: [PATCH 20/53] removed client a thin interface --- src/mcp/client/transport_session.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 71e69ee3e4..41150a039b 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -80,16 +80,6 @@ async def call_tool( """Send a tools/call request with optional progress callback support.""" raise NotImplementedError - @abstractmethod - async def _validate_tool_result( - self, - name: str, - result: types.CallToolResult, - ) -> None: - """Validate the structured content of a tool result against its output - schema.""" - raise NotImplementedError - @abstractmethod async def list_prompts( self, From ccbdde86fa9042514d42707042d02afac79ba9fb Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 16:49:00 +0530 Subject: [PATCH 21/53] add description --- src/mcp/client/transport_session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 41150a039b..6157749893 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -18,6 +18,7 @@ async def initialize(self) -> types.InitializeResult: @abstractmethod async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" raise NotImplementedError @abstractmethod @@ -28,6 +29,7 @@ async def send_progress_notification( total: float | None = None, message: str | None = None, ) -> None: + """Send a progress notification.""" raise NotImplementedError @abstractmethod From 380710e49e07e40895ffb1f0cbae9af54a7d2bf5 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Thu, 6 Nov 2025 17:10:29 +0530 Subject: [PATCH 22/53] revert context change in this pr --- src/mcp/server/fastmcp/server.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c9883ca566..865b8e7e72 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -326,17 +326,9 @@ def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: request_context = None return Context(request_context=request_context, fastmcp=self) - async def call_tool( - self, - name: str, - arguments: dict[str, Any], - request_context: RequestContext[ServerSession, LifespanResultT, Request] | None = None, - ) -> Sequence[ContentBlock] | dict[str, Any]: + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: """Call a tool by name with arguments.""" - if request_context: - context = Context(request_context=request_context, fastmcp=self) - else: - context = self.get_context() + context = self.get_context() return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) async def list_resources(self) -> list[MCPResource]: From 3f977b380cdcfcc048f23981c0985142ff6741d2 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 11:51:36 +0530 Subject: [PATCH 23/53] rename classes --- src/mcp/__init__.py | 4 ++++ src/mcp/client/session.py | 4 ++-- src/mcp/client/transport_session.py | 2 +- src/mcp/server/session.py | 4 ++-- src/mcp/server/transport_session.py | 22 +++++++++++----------- 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index e93b95c902..ae74dfa326 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,4 +1,6 @@ from .client.session import ClientSession +from .client.transport_session import ClientTransportSession +from .server.transport_session import ServerTransportSession from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client from .server.session import ServerSession @@ -113,4 +115,6 @@ "stdio_server", "CompleteRequest", "JSONRPCResponse", + "ClientTransportSession", + "ServerTransportSession", ] diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 02646924bd..4243fa999e 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,7 +9,7 @@ from typing_extensions import deprecated import mcp.types as types -from mcp.client.transport_session import TransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -101,7 +101,7 @@ async def _default_logging_callback( class ClientSession( - TransportSession, + ClientTransportSession, BaseSession[ types.ClientRequest, types.ClientNotification, diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 6157749893..6f6f523226 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -8,7 +8,7 @@ from mcp.shared.session import ProgressFnT -class TransportSession(ABC): +class ClientTransportSession(ABC): """Abstract base class for communication transports.""" @abstractmethod diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 99fdb8f3f7..96f879034c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions -from mcp.server.transport_session import TransportSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -70,7 +70,7 @@ class InitializationState(Enum): class ServerSession( - TransportSession, + ServerTransportSession, BaseSession[ types.ServerRequest, types.ServerNotification, diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index fcb4a21e8a..bf3f6a1d1c 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -1,6 +1,6 @@ """Abstract base class for transport sessions.""" -import abc +from abc import ABC, abstractmethod from typing import Any from pydantic import AnyUrl @@ -8,10 +8,10 @@ import mcp.types as types -class TransportSession(abc.ABC): +class ServerTransportSession(ABC): """Abstract base class for transport sessions.""" - @abc.abstractmethod + @abstractmethod async def send_log_message( self, level: types.LoggingLevel, @@ -22,17 +22,17 @@ async def send_log_message( """Send a log message notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_resource_updated(self, uri: AnyUrl) -> None: """Send a resource updated notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def elicit( self, message: str, @@ -42,12 +42,12 @@ async def elicit( """Send an elicitation/create request.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_progress_notification( self, progress_token: str | int, @@ -59,17 +59,17 @@ async def send_progress_notification( """Send a progress notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" raise NotImplementedError - @abc.abstractmethod + @abstractmethod async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" raise NotImplementedError From ec7b6d6a2592c243dff686d3bc27685f186759c8 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 06:23:59 +0000 Subject: [PATCH 24/53] ruff fix --- src/mcp/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index ae74dfa326..93ef8acdf7 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,10 +1,10 @@ from .client.session import ClientSession -from .client.transport_session import ClientTransportSession -from .server.transport_session import ServerTransportSession from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client +from .client.transport_session import ClientTransportSession from .server.session import ServerSession from .server.stdio import stdio_server +from .server.transport_session import ServerTransportSession from .shared.exceptions import McpError from .types import ( CallToolRequest, From 0359aa899a2a89388e01816c4c7dd48c57ca196d Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Wed, 12 Nov 2025 14:18:15 +0530 Subject: [PATCH 25/53] merge main --- .../mcp_simple_auth_client/main.py | 4 ++-- .../simple-chatbot/mcp_simple_chatbot/main.py | 3 ++- .../snippets/clients/display_utilities.py | 5 ++-- examples/snippets/clients/stdio_client.py | 3 ++- src/mcp/client/session.py | 16 ++++++------- src/mcp/client/session_group.py | 24 ++++++++++--------- src/mcp/client/transport_session.py | 22 +++++++++++++++-- src/mcp/shared/context.py | 3 ++- tests/client/test_list_roots_callback.py | 4 ++-- tests/client/test_sampling_callback.py | 4 ++-- tests/client/test_session.py | 6 ++--- tests/server/fastmcp/test_elicitation.py | 18 +++++++------- tests/server/fastmcp/test_integration.py | 6 ++--- tests/shared/test_streamable_http.py | 4 ++-- 14 files changed, 74 insertions(+), 48 deletions(-) diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 5987a878ef..6c7201e044 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -17,7 +17,7 @@ from urllib.parse import parse_qs, urlparse from mcp.client.auth import OAuthClientProvider, TokenStorage -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -153,7 +153,7 @@ class SimpleAuthClient: def __init__(self, server_url: str, transport_type: str = "streamable-http"): self.server_url = server_url self.transport_type = transport_type - self.session: ClientSession | None = None + self.session: ClientTransportSession | None = None async def connect(self): """Connect to the MCP server.""" diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 78a81a4d9f..3a9d201b17 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -10,6 +10,7 @@ from dotenv import load_dotenv from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.transport_session import ClientTransportSession # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -67,7 +68,7 @@ def __init__(self, name: str, config: dict[str, Any]) -> None: self.name: str = name self.config: dict[str, Any] = config self.stdio_context: Any | None = None - self.session: ClientSession | None = None + self.session: ClientTransportSession | None = None self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.exit_stack: AsyncExitStack = AsyncExitStack() diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index 5f1d50510d..5e1b203ee6 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -8,6 +8,7 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection @@ -18,7 +19,7 @@ ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -30,7 +31,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index ac978035d4..62fb0f4c47 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -9,6 +9,7 @@ from pydantic import AnyUrl from mcp import ClientSession, StdioServerParameters, types +from mcp.client.session import ClientTransportSession from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -22,7 +23,7 @@ # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 4243fa999e..c3559b13a4 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -23,7 +23,7 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: ... # pragma: no branch @@ -31,15 +31,15 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any] - ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch + self, context: RequestContext["ClientTransportSession", Any] + ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch class LoggingFnT(Protocol): @@ -63,7 +63,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.ErrorData( @@ -73,7 +73,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( # pragma: no cover @@ -83,7 +83,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -510,7 +510,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: - ctx = RequestContext[ClientSession, Any]( + ctx = RequestContext[ClientTransportSession, Any]( request_id=responder.request_id, meta=responder.request_meta, session=self, diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 2c55bb7752..9e95ed909a 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -96,10 +96,10 @@ class _ComponentNames(BaseModel): _tools: dict[str, types.Tool] # Client-server connection management. - _sessions: dict[mcp.ClientSession, _ComponentNames] - _tool_to_session: dict[str, mcp.ClientSession] + _sessions: dict[mcp.ClientTransportSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientTransportSession] _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] + _session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -153,7 +153,7 @@ async def __aexit__( tg.start_soon(exit_stack.aclose) @property - def sessions(self) -> list[mcp.ClientSession]: + def sessions(self) -> list[mcp.ClientTransportSession]: """Returns the list of sessions being managed.""" return list(self._sessions.keys()) # pragma: no cover @@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) - async def disconnect_from_server(self, session: mcp.ClientSession) -> None: + async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None: """Disconnects from a single MCP server.""" session_known_for_components = session in self._sessions @@ -216,8 +216,8 @@ async def disconnect_from_server(self, session: mcp.ClientSession) -> None: await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( - self, server_info: types.Implementation, session: mcp.ClientSession - ) -> mcp.ClientSession: + self, server_info: types.Implementation, session: mcp.ClientTransportSession + ) -> mcp.ClientTransportSession: """Connects to a single MCP server.""" await self._aggregate_components(server_info, session) return session @@ -225,14 +225,14 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, - ) -> mcp.ClientSession: + ) -> mcp.ClientTransportSession: """Connects to a single MCP server.""" server_info, session = await self._establish_session(server_params) return await self.connect_with_session(server_info, session) async def _establish_session( self, server_params: ServerParameters - ) -> tuple[types.Implementation, mcp.ClientSession]: + ) -> tuple[types.Implementation, mcp.ClientTransportSession]: """Establish a client session to an MCP server.""" session_stack = contextlib.AsyncExitStack() @@ -276,7 +276,9 @@ async def _establish_session( await session_stack.aclose() raise - async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: + async def _aggregate_components( + self, server_info: types.Implementation, session: mcp.ClientTransportSession + ) -> None: """Aggregates prompts, resources, and tools from a given session.""" # Create a reverse index so we can find all prompts, resources, and @@ -289,7 +291,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, mcp.ClientSession] = {} + tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {} # Query the server for its prompts and aggregate to list. try: diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 6f6f523226..c51b059f66 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from datetime import timedelta -from typing import Any +from typing import Any, overload from pydantic import AnyUrl +from typing_extensions import deprecated from mcp import types from mcp.shared.session import ProgressFnT @@ -109,12 +110,29 @@ async def complete( """Send a completion/complete request.""" raise NotImplementedError + @overload + @deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead") + async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ... + + @overload + async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ... + + @overload + async def list_tools(self) -> types.ListToolsResult: ... + @abstractmethod async def list_tools( self, cursor: str | None = None, + *, + params: types.PaginatedRequestParams | None = None, ) -> types.ListToolsResult: - """Send a tools/list request.""" + """Send a tools/list request. + + Args: + cursor: Simple cursor string for pagination (deprecated, use params instead) + params: Full pagination parameters including cursor and any future fields + """ raise NotImplementedError @abstractmethod diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f3006e7d5f..0fb12c649c 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,10 +3,11 @@ from typing_extensions import TypeVar +from mcp.client.transport_session import ClientTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) +SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 0da0fff07a..dc53eddbcb 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,7 +1,7 @@ import pytest from pydantic import FileUrl -from mcp.client.session import ClientSession +from mcp.client.session import ClientTransportSession from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession from mcp.shared.context import RequestContext @@ -31,7 +31,7 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientTransportSession, None], ) -> ListRootsResult: return callback_return diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index a3f6affda8..8cd2c71166 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,6 @@ import pytest -from mcp.client.session import ClientSession +from mcp.client.session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, @@ -27,7 +27,7 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientTransportSession, None], params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 8d0ef68a98..c327a806f6 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -4,7 +4,7 @@ import pytest import mcp.types as types -from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession, ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @@ -427,7 +427,7 @@ async def test_client_capabilities_with_custom_callbacks(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( @@ -437,7 +437,7 @@ async def custom_sampling_callback( # pragma: no cover ) async def custom_list_roots_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 2c74d0e88b..dd1ae72dc7 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,7 +7,7 @@ import pytest from pydantic import BaseModel, Field -from mcp.client.session import ClientSession, ElicitationFnT +from mcp.client.session import ClientTransportSession, ElicitationFnT from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.context import RequestContext @@ -72,7 +72,7 @@ async def test_stdio_elicitation(): # Create a custom handler for elicitation requests async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) @@ -90,7 +90,7 @@ async def test_stdio_elicitation_decline(): mcp = FastMCP(name="StdioElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -129,7 +129,7 @@ class InvalidNestedSchema(BaseModel): # Dummy callback (won't be called due to validation failure) async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -189,7 +189,7 @@ async def optional_tool(ctx: Context[ServerSession, None]) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -208,7 +208,7 @@ async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pr return f"Validation failed: {str(e)}" async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -245,7 +245,9 @@ async def defaults_tool(ctx: Context[ServerSession, None]) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_schema_verify( + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams + ): # Verify the schema includes defaults schema = params.requestedSchema props = schema["properties"] @@ -266,7 +268,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession, None], p ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_override(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index b1cefca29c..778b0bfd7a 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -32,7 +32,7 @@ structured_output, tool_progress, ) -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.context import RequestContext @@ -212,7 +212,7 @@ def unpack_streams( # Callback functions for testing async def sampling_callback( - context: RequestContext[ClientSession, None], params: CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: CreateMessageRequestParams ) -> CreateMessageResult: """Sampling callback for tests.""" return CreateMessageResult( @@ -225,7 +225,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): +async def elicitation_callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 43b321d96e..736e261cd3 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -21,7 +21,7 @@ from starlette.routing import Mount import mcp.types as types -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( @@ -1233,7 +1233,7 @@ async def test_streamablehttp_server_sampling(basic_server: None, basic_server_u # Define sampling callback that returns a mock response async def sampling_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult: nonlocal sampling_callback_invoked, captured_message_params From b733fcfc8ea0df816ff5f82b3b9f4f24301f6363 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:16:07 +0000 Subject: [PATCH 26/53] fix type hints for serversession --- examples/snippets/servers/elicitation.py | 4 ++-- examples/snippets/servers/lifespan_example.py | 4 ++-- examples/snippets/servers/notifications.py | 4 ++-- examples/snippets/servers/tool_progress.py | 4 ++-- src/mcp/server/elicitation.py | 4 ++-- src/mcp/server/fastmcp/server.py | 4 ++-- src/mcp/server/lowlevel/server.py | 6 +++--- src/mcp/server/session.py | 2 +- src/mcp/shared/context.py | 6 +++++- tests/client/test_sampling_callback.py | 5 ++++- tests/shared/test_streamable_http.py | 6 ++++-- 11 files changed, 29 insertions(+), 20 deletions(-) diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 2c8a3b35ac..049b42516f 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -17,7 +17,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check.""" # Check if date is available if date == "2024-12-25": diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 62278b6aac..32b6997304 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession # Mock database class for example @@ -51,7 +51,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 833bc89053..36d9712eb6 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,11 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index 2ac458f6aa..dddd8c9eb2 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,11 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index bba988f496..65399e27c2 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -74,7 +74,7 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: async def elicit_with_validation( - session: ServerSession, + session: ServerTransportSession, message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 865b8e7e72..03e2233296 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -54,7 +54,7 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSession, ServerSessionT +from mcp.server.session import ServerSessionT, ServerTransportSession from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore @@ -315,7 +315,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: + def get_context(self) -> Context[ServerTransportSession, LifespanResultT, Request]: """ Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 49d289fb75..329cd1dd23 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,7 @@ async def main(): from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp.server.session import ServerSession, ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage @@ -102,7 +102,7 @@ async def main(): CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -231,7 +231,7 @@ def get_capabilities( @property def request_context( self, - ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: + ) -> RequestContext[ServerTransportSession, LifespanResultT, RequestT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 96f879034c..9456ebf9fd 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -62,7 +62,7 @@ class InitializationState(Enum): Initialized = 3 -ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") +ServerSessionT = TypeVar("ServerSessionT", bound="ServerTransportSession") ServerRequestResponder = ( RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 0fb12c649c..094fbddf40 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -4,10 +4,14 @@ from typing_extensions import TypeVar from mcp.client.transport_session import ClientTransportSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession) +SessionT = TypeVar("SessionT", + bound=BaseSession[Any, Any, Any, Any, Any] | + ClientTransportSession | + ServerTransportSession) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 8cd2c71166..3fe50a132c 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,8 @@ import pytest from mcp.client.session import ClientTransportSession +from mcp.server.session import ServerSession +from typing import cast from mcp.shared.context import RequestContext from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, @@ -34,7 +36,8 @@ async def sampling_callback( @server.tool("test_sampling") async def test_sampling_tool(message: str): - value = await server.get_context().session.create_message( + session = cast(ServerSession, server.get_context().session) + value = await session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 736e261cd3..95bbd633ef 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -8,7 +8,7 @@ import multiprocessing import socket from collections.abc import Generator -from typing import Any +from typing import Any, cast import anyio import httpx @@ -22,6 +22,7 @@ import mcp.types as types from mcp.client.session import ClientSession, ClientTransportSession +from mcp.server.session import ServerSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( @@ -198,7 +199,8 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] elif name == "test_sampling_tool": # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( + session = cast(ServerSession, ctx.session) + sampling_result = await session.create_message( messages=[ types.SamplingMessage( role="user", From cdc39f4edbe0af04ab84f338e8b807e87ac514dc Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:16:44 +0000 Subject: [PATCH 27/53] fix ruff --- src/mcp/server/lowlevel/server.py | 4 +++- src/mcp/shared/context.py | 7 +++---- tests/client/test_sampling_callback.py | 3 ++- tests/shared/test_streamable_http.py | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 329cd1dd23..b60e049749 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -102,7 +102,9 @@ async def main(): CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar( + "request_ctx" +) class NotificationOptions: diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 094fbddf40..63fafa2418 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -8,10 +8,9 @@ from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -SessionT = TypeVar("SessionT", - bound=BaseSession[Any, Any, Any, Any, Any] | - ClientTransportSession | - ServerTransportSession) +SessionT = TypeVar( + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession +) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 3fe50a132c..feed499afd 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,8 +1,9 @@ +from typing import cast + import pytest from mcp.client.session import ClientTransportSession from mcp.server.session import ServerSession -from typing import cast from mcp.shared.context import RequestContext from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 95bbd633ef..08968d6f7e 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -22,9 +22,9 @@ import mcp.types as types from mcp.client.session import ClientSession, ClientTransportSession -from mcp.server.session import ServerSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server +from mcp.server.session import ServerSession from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, From 65a3b0f8aee6257cb70ddbef97873846890100a9 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:17:17 +0000 Subject: [PATCH 28/53] uv run scripts/update_readme_snippets.py --- README.md | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 5dbc4bd9dd..6450adbe98 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession # Mock database class for example @@ -254,7 +254,7 @@ mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() @@ -326,13 +326,13 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -674,13 +674,13 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -798,7 +798,7 @@ Request additional information from users. This example shows an Elicitation dur from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -814,7 +814,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check.""" # Check if date is available if date == "2024-12-25": @@ -888,13 +888,13 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") @@ -2038,6 +2038,7 @@ import os from pydantic import AnyUrl from mcp import ClientSession, StdioServerParameters, types +from mcp.client.session import ClientTransportSession from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2051,7 +2052,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( @@ -2169,6 +2170,7 @@ import os from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection @@ -2179,7 +2181,7 @@ server_params = StdioServerParameters( ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -2191,7 +2193,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() From f34e8fe12c101de34ce67e316b9c9e490487c555 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:23:41 +0000 Subject: [PATCH 29/53] some fixes --- src/mcp/client/session_group.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9e95ed909a..233c532cea 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -96,10 +96,10 @@ class _ComponentNames(BaseModel): _tools: dict[str, types.Tool] # Client-server connection management. - _sessions: dict[mcp.ClientTransportSession, _ComponentNames] - _tool_to_session: dict[str, mcp.ClientTransportSession] + _sessions: dict["mcp.ClientTransportSession", _ComponentNames] + _tool_to_session: dict[str, "mcp.ClientTransportSession"] _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack] + _session_exit_stacks: dict["mcp.ClientTransportSession", contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -153,7 +153,7 @@ async def __aexit__( tg.start_soon(exit_stack.aclose) @property - def sessions(self) -> list[mcp.ClientTransportSession]: + def sessions(self) -> list["mcp.ClientTransportSession"]: """Returns the list of sessions being managed.""" return list(self._sessions.keys()) # pragma: no cover @@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) - async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None: + async def disconnect_from_server(self, session: "mcp.ClientTransportSession") -> None: """Disconnects from a single MCP server.""" session_known_for_components = session in self._sessions @@ -216,8 +216,8 @@ async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> N await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( - self, server_info: types.Implementation, session: mcp.ClientTransportSession - ) -> mcp.ClientTransportSession: + self, server_info: types.Implementation, session: "mcp.ClientTransportSession" + ) -> "mcp.ClientTransportSession": """Connects to a single MCP server.""" await self._aggregate_components(server_info, session) return session @@ -225,14 +225,14 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, - ) -> mcp.ClientTransportSession: + ) -> "mcp.ClientTransportSession": """Connects to a single MCP server.""" server_info, session = await self._establish_session(server_params) return await self.connect_with_session(server_info, session) async def _establish_session( self, server_params: ServerParameters - ) -> tuple[types.Implementation, mcp.ClientTransportSession]: + ) -> tuple[types.Implementation, "mcp.ClientTransportSession"]: """Establish a client session to an MCP server.""" session_stack = contextlib.AsyncExitStack() @@ -277,7 +277,7 @@ async def _establish_session( raise async def _aggregate_components( - self, server_info: types.Implementation, session: mcp.ClientTransportSession + self, server_info: types.Implementation, session: "mcp.ClientTransportSession" ) -> None: """Aggregates prompts, resources, and tools from a given session.""" @@ -291,7 +291,7 @@ async def _aggregate_components( prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {} + tool_to_session_temp: dict[str, "mcp.ClientTransportSession"] = {} # Query the server for its prompts and aggregate to list. try: From 1bfc08696de26b14b83055eb855dab236fa211bd Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:24:59 +0000 Subject: [PATCH 30/53] fix ruff --- src/mcp/client/session_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 233c532cea..f3d351d312 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -291,7 +291,7 @@ async def _aggregate_components( prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, "mcp.ClientTransportSession"] = {} + tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {} # Query the server for its prompts and aggregate to list. try: From 481f7eabe10ef4b6cbedbb5745b160fe111e6ae7 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:49:55 +0000 Subject: [PATCH 31/53] fix type hints without cast --- src/mcp/client/session_group.py | 20 ++++++++++---------- src/mcp/shared/memory.py | 3 ++- tests/client/test_sampling_callback.py | 5 ++--- tests/server/test_cancel_handling.py | 2 ++ tests/shared/test_memory.py | 4 ++-- tests/shared/test_progress_notifications.py | 1 + tests/shared/test_session.py | 8 +++++--- tests/shared/test_sse.py | 4 ++-- tests/shared/test_streamable_http.py | 5 +++-- tests/shared/test_ws.py | 4 ++-- 10 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index f3d351d312..9e95ed909a 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -96,10 +96,10 @@ class _ComponentNames(BaseModel): _tools: dict[str, types.Tool] # Client-server connection management. - _sessions: dict["mcp.ClientTransportSession", _ComponentNames] - _tool_to_session: dict[str, "mcp.ClientTransportSession"] + _sessions: dict[mcp.ClientTransportSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientTransportSession] _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict["mcp.ClientTransportSession", contextlib.AsyncExitStack] + _session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -153,7 +153,7 @@ async def __aexit__( tg.start_soon(exit_stack.aclose) @property - def sessions(self) -> list["mcp.ClientTransportSession"]: + def sessions(self) -> list[mcp.ClientTransportSession]: """Returns the list of sessions being managed.""" return list(self._sessions.keys()) # pragma: no cover @@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) - async def disconnect_from_server(self, session: "mcp.ClientTransportSession") -> None: + async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None: """Disconnects from a single MCP server.""" session_known_for_components = session in self._sessions @@ -216,8 +216,8 @@ async def disconnect_from_server(self, session: "mcp.ClientTransportSession") -> await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( - self, server_info: types.Implementation, session: "mcp.ClientTransportSession" - ) -> "mcp.ClientTransportSession": + self, server_info: types.Implementation, session: mcp.ClientTransportSession + ) -> mcp.ClientTransportSession: """Connects to a single MCP server.""" await self._aggregate_components(server_info, session) return session @@ -225,14 +225,14 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, - ) -> "mcp.ClientTransportSession": + ) -> mcp.ClientTransportSession: """Connects to a single MCP server.""" server_info, session = await self._establish_session(server_params) return await self.connect_with_session(server_info, session) async def _establish_session( self, server_params: ServerParameters - ) -> tuple[types.Implementation, "mcp.ClientTransportSession"]: + ) -> tuple[types.Implementation, mcp.ClientTransportSession]: """Establish a client session to an MCP server.""" session_stack = contextlib.AsyncExitStack() @@ -277,7 +277,7 @@ async def _establish_session( raise async def _aggregate_components( - self, server_info: types.Implementation, session: "mcp.ClientTransportSession" + self, server_info: types.Implementation, session: mcp.ClientTransportSession ) -> None: """Aggregates prompts, resources, and tools from a given session.""" diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 06d404e311..9f68a0c47e 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -14,6 +14,7 @@ import mcp.types as types from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ClientTransportSession from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage @@ -57,7 +58,7 @@ async def create_connected_server_and_client_session( client_info: types.Implementation | None = None, raise_exceptions: bool = False, elicitation_callback: ElicitationFnT | None = None, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" # TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport", diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index feed499afd..49138398c0 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,5 +1,3 @@ -from typing import cast - import pytest from mcp.client.session import ClientTransportSession @@ -37,7 +35,8 @@ async def sampling_callback( @server.tool("test_sampling") async def test_sampling_tool(message: str): - session = cast(ServerSession, server.get_context().session) + session = server.get_context().session + assert isinstance(session, ServerSession) value = await session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 47c49bb62b..3a0df20cc0 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -9,6 +9,7 @@ from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session +from mcp.client.session import ClientSession from mcp.types import ( CallToolRequest, CallToolRequestParams, @@ -56,6 +57,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ async with create_connected_server_and_client_session(server) as client: # First request (will be cancelled) + assert isinstance(client, ClientSession) async def first_request(): try: await client.send_request( diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index ca4368e9f8..4ebce3b15d 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -2,7 +2,7 @@ from pydantic import AnyUrl from typing_extensions import AsyncGenerator -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.server import Server from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import EmptyResult, Resource @@ -28,7 +28,7 @@ async def handle_list_resources(): # pragma: no cover @pytest.fixture async def client_connected_to_server( mcp_server: Server, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: async with create_connected_server_and_client_session(mcp_server) as client_session: yield client_session diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 1552711d2e..25afd7f328 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -370,6 +370,7 @@ async def handle_list_tools() -> list[types.Tool]: with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): async with create_connected_server_and_client_session(server) as client_session: # Send a request with a failing progress callback + assert isinstance(client_session, ClientSession) result = await client_session.send_request( types.ClientRequest( types.CallToolRequest( diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 313ec99265..47b5a02f62 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -5,7 +5,7 @@ import pytest import mcp.types as types -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session @@ -27,19 +27,20 @@ def mcp_server() -> Server: @pytest.fixture async def client_connected_to_server( mcp_server: Server, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: async with create_connected_server_and_client_session(mcp_server) as client_session: yield client_session @pytest.mark.anyio async def test_in_flight_requests_cleared_after_completion( - client_connected_to_server: ClientSession, + client_connected_to_server: ClientTransportSession, ): """Verify that _in_flight is empty after all requests complete.""" # Send a request and wait for response response = await client_connected_to_server.send_ping() assert isinstance(response, EmptyResult) + assert isinstance(client_connected_to_server, ClientSession) # Verify _in_flight is empty assert len(client_connected_to_server._in_flight) == 0 @@ -101,6 +102,7 @@ async def make_request(client_session: ClientSession): async with create_connected_server_and_client_session(make_server()) as client_session: async with anyio.create_task_group() as tg: + assert isinstance(client_session, ClientSession) tg.start_soon(make_request, client_session) # Wait for the request to be in-flight diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 28ac07d092..ba823ab6ad 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -17,7 +17,7 @@ from starlette.routing import Mount, Route import mcp.types as types -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport @@ -185,7 +185,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 08968d6f7e..603a4270a6 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -8,7 +8,7 @@ import multiprocessing import socket from collections.abc import Generator -from typing import Any, cast +from typing import Any import anyio import httpx @@ -199,7 +199,8 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] elif name == "test_sampling_tool": # Test sampling by requesting the client to sample a message - session = cast(ServerSession, ctx.session) + session = ctx.session + assert isinstance(session, ServerSession) sampling_result = await session.create_message( messages=[ types.SamplingMessage( diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index f093cb4927..1fac8696f7 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -12,7 +12,7 @@ from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server @@ -125,7 +125,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]: """Create and initialize a WebSocket client session""" async with websocket_client(server_url + "/ws") as streams: async with ClientSession(*streams) as session: From 6b8f7374b71b037098864047885ec356a4e930d9 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:51:45 +0000 Subject: [PATCH 32/53] fix ruff --- src/mcp/shared/memory.py | 11 +++++++++-- tests/server/test_cancel_handling.py | 3 ++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 9f68a0c47e..b8466fe91c 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -13,8 +13,15 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types -from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT -from mcp.client.session import ClientTransportSession +from mcp.client.session import ( + ClientSession, + ClientTransportSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + SamplingFnT, +) from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 3a0df20cc0..b1f825933a 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,10 +6,10 @@ import pytest import mcp.types as types +from mcp.client.session import ClientSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session -from mcp.client.session import ClientSession from mcp.types import ( CallToolRequest, CallToolRequestParams, @@ -58,6 +58,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ async with create_connected_server_and_client_session(server) as client: # First request (will be cancelled) assert isinstance(client, ClientSession) + async def first_request(): try: await client.send_request( From 99856e813f2c2272a034cbbcfddb6da7e6612500 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:54:29 +0000 Subject: [PATCH 33/53] remove overload --- src/mcp/client/transport_session.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index c51b059f66..37578f2110 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,11 +1,10 @@ from abc import ABC, abstractmethod from datetime import timedelta -from typing import Any, overload +from typing import Any from pydantic import AnyUrl -from typing_extensions import deprecated -from mcp import types +import mcp.types as types from mcp.shared.session import ProgressFnT @@ -110,15 +109,8 @@ async def complete( """Send a completion/complete request.""" raise NotImplementedError - @overload - @deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead") - async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ... - @overload - async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ... - - @overload - async def list_tools(self) -> types.ListToolsResult: ... + @abstractmethod async def list_tools( From ea8a33ca220397e88b8e6e11225c1c669c759fe2 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 08:56:27 +0000 Subject: [PATCH 34/53] revert client session group --- src/mcp/client/session_group.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9e95ed909a..ecab5aecf1 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -96,10 +96,10 @@ class _ComponentNames(BaseModel): _tools: dict[str, types.Tool] # Client-server connection management. - _sessions: dict[mcp.ClientTransportSession, _ComponentNames] - _tool_to_session: dict[str, mcp.ClientTransportSession] + _sessions: dict[mcp.ClientSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientSession] _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack] + _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -153,7 +153,7 @@ async def __aexit__( tg.start_soon(exit_stack.aclose) @property - def sessions(self) -> list[mcp.ClientTransportSession]: + def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" return list(self._sessions.keys()) # pragma: no cover @@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) - async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None: + async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" session_known_for_components = session in self._sessions @@ -216,8 +216,8 @@ async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> N await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( - self, server_info: types.Implementation, session: mcp.ClientTransportSession - ) -> mcp.ClientTransportSession: + self, server_info: types.Implementation, session: mcp.ClientSession + ) -> mcp.ClientSession: """Connects to a single MCP server.""" await self._aggregate_components(server_info, session) return session @@ -225,14 +225,14 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, - ) -> mcp.ClientTransportSession: + ) -> mcp.ClientSession: """Connects to a single MCP server.""" server_info, session = await self._establish_session(server_params) return await self.connect_with_session(server_info, session) async def _establish_session( self, server_params: ServerParameters - ) -> tuple[types.Implementation, mcp.ClientTransportSession]: + ) -> tuple[types.Implementation, mcp.ClientSession]: """Establish a client session to an MCP server.""" session_stack = contextlib.AsyncExitStack() @@ -277,7 +277,7 @@ async def _establish_session( raise async def _aggregate_components( - self, server_info: types.Implementation, session: mcp.ClientTransportSession + self, server_info: types.Implementation, session: mcp.ClientSession ) -> None: """Aggregates prompts, resources, and tools from a given session.""" @@ -291,7 +291,7 @@ async def _aggregate_components( prompts_temp: dict[str, types.Prompt] = {} resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {} + tool_to_session_temp: dict[str, mcp.ClientSession] = {} # Query the server for its prompts and aggregate to list. try: From 5bcfe6200e0314d9a8fcf7e1d7c8a50fcf9be5ce Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 7 Nov 2025 09:30:51 +0000 Subject: [PATCH 35/53] fix ruff pyright --- src/mcp/client/session_group.py | 4 +--- src/mcp/client/transport_session.py | 3 --- src/mcp/shared/context.py | 10 ++++++---- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index ecab5aecf1..2c55bb7752 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -276,9 +276,7 @@ async def _establish_session( await session_stack.aclose() raise - async def _aggregate_components( - self, server_info: types.Implementation, session: mcp.ClientSession - ) -> None: + async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: """Aggregates prompts, resources, and tools from a given session.""" # Create a reverse index so we can find all prompts, resources, and diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 37578f2110..07389d59a0 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -109,9 +109,6 @@ async def complete( """Send a completion/complete request.""" raise NotImplementedError - - - @abstractmethod async def list_tools( self, diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 63fafa2418..845cc50e20 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,15 +1,17 @@ from dataclasses import dataclass -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from typing_extensions import TypeVar -from mcp.client.transport_session import ClientTransportSession -from mcp.server.transport_session import ServerTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams +if TYPE_CHECKING: + from mcp.client.session import ClientTransportSession + from mcp.server.session import ServerTransportSession + SessionT = TypeVar( - "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" ) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) From af6be96916628e27d6e4661ca34e48749e51cfa6 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Wed, 12 Nov 2025 08:50:14 +0000 Subject: [PATCH 36/53] fix ruff --- src/mcp/client/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c3559b13a4..0bd4e9608e 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -39,7 +39,7 @@ async def __call__( class ListRootsFnT(Protocol): async def __call__( self, context: RequestContext["ClientTransportSession", Any] - ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch + ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch class LoggingFnT(Protocol): From 4377d41eb524bb6165b7bf1d9a6f9d2eb392fcd6 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:27:36 +0000 Subject: [PATCH 37/53] fix imports --- examples/snippets/clients/display_utilities.py | 3 +-- examples/snippets/clients/stdio_client.py | 3 +-- examples/snippets/servers/elicitation.py | 2 +- examples/snippets/servers/lifespan_example.py | 2 +- examples/snippets/servers/notifications.py | 2 +- examples/snippets/servers/tool_progress.py | 2 +- src/mcp/server/elicitation.py | 2 +- src/mcp/server/fastmcp/server.py | 3 ++- src/mcp/server/lowlevel/server.py | 3 ++- src/mcp/shared/context.py | 9 ++++----- src/mcp/shared/memory.py | 2 +- tests/client/test_list_roots_callback.py | 2 +- tests/client/test_sampling_callback.py | 2 +- tests/client/test_session.py | 3 ++- tests/server/fastmcp/test_elicitation.py | 3 ++- tests/server/fastmcp/test_integration.py | 3 ++- tests/shared/test_memory.py | 3 ++- tests/shared/test_session.py | 3 ++- tests/shared/test_sse.py | 3 ++- tests/shared/test_streamable_http.py | 3 ++- tests/shared/test_ws.py | 3 ++- 21 files changed, 34 insertions(+), 27 deletions(-) diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index 5e1b203ee6..b8ad7dffc2 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -6,9 +6,8 @@ import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index 62fb0f4c47..c72cc54f2e 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -8,8 +8,7 @@ from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.session import ClientTransportSession +from mcp import ClientSession, StdioServerParameters, types, ClientTransportSession from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 049b42516f..45f2cb68b9 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 32b6997304..46f01f427f 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 36d9712eb6..995ecd8178 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index dddd8c9eb2..a0f62fda61 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 65399e27c2..b2f33ec7ce 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 03e2233296..840273e503 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -54,7 +54,8 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSessionT, ServerTransportSession +from mcp.server.session import ServerSessionT +from mcp.server.transport_session import ServerTransportSession from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b60e049749..85846afc6e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,8 @@ async def main(): from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession, ServerTransportSession +from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 845cc50e20..eaa3e27933 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,17 +1,16 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic +from typing import Any, Generic from typing_extensions import TypeVar from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -if TYPE_CHECKING: - from mcp.client.session import ClientTransportSession - from mcp.server.session import ServerTransportSession +from mcp.client.transport_session import ClientTransportSession +from mcp.server.transport_session import ServerTransportSession SessionT = TypeVar( - "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession ) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index b8466fe91c..2d203d7430 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -15,13 +15,13 @@ import mcp.types as types from mcp.client.session import ( ClientSession, - ClientTransportSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT, ) +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index dc53eddbcb..5acb3b21aa 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,7 +1,7 @@ import pytest from pydantic import FileUrl -from mcp.client.session import ClientTransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession from mcp.shared.context import RequestContext diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 49138398c0..9fb6e29c75 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,6 @@ import pytest -from mcp.client.session import ClientTransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.memory import ( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c327a806f6..bd51e4e102 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -4,7 +4,8 @@ import pytest import mcp.types as types -from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession, ClientTransportSession +from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index dd1ae72dc7..77f97e6772 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,7 +7,8 @@ import pytest from pydantic import BaseModel, Field -from mcp.client.session import ClientTransportSession, ElicitationFnT +from mcp.client.session import ElicitationFnT +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.context import RequestContext diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 778b0bfd7a..99e8972a95 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -32,7 +32,8 @@ structured_output, tool_progress, ) -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.shared.context import RequestContext diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index 4ebce3b15d..56e0b98e70 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -2,7 +2,8 @@ from pydantic import AnyUrl from typing_extensions import AsyncGenerator -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import EmptyResult, Resource diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 47b5a02f62..a056f705ba 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -5,7 +5,8 @@ import pytest import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index ba823ab6ad..967925a118 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -17,7 +17,8 @@ from starlette.routing import Mount, Route import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 603a4270a6..7aa768ae1b 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -21,7 +21,8 @@ from starlette.routing import Mount import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.session import ServerSession diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 1fac8696f7..107cd5589e 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -12,7 +12,8 @@ from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server From f02873fa25e2dc96bfe44572e31a688167f202bc Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:29:29 +0000 Subject: [PATCH 38/53] fix ruff --- examples/snippets/clients/stdio_client.py | 2 +- src/mcp/server/fastmcp/server.py | 2 +- src/mcp/shared/context.py | 5 ++--- tests/server/fastmcp/test_integration.py | 2 +- tests/shared/test_sse.py | 2 +- tests/shared/test_streamable_http.py | 2 +- 6 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index c72cc54f2e..90f9fdff9b 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -8,7 +8,7 @@ from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types, ClientTransportSession +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 840273e503..cc05403dd8 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -55,12 +55,12 @@ from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.session import ServerSessionT -from mcp.server.transport_session import ServerTransportSession from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index eaa3e27933..63fafa2418 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,11 +3,10 @@ from typing_extensions import TypeVar -from mcp.shared.session import BaseSession -from mcp.types import RequestId, RequestParams - from mcp.client.transport_session import ClientTransportSession from mcp.server.transport_session import ServerTransportSession +from mcp.shared.session import BaseSession +from mcp.types import RequestId, RequestParams SessionT = TypeVar( "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 99e8972a95..d95d3a380e 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -33,9 +33,9 @@ tool_progress, ) from mcp.client.session import ClientSession -from mcp.client.transport_session import ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 967925a118..0f850599a3 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -18,8 +18,8 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.transport_session import ClientTransportSession from mcp.client.sse import sse_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7aa768ae1b..be80e38201 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -22,8 +22,8 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.transport_session import ClientTransportSession from mcp.client.streamable_http import streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.session import ServerSession from mcp.server.streamable_http import ( From d4895a72e2d5515c3cb7b42b1c84fe16ba2fb4fc Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:30:54 +0000 Subject: [PATCH 39/53] fix circle --- src/mcp/shared/context.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 63fafa2418..e38fc3d5ce 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,8 +3,7 @@ from typing_extensions import TypeVar -from mcp.client.transport_session import ClientTransportSession -from mcp.server.transport_session import ServerTransportSession +from mcp import ServerTransportSession, ClientTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams From 3f0b620e1f3944d64b2785e194bab135193b48de Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:31:48 +0000 Subject: [PATCH 40/53] fix readme --- README.md | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 6450adbe98..5cbb6510f3 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -326,7 +326,7 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @@ -674,7 +674,7 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @@ -798,7 +798,7 @@ Request additional information from users. This example shows an Elicitation dur from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -888,7 +888,7 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @@ -2037,8 +2037,7 @@ import os from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.session import ClientTransportSession +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2168,9 +2167,8 @@ cd to the `examples/snippets` directory and run: import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection From fc17b953ae64ee658af806f0261fa0ed6a48ab52 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:32:07 +0000 Subject: [PATCH 41/53] fix ruff check --- src/mcp/shared/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index e38fc3d5ce..9ec3a2f170 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,7 +3,7 @@ from typing_extensions import TypeVar -from mcp import ServerTransportSession, ClientTransportSession +from mcp import ClientTransportSession, ServerTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams From 1d2b6262ad061650ff83e132e9dab4ddb66c6faa Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:33:23 +0000 Subject: [PATCH 42/53] fix circular import --- src/mcp/shared/context.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 9ec3a2f170..7267f4954c 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,14 +1,16 @@ from dataclasses import dataclass -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from typing_extensions import TypeVar -from mcp import ClientTransportSession, ServerTransportSession from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams +if TYPE_CHECKING: + from mcp import ClientTransportSession, ServerTransportSession + SessionT = TypeVar( - "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession | ServerTransportSession + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" ) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) From fd22fe2a62b9e320bff0e759a836be447b80af17 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 06:27:36 +0000 Subject: [PATCH 43/53] fix imports --- README.md | 16 +++++++--------- examples/snippets/clients/display_utilities.py | 3 +-- examples/snippets/clients/stdio_client.py | 3 +-- examples/snippets/servers/elicitation.py | 2 +- examples/snippets/servers/lifespan_example.py | 2 +- examples/snippets/servers/notifications.py | 2 +- examples/snippets/servers/tool_progress.py | 2 +- src/mcp/server/elicitation.py | 2 +- src/mcp/server/fastmcp/server.py | 3 ++- src/mcp/server/lowlevel/server.py | 3 ++- src/mcp/shared/context.py | 3 +-- src/mcp/shared/memory.py | 2 +- tests/client/test_list_roots_callback.py | 2 +- tests/client/test_sampling_callback.py | 2 +- tests/client/test_session.py | 3 ++- tests/server/fastmcp/test_elicitation.py | 3 ++- tests/server/fastmcp/test_integration.py | 3 ++- tests/shared/test_memory.py | 3 ++- tests/shared/test_session.py | 3 ++- tests/shared/test_sse.py | 3 ++- tests/shared/test_streamable_http.py | 3 ++- tests/shared/test_ws.py | 3 ++- 22 files changed, 38 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 6450adbe98..5cbb6510f3 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -326,7 +326,7 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @@ -674,7 +674,7 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @@ -798,7 +798,7 @@ Request additional information from users. This example shows an Elicitation dur from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -888,7 +888,7 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @@ -2037,8 +2037,7 @@ import os from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.session import ClientTransportSession +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2168,9 +2167,8 @@ cd to the `examples/snippets` directory and run: import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index 5e1b203ee6..b8ad7dffc2 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -6,9 +6,8 @@ import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client -from mcp.client.transport_session import ClientTransportSession from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index 62fb0f4c47..90f9fdff9b 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -8,8 +8,7 @@ from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.session import ClientTransportSession +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 049b42516f..45f2cb68b9 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 32b6997304..46f01f427f 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 36d9712eb6..995ecd8178 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index dddd8c9eb2..a0f62fda61 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 65399e27c2..b2f33ec7ce 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerTransportSession +from mcp.server.transport_session import ServerTransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 03e2233296..cc05403dd8 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -54,12 +54,13 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSessionT, ServerTransportSession +from mcp.server.session import ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b60e049749..85846afc6e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,8 @@ async def main(): from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession, ServerTransportSession +from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 845cc50e20..7267f4954c 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -7,8 +7,7 @@ from mcp.types import RequestId, RequestParams if TYPE_CHECKING: - from mcp.client.session import ClientTransportSession - from mcp.server.session import ServerTransportSession + from mcp import ClientTransportSession, ServerTransportSession SessionT = TypeVar( "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index b8466fe91c..2d203d7430 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -15,13 +15,13 @@ import mcp.types as types from mcp.client.session import ( ClientSession, - ClientTransportSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT, ) +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index dc53eddbcb..5acb3b21aa 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,7 +1,7 @@ import pytest from pydantic import FileUrl -from mcp.client.session import ClientTransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession from mcp.shared.context import RequestContext diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 49138398c0..9fb6e29c75 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,6 @@ import pytest -from mcp.client.session import ClientTransportSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.memory import ( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c327a806f6..bd51e4e102 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -4,7 +4,8 @@ import pytest import mcp.types as types -from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession, ClientTransportSession +from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index dd1ae72dc7..77f97e6772 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,7 +7,8 @@ import pytest from pydantic import BaseModel, Field -from mcp.client.session import ClientTransportSession, ElicitationFnT +from mcp.client.session import ElicitationFnT +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.context import RequestContext diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 778b0bfd7a..d95d3a380e 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -32,9 +32,10 @@ structured_output, tool_progress, ) -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index 4ebce3b15d..56e0b98e70 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -2,7 +2,8 @@ from pydantic import AnyUrl from typing_extensions import AsyncGenerator -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import EmptyResult, Resource diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 47b5a02f62..a056f705ba 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -5,7 +5,8 @@ import pytest import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index ba823ab6ad..0f850599a3 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -17,8 +17,9 @@ from starlette.routing import Mount, Route import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession from mcp.client.sse import sse_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 603a4270a6..be80e38201 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -21,8 +21,9 @@ from starlette.routing import Mount import mcp.types as types -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.session import ServerSession from mcp.server.streamable_http import ( diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 1fac8696f7..107cd5589e 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -12,7 +12,8 @@ from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp.client.session import ClientSession, ClientTransportSession +from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server From f36e9399d45a431330a79d2b3b18c2dd6820e167 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 14 Nov 2025 07:00:48 +0000 Subject: [PATCH 44/53] fix some more type hints --- tests/server/fastmcp/test_elicitation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 77f97e6772..52e6799b7b 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -10,7 +10,7 @@ from mcp.client.session import ElicitationFnT from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -25,7 +25,7 @@ def create_ask_user_tool(mcp: FastMCP): """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: + async def ask_user(prompt: str, ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) if result.action == "accept" and result.data: @@ -106,7 +106,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + async def tool(ctx: Context[ServerTransportSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" @@ -160,7 +160,7 @@ class OptionalSchema(BaseModel): subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") @mcp.tool(description="Tool with optional fields") - async def optional_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_tool(ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) if result.action == "accept" and result.data: @@ -201,7 +201,7 @@ class InvalidOptionalSchema(BaseModel): optional_list: list[str] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + async def invalid_optional_tool(ctx: Context[ServerTransportSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" @@ -234,7 +234,7 @@ class DefaultsSchema(BaseModel): email: str = Field(description="Email address (required)") @mcp.tool(description="Tool with default values") - async def defaults_tool(ctx: Context[ServerSession, None]) -> str: + async def defaults_tool(ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=DefaultsSchema) if result.action == "accept" and result.data: From 48324fe6f8e148ee4a205f1bd24a20bba2d9fefc Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 12:46:43 +0000 Subject: [PATCH 45/53] fixes ci --- tests/server/fastmcp/test_elicitation.py | 1 + tests/server/test_cancel_handling.py | 2 -- tests/shared/test_progress_notifications.py | 1 + tests/shared/test_session.py | 14 -------------- 4 files changed, 2 insertions(+), 16 deletions(-) diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index b4c6e6c447..4679a38920 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -10,6 +10,7 @@ from mcp.client.session import ElicitationFnT from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 152d39ac07..49030cbcc5 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -55,8 +55,6 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ async with Client(server) as client: # First request (will be cancelled) - assert isinstance(client, ClientSession) - async def first_request(): try: await client.session.send_request( diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index f9e61bd028..13edcec01d 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -369,6 +369,7 @@ async def handle_list_tools() -> list[types.Tool]: # Test with mocked logging with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): async with Client(server) as client: + # Call tool with a failing progress callback result = await client.call_tool( "progress_tool", arguments={}, diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 41d3cec51c..c3b27f2049 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -23,26 +23,12 @@ ) -@pytest.fixture -def mcp_server() -> Server: - return Server(name="test server") - - -@pytest.fixture -async def client_connected_to_server( - mcp_server: Server, -) -> AsyncGenerator[ClientTransportSession, None]: - async with create_connected_server_and_client_session(mcp_server) as client_session: - yield client_session - - @pytest.mark.anyio async def test_in_flight_requests_cleared_after_completion(): """Verify that _in_flight is empty after all requests complete.""" # Send a request and wait for response server = Server(name="test server") async with Client(server) as client: - # Send a request and wait for response response = await client.send_ping() assert isinstance(response, EmptyResult) From 4e6800eea8852245a99d7e5d05df36a9b8b728d0 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 12:47:11 +0000 Subject: [PATCH 46/53] fixes tests --- tests/shared/test_streamable_http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7c639d102f..3773caa4f1 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -27,7 +27,7 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.client.transport_session import ClientTransportSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client from mcp.server import Server From a85068e6ec24bf75a55c54e5b131f5761887b19a Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 13:03:15 +0000 Subject: [PATCH 47/53] some pyright fixes --- .../main.py | 6 ++--- .../mcp_simple_streamablehttp/server.py | 4 +++- examples/snippets/clients/stdio_client.py | 1 - .../clients/url_elicitation_client.py | 4 ++-- src/mcp/client/transport_session.py | 24 ++++++++++--------- tests/server/test_cancel_handling.py | 1 - tests/shared/test_memory.py | 2 -- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py index 5f34eb9491..a5204ca279 100644 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py @@ -10,7 +10,7 @@ from typing import Any import click -from mcp import ClientSession +from mcp import ClientSession, ClientTransportSession from mcp.client.streamable_http import streamable_http_client from mcp.shared.context import RequestContext from mcp.types import ( @@ -24,7 +24,7 @@ async def elicitation_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: ElicitRequestParams, ) -> ElicitResult: """Handle elicitation requests from the server.""" @@ -39,7 +39,7 @@ async def elicitation_callback( async def sampling_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: CreateMessageRequestParams, ) -> CreateMessageResult: """Handle sampling requests from the server.""" diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index bb09c119f0..d7a83c37a0 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -13,6 +13,8 @@ from starlette.routing import Mount from starlette.types import Receive, Scope, Send +from pydantic import AnyUrl + from .event_store import InMemoryEventStore # Configure logging @@ -73,7 +75,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB # This will send a resource notificaiton though standalone SSE # established by GET request - await ctx.session.send_resource_updated(uri="http:///test_resource") + await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource")) return [ types.TextContent( type="text", diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index bf3cc9558a..bcb8b5a5ff 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -6,7 +6,6 @@ import os from mcp import ClientSession, StdioServerParameters, types -from pydantic import AnyUrl from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index 300c38fa0c..1b8ec5431e 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -30,7 +30,7 @@ from typing import Any from urllib.parse import urlparse -from mcp import ClientSession, types +from mcp import ClientSession, ClientTransportSession, types from mcp.client.sse import sse_client from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError, UrlElicitationRequiredError @@ -38,7 +38,7 @@ async def handle_elicitation( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: """Handle elicitation requests from the server. diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 07389d59a0..3f221efdb7 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -41,19 +41,19 @@ async def set_logging_level( raise NotImplementedError @abstractmethod - async def list_resources( - self, - cursor: str | None = None, - ) -> types.ListResourcesResult: + async def list_resources(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListResourcesResult: """Send a resources/list request.""" raise NotImplementedError @abstractmethod async def list_resource_templates( - self, - cursor: str | None = None, + self, *, params: types.PaginatedRequestParams | None = None ) -> types.ListResourceTemplatesResult: - """Send a resources/templates/list request.""" + """Send a resources/templates/list request. + + Args: + params: Full pagination parameters including cursor and any future fields + """ raise NotImplementedError @abstractmethod @@ -62,12 +62,12 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: raise NotImplementedError @abstractmethod - async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: """Send a resources/subscribe request.""" raise NotImplementedError @abstractmethod - async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: """Send a resources/unsubscribe request.""" raise NotImplementedError @@ -75,9 +75,11 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: async def call_tool( self, name: str, - arguments: Any | None = None, - read_timeout_seconds: timedelta | None = None, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, progress_callback: ProgressFnT | None = None, + *, + meta: RequestParamsMeta | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" raise NotImplementedError diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 49030cbcc5..98f34df465 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,7 +6,6 @@ import pytest import mcp.types as types -from mcp.client.session import ClientSession from mcp import Client from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index c199441bb8..31238b9ffd 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -1,8 +1,6 @@ import pytest from mcp import Client -from mcp.client.session import ClientSession -from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.types import EmptyResult, Resource From f0fac5c92ee00c73ecebfb8e7f43ac6f506859be Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 13:48:49 +0000 Subject: [PATCH 48/53] fix some pyright --- .../mcp_simple_streamablehttp/server.py | 3 +- src/mcp/client/session.py | 8 ++--- src/mcp/client/transport_session.py | 14 ++++----- src/mcp/server/elicitation.py | 9 ++++-- .../experimental/task_result_handler.py | 8 ++--- src/mcp/server/session.py | 4 +-- src/mcp/server/transport_session.py | 31 +++++++++++++++++-- src/mcp/shared/context.py | 4 ++- src/mcp/shared/session.py | 20 ++++++------ tests/client/test_list_roots_callback.py | 5 ++- tests/client/test_sampling_callback.py | 7 +++-- tests/shared/test_session.py | 1 - 12 files changed, 71 insertions(+), 43 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index d7a83c37a0..4c3dde4e2c 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -13,7 +13,6 @@ from starlette.routing import Mount from starlette.types import Receive, Scope, Send -from pydantic import AnyUrl from .event_store import InMemoryEventStore @@ -75,7 +74,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB # This will send a resource notificaiton though standalone SSE # established by GET request - await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource")) + await ctx.session.send_resource_updated(uri="http:///test_resource") return [ types.TextContent( type="text", diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cddc7e38be..c28fb05e76 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -23,7 +23,7 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientTransportSession", Any], + context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch @@ -31,14 +31,14 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext["ClientTransportSession", Any], + context: RequestContext["ClientSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientTransportSession", Any] + self, context: RequestContext["ClientSession", Any] ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch @@ -418,7 +418,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover await self.send_notification(types.RootsListChangedNotification()) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: - ctx = RequestContext[ClientTransportSession, Any]( + ctx = RequestContext[ClientSession, Any]( request_id=responder.request_id, meta=responder.request_meta, session=self, diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py index 3f221efdb7..d6b03dc9d3 100644 --- a/src/mcp/client/transport_session.py +++ b/src/mcp/client/transport_session.py @@ -1,11 +1,9 @@ from abc import ABC, abstractmethod -from datetime import timedelta from typing import Any -from pydantic import AnyUrl - import mcp.types as types from mcp.shared.session import ProgressFnT +from mcp.types import RequestParamsMeta class ClientTransportSession(ABC): @@ -57,17 +55,17 @@ async def list_resource_templates( raise NotImplementedError @abstractmethod - async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + async def read_resource(self, uri: str) -> types.ReadResourceResult: """Send a resources/read request.""" raise NotImplementedError @abstractmethod - async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: + async def subscribe_resource(self, uri: str) -> types.EmptyResult: """Send a resources/subscribe request.""" raise NotImplementedError @abstractmethod - async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: + async def unsubscribe_resource(self, uri: str) -> types.EmptyResult: """Send a resources/unsubscribe request.""" raise NotImplementedError @@ -87,7 +85,8 @@ async def call_tool( @abstractmethod async def list_prompts( self, - cursor: str | None = None, + *, + params: types.PaginatedRequestParams | None = None, ) -> types.ListPromptsResult: """Send a prompts/list request.""" raise NotImplementedError @@ -114,7 +113,6 @@ async def complete( @abstractmethod async def list_tools( self, - cursor: str | None = None, *, params: types.PaginatedRequestParams | None = None, ) -> types.ListToolsResult: diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index e3c7f79081..b07cb1a3ef 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -9,7 +9,10 @@ from pydantic import BaseModel from mcp.server.transport_session import ServerTransportSession -from mcp.types import RequestId +from mcp.types import ( + ElicitResult, + RequestId, +) ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -123,7 +126,7 @@ async def elicit_with_validation( json_schema = schema.model_json_schema() - result = await session.elicit_form( + result: ElicitResult = await session.elicit( message=message, requested_schema=json_schema, related_request_id=related_request_id, @@ -143,7 +146,7 @@ async def elicit_with_validation( async def elicit_url( - session: ServerSession, + session: ServerTransportSession, message: str, url: str, elicitation_id: str, diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 4d763ef0e6..7390b49cb3 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -14,7 +14,7 @@ import anyio -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession from mcp.shared.exceptions import McpError from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue @@ -69,7 +69,7 @@ def __init__( async def send_message( self, - session: ServerSession, + session: ServerTransportSession, message: SessionMessage, ) -> None: """Send a message via the session. @@ -81,7 +81,7 @@ async def send_message( async def handle( self, request: GetTaskPayloadRequest, - session: ServerSession, + session: ServerTransportSession, request_id: RequestId, ) -> GetTaskPayloadResult: """Handle a tasks/result request. @@ -131,7 +131,7 @@ async def handle( async def _deliver_queued_messages( self, task_id: str, - session: ServerSession, + session: ServerTransportSession, request_id: RequestId, ) -> None: """Dequeue and send all pending messages for a task. diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index e5cd0e9819..08c7bb711b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -42,7 +42,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl, TypeAdapter +from pydantic import TypeAdapter import mcp.types as types from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures @@ -232,7 +232,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str) -> None: # pragma: no cover """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index bf3f6a1d1c..4b3cb4546f 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -3,13 +3,17 @@ from abc import ABC, abstractmethod from typing import Any -from pydantic import AnyUrl import mcp.types as types +from mcp.shared.message import SessionMessage class ServerTransportSession(ABC): """Abstract base class for transport sessions.""" + @abstractmethod + async def send_message(self, message: SessionMessage) -> None: + """Send a raw session message.""" + raise NotImplementedError @abstractmethod async def send_log_message( @@ -23,7 +27,7 @@ async def send_log_message( raise NotImplementedError @abstractmethod - async def send_resource_updated(self, uri: AnyUrl) -> None: + async def send_resource_updated(self, uri: str) -> None: """Send a resource updated notification.""" raise NotImplementedError @@ -36,12 +40,33 @@ async def list_roots(self) -> types.ListRootsResult: async def elicit( self, message: str, - requestedSchema: types.ElicitRequestedSchema, + requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, ) -> types.ElicitResult: """Send an elicitation/create request.""" raise NotImplementedError + @abstractmethod + async def elicit_form( + self, + message: str, + requested_schema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a form mode elicitation/create request.""" + raise NotImplementedError + + @abstractmethod + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a URL mode elicitation/create request.""" + raise NotImplementedError + @abstractmethod async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 4ef3626155..0044396439 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -13,7 +13,9 @@ from mcp import ClientTransportSession, ServerTransportSession SessionT = TypeVar( - "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" + "SessionT", + bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession", + covariant=True, ) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 341e1fac03..afa6598eb0 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -153,8 +153,19 @@ def in_flight(self) -> bool: # pragma: no cover def cancelled(self) -> bool: # pragma: no cover return self._cancel_scope.cancel_called +class Session: + """Base class for a session that can send progress notifications.""" + async def send_progress_notification( + self, + progress_token: ProgressToken, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + """Sends a progress notification for a request that is currently being processed.""" class BaseSession( + Session, Generic[ SendRequestT, SendNotificationT, @@ -500,15 +511,6 @@ async def _received_notification(self, notification: ReceiveNotificationT) -> No to listen on the message stream. """ - async def send_progress_notification( - self, - progress_token: ProgressToken, - progress: float, - total: float | None = None, - message: str | None = None, - ) -> None: - """Sends a progress notification for a request that is currently being processed.""" - async def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception ) -> None: diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index bc1f8eca2d..1b416d1ca7 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,9 +1,8 @@ import pytest from pydantic import FileUrl -from mcp.client.transport_session import ClientTransportSession -from mcp import Client from mcp.client.session import ClientSession +from mcp import Client from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession @@ -29,7 +28,7 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientTransportSession, None], + context: RequestContext[ClientSession, None], ) -> ListRootsResult: return callback_return diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index f45878fae4..ebd24a857b 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,5 @@ import pytest -from mcp.client.transport_session import ClientTransportSession from mcp.server.session import ServerSession from mcp import Client from mcp.client.session import ClientSession @@ -28,7 +27,7 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientTransportSession, None], + context: RequestContext[ClientSession, None], params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return @@ -83,7 +82,9 @@ async def sampling_callback( @server.tool("test_backwards_compat") async def test_tool(message: str): # Call create_message WITHOUT tools - result = await server.get_context().session.create_message( + session = server.get_context().session + assert isinstance(session, ServerSession) + result = await session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index c3b27f2049..68df1e58c0 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -6,7 +6,6 @@ import mcp.types as types from mcp import Client from mcp.client.session import ClientSession -from mcp.client.transport_session import ClientTransportSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams From f0a8a41408f058ee3f9d263661f3630fbfbd9d71 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 13:50:17 +0000 Subject: [PATCH 49/53] fix pyright --- .../experimental/tasks/test_elicitation_scenarios.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 1cefe847da..893979fb45 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -19,6 +19,7 @@ from mcp.client.session import ClientSession from mcp.server import Server from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.session import ServerSession from mcp.server.lowlevel import NotificationOptions from mcp.shared.context import RequestContext from mcp.shared.experimental.tasks.helpers import is_terminal @@ -283,8 +284,11 @@ async def list_tools() -> list[Tool]: async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: ctx = server.request_context + session = ctx.session + assert isinstance(session, ServerSession) + # Task-augmented elicitation - server polls client - result = await ctx.session.experimental.elicit_as_task( + result = await session.experimental.elicit_as_task( message="Please confirm the action", requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, ttl=60000, @@ -574,8 +578,11 @@ async def list_tools() -> list[Tool]: async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: ctx = server.request_context + session = ctx.session + assert isinstance(session, ServerSession) + # Task-augmented sampling - server polls client - result = await ctx.session.experimental.create_message_as_task( + result = await session.experimental.create_message_as_task( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], max_tokens=100, ttl=60000, From e522882d0428afb520f5edb3525148b7279b723f Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 13:53:32 +0000 Subject: [PATCH 50/53] ruff fixes --- .../mcp_simple_streamablehttp/server.py | 1 - examples/snippets/clients/stdio_client.py | 2 -- examples/snippets/servers/elicitation.py | 2 +- src/mcp/__init__.py | 3 +-- src/mcp/server/lowlevel/server.py | 2 +- src/mcp/server/session.py | 2 +- src/mcp/server/transport_session.py | 2 +- src/mcp/shared/context.py | 8 ++++---- src/mcp/shared/session.py | 3 +++ tests/client/test_list_roots_callback.py | 2 +- tests/client/test_sampling_callback.py | 2 +- tests/experimental/tasks/test_elicitation_scenarios.py | 2 +- tests/server/fastmcp/test_elicitation.py | 1 - tests/server/fastmcp/test_integration.py | 2 +- tests/shared/test_sse.py | 1 - tests/shared/test_streamable_http.py | 3 +-- 16 files changed, 17 insertions(+), 21 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 4c3dde4e2c..bb09c119f0 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -13,7 +13,6 @@ from starlette.routing import Mount from starlette.types import Receive, Scope, Send - from .event_store import InMemoryEventStore # Configure logging diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index bcb8b5a5ff..fe07f24f0a 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -5,8 +5,6 @@ import asyncio import os -from mcp import ClientSession, StdioServerParameters, types - from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 6d73f0c711..bc7c17ce69 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -11,9 +11,9 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams -from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 088bcb013f..7cc1fec7a7 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -5,9 +5,8 @@ from .client.transport_session import ClientTransportSession from .server.session import ServerSession from .server.stdio import stdio_server -from .shared.exceptions import McpError, UrlElicitationRequiredError from .server.transport_session import ServerTransportSession -from .shared.exceptions import McpError +from .shared.exceptions import McpError, UrlElicitationRequiredError from .types import ( CallToolRequest, ClientCapabilities, diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a06d13326a..667ca5e897 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -97,10 +97,10 @@ async def main(): from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.server.transport_session import ServerTransportSession from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError, UrlElicitationRequiredError from mcp.shared.message import ServerMessageMetadata, SessionMessage diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 08c7bb711b..68fc74dbe9 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,11 +47,11 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions +from mcp.server.transport_session import ServerTransportSession from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY -from mcp.server.transport_session import ServerTransportSession from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py index 4b3cb4546f..c6077437ab 100644 --- a/src/mcp/server/transport_session.py +++ b/src/mcp/server/transport_session.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from typing import Any - import mcp.types as types from mcp.shared.message import SessionMessage class ServerTransportSession(ABC): """Abstract base class for transport sessions.""" + @abstractmethod async def send_message(self, message: SessionMessage) -> None: """Send a raw session message.""" diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 0044396439..67d963bd59 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -12,8 +12,8 @@ if TYPE_CHECKING: from mcp import ClientTransportSession, ServerTransportSession -SessionT = TypeVar( - "SessionT", +SessionT_co = TypeVar( + "SessionT_co", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession", covariant=True, ) @@ -22,10 +22,10 @@ @dataclass -class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): +class RequestContext(Generic[SessionT_co, LifespanContextT, RequestT]): request_id: RequestId meta: RequestParamsMeta | None - session: SessionT + session: SessionT_co lifespan_context: LifespanContextT # NOTE: This is typed as Any to avoid circular imports. The actual type is # mcp.server.experimental.request_context.Experimental, but importing it here diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index afa6598eb0..e99dd6c9fa 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -153,8 +153,10 @@ def in_flight(self) -> bool: # pragma: no cover def cancelled(self) -> bool: # pragma: no cover return self._cancel_scope.cancel_called + class Session: """Base class for a session that can send progress notifications.""" + async def send_progress_notification( self, progress_token: ProgressToken, @@ -164,6 +166,7 @@ async def send_progress_notification( ) -> None: """Sends a progress notification for a request that is currently being processed.""" + class BaseSession( Session, Generic[ diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 1b416d1ca7..a8f8823fe5 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,8 +1,8 @@ import pytest from pydantic import FileUrl -from mcp.client.session import ClientSession from mcp import Client +from mcp.client.session import ClientSession from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index ebd24a857b..d7fbe5db50 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,9 +1,9 @@ import pytest -from mcp.server.session import ServerSession from mcp import Client from mcp.client.session import ClientSession from mcp.server.fastmcp import FastMCP +from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.types import ( CreateMessageRequestParams, diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 893979fb45..058a5970c5 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -19,8 +19,8 @@ from mcp.client.session import ClientSession from mcp.server import Server from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.session import ServerSession from mcp.server.lowlevel import NotificationOptions +from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.experimental.tasks.helpers import is_terminal from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 4679a38920..ac6be8c66d 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,7 +7,6 @@ from mcp import Client, types from mcp.client.session import ClientSession, ElicitationFnT -from mcp.client.session import ElicitationFnT from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index cbbe1e8770..3254ecd256 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -33,8 +33,8 @@ ) from mcp.client.session import ClientSession from mcp.client.sse import sse_client -from mcp.client.transport_session import ClientTransportSession from mcp.client.streamable_http import GetSessionIdCallback, streamable_http_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 13d7ca5051..fda3d8ddba 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -22,7 +22,6 @@ import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client -from mcp.client.sse import sse_client from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.sse import SseServerTransport diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3773caa4f1..7c7eefb52f 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -27,9 +27,8 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.streamable_http import streamable_http_client -from mcp.client.transport_session import ClientTransportSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.session import ServerSession from mcp.server.streamable_http import ( From 7b25f8e74a6bb1bd72757f65a60b5afab47d2d53 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 13:56:02 +0000 Subject: [PATCH 51/53] fix readme --- README.md | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index b6f8087ab8..634da7c29b 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -273,7 +273,7 @@ mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() @@ -345,13 +345,13 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -693,13 +693,13 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -826,6 +826,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -843,7 +844,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. @@ -969,13 +970,13 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") @@ -2119,7 +2120,7 @@ uv run client import asyncio import os -from mcp import ClientSession, StdioServerParameters, types +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2133,7 +2134,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( @@ -2247,7 +2248,7 @@ uv run display-utilities-client import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.shared.metadata_utils import get_display_name @@ -2259,7 +2260,7 @@ server_params = StdioServerParameters( ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -2271,7 +2272,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() From 99f194f8971e9d8a514e105eabdc2b8da60de8bc Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 13:57:46 +0000 Subject: [PATCH 52/53] fix tests and readme --- src/mcp/shared/context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 67d963bd59..7e37909542 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,7 +1,7 @@ """Request context for MCP handlers.""" from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic +from typing import TYPE_CHECKING, Any, Generic, Union from typing_extensions import TypeVar @@ -14,7 +14,7 @@ SessionT_co = TypeVar( "SessionT_co", - bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession", + bound=Union[BaseSession[Any, Any, Any, Any, Any], "ClientTransportSession", "ServerTransportSession"], covariant=True, ) LifespanContextT = TypeVar("LifespanContextT") From 1a4ad49f734b3f90efed3e0fa73557b651832be4 Mon Sep 17 00:00:00 2001 From: Ashesh Vidyut Date: Fri, 23 Jan 2026 14:18:44 +0000 Subject: [PATCH 53/53] revert anyurl change --- src/mcp/server/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 68fc74dbe9..fa1a90e1b1 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -42,7 +42,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import TypeAdapter +from pydantic import AnyUrl, TypeAdapter import mcp.types as types from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures @@ -232,7 +232,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification(