Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/mcp/shared/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal

from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_serializer, field_validator


class OAuthToken(BaseModel):
Expand Down Expand Up @@ -129,6 +129,11 @@ class OAuthMetadata(BaseModel):
code_challenge_methods_supported: list[str] | None = None
client_id_metadata_document_supported: bool | None = None

@field_serializer("issuer")
def serialize_issuer(self, v: AnyHttpUrl) -> str:
"""Strip trailing slash from issuer URL for RFC 8414 §3.3 compliance."""
return str(v).rstrip("/")


class ProtectedResourceMetadata(BaseModel):
"""RFC 9728 OAuth 2.0 Protected Resource Metadata.
Expand All @@ -151,3 +156,13 @@ class ProtectedResourceMetadata(BaseModel):
dpop_signing_alg_values_supported: list[str] | None = None
# dpop_bound_access_tokens_required default is False, but ommited here for clarity
dpop_bound_access_tokens_required: bool | None = None

@field_serializer("resource")
def serialize_resource(self, v: AnyHttpUrl) -> str:
"""Strip trailing slash from resource URL for RFC 9728 §3 compliance."""
return str(v).rstrip("/")

@field_serializer("authorization_servers")
def serialize_authorization_servers(self, v: list[AnyHttpUrl]) -> list[str]:
"""Strip trailing slashes from authorization server URLs for RFC 9728 §3 compliance."""
return [str(s).rstrip("/") for s in v]
4 changes: 2 additions & 2 deletions tests/server/auth/test_protected_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncC
assert response.status_code == 200
assert response.json() == snapshot(
{
"resource": "https://example.com/",
"authorization_servers": ["https://auth.example.com/"],
"resource": "https://example.com",
"authorization_servers": ["https://auth.example.com"],
"scopes_supported": ["read"],
"resource_name": "Root Resource",
"bearer_methods_supported": ["header"],
Expand Down
240 changes: 240 additions & 0 deletions tests/server/auth/test_trailing_slash_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""Tests for OAuth metadata URL trailing slash handling.

These tests verify that trailing slashes are properly stripped from OAuth metadata URLs
to ensure compliance with RFC 8414 §3.3 and RFC 9728 §3, which require that the issuer/
resource URL in the metadata response must be identical to the URL used for discovery.

These tests would fail on main (before the fix) but pass on this branch.
"""

import httpx
import pytest
from pydantic import AnyHttpUrl
from starlette.applications import Starlette

from mcp.server.auth.routes import build_metadata, create_auth_routes, create_protected_resource_routes
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
from tests.server.fastmcp.auth.test_auth_integration import MockOAuthProvider


def test_build_metadata_strips_trailing_slash_from_issuer():
"""Test that build_metadata strips trailing slash from issuer URL when serialized.

Pydantic's AnyHttpUrl automatically adds trailing slashes to bare hostnames.
This test verifies that we strip them during serialization to comply with RFC 8414 §3.3.
"""
# Use a bare hostname URL which Pydantic will add a trailing slash to
issuer_url = AnyHttpUrl("http://localhost:8000")

metadata = build_metadata(
issuer_url=issuer_url,
service_documentation_url=None,
client_registration_options=ClientRegistrationOptions(enabled=False),
revocation_options=RevocationOptions(enabled=False),
)

# The serialized issuer should NOT have a trailing slash
serialized = metadata.model_dump(mode="json")
assert serialized["issuer"] == "http://localhost:8000"
assert not serialized["issuer"].endswith("/")


def test_build_metadata_strips_trailing_slash_from_issuer_with_path():
"""Test that build_metadata strips trailing slash from issuer URL with path when serialized."""
# URL with path that has trailing slash
issuer_url = AnyHttpUrl("http://localhost:8000/auth/")

metadata = build_metadata(
issuer_url=issuer_url,
service_documentation_url=None,
client_registration_options=ClientRegistrationOptions(enabled=False),
revocation_options=RevocationOptions(enabled=False),
)

# The serialized issuer should NOT have a trailing slash
serialized = metadata.model_dump(mode="json")
assert serialized["issuer"] == "http://localhost:8000/auth"
assert not serialized["issuer"].endswith("/")


def test_build_metadata_endpoints_have_no_double_slashes():
"""Test that endpoint URLs don't have double slashes when issuer has trailing slash."""
# Use a URL that Pydantic will add trailing slash to
issuer_url = AnyHttpUrl("http://localhost:8000")

metadata = build_metadata(
issuer_url=issuer_url,
service_documentation_url=None,
client_registration_options=ClientRegistrationOptions(enabled=True),
revocation_options=RevocationOptions(enabled=True),
)

# All endpoints should be correctly formed without double slashes
assert str(metadata.authorization_endpoint) == "http://localhost:8000/authorize"
assert str(metadata.token_endpoint) == "http://localhost:8000/token"
assert str(metadata.registration_endpoint) == "http://localhost:8000/register"
assert str(metadata.revocation_endpoint) == "http://localhost:8000/revoke"

# None should have double slashes
assert "//" not in str(metadata.authorization_endpoint).replace("http://", "")
assert "//" not in str(metadata.token_endpoint).replace("http://", "")
assert "//" not in str(metadata.registration_endpoint).replace("http://", "")
assert "//" not in str(metadata.revocation_endpoint).replace("http://", "")


def test_protected_resource_metadata_strips_trailing_slash_from_resource():
"""Test that protected resource metadata strips trailing slash from resource URL when serialized.

RFC 9728 §3 requires that the resource URL in the metadata response must be
identical to the URL used for discovery.
"""
# Use a bare hostname URL which Pydantic will add a trailing slash to
resource_url = AnyHttpUrl("http://localhost:8000")
auth_server_url = AnyHttpUrl("http://auth.example.com")

