Skip to content

fix: attempt to query resource-specific protected resource metadata before root PRM #1142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,24 @@ def __init__(
)
self._initialized = False

async def _discover_protected_resource(self) -> httpx.Request:
def _build_well_known_path_protected_resource(self, pathname: str) -> str:
"""Construct well-known path for OAuth protected resource metadata discovery."""
well_known_path = f"/.well-known/oauth-protected-resource{pathname}"
if pathname.endswith("/"):
# Strip trailing slash from pathname to avoid double slashes
well_known_path = well_known_path[:-1]
return well_known_path

async def _discover_protected_resource(self, is_fallback: bool = False) -> httpx.Request:
"""Build discovery request for protected resource metadata."""
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
auth_url_parsed = urlparse(self.context.server_url)
pathname = auth_url_parsed.path if not is_fallback else "/"
well_known_path = self._build_well_known_path_protected_resource(pathname)
url = urljoin(auth_base_url, well_known_path)
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
"""Handle discovery response."""
if response.status_code == 200:
try:
Expand All @@ -218,8 +229,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
self.context.protected_resource_metadata = metadata
if metadata.authorization_servers:
self.context.auth_server_url = str(metadata.authorization_servers[0])
return True
except ValidationError:
pass
return False

def _build_well_known_path(self, pathname: str) -> str:
"""Construct well-known path for OAuth metadata discovery."""
Expand Down Expand Up @@ -497,7 +510,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
discovery_request = await self._discover_protected_resource()
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)
discovery_handled = await self._handle_protected_resource_response(discovery_response)

# If path-aware discovery failed, try fallback to root
if not discovery_handled:
discovery_request = await self._discover_protected_resource(is_fallback=True)
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
Expand Down Expand Up @@ -549,7 +568,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
discovery_request = await self._discover_protected_resource()
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)
discovery_handled = await self._handle_protected_resource_response(discovery_response)

# If path-aware discovery failed, try fallback to root
if not discovery_handled:
discovery_request = await self._discover_protected_resource(is_fallback=True)
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
Expand Down
11 changes: 11 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,17 @@ async def test_discover_protected_resource_request(self, oauth_provider):
request = await oauth_provider._discover_protected_resource()

assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_protected_resource_request_fallback(self, oauth_provider):
"""Test protected resource discovery request building after a failure to discover metadata at the
standard endpoint."""
request = await oauth_provider._discover_protected_resource(is_fallback=True)

assert request.method == "GET"
# Falls back to the root
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
assert "mcp-protocol-version" in request.headers

Expand Down
Loading