Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from mcp.server.auth.provider import AccessToken, TokenVerifier
from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url
from mcp.shared._auth_utils import check_resource_allowed, resource_url_from_server_url

logger = logging.getLogger(__name__)

Expand Down
102 changes: 87 additions & 15 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Implements authorization code flow with PKCE and automatic token refresh.
"""

from __future__ import annotations as _annotations

import base64
import hashlib
import logging
Expand All @@ -13,11 +15,11 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
from urllib.parse import quote, urlencode, urljoin, urlparse, urlsplit, urlunsplit

import anyio
import httpx
from pydantic import BaseModel, Field, ValidationError
from pydantic import AnyUrl, BaseModel, Field, HttpUrl, ValidationError

from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError
from mcp.client.auth.utils import (
Expand Down Expand Up @@ -45,11 +47,6 @@
OAuthToken,
ProtectedResourceMetadata,
)
from mcp.shared.auth_utils import (
calculate_token_expiry,
check_resource_allowed,
resource_url_from_server_url,
)

logger = logging.getLogger(__name__)

Expand All @@ -61,7 +58,7 @@ class PKCEParameters(BaseModel):
code_challenge: str = Field(..., min_length=43, max_length=128)

@classmethod
def generate(cls) -> "PKCEParameters":
def generate(cls) -> PKCEParameters:
"""Generate new PKCE parameters."""
code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
digest = hashlib.sha256(code_verifier.encode()).digest()
Expand All @@ -74,19 +71,15 @@ class TokenStorage(Protocol):

async def get_tokens(self) -> OAuthToken | None:
"""Get stored tokens."""
...

async def set_tokens(self, tokens: OAuthToken) -> None:
"""Store tokens."""
...

async def get_client_info(self) -> OAuthClientInformationFull | None:
"""Get stored client information."""
...

async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
"""Store client information."""
...


@dataclass
Expand Down Expand Up @@ -124,7 +117,7 @@ def get_authorization_base_url(self, server_url: str) -> str:

def update_token_expiry(self, token: OAuthToken) -> None:
"""Update token expiry time using shared util function."""
self.token_expiry_time = calculate_token_expiry(token.expires_in)
self.token_expiry_time = _calculate_token_expiry(token.expires_in)

def is_token_valid(self) -> bool:
"""Check if current token is valid."""
Expand All @@ -148,12 +141,12 @@ def get_resource_url(self) -> str:

Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
"""
resource = resource_url_from_server_url(self.server_url)
resource = _resource_url_from_server_url(self.server_url)

# If PRM provides a resource that's a valid parent, use it
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
prm_resource = str(self.protected_resource_metadata.resource)
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
if _check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
resource = prm_resource

return resource
Expand Down Expand Up @@ -614,3 +607,82 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Retry with new tokens
self._add_auth_header(request)
yield request


def _resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str:
"""Convert server URL to canonical resource URL per RFC 8707.

RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component".
Returns absolute URI with lowercase scheme/host for canonical form.

Args:
url: Server URL to convert

Returns:
Canonical resource URL string
"""
# Convert to string if needed
url_str = str(url)

# Parse the URL and remove fragment, create canonical form
parsed = urlsplit(url_str)
canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment=""))

return canonical


def _check_resource_allowed(requested_resource: str, configured_resource: str) -> bool:
"""Check if a requested resource URL matches a configured resource URL.

A requested resource matches if it has the same scheme, domain, port,
and its path starts with the configured resource's path. This allows
hierarchical matching where a token for a parent resource can be used
for child resources.

Args:
requested_resource: The resource URL being requested
configured_resource: The resource URL that has been configured

Returns:
True if the requested resource matches the configured resource
"""
# Parse both URLs
requested = urlparse(requested_resource)
configured = urlparse(configured_resource)

# Compare scheme, host, and port (origin)
if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower():
return False

# Handle cases like requested=/foo and configured=/foo/
requested_path = requested.path
configured_path = configured.path

# If requested path is shorter, it cannot be a child
if len(requested_path) < len(configured_path):
return False

# Check if the requested path starts with the configured path
# Ensure both paths end with / for proper comparison
# This ensures that paths like "/api123" don't incorrectly match "/api"
if not requested_path.endswith("/"):
requested_path += "/"
if not configured_path.endswith("/"):
configured_path += "/"

return requested_path.startswith(configured_path)


def _calculate_token_expiry(expires_in: int | str | None) -> float | None:
"""Calculate token expiry timestamp from expires_in seconds.

Args:
expires_in: Seconds until token expiration (may be string from some servers)

Returns:
Unix timestamp when token expires, or None if no expiry specified
"""
if expires_in is None:
return None # pragma: no cover
# Defensive: handle servers that return expires_in as string
return time.time() + int(expires_in)
14 changes: 8 additions & 6 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations as _annotations

import logging
from typing import Any, Protocol, overload

Expand All @@ -22,22 +24,22 @@
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession, Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch


class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession, Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
self, context: RequestContext[ClientSession, Any]
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch


Expand All @@ -62,7 +64,7 @@ async def _default_message_handler(


async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession, Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
return types.ErrorData(
Expand All @@ -72,7 +74,7 @@ async def _default_sampling_callback(


async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession, Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData( # pragma: no cover
Expand All @@ -82,7 +84,7 @@ async def _default_elicitation_callback(


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession, Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


class SseServerParameters(BaseModel):
"""Parameters for intializing a sse_client."""
"""Parameters for initializing a sse_client."""

# The endpoint URL.
url: str
Expand Down
10 changes: 3 additions & 7 deletions src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

Expand All @@ -12,8 +13,6 @@
import mcp.types as types
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)


@asynccontextmanager
async def websocket_client(
Expand Down Expand Up @@ -64,10 +63,7 @@ async def ws_reader():
await read_stream_writer.send(exc)

async def ws_writer():
"""
Reads JSON-RPC messages from write_stream_reader and
sends them to the server.
"""
"""Reads JSON-RPC messages from write_stream_reader and sends them to the server."""
async with write_stream_reader:
async for session_message in write_stream_reader:
# Convert to a dict, then to JSON
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.shared._tool_name_validation import validate_and_warn_tool_name
from mcp.shared.exceptions import UrlElicitationRequiredError
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
from mcp.types import Icon, ToolAnnotations

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ async def main():
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared._tool_name_validation import validate_and_warn_tool_name
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError, UrlElicitationRequiredError
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
from mcp.shared.tool_name_validation import validate_and_warn_tool_name

logger = logging.getLogger(__name__)

Expand Down
5 changes: 1 addition & 4 deletions src/mcp/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@

from pydantic import BaseModel

from mcp.types import (
Icon,
ServerCapabilities,
)
from mcp.types import Icon, ServerCapabilities


class InitializationOptions(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
See: https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names
"""

from __future__ import annotations
from __future__ import annotations as _annotations

import logging
import re
Expand Down
Loading
Loading