routes = create_protected_resource_routes(
resource_url=resource_url,
authorization_servers=[auth_server_url],
)

# Extract metadata from the handler
# The handler is wrapped in CORS middleware, so we need to unwrap it
route = routes[0]
# Access the app inside the middleware
cors_app = route.endpoint
handler = cors_app.app.func # type: ignore

metadata = handler.__self__.metadata # type: ignore

# The serialized resource URL should NOT have a trailing slash
serialized = metadata.model_dump(mode="json")
assert serialized["resource"] == "http://localhost:8000"
assert not serialized["resource"].endswith("/")


def test_protected_resource_metadata_strips_trailing_slash_from_authorization_servers():
"""Test that protected resource metadata strips trailing slashes from auth server URLs when serialized."""
resource_url = AnyHttpUrl("http://localhost:8000/resource")
# Use bare hostname URLs which Pydantic will add trailing slashes to
auth_servers = [
AnyHttpUrl("http://auth1.example.com"),
AnyHttpUrl("http://auth2.example.com"),
]

routes = create_protected_resource_routes(
resource_url=resource_url,
authorization_servers=auth_servers,
)

# Extract metadata from the handler
route = routes[0]
cors_app = route.endpoint
handler = cors_app.app.func # type: ignore
metadata = handler.__self__.metadata # type: ignore

# All serialized authorization server URLs should NOT have trailing slashes
serialized = metadata.model_dump(mode="json")
assert serialized["authorization_servers"][0] == "http://auth1.example.com"
assert serialized["authorization_servers"][1] == "http://auth2.example.com"
assert not serialized["authorization_servers"][0].endswith("/")
assert not serialized["authorization_servers"][1].endswith("/")


@pytest.fixture
def oauth_provider():
"""Return a MockOAuthProvider instance for testing."""
return MockOAuthProvider()


@pytest.fixture
def app(oauth_provider: MockOAuthProvider):
"""Create a Starlette app with OAuth routes using a bare hostname issuer URL."""
# Use a bare hostname which Pydantic will add a trailing slash to
# This simulates the real-world scenario that was failing
issuer_url = AnyHttpUrl("http://localhost:8000")

auth_routes = create_auth_routes(
oauth_provider,
issuer_url=issuer_url,
client_registration_options=ClientRegistrationOptions(enabled=True),
revocation_options=RevocationOptions(enabled=True),
)

return Starlette(routes=auth_routes)


@pytest.fixture
def client(app: Starlette):
"""Create an HTTP client for the OAuth app."""
transport = httpx.ASGITransport(app=app)
return httpx.AsyncClient(transport=transport, base_url="http://localhost:8000")


@pytest.mark.anyio
async def test_oauth_metadata_endpoint_has_no_trailing_slash_in_issuer(client: httpx.AsyncClient):
"""Test that the OAuth metadata endpoint returns issuer without trailing slash.

This is the integration test that verifies the fix works end-to-end.
This test would FAIL on main because the issuer would have a trailing slash.
"""
response = await client.get("/.well-known/oauth-authorization-server")

assert response.status_code == 200
metadata = response.json()

# The issuer should NOT have a trailing slash
assert metadata["issuer"] == "http://localhost:8000"
assert not metadata["issuer"].endswith("/")

# Endpoints should be correctly formed
assert metadata["authorization_endpoint"] == "http://localhost:8000/authorize"
assert metadata["token_endpoint"] == "http://localhost:8000/token"
assert metadata["registration_endpoint"] == "http://localhost:8000/register"
assert metadata["revocation_endpoint"] == "http://localhost:8000/revoke"


@pytest.fixture
def protected_resource_app():
"""Create a Starlette app with protected resource routes using bare hostname URLs."""
# Use bare hostname URLs which Pydantic will add trailing slashes to
resource_url = AnyHttpUrl("http://localhost:9000")
auth_servers = [AnyHttpUrl("http://auth.example.com")]

routes = create_protected_resource_routes(
resource_url=resource_url,
authorization_servers=auth_servers,
scopes_supported=["read", "write"],
)

return Starlette(routes=routes)


@pytest.fixture
def protected_resource_client(protected_resource_app: Starlette):
"""Create an HTTP client for the protected resource app."""
transport = httpx.ASGITransport(app=protected_resource_app)
return httpx.AsyncClient(transport=transport, base_url="http://localhost:9000")


@pytest.mark.anyio
async def test_protected_resource_metadata_endpoint_has_no_trailing_slashes(
protected_resource_client: httpx.AsyncClient,
):
"""Test that protected resource metadata endpoint returns URLs without trailing slashes.

This integration test verifies the fix for protected resource metadata.
This test would FAIL on main because resource and authorization_servers would have trailing slashes.
"""
response = await protected_resource_client.get("/.well-known/oauth-protected-resource")

assert response.status_code == 200
metadata = response.json()

# The resource URL should NOT have a trailing slash
assert metadata["resource"] == "http://localhost:9000"
assert not metadata["resource"].endswith("/")

# Authorization server URLs should NOT have trailing slashes
assert metadata["authorization_servers"] == ["http://auth.example.com"]
for auth_server in metadata["authorization_servers"]:
assert not auth_server.endswith("/")
2 changes: 1 addition & 1 deletion tests/server/fastmcp/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
assert response.status_code == 200

metadata = response.json()
assert metadata["issuer"] == "https://auth.example.com/"
assert metadata["issuer"] == "https://auth.example.com"
assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
assert metadata["token_endpoint"] == "https://auth.example.com/token"
assert metadata["registration_endpoint"] == "https://auth.example.com/register"
Expand Down
Loading