diff --git a/README.md b/README.md index b6f8087ab..634da7c29 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -273,7 +273,7 @@ mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() @@ -345,13 +345,13 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -693,13 +693,13 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -826,6 +826,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -843,7 +844,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. @@ -969,13 +970,13 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") @@ -2119,7 +2120,7 @@ uv run client import asyncio import os -from mcp import ClientSession, StdioServerParameters, types +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2133,7 +2134,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( @@ -2247,7 +2248,7 @@ uv run display-utilities-client import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.shared.metadata_utils import get_display_name @@ -2259,7 +2260,7 @@ server_params = StdioServerParameters( ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -2271,7 +2272,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 684222dec..3093a7144 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -20,7 +20,7 @@ import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.auth import OAuthClientProvider, TokenStorage -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -173,7 +173,7 @@ def __init__( self.server_url = server_url self.transport_type = transport_type self.client_metadata_url = client_metadata_url - self.session: ClientSession | None = None + self.session: ClientTransportSession | None = None async def connect(self): """Connect to the MCP server.""" diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 72b1a6f20..4e3d83da3 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -12,6 +12,7 @@ from dotenv import load_dotenv from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.transport_session import ClientTransportSession # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -69,7 +70,7 @@ def __init__(self, name: str, config: dict[str, Any]) -> None: self.name: str = name self.config: dict[str, Any] = config self.stdio_context: Any | None = None - self.session: ClientSession | None = None + self.session: ClientTransportSession | None = None self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.exit_stack: AsyncExitStack = AsyncExitStack() diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py index 5f34eb949..a5204ca27 100644 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py @@ -10,7 +10,7 @@ from typing import Any import click -from mcp import ClientSession +from mcp import ClientSession, ClientTransportSession from mcp.client.streamable_http import streamable_http_client from mcp.shared.context import RequestContext from mcp.types import ( @@ -24,7 +24,7 @@ async def elicitation_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: ElicitRequestParams, ) -> ElicitResult: """Handle elicitation requests from the server.""" @@ -39,7 +39,7 @@ async def elicitation_callback( async def sampling_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: CreateMessageRequestParams, ) -> CreateMessageResult: """Handle sampling requests from the server.""" diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index 40e31cf2b..b90a0c739 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -5,7 +5,7 @@ import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.shared.metadata_utils import get_display_name @@ -17,7 +17,7 @@ ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -29,7 +29,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index e4d430397..fe07f24f0 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -5,7 +5,7 @@ import asyncio import os -from mcp import ClientSession, StdioServerParameters, types +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -19,7 +19,7 @@ # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index 300c38fa0..1b8ec5431 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -30,7 +30,7 @@ from typing import Any from urllib.parse import urlparse -from mcp import ClientSession, types +from mcp import ClientSession, ClientTransportSession, types from mcp.client.sse import sse_client from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError, UrlElicitationRequiredError @@ -38,7 +38,7 @@ async def handle_elicitation( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: """Handle elicitation requests from the server. diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 34921aa4b..bc7c17ce6 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -11,6 +11,7 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams @@ -28,7 +29,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check. This demonstrates form mode elicitation for collecting non-sensitive user input. diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 62278b6aa..46f01f427 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -51,7 +51,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 833bc8905..995ecd817 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,11 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index 2ac458f6a..a0f62fda6 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,11 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 982352314..7cc1fec7a 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -2,8 +2,10 @@ from .client.session import ClientSession from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client +from .client.transport_session import ClientTransportSession from .server.session import ServerSession from .server.stdio import stdio_server +from .server.transport_session import ServerTransportSession from .shared.exceptions import McpError, UrlElicitationRequiredError from .types import ( CallToolRequest, @@ -132,4 +134,8 @@ "UrlElicitationRequiredError", "stdio_client", "stdio_server", + "CompleteRequest", + "JSONRPCResponse", + "ClientTransportSession", + "ServerTransportSession", ] diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index d5d4c8607..c28fb05e7 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,6 +8,7 @@ import mcp.types as types from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -62,7 +63,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: return types.ErrorData( @@ -72,7 +73,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( # pragma: no cover @@ -82,7 +83,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -100,13 +101,14 @@ async def _default_logging_callback( class ClientSession( + ClientTransportSession, BaseSession[ types.ClientRequest, types.ClientNotification, types.ClientResult, types.ServerRequest, types.ServerNotification, - ] + ], ): def __init__( self, diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py new file mode 100644 index 000000000..d6b03dc9d --- /dev/null +++ b/src/mcp/client/transport_session.py @@ -0,0 +1,130 @@ +from abc import ABC, abstractmethod +from typing import Any + +import mcp.types as types +from mcp.shared.session import ProgressFnT +from mcp.types import RequestParamsMeta + + +class ClientTransportSession(ABC): + """Abstract base class for communication transports.""" + + @abstractmethod + async def initialize(self) -> types.InitializeResult: + """Send an initialize request.""" + raise NotImplementedError + + @abstractmethod + async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + raise NotImplementedError + + @abstractmethod + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + """Send a progress notification.""" + raise NotImplementedError + + @abstractmethod + async def set_logging_level( + self, + level: types.LoggingLevel, + ) -> types.EmptyResult: + """Send a logging/setLevel request.""" + raise NotImplementedError + + @abstractmethod + async def list_resources(self, *, params: types.PaginatedRequestParams | None = None) -> types.ListResourcesResult: + """Send a resources/list request.""" + raise NotImplementedError + + @abstractmethod + async def list_resource_templates( + self, *, params: types.PaginatedRequestParams | None = None + ) -> types.ListResourceTemplatesResult: + """Send a resources/templates/list request. + + Args: + params: Full pagination parameters including cursor and any future fields + """ + raise NotImplementedError + + @abstractmethod + async def read_resource(self, uri: str) -> types.ReadResourceResult: + """Send a resources/read request.""" + raise NotImplementedError + + @abstractmethod + async def subscribe_resource(self, uri: str) -> types.EmptyResult: + """Send a resources/subscribe request.""" + raise NotImplementedError + + @abstractmethod + async def unsubscribe_resource(self, uri: str) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + raise NotImplementedError + + @abstractmethod + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: RequestParamsMeta | None = None, + ) -> types.CallToolResult: + """Send a tools/call request with optional progress callback support.""" + raise NotImplementedError + + @abstractmethod + async def list_prompts( + self, + *, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListPromptsResult: + """Send a prompts/list request.""" + raise NotImplementedError + + @abstractmethod + async def get_prompt( + self, + name: str, + arguments: dict[str, str] | None = None, + ) -> types.GetPromptResult: + """Send a prompts/get request.""" + raise NotImplementedError + + @abstractmethod + async def complete( + self, + ref: types.ResourceTemplateReference | types.PromptReference, + argument: dict[str, str], + context_arguments: dict[str, str] | None = None, + ) -> types.CompleteResult: + """Send a completion/complete request.""" + raise NotImplementedError + + @abstractmethod + async def list_tools( + self, + *, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListToolsResult: + """Send a tools/list request. + + Args: + cursor: Simple cursor string for pagination (deprecated, use params instead) + params: Full pagination parameters including cursor and any future fields + """ + raise NotImplementedError + + @abstractmethod + async def send_roots_list_changed(self) -> None: + """Send a roots/list_changed notification.""" + raise NotImplementedError diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 58e9fe448..b07cb1a3e 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,8 +8,11 @@ from pydantic import BaseModel -from mcp.server.session import ServerSession -from mcp.types import RequestId +from mcp.server.transport_session import ServerTransportSession +from mcp.types import ( + ElicitResult, + RequestId, +) ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -103,7 +106,7 @@ def _is_primitive_field(annotation: type) -> bool: async def elicit_with_validation( - session: ServerSession, + session: ServerTransportSession, message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, @@ -123,7 +126,7 @@ async def elicit_with_validation( json_schema = schema.model_json_schema() - result = await session.elicit_form( + result: ElicitResult = await session.elicit( message=message, requested_schema=json_schema, related_request_id=related_request_id, @@ -143,7 +146,7 @@ async def elicit_with_validation( async def elicit_url( - session: ServerSession, + session: ServerTransportSession, message: str, url: str, elicitation_id: str, diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 4d763ef0e..7390b49cb 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -14,7 +14,7 @@ import anyio -from mcp.server.session import ServerSession +from mcp.server.session import ServerTransportSession from mcp.shared.exceptions import McpError from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue @@ -69,7 +69,7 @@ def __init__( async def send_message( self, - session: ServerSession, + session: ServerTransportSession, message: SessionMessage, ) -> None: """Send a message via the session. @@ -81,7 +81,7 @@ async def send_message( async def handle( self, request: GetTaskPayloadRequest, - session: ServerSession, + session: ServerTransportSession, request_id: RequestId, ) -> GetTaskPayloadResult: """Handle a tasks/result request. @@ -131,7 +131,7 @@ async def handle( async def _deliver_queued_messages( self, task_id: str, - session: ServerSession, + session: ServerTransportSession, request_id: RequestId, ) -> None: """Dequeue and send all pending messages for a task. diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index d0a550280..ac67e195f 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -37,12 +37,13 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSession, ServerSessionT +from mcp.server.session import ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt @@ -297,7 +298,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: + def get_context(self) -> Context[ServerTransportSession, LifespanResultT, Request]: """Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. """ diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 6bea4126f..667ca5e89 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -100,6 +100,7 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError, UrlElicitationRequiredError from mcp.shared.message import ServerMessageMetadata, SessionMessage @@ -117,7 +118,9 @@ async def main(): CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar( + "request_ctx" +) class NotificationOptions: @@ -255,7 +258,7 @@ def get_capabilities( @property def request_context( self, - ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: + ) -> RequestContext[ServerTransportSession, LifespanResultT, RequestT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 5a70ee02e..fa1a90e1b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions +from mcp.server.transport_session import ServerTransportSession from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability @@ -65,7 +66,7 @@ class InitializationState(Enum): Initialized = 3 -ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") +ServerSessionT = TypeVar("ServerSessionT", bound="ServerTransportSession") ServerRequestResponder = ( RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception @@ -73,13 +74,14 @@ class InitializationState(Enum): class ServerSession( + ServerTransportSession, BaseSession[ types.ServerRequest, types.ServerNotification, types.ServerResult, types.ClientRequest, types.ClientNotification, - ] + ], ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py new file mode 100644 index 000000000..c6077437a --- /dev/null +++ b/src/mcp/server/transport_session.py @@ -0,0 +1,100 @@ +"""Abstract base class for transport sessions.""" + +from abc import ABC, abstractmethod +from typing import Any + +import mcp.types as types +from mcp.shared.message import SessionMessage + + +class ServerTransportSession(ABC): + """Abstract base class for transport sessions.""" + + @abstractmethod + async def send_message(self, message: SessionMessage) -> None: + """Send a raw session message.""" + raise NotImplementedError + + @abstractmethod + async def send_log_message( + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send a log message notification.""" + raise NotImplementedError + + @abstractmethod + async def send_resource_updated(self, uri: str) -> None: + """Send a resource updated notification.""" + raise NotImplementedError + + @abstractmethod + async def list_roots(self) -> types.ListRootsResult: + """Send a roots/list request.""" + raise NotImplementedError + + @abstractmethod + async def elicit( + self, + message: str, + requested_schema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send an elicitation/create request.""" + raise NotImplementedError + + @abstractmethod + async def elicit_form( + self, + message: str, + requested_schema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a form mode elicitation/create request.""" + raise NotImplementedError + + @abstractmethod + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a URL mode elicitation/create request.""" + raise NotImplementedError + + @abstractmethod + async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + raise NotImplementedError + + @abstractmethod + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + related_request_id: str | None = None, + ) -> None: + """Send a progress notification.""" + raise NotImplementedError + + @abstractmethod + async def send_resource_list_changed(self) -> None: + """Send a resource list changed notification.""" + raise NotImplementedError + + @abstractmethod + async def send_tool_list_changed(self) -> None: + """Send a tool list changed notification.""" + raise NotImplementedError + + @abstractmethod + async def send_prompt_list_changed(self) -> None: + """Send a prompt list changed notification.""" + raise NotImplementedError diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index b140f9a77..7e3790954 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,7 +1,7 @@ """Request context for MCP handlers.""" from dataclasses import dataclass, field -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic, Union from typing_extensions import TypeVar @@ -9,16 +9,23 @@ from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParamsMeta -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) +if TYPE_CHECKING: + from mcp import ClientTransportSession, ServerTransportSession + +SessionT_co = TypeVar( + "SessionT_co", + bound=Union[BaseSession[Any, Any, Any, Any, Any], "ClientTransportSession", "ServerTransportSession"], + covariant=True, +) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) @dataclass -class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): +class RequestContext(Generic[SessionT_co, LifespanContextT, RequestT]): request_id: RequestId meta: RequestParamsMeta | None - session: SessionT + session: SessionT_co lifespan_context: LifespanContextT # NOTE: This is typed as Any to avoid circular imports. The actual type is # mcp.server.experimental.request_context.Experimental, but importing it here diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 341e1fac0..e99dd6c9f 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -154,7 +154,21 @@ def cancelled(self) -> bool: # pragma: no cover return self._cancel_scope.cancel_called +class Session: + """Base class for a session that can send progress notifications.""" + + async def send_progress_notification( + self, + progress_token: ProgressToken, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + """Sends a progress notification for a request that is currently being processed.""" + + class BaseSession( + Session, Generic[ SendRequestT, SendNotificationT, @@ -500,15 +514,6 @@ async def _received_notification(self, notification: ReceiveNotificationT) -> No to listen on the message stream. """ - async def send_progress_notification( - self, - progress_token: ProgressToken, - progress: float, - total: float | None = None, - message: str | None = None, - ) -> None: - """Sends a progress notification for a request that is currently being processed.""" - async def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception ) -> None: diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 1394e665c..d7fbe5db5 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -3,6 +3,7 @@ from mcp import Client from mcp.client.session import ClientSession from mcp.server.fastmcp import FastMCP +from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.types import ( CreateMessageRequestParams, @@ -33,7 +34,9 @@ async def sampling_callback( @server.tool("test_sampling") async def test_sampling_tool(message: str): - value = await server.get_context().session.create_message( + session = server.get_context().session + assert isinstance(session, ServerSession) + value = await session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) @@ -79,7 +82,9 @@ async def sampling_callback( @server.tool("test_backwards_compat") async def test_tool(message: str): # Call create_message WITHOUT tools - result = await server.get_context().session.create_message( + session = server.get_context().session + assert isinstance(session, ServerSession) + result = await session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 220c571a5..de684605c 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -5,6 +5,7 @@ import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @@ -390,7 +391,7 @@ async def test_client_capabilities_with_custom_callbacks(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( @@ -400,7 +401,7 @@ async def custom_sampling_callback( # pragma: no cover ) async def custom_list_roots_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 1cefe847d..058a5970c 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -20,6 +20,7 @@ from mcp.server import Server from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import NotificationOptions +from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.experimental.tasks.helpers import is_terminal from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore @@ -283,8 +284,11 @@ async def list_tools() -> list[Tool]: async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: ctx = server.request_context + session = ctx.session + assert isinstance(session, ServerSession) + # Task-augmented elicitation - server polls client - result = await ctx.session.experimental.elicit_as_task( + result = await session.experimental.elicit_as_task( message="Please confirm the action", requested_schema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, ttl=60000, @@ -574,8 +578,11 @@ async def list_tools() -> list[Tool]: async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: ctx = server.request_context + session = ctx.session + assert isinstance(session, ServerSession) + # Task-augmented sampling - server polls client - result = await ctx.session.experimental.create_message_as_task( + result = await session.experimental.create_message_as_task( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], max_tokens=100, ttl=60000, diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index efed572e4..ac6be8c66 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,8 +7,10 @@ from mcp import Client, types from mcp.client.session import ClientSession, ElicitationFnT +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -22,7 +24,7 @@ def create_ask_user_tool(mcp: FastMCP): """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: + async def ask_user(prompt: str, ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) if result.action == "accept" and result.data: @@ -66,7 +68,7 @@ async def test_stdio_elicitation(): # Create a custom handler for elicitation requests async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) @@ -84,7 +86,7 @@ async def test_stdio_elicitation_decline(): mcp = FastMCP(name="StdioElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -99,7 +101,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + async def tool(ctx: Context[ServerTransportSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" @@ -123,7 +125,7 @@ class InvalidNestedSchema(BaseModel): # Dummy callback (won't be called due to validation failure) async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -149,7 +151,7 @@ class OptionalSchema(BaseModel): subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") @mcp.tool(description="Tool with optional fields") - async def optional_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_tool(ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) if result.action == "accept" and result.data: @@ -179,7 +181,7 @@ async def optional_tool(ctx: Context[ServerSession, None]) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -190,7 +192,7 @@ class InvalidOptionalSchema(BaseModel): optional_list: list[int] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + async def invalid_optional_tool(ctx: Context[ServerTransportSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" @@ -198,7 +200,7 @@ async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pr return f"Validation failed: {str(e)}" async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -264,7 +266,7 @@ class DefaultsSchema(BaseModel): email: str = Field(description="Email address (required)") @mcp.tool(description="Tool with default values") - async def defaults_tool(ctx: Context[ServerSession, None]) -> str: + async def defaults_tool(ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=DefaultsSchema) if result.action == "accept" and result.data: @@ -276,7 +278,9 @@ async def defaults_tool(ctx: Context[ServerSession, None]) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_schema_verify( + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams + ): # Verify the schema includes defaults assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation" schema = params.requested_schema @@ -298,7 +302,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession, None], p ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_override(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index a7f945f78..3254ecd25 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -34,6 +34,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamable_http_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @@ -213,7 +214,7 @@ def unpack_streams( # Callback functions for testing async def sampling_callback( - context: RequestContext[ClientSession, None], params: CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: CreateMessageRequestParams ) -> CreateMessageResult: """Sampling callback for tests.""" return CreateMessageResult( @@ -226,7 +227,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): +async def elicitation_callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 89fe18ebb..68df1e58c 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -25,9 +25,9 @@ @pytest.mark.anyio async def test_in_flight_requests_cleared_after_completion(): """Verify that _in_flight is empty after all requests complete.""" + # Send a request and wait for response server = Server(name="test server") async with Client(server) as client: - # Send a request and wait for response response = await client.send_ping() assert isinstance(response, EmptyResult) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index fb006424c..fda3d8ddb 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -22,6 +22,7 @@ import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings @@ -243,7 +244,7 @@ def mock_extract(url: str) -> None: @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index ed86f9860..7c7eefb52 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -28,7 +28,9 @@ import mcp.types as types from mcp.client.session import ClientSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server +from mcp.server.session import ServerSession from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, @@ -241,7 +243,9 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] elif name == "test_sampling_tool": # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( + session = ctx.session + assert isinstance(session, ServerSession) + sampling_result = await session.create_message( messages=[ types.SamplingMessage( role="user", @@ -1361,7 +1365,7 @@ async def test_streamablehttp_server_sampling(basic_server: None, basic_server_u # Define sampling callback that returns a mock response async def sampling_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult: nonlocal sampling_callback_invoked, captured_message_params diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 8fb7aeec3..8415ab51b 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -13,6 +13,7 @@ from starlette.websockets import WebSocket from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server @@ -126,7 +127,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]: """Create and initialize a WebSocket client session""" async with websocket_client(server_url + "/ws") as streams: async with ClientSession(*streams) as session: