diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..12301490 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a110e2ca..70cf628d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,11 +18,13 @@ jobs: strategy: matrix: platform: ["ubuntu-latest", "windows-latest"] - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", + "3.11.0-beta - 3.11", "pypy-3.8", "pypy-3.9"] steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v2" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" + with: python-version: "${{ matrix.python-version }}" @@ -38,7 +40,7 @@ jobs: # We always use a modern Python version for combining coverage to prevent # parsing errors in older versions for modern code. - - uses: "actions/setup-python@v2" + - uses: "actions/setup-python@v4" with: python-version: "3.8" @@ -52,7 +54,7 @@ jobs: - name: "Upload coverage to Codecov" if: "contains(env.USING_COVERAGE, matrix.python-version) && matrix.platform == 'ubuntu-latest'" - uses: "codecov/codecov-action@v1" + uses: "codecov/codecov-action@v3" with: fail_ci_if_error: true @@ -61,8 +63,9 @@ jobs: runs-on: "ubuntu-latest" steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v2" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" + with: python-version: "3.8" @@ -87,8 +90,9 @@ jobs: runs-on: "${{ matrix.os }}" steps: - - uses: "actions/checkout@v2" - - uses: "actions/setup-python@v2" + - uses: "actions/checkout@v3" + - uses: "actions/setup-python@v4" + with: python-version: "3.8" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000..82b7549a --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,25 @@ +name: 'Stale issue handler' +on: + workflow_dispatch: + schedule: + - cron: '30 1 * * *' + +permissions: + issues: write + pull-requests: write + +jobs: + stale: + runs-on: ubuntu-latest + steps: + - uses: actions/stale@v5 + id: stale + with: + stale-issue-message: 'This issue is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 7 days' + days-before-stale: 60 + days-before-close: 7 + stale-issue-label: stale + stale-pr-label: stale + exempt-issue-labels: 'blocked,must,should,keep' + - name: Print outputs + run: echo ${{ join(steps.stale.outputs.*, ',') }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 890329d4..506dc820 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,18 +1,18 @@ repos: - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 22.6.0 hooks: - id: black - args: ["--target-version=py36"] + args: ["--target-version=py37"] - repo: https://github.com/asottile/blacken-docs rev: v1.12.1 hooks: - id: blacken-docs - args: ["--target-version=py36"] + args: ["--target-version=py37"] - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 5.0.4 hooks: - id: flake8 language_version: python3.8 @@ -23,7 +23,7 @@ repos: - id: isort - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v4.3.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -34,3 +34,8 @@ repos: hooks: - id: check-manifest args: [--no-build-isolation] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v0.971" + hooks: + - id: mypy diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8b3b1467..b399dc42 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,7 +4,7 @@ Changelog All notable changes to this project will be documented in this file. This project adheres to `Semantic Versioning `__. -`Unreleased `__ +`Unreleased `__ ----------------------------------------------------------------------- Changed @@ -16,6 +16,38 @@ Fixed Added ~~~~~ +`v2.4.0 `__ +----------------------------------------------------------------------- + +Changed +~~~~~~~ + +- Skip keys with incompatible alg when loading JWKSet by @DaGuich in `#762 `__ +- Remove support for python3.6 by @sirosen in `#777 `__ +- Emit a deprecation warning for unsupported kwargs by @sirosen in `#776 `__ +- Remove redundant wheel dep from pyproject.toml by @mgorny in `#765 `__ +- Do not fail when an unusable key occurs by @DaGuich in `#762 `__ +- Update audience typing by @JulianMaurin in `#782 `__ +- Improve PyJWKSet error accuracy by @JulianMaurin in `#786 `__ +- Mypy as pre-commit check + api_jws typing by @JulianMaurin in `#787 `__ + +Fixed +~~~~~ + +- Adjust expected exceptions in option merging tests for PyPy3 by @mgorny in `#763 `__ +- Fixes for pyright on strict mode by @brandon-leapyear in `#747 `__ +- docs: fix simple typo, iinstance -> isinstance by @timgates42 in `#774 `__ +- Fix typo: priot -> prior by @jdufresne in `#780 `__ +- Fix for headers disorder issue by @kadabusha in `#721 `__ + +Added +~~~~~ + +- Add to_jwk static method to ECAlgorithm by @leonsmith in `#732 `__ +- Expose get_algorithm_by_name as new method by @sirosen in `#773 `__ +- Add type hints to jwt/help.py and add missing types dependency by @kkirsche in `#784 `__ +- Add cacheing functionality for JWK set by @wuhaoyujerry in `#781 `__ + `v2.4.0 `__ ----------------------------------------------------------------------- @@ -315,7 +347,7 @@ Pull Requests by @jdufresne - Remove unnecessary Unicode decoding before json.loads() (#542) by @jdufresne -- Remove unnecessary force\_bytes() calls priot to base64url\_decode() +- Remove unnecessary force\_bytes() calls prior to base64url\_decode() (#543) by @jdufresne - Remove deprecated arguments from docs (#544) by @jdufresne - Update code blocks in docs (#545) by @jdufresne diff --git a/README.rst b/README.rst index 49aa77a8..432631e7 100644 --- a/README.rst +++ b/README.rst @@ -42,7 +42,7 @@ Usage >>> import jwt >>> encoded = jwt.encode({"some": "payload"}, "secret", algorithm="HS256") >>> print(encoded) - eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzb21lIjoicGF5bG9hZCJ9.Joh1R2dYzkRvDkqv3sygm5YyK8Gi4ShZqbhK2gxcs2U + eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzb21lIjoicGF5bG9hZCJ9.4twFt5NiznN84AWoo1d7KO1T_yoc0Z6XOpOVswacPZg >>> jwt.decode(encoded, "secret", algorithms=["HS256"]) {'some': 'payload'} diff --git a/docs/api.rst b/docs/api.rst index 2f81b1f7..919b6af9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -62,7 +62,7 @@ API Reference if ``verify_exp``, ``verify_iat``, and ``verify_nbf`` respectively is set to ``True``). - :param Iterable audience: optional, the value for ``verify_aud`` check + :param Union[str, Iterable] audience: optional, the value for ``verify_aud`` check :param str issuer: optional, the value for ``verify_iss`` check :param float leeway: a time margin in seconds for the expiration check :rtype: dict diff --git a/docs/index.rst b/docs/index.rst index 63e67945..5cdf5654 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,7 +32,7 @@ Example Usage >>> import jwt >>> encoded_jwt = jwt.encode({"some": "payload"}, "secret", algorithm="HS256") >>> print(encoded_jwt) - eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzb21lIjoicGF5bG9hZCJ9.Joh1R2dYzkRvDkqv3sygm5YyK8Gi4ShZqbhK2gxcs2U + eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzb21lIjoicGF5bG9hZCJ9.4twFt5NiznN84AWoo1d7KO1T_yoc0Z6XOpOVswacPZg >>> jwt.decode(encoded_jwt, "secret", algorithms=["HS256"]) {'some': 'payload'} diff --git a/jwt/__init__.py b/jwt/__init__.py index 6b3f8ab1..9f9eda9a 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -1,6 +1,7 @@ from .api_jwk import PyJWK, PyJWKSet from .api_jws import ( PyJWS, + get_algorithm_by_name, get_unverified_header, register_algorithm, unregister_algorithm, @@ -25,7 +26,7 @@ ) from .jwks_client import PyJWKClient -__version__ = "2.4.0" +__version__ = "2.5.0" __title__ = "PyJWT" __description__ = "JSON Web Token implementation in Python" @@ -51,6 +52,7 @@ "get_unverified_header", "register_algorithm", "unregister_algorithm", + "get_algorithm_by_name", # Exceptions "DecodeError", "ExpiredSignatureError", diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 46a1a532..93fadf4c 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -439,6 +439,41 @@ def verify(self, msg, key, sig): except InvalidSignature: return False + @staticmethod + def to_jwk(key_obj): + + if isinstance(key_obj, EllipticCurvePrivateKey): + public_numbers = key_obj.public_key().public_numbers() + elif isinstance(key_obj, EllipticCurvePublicKey): + public_numbers = key_obj.public_numbers() + else: + raise InvalidKeyError("Not a public or private key") + + if isinstance(key_obj.curve, ec.SECP256R1): + crv = "P-256" + elif isinstance(key_obj.curve, ec.SECP384R1): + crv = "P-384" + elif isinstance(key_obj.curve, ec.SECP521R1): + crv = "P-521" + elif isinstance(key_obj.curve, ec.SECP256K1): + crv = "secp256k1" + else: + raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") + + obj = { + "kty": "EC", + "crv": crv, + "x": to_base64url_uint(public_numbers.x).decode(), + "y": to_base64url_uint(public_numbers.y).decode(), + } + + if isinstance(key_obj, EllipticCurvePrivateKey): + obj["d"] = to_base64url_uint( + key_obj.private_numbers().private_value + ).decode() + + return json.dumps(obj) + @staticmethod def from_jwk(jwk): try: @@ -574,7 +609,7 @@ def sign(self, msg, key): Sign a message ``msg`` using the EdDSA private key ``key`` :param str|bytes msg: Message to sign :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey` - or :class:`.Ed448PrivateKey` iinstance + or :class:`.Ed448PrivateKey` isinstance :return bytes signature: The signature, as bytes """ msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index 31250d57..aa3dd321 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import json +import time from .algorithms import get_default_algorithms from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError @@ -74,17 +77,24 @@ def public_key_use(self): class PyJWKSet: - def __init__(self, keys): + def __init__(self, keys: list[dict]) -> None: self.keys = [] - if not keys or not isinstance(keys, list): - raise PyJWKSetError("Invalid JWK Set value") - - if len(keys) == 0: + if not keys: raise PyJWKSetError("The JWK Set did not contain any keys") + if not isinstance(keys, list): + raise PyJWKSetError("Invalid JWK Set value") + for key in keys: - self.keys.append(PyJWK(key)) + try: + self.keys.append(PyJWK(key)) + except PyJWKError: + # skip unusable keys + continue + + if len(self.keys) == 0: + raise PyJWKSetError("The JWK Set did not contain any usable keys") @staticmethod def from_dict(obj): @@ -101,3 +111,15 @@ def __getitem__(self, kid): if key.key_id == kid: return key raise KeyError(f"keyset has no key for kid: {kid}") + + +class PyJWTSetWithTimestamp: + def __init__(self, jwk_set: PyJWKSet): + self.jwk_set = jwk_set + self.timestamp = time.monotonic() + + def get_jwk_set(self): + return self.jwk_set + + def get_timestamp(self): + return self.timestamp diff --git a/jwt/api_jws.py b/jwt/api_jws.py index cbf4f6f5..ab8490f9 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import binascii import json -from collections.abc import Mapping -from typing import Any, Dict, List, Optional, Type +import warnings +from typing import Any, Type from .algorithms import ( Algorithm, @@ -16,12 +18,13 @@ InvalidTokenError, ) from .utils import base64url_decode, base64url_encode +from .warnings import RemovedInPyjwt3Warning class PyJWS: header_typ = "JWT" - def __init__(self, algorithms=None, options=None): + def __init__(self, algorithms=None, options=None) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( set(algorithms) if algorithms is not None else set(self._algorithms) @@ -37,10 +40,10 @@ def __init__(self, algorithms=None, options=None): self.options = {**self._get_default_options(), **options} @staticmethod - def _get_default_options(): + def _get_default_options() -> dict[str, bool]: return {"verify_signature": True} - def register_algorithm(self, alg_id, alg_obj): + def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None: """ Registers a new Algorithm for use when creating and verifying tokens. """ @@ -53,7 +56,7 @@ def register_algorithm(self, alg_id, alg_obj): self._algorithms[alg_id] = alg_obj self._valid_algs.add(alg_id) - def unregister_algorithm(self, alg_id): + def unregister_algorithm(self, alg_id: str) -> None: """ Unregisters an Algorithm for use when creating and verifying tokens Throws KeyError if algorithm is not registered. @@ -67,38 +70,55 @@ def unregister_algorithm(self, alg_id): del self._algorithms[alg_id] self._valid_algs.remove(alg_id) - def get_algorithms(self): + def get_algorithms(self) -> list[str]: """ Returns a list of supported values for the 'alg' parameter. """ return list(self._valid_algs) + def get_algorithm_by_name(self, alg_name: str) -> Algorithm: + """ + For a given string name, return the matching Algorithm object. + + Example usage: + + >>> jws_obj.get_algorithm_by_name("RS256") + """ + try: + return self._algorithms[alg_name] + except KeyError as e: + if not has_crypto and alg_name in requires_cryptography: + raise NotImplementedError( + f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?" + ) from e + raise NotImplementedError("Algorithm not supported") from e + def encode( self, payload: bytes, key: str, - algorithm: Optional[str] = "HS256", - headers: Optional[Dict] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + algorithm: str | None = "HS256", + headers: dict[str, Any] | None = None, + json_encoder: Type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, ) -> str: segments = [] - if algorithm is None: - algorithm = "none" + # declare a new var to narrow the type for type checkers + algorithm_: str = algorithm if algorithm is not None else "none" # Prefer headers values if present to function parameters. if headers: headers_alg = headers.get("alg") if headers_alg: - algorithm = headers["alg"] + algorithm_ = headers["alg"] headers_b64 = headers.get("b64") if headers_b64 is False: is_payload_detached = True # Header - header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any] + header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_} if headers: self._validate_headers(headers) @@ -113,8 +133,9 @@ def encode( # True is the standard value for b64, so no need for it del header["b64"] + # Fix for headers misorder - issue #715 json_header = json.dumps( - header, separators=(",", ":"), cls=json_encoder + header, separators=(",", ":"), cls=json_encoder, sort_keys=True ).encode() segments.append(base64url_encode(json_header)) @@ -128,17 +149,9 @@ def encode( # Segments signing_input = b".".join(segments) - try: - alg_obj = self._algorithms[algorithm] - key = alg_obj.prepare_key(key) - signature = alg_obj.sign(signing_input, key) - - except KeyError as e: - if not has_crypto and algorithm in requires_cryptography: - raise NotImplementedError( - f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?" - ) from e - raise NotImplementedError("Algorithm not supported") from e + alg_obj = self.get_algorithm_by_name(algorithm_) + key = alg_obj.prepare_key(key) + signature = alg_obj.sign(signing_input, key) segments.append(base64url_encode(signature)) @@ -153,11 +166,18 @@ def decode_complete( self, jwt: str, key: str = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict] = None, - detached_payload: Optional[bytes] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, + detached_payload: bytes | None = None, **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: + if kwargs: + warnings.warn( + "passing additional kwargs to decode_complete() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) if options is None: options = {} merged_options = {**self.options, **options} @@ -191,14 +211,24 @@ def decode( self, jwt: str, key: str = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict] = None, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, + detached_payload: bytes | None = None, **kwargs, ) -> str: - decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + if kwargs: + warnings.warn( + "passing additional kwargs to decode() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) + decoded = self.decode_complete( + jwt, key, algorithms, options, detached_payload=detached_payload + ) return decoded["payload"] - def get_unverified_header(self, jwt): + def get_unverified_header(self, jwt: str | bytes) -> dict: """Returns back the JWT header parameters as a dict() Note: The signature is not verified so the header parameters @@ -209,7 +239,7 @@ def get_unverified_header(self, jwt): return headers - def _load(self, jwt): + def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]: if isinstance(jwt, str): jwt = jwt.encode("utf-8") @@ -232,7 +262,7 @@ def _load(self, jwt): except ValueError as e: raise DecodeError(f"Invalid header string: {e}") from e - if not isinstance(header, Mapping): + if not isinstance(header, dict): raise DecodeError("Invalid header string: must be a json object") try: @@ -249,33 +279,32 @@ def _load(self, jwt): def _verify_signature( self, - signing_input, - header, - signature, - key="", - algorithms=None, - ): + signing_input: bytes, + header: dict, + signature: bytes, + key: str = "", + algorithms: list[str] | None = None, + ) -> None: alg = header.get("alg") - if algorithms is not None and alg not in algorithms: + if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") try: - alg_obj = self._algorithms[alg] - key = alg_obj.prepare_key(key) - - if not alg_obj.verify(signing_input, key, signature): - raise InvalidSignatureError("Signature verification failed") - - except KeyError as e: + alg_obj = self.get_algorithm_by_name(alg) + except NotImplementedError as e: raise InvalidAlgorithmError("Algorithm not supported") from e + key = alg_obj.prepare_key(key) + + if not alg_obj.verify(signing_input, key, signature): + raise InvalidSignatureError("Signature verification failed") - def _validate_headers(self, headers): + def _validate_headers(self, headers: dict[str, Any]) -> None: if "kid" in headers: self._validate_kid(headers["kid"]) - def _validate_kid(self, kid): + def _validate_kid(self, kid: str) -> None: if not isinstance(kid, str): raise InvalidTokenError("Key ID header parameter must be a string") @@ -286,4 +315,5 @@ def _validate_kid(self, kid): decode = _jws_global_obj.decode register_algorithm = _jws_global_obj.register_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm +get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name get_unverified_header = _jws_global_obj.get_unverified_header diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index 7d2177bf..91a6d2e8 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import warnings from calendar import timegm @@ -15,6 +17,7 @@ InvalidIssuerError, MissingRequiredClaimError, ) +from .warnings import RemovedInPyjwt3Warning class PyJWT: @@ -40,7 +43,7 @@ def encode( payload: Dict[str, Any], key: str, algorithm: Optional[str] = "HS256", - headers: Optional[Dict] = None, + headers: Optional[Dict[str, Any]] = None, json_encoder: Optional[Type[json.JSONEncoder]] = None, ) -> str: # Check that we get a mapping @@ -68,16 +71,33 @@ def decode_complete( jwt: str, key: str = "", algorithms: Optional[List[str]] = None, - options: Optional[Dict] = None, + options: Optional[Dict[str, Any]] = None, + # deprecated arg, remove in pyjwt3 + verify: Optional[bool] = None, + # could be used as passthrough to api_jws, consider removal in pyjwt3 + detached_payload: Optional[bytes] = None, + # passthrough arguments to _validate_claims + # consider putting in options + audience: Optional[Union[str, Iterable[str]]] = None, + issuer: Optional[str] = None, + leeway: Union[int, float, timedelta] = 0, + # kwargs **kwargs, ) -> Dict[str, Any]: + if kwargs: + warnings.warn( + "passing additional kwargs to decode_complete() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) options = dict(options or {}) # shallow-copy or initialize an empty dict options.setdefault("verify_signature", True) # If the user has set the legacy `verify` argument, and it doesn't match # what the relevant `options` entry for the argument is, inform the user # that they're likely making a mistake. - if "verify" in kwargs and kwargs["verify"] != options["verify_signature"]: + if verify is not None and verify != options["verify_signature"]: warnings.warn( "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. " "The equivalent is setting `verify_signature` to False in the `options` dictionary. " @@ -102,7 +122,7 @@ def decode_complete( key=key, algorithms=algorithms, options=options, - **kwargs, + detached_payload=detached_payload, ) try: @@ -113,7 +133,9 @@ def decode_complete( raise DecodeError("Invalid payload string: must be a json object") merged_options = {**self.options, **options} - self._validate_claims(payload, merged_options, **kwargs) + self._validate_claims( + payload, merged_options, audience=audience, issuer=issuer, leeway=leeway + ) decoded["payload"] = payload return decoded @@ -123,20 +145,45 @@ def decode( jwt: str, key: str = "", algorithms: Optional[List[str]] = None, - options: Optional[Dict] = None, + options: Optional[Dict[str, Any]] = None, + # deprecated arg, remove in pyjwt3 + verify: Optional[bool] = None, + # could be used as passthrough to api_jws, consider removal in pyjwt3 + detached_payload: Optional[bytes] = None, + # passthrough arguments to _validate_claims + # consider putting in options + audience: Optional[Union[str, Iterable[str]]] = None, + issuer: Optional[str] = None, + leeway: Union[int, float, timedelta] = 0, + # kwargs **kwargs, ) -> Dict[str, Any]: - decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) + if kwargs: + warnings.warn( + "passing additional kwargs to decode() is deprecated " + "and will be removed in pyjwt version 3. " + f"Unsupported kwargs: {tuple(kwargs.keys())}", + RemovedInPyjwt3Warning, + ) + decoded = self.decode_complete( + jwt, + key, + algorithms, + options, + verify=verify, + detached_payload=detached_payload, + audience=audience, + issuer=issuer, + leeway=leeway, + ) return decoded["payload"] - def _validate_claims( - self, payload, options, audience=None, issuer=None, leeway=0, **kwargs - ): + def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0): if isinstance(leeway, timedelta): leeway = leeway.total_seconds() - if not isinstance(audience, (bytes, str, type(None), Iterable)): - raise TypeError("audience must be a string, iterable, or None") + if audience is not None and not isinstance(audience, (str, Iterable)): + raise TypeError("audience must be a string, iterable or None") self._validate_required_claims(payload, options) diff --git a/jwt/help.py b/jwt/help.py index d5c3ebbf..0c02eb92 100644 --- a/jwt/help.py +++ b/jwt/help.py @@ -1,16 +1,17 @@ import json import platform import sys +from typing import Dict from . import __version__ as pyjwt_version try: import cryptography except ModuleNotFoundError: - cryptography = None # type: ignore + cryptography = None -def info(): +def info() -> Dict[str, Dict[str, str]]: """ Generate information for a bug report. Based on the requests package help utility module. @@ -28,14 +29,15 @@ def info(): if implementation == "CPython": implementation_version = platform.python_version() elif implementation == "PyPy": + pypy_version_info = getattr(sys, "pypy_version_info") implementation_version = ( - f"{sys.pypy_version_info.major}." - f"{sys.pypy_version_info.minor}." - f"{sys.pypy_version_info.micro}" + f"{pypy_version_info.major}." + f"{pypy_version_info.minor}." + f"{pypy_version_info.micro}" ) - if sys.pypy_version_info.releaselevel != "final": + if pypy_version_info.releaselevel != "final": implementation_version = "".join( - [implementation_version, sys.pypy_version_info.releaselevel] + [implementation_version, pypy_version_info.releaselevel] ) else: implementation_version = "Unknown" @@ -51,7 +53,7 @@ def info(): } -def main(): +def main() -> None: """Pretty-print the bug information as JSON.""" print(json.dumps(info(), sort_keys=True, indent=2)) diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py new file mode 100644 index 00000000..e8c2a7e0 --- /dev/null +++ b/jwt/jwk_set_cache.py @@ -0,0 +1,32 @@ +import time +from typing import Optional + +from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp + + +class JWKSetCache: + def __init__(self, lifespan: int): + self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None + self.lifespan = lifespan + + def put(self, jwk_set: PyJWKSet): + if jwk_set is not None: + self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set) + else: + # clear cache + self.jwk_set_with_timestamp = None + + def get(self) -> Optional[PyJWKSet]: + if self.jwk_set_with_timestamp is None or self.is_expired(): + return None + + return self.jwk_set_with_timestamp.get_jwk_set() + + def is_expired(self) -> bool: + + return ( + self.jwk_set_with_timestamp is not None + and self.lifespan > -1 + and time.monotonic() + > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan + ) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 767b7179..b4e98007 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -1,31 +1,68 @@ import json import urllib.request from functools import lru_cache -from typing import Any, List +from typing import Any, List, Optional +from urllib.error import URLError from .api_jwk import PyJWK, PyJWKSet from .api_jwt import decode_complete as decode_token from .exceptions import PyJWKClientError +from .jwk_set_cache import JWKSetCache class PyJWKClient: - def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16): + def __init__( + self, + uri: str, + cache_keys: bool = False, + max_cached_keys: int = 16, + cache_jwk_set: bool = True, + lifespan: int = 300, + ): self.uri = uri + self.jwk_set_cache: Optional[JWKSetCache] = None + + if cache_jwk_set: + # Init jwt set cache with default or given lifespan. + # Default lifespan is 300 seconds (5 minutes). + if lifespan <= 0: + raise PyJWKClientError( + f'Lifespan must be greater than 0, the input is "{lifespan}"' + ) + self.jwk_set_cache = JWKSetCache(lifespan) + else: + self.jwk_set_cache = None + if cache_keys: # Cache signing keys # Ignore mypy (https://github.com/python/mypy/issues/2427) self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore def fetch_data(self) -> Any: - with urllib.request.urlopen(self.uri) as response: - return json.load(response) + jwk_set: Any = None + try: + with urllib.request.urlopen(self.uri) as response: + jwk_set = json.load(response) + except URLError as e: + raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') + else: + return jwk_set + finally: + if self.jwk_set_cache is not None: + self.jwk_set_cache.put(jwk_set) + + def get_jwk_set(self, refresh: bool = False) -> PyJWKSet: + data = None + if self.jwk_set_cache is not None and not refresh: + data = self.jwk_set_cache.get() + + if data is None: + data = self.fetch_data() - def get_jwk_set(self) -> PyJWKSet: - data = self.fetch_data() return PyJWKSet.from_dict(data) - def get_signing_keys(self) -> List[PyJWK]: - jwk_set = self.get_jwk_set() + def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: + jwk_set = self.get_jwk_set(refresh) signing_keys = [ jwk_set_key for jwk_set_key in jwk_set.keys @@ -39,17 +76,17 @@ def get_signing_keys(self) -> List[PyJWK]: def get_signing_key(self, kid: str) -> PyJWK: signing_keys = self.get_signing_keys() - signing_key = None - - for key in signing_keys: - if key.key_id == kid: - signing_key = key - break + signing_key = self.match_kid(signing_keys, kid) if not signing_key: - raise PyJWKClientError( - f'Unable to find a signing key that matches: "{kid}"' - ) + # If no matching signing key from the jwk set, refresh the jwk set and try again. + signing_keys = self.get_signing_keys(refresh=True) + signing_key = self.match_kid(signing_keys, kid) + + if not signing_key: + raise PyJWKClientError( + f'Unable to find a signing key that matches: "{kid}"' + ) return signing_key @@ -57,3 +94,14 @@ def get_signing_key_from_jwt(self, token: str) -> PyJWK: unverified = decode_token(token, options={"verify_signature": False}) header = unverified["header"] return self.get_signing_key(header.get("kid")) + + @staticmethod + def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]: + signing_key = None + + for key in signing_keys: + if key.key_id == kid: + signing_key = key + break + + return signing_key diff --git a/jwt/utils.py b/jwt/utils.py index 8ab73b42..16cae066 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import Any, Union +from typing import Union try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -10,7 +10,7 @@ encode_dss_signature, ) except ModuleNotFoundError: - EllipticCurve = Any # type: ignore + EllipticCurve = None def force_bytes(value: Union[str, bytes]) -> bytes: @@ -136,7 +136,7 @@ def is_pem_format(key: bytes) -> bool: # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 _CERT_SUFFIX = b"-cert-v01@openssh.com" -_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)") +_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") _SSH_KEY_FORMATS = [ b"ssh-ed25519", b"ssh-rsa", diff --git a/jwt/warnings.py b/jwt/warnings.py new file mode 100644 index 00000000..8762a8cb --- /dev/null +++ b/jwt/warnings.py @@ -0,0 +1,2 @@ +class RemovedInPyjwt3Warning(DeprecationWarning): + pass diff --git a/pyproject.toml b/pyproject.toml index 85e9eca2..f4065051 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,7 @@ [build-system] -requires = ["setuptools", "wheel"] +requires = ["setuptools"] build-backend = "setuptools.build_meta" - [tool.coverage.run] parallel = true branch = true @@ -14,8 +13,13 @@ source = ["jwt", ".tox/*/site-packages"] [tool.coverage.report] show_missing = true - [tool.isort] profile = "black" atomic = true combine_as_imports = true + +[tool.mypy] +python_version = 3.7 +ignore_missing_imports = true +warn_unused_ignores = true +no_implicit_optional = true diff --git a/setup.cfg b/setup.cfg index 5e0b244c..434e22c2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,7 +23,6 @@ classifiers = Programming Language :: Python Programming Language :: Python :: 3 Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 @@ -33,7 +32,7 @@ classifiers = [options] zip_safe = false include_package_data = true -python_requires = >=3.6 +python_requires = >=3.7 packages = find: [options.package_data] @@ -41,22 +40,23 @@ packages = find: [options.extras_require] docs = - sphinx + sphinx>=4.5.0,<5.0.0 sphinx-rtd-theme zope.interface crypto = cryptography>=3.3.1 + types-cryptography>=3.3.21 tests = pytest>=6.0.0,<7.0.0 coverage[toml]==5.0.4 dev = - sphinx + sphinx>=4.5.0,<5.0.0 sphinx-rtd-theme zope.interface cryptography>=3.3.1 + types-cryptography>=3.3.21 pytest>=6.0.0,<7.0.0 coverage[toml]==5.0.4 - mypy pre-commit [options.packages.find] @@ -66,9 +66,3 @@ exclude = [flake8] extend-ignore = E203, E501 - -[mypy] -python_version = 3.6 -ignore_missing_imports = true -warn_unused_ignores = true -no_implicit_optional = true diff --git a/tests/keys/jwk_keyset_only_unknown_alg.json b/tests/keys/jwk_keyset_only_unknown_alg.json new file mode 100644 index 00000000..963e62c8 --- /dev/null +++ b/tests/keys/jwk_keyset_only_unknown_alg.json @@ -0,0 +1 @@ +{"keys":[{"kid":"lYXxnemSzWNBUoPug_h0hZnjPi5oKCmQ9awQJaZCWWM","kty":"RSA","alg":"RSA-OAEP","use":"enc","n":"k75Ghd4r8h_fdydTAXyMjrGYNnuiG7yevoW1ZIIuegEUK3LLGY0Z3Q8PhCrkmi6LpkPwwR1C8ck9plvSs4vZ9GqmUoi5YcQEile6HjPG3NBwQ-cHWY4ZH_D-ItdzcZUKDxjHYaY-GW1yLeJ1RAh8wMPM7cenA2v0eNIq4HaIXzZJ2Hgxh4Ei-CSYcD0f_TYEySqUEb8jd0dC8frpkYDkOUCVizRBDUEg_hkPSpVqfLP8ekxIHxkC9wcfL-d2FhptxBQYN8NFnIuG9NFXbZ5mdzdmIuN6WPr_CECcgL9qXsph9U-L829dU67ufeBvzEejJ8qwiswslRdx4ZcYjtaBdQ","e":"AQAB","x5c":["MIICnTCCAYUCBgGAUN05KzANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDDAdUZXN0aW5nMB4XDTIyMDQyMjEwNDAxN1oXDTMyMDQyMjEwNDE1N1owEjEQMA4GA1UEAwwHVGVzdGluZzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAJO+RoXeK/If33cnUwF8jI6xmDZ7ohu8nr6FtWSCLnoBFCtyyxmNGd0PD4Qq5Joui6ZD8MEdQvHJPaZb0rOL2fRqplKIuWHEBIpXuh4zxtzQcEPnB1mOGR/w/iLXc3GVCg8Yx2GmPhltci3idUQIfMDDzO3HpwNr9HjSKuB2iF82Sdh4MYeBIvgkmHA9H/02BMkqlBG/I3dHQvH66ZGA5DlAlYs0QQ1BIP4ZD0qVanyz/HpMSB8ZAvcHHy/ndhYabcQUGDfDRZyLhvTRV22eZnc3ZiLjelj6/whAnIC/al7KYfVPi/NvXVOu7n3gb8xHoyfKsIrMLJUXceGXGI7WgXUCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAeMUFrCX4eAfF8i6wILOP5dDJOBN10nPP63VNliQ7+YHu1ZI0VGB7TNrImRE9riH2IWenSXD21DxK31qBlZKNEgaH7rVwwvOZ22qCyWacv1+QdanxAiljD03rU7HOR/tyqcvjl6U2Yadxcq6OWlKKVaa0fNtbPigqAwQ3iVpg9N+OthANYyKHxlmzJKGeEaDA69/uJ6UwektHlv/9BnNFh8We6EwJxYG7/rejI02EgbJFxGO1RlcmigTxRc5l3Dw4WldBIRxWiJgSEkKSfUy5S7sQdFQokZjTyqy6h1ldb/tgrWLIE0srGQ2u/fQeSgPTbAzihaeOf+WKq5RDXoq5bw=="],"x5t":"FaWinuPZQiDMljn3x9DMAuepBYQ","x5t#S256":"_0B--Hh1KgNtdyZqAp1NWUAikRPvlt2HGm__xXpjTi0"}]} diff --git a/tests/keys/jwk_keyset_with_unknown_alg.json b/tests/keys/jwk_keyset_with_unknown_alg.json new file mode 100644 index 00000000..0a0a7540 --- /dev/null +++ b/tests/keys/jwk_keyset_with_unknown_alg.json @@ -0,0 +1 @@ +{"keys":[{"kid":"U1MayerhVuRj8xtFR8hyMH9lCfVMKlb3TG7mbQAS19M","kty":"RSA","alg":"RS256","use":"sig","n":"omef3NkXf4--6BtUPKjhlV7pf6Vv7HMg-VL-ITX8KQZTD4LTzWO3x9RPwVepKjgfvJe_IiZFaJX78-a7zpcG9mpZG8czp3C8nZSvAJKphvYLd9s9qYrGMFW9t1eHyGwmIQN02VXwHeZ0JDd5X4i7sO4XPkNycfzSoxaQbv7wANYBTcvcWcjYVxIj4ZpYkSsQqrrOTm69G7FyurtfExGc7jlSRcv-Gubq_K3IQLHGHTlil20wqZmis1dLJwpAjgTxY7uQSwEdqJHCJR3q76bsDelIBZpbR07kqIOXqYu52w0wkC_1W7_HcVPLNp6T_ML09P8jGsOWfMO95_zchkseQw","e":"AQAB","x5c":["MIICnTCCAYUCBgGAUN03JTANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDDAdUZXN0aW5nMB4XDTIyMDQyMjEwNDAxNloXDTMyMDQyMjEwNDE1NlowEjEQMA4GA1UEAwwHVGVzdGluZzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKJnn9zZF3+PvugbVDyo4ZVe6X+lb+xzIPlS/iE1/CkGUw+C081jt8fUT8FXqSo4H7yXvyImRWiV+/Pmu86XBvZqWRvHM6dwvJ2UrwCSqYb2C3fbPamKxjBVvbdXh8hsJiEDdNlV8B3mdCQ3eV+Iu7DuFz5DcnH80qMWkG7+8ADWAU3L3FnI2FcSI+GaWJErEKq6zk5uvRuxcrq7XxMRnO45UkXL/hrm6vytyECxxh05YpdtMKmZorNXSycKQI4E8WO7kEsBHaiRwiUd6u+m7A3pSAWaW0dO5KiDl6mLudsNMJAv9Vu/x3FTyzaek/zC9PT/IxrDlnzDvef83IZLHkMCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAi7ZppYbkpt0ALn5NXIIPgA04svRwAmsUJWKLBS5iKVXq6HOJPsz0GAB9oKpjar83rUomwK2UE0XFJLMDvrB0nTZJBjm2DCANLL1GtTKUd+mdvhyHCIMrUApkhAYzv2Rk1c4+Jt7f5/h8FnM8jdl9FGc5TBy5ixS0OxnyW1JOakClYQz8vNS7LrC4hmLWwy7GAmUdemNLEefQcECaNzaLN5gGk1ht5lJyNCsHu9STZeYM2UXdDAtMtu9HAepfzh2CAOscSDtZr89SmFSwxKaOfbJyXH4PivMgWK4zO0P6ofuv8d8gRbUAUgnysKHQc0isTVWOxgmzI69EUe/iVXJHig=="],"x5t":"0C94xr3ayzaC9OUcSSLyrwDGdmI","x5t#S256":"O6ntIrYkVK0hX-_AwnrwJW1CO97lP3D2_aKnELuNLSo"},{"kid":"lYXxnemSzWNBUoPug_h0hZnjPi5oKCmQ9awQJaZCWWM","kty":"RSA","alg":"RSA-OAEP","use":"enc","n":"k75Ghd4r8h_fdydTAXyMjrGYNnuiG7yevoW1ZIIuegEUK3LLGY0Z3Q8PhCrkmi6LpkPwwR1C8ck9plvSs4vZ9GqmUoi5YcQEile6HjPG3NBwQ-cHWY4ZH_D-ItdzcZUKDxjHYaY-GW1yLeJ1RAh8wMPM7cenA2v0eNIq4HaIXzZJ2Hgxh4Ei-CSYcD0f_TYEySqUEb8jd0dC8frpkYDkOUCVizRBDUEg_hkPSpVqfLP8ekxIHxkC9wcfL-d2FhptxBQYN8NFnIuG9NFXbZ5mdzdmIuN6WPr_CECcgL9qXsph9U-L829dU67ufeBvzEejJ8qwiswslRdx4ZcYjtaBdQ","e":"AQAB","x5c":["MIICnTCCAYUCBgGAUN05KzANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDDAdUZXN0aW5nMB4XDTIyMDQyMjEwNDAxN1oXDTMyMDQyMjEwNDE1N1owEjEQMA4GA1UEAwwHVGVzdGluZzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAJO+RoXeK/If33cnUwF8jI6xmDZ7ohu8nr6FtWSCLnoBFCtyyxmNGd0PD4Qq5Joui6ZD8MEdQvHJPaZb0rOL2fRqplKIuWHEBIpXuh4zxtzQcEPnB1mOGR/w/iLXc3GVCg8Yx2GmPhltci3idUQIfMDDzO3HpwNr9HjSKuB2iF82Sdh4MYeBIvgkmHA9H/02BMkqlBG/I3dHQvH66ZGA5DlAlYs0QQ1BIP4ZD0qVanyz/HpMSB8ZAvcHHy/ndhYabcQUGDfDRZyLhvTRV22eZnc3ZiLjelj6/whAnIC/al7KYfVPi/NvXVOu7n3gb8xHoyfKsIrMLJUXceGXGI7WgXUCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAeMUFrCX4eAfF8i6wILOP5dDJOBN10nPP63VNliQ7+YHu1ZI0VGB7TNrImRE9riH2IWenSXD21DxK31qBlZKNEgaH7rVwwvOZ22qCyWacv1+QdanxAiljD03rU7HOR/tyqcvjl6U2Yadxcq6OWlKKVaa0fNtbPigqAwQ3iVpg9N+OthANYyKHxlmzJKGeEaDA69/uJ6UwektHlv/9BnNFh8We6EwJxYG7/rejI02EgbJFxGO1RlcmigTxRc5l3Dw4WldBIRxWiJgSEkKSfUy5S7sQdFQokZjTyqy6h1ldb/tgrWLIE0srGQ2u/fQeSgPTbAzihaeOf+WKq5RDXoq5bw=="],"x5t":"FaWinuPZQiDMljn3x9DMAuepBYQ","x5t#S256":"_0B--Hh1KgNtdyZqAp1NWUAikRPvlt2HGm__xXpjTi0"}]} diff --git a/tests/keys/testkey_ec_secp192r1.priv b/tests/keys/testkey_ec_secp192r1.priv new file mode 100644 index 00000000..0f4d1c71 --- /dev/null +++ b/tests/keys/testkey_ec_secp192r1.priv @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MG8CAQAwEwYHKoZIzj0CAQYIKoZIzj0DAQEEVTBTAgEBBBiON6kYcPu8ZUDRTu8W +eXJ2FmX7e9yq0hahNAMyAARHecLjkXWDUJfZ4wiFH61JpmonCYH1GpinVlqw68Sf +wtDHg2F6SifQEFC6VKj1ZXw= +-----END PRIVATE KEY----- diff --git a/tests/test_advisory.py b/tests/test_advisory.py index a4a7d237..ed768d4b 100644 --- a/tests/test_advisory.py +++ b/tests/test_advisory.py @@ -1,14 +1,17 @@ -import jwt import pytest + +import jwt from jwt.exceptions import InvalidKeyError from .utils import crypto_required -priv_key_bytes = b'''-----BEGIN PRIVATE KEY----- +priv_key_bytes = b"""-----BEGIN PRIVATE KEY----- MC4CAQAwBQYDK2VwBCIEIIbBhdo2ah7X32i50GOzrCr4acZTe6BezUdRIixjTAdL ------END PRIVATE KEY-----''' +-----END PRIVATE KEY-----""" -pub_key_bytes = b'ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPL1I9oiq+B8crkmuV4YViiUnhdLjCp3hvy1bNGuGfNL' +pub_key_bytes = ( + b"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPL1I9oiq+B8crkmuV4YViiUnhdLjCp3hvy1bNGuGfNL" +) ssh_priv_key_bytes = b"""-----BEGIN EC PRIVATE KEY----- MHcCAQEEIOWc7RbaNswMtNtc+n6WZDlUblMr2FBPo79fcGXsJlGQoAoGCCqGSM49 @@ -41,11 +44,11 @@ def test_ghsa_ffqj_6fqr_9h24(self): # Making a good jwt token that should work by signing it # with the private key # encoded_good = jwt.encode({"test": 1234}, priv_key_bytes, algorithm="EdDSA") - encoded_good = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJFZERTQSJ9.eyJ0ZXN0IjoxMjM0fQ.M5y1EEavZkHSlj9i8yi9nXKKyPBSAUhDRTOYZi3zZY11tZItDaR3qwAye8pc74_lZY3Ogt9KPNFbVOSGnUBHDg' + encoded_good = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFZERTQSJ9.eyJ0ZXN0IjoxMjM0fQ.M5y1EEavZkHSlj9i8yi9nXKKyPBSAUhDRTOYZi3zZY11tZItDaR3qwAye8pc74_lZY3Ogt9KPNFbVOSGnUBHDg" # Using HMAC with the public key to trick the receiver to think that the # public key is a HMAC secret - encoded_bad = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0ZXN0IjoxMjM0fQ.6ulDpqSlbHmQ8bZXhZRLFko9SwcHrghCwh8d-exJEE4' + encoded_bad = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0ZXN0IjoxMjM0fQ.6ulDpqSlbHmQ8bZXhZRLFko9SwcHrghCwh8d-exJEE4" # Both of the jwt tokens are validated as valid jwt.decode( @@ -101,12 +104,12 @@ def test_ghsa_ffqj_6fqr_9h24(self): jwt.decode( encoded_good, ssh_key_bytes, - algorithms=jwt.algorithms.get_default_algorithms() + algorithms=jwt.algorithms.get_default_algorithms(), ) with pytest.raises(InvalidKeyError): jwt.decode( encoded_bad, ssh_key_bytes, - algorithms=jwt.algorithms.get_default_algorithms() + algorithms=jwt.algorithms.get_default_algorithms(), ) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index ac26600d..538078af 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -236,6 +236,103 @@ def test_ec_jwk_fails_on_invalid_json(self): f'{{"kty": "EC", "crv": "{curve}", "x": "{point["x"]}", "y": "{point["y"]}", "d": "dGVzdA=="}}' ) + @crypto_required + def test_ec_private_key_to_jwk_works_with_from_jwk(self): + algo = ECAlgorithm(ECAlgorithm.SHA256) + + with open(key_path("testkey_ec.priv")) as ec_key: + orig_key = algo.prepare_key(ec_key.read()) + + parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + assert parsed_key.private_numbers() == orig_key.private_numbers() + assert ( + parsed_key.private_numbers().public_numbers + == orig_key.private_numbers().public_numbers + ) + + @crypto_required + def test_ec_public_key_to_jwk_works_with_from_jwk(self): + algo = ECAlgorithm(ECAlgorithm.SHA256) + + with open(key_path("testkey_ec.pub")) as ec_key: + orig_key = algo.prepare_key(ec_key.read()) + + parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + assert parsed_key.public_numbers() == orig_key.public_numbers() + + @crypto_required + def test_ec_to_jwk_returns_correct_values_for_public_key(self): + algo = ECAlgorithm(ECAlgorithm.SHA256) + + with open(key_path("testkey_ec.pub")) as keyfile: + pub_key = algo.prepare_key(keyfile.read()) + + key = algo.to_jwk(pub_key) + + expected = { + "kty": "EC", + "crv": "P-256", + "x": "HzAcUWSlGBHcuf3y3RiNrWI-pE6-dD2T7fIzg9t6wEc", + "y": "t2G02kbWiOqimYfQAfnARdp2CTycsJPhwA8rn1Cn0SQ", + } + + assert json.loads(key) == expected + + @crypto_required + def test_ec_to_jwk_returns_correct_values_for_private_key(self): + algo = ECAlgorithm(ECAlgorithm.SHA256) + + with open(key_path("testkey_ec.priv")) as keyfile: + priv_key = algo.prepare_key(keyfile.read()) + + key = algo.to_jwk(priv_key) + + expected = { + "kty": "EC", + "crv": "P-256", + "x": "HzAcUWSlGBHcuf3y3RiNrWI-pE6-dD2T7fIzg9t6wEc", + "y": "t2G02kbWiOqimYfQAfnARdp2CTycsJPhwA8rn1Cn0SQ", + "d": "2nninfu2jMHDwAbn9oERUhRADS6duQaJEadybLaa0YQ", + } + + assert json.loads(key) == expected + + @crypto_required + def test_ec_to_jwk_raises_exception_on_invalid_key(self): + algo = ECAlgorithm(ECAlgorithm.SHA256) + + with pytest.raises(InvalidKeyError): + algo.to_jwk({"not": "a valid key"}) + + @crypto_required + def test_ec_to_jwk_with_valid_curves(self): + tests = { + "P-256": ECAlgorithm.SHA256, + "P-384": ECAlgorithm.SHA384, + "P-521": ECAlgorithm.SHA512, + "secp256k1": ECAlgorithm.SHA256, + } + for (curve, hash) in tests.items(): + algo = ECAlgorithm(hash) + + with open(key_path(f"jwk_ec_pub_{curve}.json")) as keyfile: + pub_key = algo.from_jwk(keyfile.read()) + assert json.loads(algo.to_jwk(pub_key))["crv"] == curve + + with open(key_path(f"jwk_ec_key_{curve}.json")) as keyfile: + priv_key = algo.from_jwk(keyfile.read()) + assert json.loads(algo.to_jwk(priv_key))["crv"] == curve + + @crypto_required + def test_ec_to_jwk_with_invalid_curve(self): + algo = ECAlgorithm(ECAlgorithm.SHA256) + + with open(key_path("testkey_ec_secp192r1.priv")) as keyfile: + priv_key = algo.prepare_key(keyfile.read()) + + with pytest.raises(InvalidKeyError): + algo.to_jwk(priv_key) + @crypto_required def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) diff --git a/tests/test_api_jwk.py b/tests/test_api_jwk.py index 6f6cb899..040e81b9 100644 --- a/tests/test_api_jwk.py +++ b/tests/test_api_jwk.py @@ -4,7 +4,7 @@ from jwt.algorithms import has_crypto from jwt.api_jwk import PyJWK, PyJWKSet -from jwt.exceptions import InvalidKeyError, PyJWKError +from jwt.exceptions import InvalidKeyError, PyJWKError, PyJWKSetError from .utils import crypto_required, key_path @@ -208,8 +208,8 @@ def test_from_dict_should_throw_exception_if_arg_is_invalid(self): PyJWK.from_dict(v) +@crypto_required class TestPyJWKSet: - @crypto_required def test_should_load_keys_from_jwk_data_dict(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -231,7 +231,6 @@ def test_should_load_keys_from_jwk_data_dict(self): assert jwk.key_id == "keyid-abc123" assert jwk.public_key_use == "sig" - @crypto_required def test_should_load_keys_from_jwk_data_json_string(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -253,7 +252,6 @@ def test_should_load_keys_from_jwk_data_json_string(self): assert jwk.key_id == "keyid-abc123" assert jwk.public_key_use == "sig" - @crypto_required def test_keyset_should_index_by_kid(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -274,4 +272,31 @@ def test_keyset_should_index_by_kid(self): assert jwk == jwk_set["keyid-abc123"] with pytest.raises(KeyError): - jwk_set["this-kid-does-not-exist"] + _ = jwk_set["this-kid-does-not-exist"] + + def test_keyset_with_unknown_alg(self): + # first keyset with unusable key and usable key + with open(key_path("jwk_keyset_with_unknown_alg.json")) as keyfile: + jwks_text = keyfile.read() + jwks = json.loads(jwks_text) + assert len(jwks.get("keys")) == 2 + keyset = PyJWKSet.from_json(jwks_text) + assert len(keyset.keys) == 1 + + # second keyset with only unusable key -> catch exception + with open(key_path("jwk_keyset_only_unknown_alg.json")) as keyfile: + jwks_text = keyfile.read() + jwks = json.loads(jwks_text) + assert len(jwks.get("keys")) == 1 + with pytest.raises(PyJWKSetError): + _ = PyJWKSet.from_json(jwks_text) + + def test_invalid_keys_list(self): + with pytest.raises(PyJWKSetError) as err: + PyJWKSet(keys="string") + assert str(err.value) == "Invalid JWK Set value" + + def test_empty_keys_list(self): + with pytest.raises(PyJWKSetError) as err: + PyJWKSet(keys=[]) + assert str(err.value) == "The JWK Set did not contain any keys" diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index 0a0e2954..cfbbe212 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -12,6 +12,7 @@ InvalidTokenError, ) from jwt.utils import base64url_decode +from jwt.warnings import RemovedInPyjwt3Warning from .utils import crypto_required, key_path, no_crypto_required @@ -83,7 +84,7 @@ def test_non_object_options_dont_persist(self, jws, payload): def test_options_must_be_dict(self, jws): pytest.raises(TypeError, PyJWS, options=object()) - pytest.raises(TypeError, PyJWS, options=("something")) + pytest.raises((TypeError, ValueError), PyJWS, options=("something")) def test_encode_decode(self, jws, payload): secret = "secret" @@ -607,7 +608,7 @@ def test_decode_options_must_be_dict(self, jws, payload): with pytest.raises(TypeError): jws.decode(token, "secret", options=object()) - with pytest.raises(TypeError): + with pytest.raises((TypeError, ValueError)): jws.decode(token, "secret", options="something") def test_custom_json_encoder(self, jws, payload): @@ -770,3 +771,37 @@ def test_decode_detached_content_without_proper_argument(self, jws): 'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.' in str(exc.value) ) + + def test_decode_warns_on_unsupported_kwarg(self, jws, payload): + secret = "secret" + jws_message = jws.encode( + payload, secret, algorithm="HS256", is_payload_detached=True + ) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jws.decode( + jws_message, + secret, + algorithms=["HS256"], + detached_payload=payload, + foo="bar", + ) + assert len(record) == 1 + assert "foo" in str(record[0].message) + + def test_decode_complete_warns_on_unuspported_kwarg(self, jws, payload): + secret = "secret" + jws_message = jws.encode( + payload, secret, algorithm="HS256", is_payload_detached=True + ) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jws.decode_complete( + jws_message, + secret, + algorithms=["HS256"], + detached_payload=payload, + foo="bar", + ) + assert len(record) == 1 + assert "foo" in str(record[0].message) diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index 84e41e0e..bebe7d28 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -17,6 +17,7 @@ MissingRequiredClaimError, ) from jwt.utils import base64url_decode +from jwt.warnings import RemovedInPyjwt3Warning from .utils import crypto_required, key_path, utc_timestamp @@ -118,7 +119,7 @@ def test_decode_with_invalid_audience_param_throws_exception(self, jwt): jwt.decode(example_jwt, secret, audience=1, algorithms=["HS256"]) exception = context.value - assert str(exception) == "audience must be a string, iterable, or None" + assert str(exception) == "audience must be a string, iterable or None" def test_decode_with_nonlist_aud_claim_throws_exception(self, jwt): secret = "secret" @@ -418,6 +419,14 @@ def test_raise_exception_invalid_audience(self, jwt): with pytest.raises(InvalidAudienceError): jwt.decode(token, "secret", audience="urn-me", algorithms=["HS256"]) + def test_raise_exception_audience_as_bytes(self, jwt): + payload = {"some": "payload", "aud": ["urn:me", "urn:someone-else"]} + token = jwt.encode(payload, "secret") + with pytest.raises(InvalidAudienceError): + jwt.decode( + token, "secret", audience="urn:me".encode(), algorithms=["HS256"] + ) + def test_raise_exception_invalid_audience_in_array(self, jwt): payload = { "some": "payload", @@ -682,3 +691,21 @@ def test_decode_no_options_mutation(self, jwt, payload): jwt_message = jwt.encode(payload, secret) jwt.decode(jwt_message, secret, options=options, algorithms=["HS256"]) assert options == orig_options + + def test_decode_warns_on_unsupported_kwarg(self, jwt, payload): + secret = "secret" + jwt_message = jwt.encode(payload, secret) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jwt.decode(jwt_message, secret, algorithms=["HS256"], foo="bar") + assert len(record) == 1 + assert "foo" in str(record[0].message) + + def test_decode_complete_warns_on_unsupported_kwarg(self, jwt, payload): + secret = "secret" + jwt_message = jwt.encode(payload, secret) + + with pytest.warns(RemovedInPyjwt3Warning) as record: + jwt.decode_complete(jwt_message, secret, algorithms=["HS256"], foo="bar") + assert len(record) == 1 + assert "foo" in str(record[0].message) diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 3e42da17..c95dfcc0 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -1,6 +1,8 @@ import contextlib import json +import time from unittest import mock +from urllib.error import URLError import pytest @@ -11,7 +13,7 @@ from .utils import crypto_required -RESPONSE_DATA = { +RESPONSE_DATA_WITH_MATCHING_KID = { "keys": [ { "alg": "RS256", @@ -28,9 +30,22 @@ ] } +RESPONSE_DATA_NO_MATCHING_KID = { + "keys": [ + { + "alg": "RS256", + "kty": "RSA", + "use": "sig", + "n": "39SJ39VgrQ0qMNK74CaueUBlyYsUyuA7yWlHYZ-jAj6tlFKugEVUTBUVbhGF44uOr99iL_cwmr-srqQDEi-jFHdkS6WFkYyZ03oyyx5dtBMtzrXPieFipSGfQ5EGUGloaKDjL-Ry9tiLnysH2VVWZ5WDDN-DGHxuCOWWjiBNcTmGfnj5_NvRHNUh2iTLuiJpHbGcPzWc5-lc4r-_ehw9EFfp2XsxE9xvtbMZ4SouJCiv9xnrnhe2bdpWuu34hXZCrQwE8DjRY3UR8LjyMxHHPLzX2LWNMHjfN3nAZMteS-Ok11VYDFI-4qCCVGo_WesBCAeqCjPLRyZoV27x1YGsUQ", + "e": "AQAB", + "kid": "MLYHNMMhwCNXw9roHIILFsK4nLs=", + } + ] +} + @contextlib.contextmanager -def mocked_response(data): +def mocked_success_response(data): with mock.patch("urllib.request.urlopen") as urlopen_mock: response = mock.Mock() response.__enter__ = mock.Mock(return_value=response) @@ -40,12 +55,35 @@ def mocked_response(data): yield urlopen_mock +@contextlib.contextmanager +def mocked_failed_response(): + with mock.patch("urllib.request.urlopen") as urlopen_mock: + urlopen_mock.side_effect = URLError("Fail to process the request.") + yield urlopen_mock + + +@contextlib.contextmanager +def mocked_first_call_wrong_kid_second_call_correct_kid( + response_data_one, response_data_two +): + with mock.patch("urllib.request.urlopen") as urlopen_mock: + response = mock.Mock() + response.__enter__ = mock.Mock(return_value=response) + response.__exit__ = mock.Mock() + response.read.side_effect = [ + json.dumps(response_data_one), + json.dumps(response_data_two), + ] + urlopen_mock.return_value = response + yield urlopen_mock + + @crypto_required class TestPyJWKClient: def test_get_jwk_set(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) jwk_set = jwks_client.get_jwk_set() @@ -54,7 +92,7 @@ def test_get_jwk_set(self): def test_get_signing_keys(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_keys = jwks_client.get_signing_keys() @@ -64,11 +102,11 @@ def test_get_signing_keys(self): def test_get_signing_keys_if_no_use_provided(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - mocked_key = RESPONSE_DATA["keys"][0].copy() + mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy() del mocked_key["use"] response = {"keys": [mocked_key]} - with mocked_response(response): + with mocked_success_response(response): jwks_client = PyJWKClient(url) signing_keys = jwks_client.get_signing_keys() @@ -78,10 +116,10 @@ def test_get_signing_keys_if_no_use_provided(self): def test_get_signing_keys_raises_if_none_found(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - mocked_key = RESPONSE_DATA["keys"][0].copy() + mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy() mocked_key["use"] = "enc" response = {"keys": [mocked_key]} - with mocked_response(response): + with mocked_success_response(response): jwks_client = PyJWKClient(url) with pytest.raises(PyJWKClientError) as exc: @@ -93,7 +131,7 @@ def test_get_signing_key(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key(kid) @@ -106,14 +144,14 @@ def test_get_signing_key_caches_result(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - jwks_client = PyJWKClient(url) + jwks_client = PyJWKClient(url, cache_keys=True) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 0 @@ -122,14 +160,14 @@ def test_get_signing_key_does_not_cache_opt_out(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - jwks_client = PyJWKClient(url, cache_keys=False) + jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 1 @@ -138,7 +176,7 @@ def test_get_signing_key_from_jwt(self): token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA" url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key_from_jwt(token) @@ -159,3 +197,102 @@ def test_get_signing_key_from_jwt(self): "azp": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC", "gty": "client-credentials", } + + def test_get_jwk_set_caches_result(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url) + assert jwks_client.jwk_set_cache is not None + + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 0 + + def test_get_jwt_set_cache_expired_result(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url, lifespan=1) + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + + time.sleep(2) + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 1 + + def test_get_jwt_set_cache_disabled(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url, cache_jwk_set=False) + assert jwks_client.jwk_set_cache is None + + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + + assert jwks_client.jwk_set_cache is None + + time.sleep(2) + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 1 + + def test_get_jwt_set_failed_request_should_clear_cache(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url) + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + + with pytest.raises(PyJWKClientError): + with mocked_failed_response(): + jwks_client.get_jwk_set(refresh=True) + + assert jwks_client.jwk_set_cache is None + + def test_get_jwt_set_refresh_cache(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url) + + kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" + + # The first call will return response with no matching kid, + # the function should make another call to try to refresh the cache. + with mocked_first_call_wrong_kid_second_call_correct_kid( + RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_WITH_MATCHING_KID + ) as call_data: + jwks_client.get_signing_key(kid) + + assert call_data.call_count == 2 + + def test_get_jwt_set_no_matching_kid_after_second_attempt(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url) + + kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" + + with pytest.raises(PyJWKClientError): + with mocked_first_call_wrong_kid_second_call_correct_kid( + RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_NO_MATCHING_KID + ): + jwks_client.get_signing_key(kid) + + def test_get_jwt_set_invalid_lifespan(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + with pytest.raises(PyJWKClientError): + jwks_client = PyJWKClient(url, lifespan=-1) + assert jwks_client is None diff --git a/tox.ini b/tox.ini index d3664617..07ace117 100644 --- a/tox.ini +++ b/tox.ini @@ -8,18 +8,20 @@ filterwarnings = [gh-actions] python = - 3.6: py36 3.7: py37, docs 3.8: py38, typing 3.9: py39 3.10: py310 + 3.11: py311 + pypy-3.8: pypy3 + pypy-3.9: pypy3 [tox] envlist = lint typing - py{36,37,38,39}-{crypto,nocrypto} + py{37,38,39,310,311,py3}-{crypto,nocrypto} docs pypi-description coverage-report @@ -46,12 +48,6 @@ commands = python -m doctest README.rst -[testenv:typing] -basepython = python3.8 -extras = dev -commands = mypy jwt - - [testenv:lint] basepython = python3.8 extras = dev