Verifying a JSON Web Token from Amazon Cognito in Python and FastAPI

NOTE: This post is intentionally structured similarly to "Verifying a JSON Web Token from Amazon Cognito in Go and Gin", but showcases the same methodology using Python-related technologies.

Intro

One popular option for integrating Amazon Cognito authentication/authorization with a backend requires decoding and verifying JSON Web Tokens (JWTs) on the server side. AWS details the steps required to validate an incoming JWT produced by Amazon Cognito in its official documentation, and offers an example using a dedicated JavaScript library.

We can implement the above steps in Python and FastAPI, using PyJWT.

Retrieving a public JWKS

The JWKS (JSON Web Key Set) URI exposes the public keys that correspond to the private keys used to sign your users' tokens. As soon as a Cognito User Pool is created, it publishes its JWKS at a public URI, composed as follows: https://cognito-idp.<region>.amazonaws.com/<userPoolId>/.well-known/jwks.json, where region is the AWS region of your User Pool and userPoolId is the ID of the User Pool.
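As a small illustration of how this URI is composed, the sketch below builds it from a region and a User Pool ID. The helper names build_issuer and build_jwks_uri and the sample values are hypothetical and not part of any library.

```python
def build_issuer(region: str, user_pool_id: str) -> str:
    """Return the base URL of a Cognito User Pool (also used as its `iss` value)."""
    return f"https://cognito-idp.{region}.amazonaws.com/{user_pool_id}"


def build_jwks_uri(region: str, user_pool_id: str) -> str:
    """Return the public JWKS URI of a Cognito User Pool."""
    return f"{build_issuer(region, user_pool_id)}/.well-known/jwks.json"


# Hypothetical sample values, for illustration only
jwks_uri = build_jwks_uri("us-east-1", "us-east-1_EXAMPLE")
print(jwks_uri)
```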

The PyJWT library offers a convenient class, jwt.PyJWKClient, which retrieves the key set with an HTTP request and parses it for us. The client is created in a single statement, and the keys are fetched when a signing key is first requested:

import jwt
from python_fastapi_cognito_jwt_verification.config import get_settings

jwks_client = jwt.PyJWKClient(
    f"https://cognito-idp.{get_settings().aws_default_region}.amazonaws.com/{get_settings().cognito_user_pool_id}/.well-known/jwks.json"
)
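The config module imported above is not shown in this post. One possible shape for it, assuming the three values are read from environment variables, is sketched below. The Settings class and the variable names AWS_DEFAULT_REGION, COGNITO_USER_POOL_ID and COGNITO_APP_CLIENT_ID are assumptions made for illustration.

```python
import os
from dataclasses import dataclass
from functools import lru_cache


@dataclass(frozen=True)
class Settings:
    """Hypothetical settings container; the field names mirror the snippets in this post."""

    aws_default_region: str
    cognito_user_pool_id: str
    cognito_app_client_id: str


@lru_cache
def get_settings() -> Settings:
    """Read the settings once and reuse the same object on later calls."""
    # os.environ[...] raises KeyError if a variable is missing, so a
    # misconfigured deployment fails on first use instead of silently.
    return Settings(
        aws_default_region=os.environ["AWS_DEFAULT_REGION"],
        cognito_user_pool_id=os.environ["COGNITO_USER_POOL_ID"],
        cognito_app_client_id=os.environ["COGNITO_APP_CLIENT_ID"],
    )
```

Caching the result keeps repeated get_settings() calls, like the ones in the f-string above, cheap.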

Verifying an incoming JWT in a FastAPI Dependency

1. Defining a FastAPI Dependency that performs Cognito authentication

FastAPI offers an elegant dependency injection mechanism in the form of Dependencies. It can also be leveraged to enforce authentication/authorization, and it is ideal for protecting endpoints using JWTs.

Even better, a FastAPI Dependency can take the form of a callable instance of a class, which can be instantiated with values of our choosing in its __init__ method and where our JWT verification logic can be implemented inside its __call__ method.

Following this convention, we can start implementing a dependency called CognitoJWTAuthorizer in a file named dependencies.py. Its constructor takes the token use (required_token_use, which can be access or id depending on the type of Cognito token we accept), the AWS region (aws_default_region), the Cognito User Pool ID (cognito_user_pool_id), the ID of the Cognito App Client that we've created (cognito_app_client_id) and the instantiated JWKS client used to fetch the User Pool's keys (jwks_client).

The __call__ method can take as an argument a FastAPI header authorization: Annotated[str | None, Header()], where its variable name authorization is automatically mapped to the actual Authorization header of an incoming HTTP request.

from enum import Enum
from typing import Annotated

import jwt
from fastapi import Header, HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED

from python_fastapi_cognito_jwt_verification.config import get_settings


jwks_client = jwt.PyJWKClient(
    f"https://cognito-idp.{get_settings().aws_default_region}.amazonaws.com/{get_settings().cognito_user_pool_id}/.well-known/jwks.json"
)


class CognitoTokenUse(Enum):
    ID = "id"
    ACCESS = "access"


class CognitoJWTAuthorizer:
    def __init__(
            self,
            required_token_use: CognitoTokenUse,
            aws_default_region: str,
            cognito_user_pool_id: str,
            cognito_app_client_id: str,
            jwks_client: jwt.PyJWKClient,
    ) -> None:
        self.required_token_use = required_token_use
        self.aws_default_region = aws_default_region
        self.cognito_user_pool_id = cognito_user_pool_id
        self.cognito_app_client_id = cognito_app_client_id
        self.jwks_client = jwks_client

    def __call__(self, authorization: Annotated[str | None, Header()] = None):
        ...  # The body is built up step by step in the sections below

2. Retrieving the JWT from an incoming request

By convention, an incoming request contains a JWT in its Authorization header using a Bearer token.

The Authorization header including the Bearer token has the format:

Authorization: Bearer abcdefg

Using the variable authorization that we passed to the __call__ method, we can retrieve the raw value of the JWT by splitting the string containing it.

if not authorization:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    )
split_authorization_tokens: list[str] = authorization.split("Bearer ")

if len(split_authorization_tokens) < 2:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    )

token: str = split_authorization_tokens[1]
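Note that splitting on "Bearer " also accepts a header such as "xyzBearer abcdefg". A stricter alternative, shown here as a sketch with a hypothetical helper named extract_bearer_token, checks the scheme and the token separately and returns None when the header is not a well-formed Bearer header.

```python
from typing import Optional


def extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from an 'Authorization: Bearer <token>' value, or None."""
    if not authorization:
        return None
    # Split once on the first space: "<scheme> <credentials>"
    scheme, _, credentials = authorization.strip().partition(" ")
    # Authentication scheme names are case-insensitive
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    # Reject an empty token or one that contains extra whitespace
    if not token or " " in token:
        return None
    return token


print(extract_bearer_token("Bearer abcdefg"))
```

Inside __call__, a None result would map to the same 401 response used above.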

3. Verify the JWT signature, signing algorithm, issuer (iss) and existence of expiry time (exp)

Now, using PyJWT, we can perform the first checks against the JWT using convenient methods offered by the library, which are also shown in its official documentation.

First, we can retrieve the signing key of the JWT:

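try:
    signing_key: jwt.PyJWK = self.jwks_client.get_signing_key_from_jwt(token)
except jwt.exceptions.PyJWTError as e:
    # Covers malformed tokens (InvalidTokenError) as well as JWKS lookup or
    # connection failures (PyJWKClientError)
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    ) from e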
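Conceptually, this lookup reads the key ID (kid) from the token's unverified header and selects the entry with the same kid from the JWKS. The sketch below illustrates that idea with the standard library only. The helper names, the token and the key set are made up for the example, and in the dependency itself PyJWT does this work for us.

```python
import base64
import json
from typing import Optional


def b64url_encode(data: bytes) -> str:
    """Base64url-encode without padding, as used in JWT segments."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(segment: str) -> bytes:
    """Base64url-decode a JWT segment, restoring the stripped padding."""
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def read_unverified_header(token: str) -> dict:
    """Parse the first JWT segment. Nothing is verified at this point."""
    return json.loads(b64url_decode(token.split(".")[0]))


def find_jwk(jwks: dict, kid: str) -> Optional[dict]:
    """Return the JWKS entry whose 'kid' matches, or None."""
    for key in jwks.get("keys", []):
        if key.get("kid") == kid:
            return key
    return None


# Made-up token and key set, for illustration only
example_header = {"alg": "RS256", "kid": "example-key-id"}
example_token = ".".join(
    [
        b64url_encode(json.dumps(example_header).encode("utf-8")),
        b64url_encode(b"{}"),
        "signature-placeholder",
    ]
)
example_jwks = {
    "keys": [
        {"kid": "other-key-id", "kty": "RSA"},
        {"kid": "example-key-id", "kty": "RSA"},
    ]
}

kid = read_unverified_header(example_token)["kid"]
matching_key = find_jwk(example_jwks, kid)
print(kid, matching_key)
```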

Then we can attempt to perform the majority of the required checks against the JWT:

  • The signature of the JWT is verified by the jwt.decode method, to which we pass the retrieved signing key as signing_key.key, along with the extra checks defined below.

  • One of the most important verifications is defining the specific algorithms that the parser will accept, and confirming that the incoming token uses those exclusively. Amazon Cognito signs its tokens with the asymmetric algorithm RS256. This can be enforced using the algorithms argument and passing a list of algorithms, in our case just ["RS256"]. Pinning the accepted algorithms this way, rather than trusting the alg value in the token's own header, protects us against algorithm confusion attacks.

  • The token should not be expired. The jwt.decode method can automate this check for us, by passing verify_exp: True as a key-value pair in the options argument.

  • The issuer claim (iss) should match our User Pool. For example, a User Pool created in the us-east-1 region will have the following iss value: https://cognito-idp.us-east-1.amazonaws.com/<userpoolID>. We can verify this using the issuer argument and also pass verify_iss: True as a key-value pair in the options argument, just to be extra explicit about this check.

  • We can also require the existence of specific claims in the token, by passing a list of them under the require key of the options argument. Our verification process needs the claims ["token_use", "exp", "iss", "sub"].

  • Verification of the audience claim (aud) is taken care of later, when we examine whether the token is a Cognito 'id' or 'access' token. That is why for now we pass verify_aud: False as a key-value pair in the options argument.

The above checks can be succinctly defined in one statement, facilitated by PyJWT, as follows:

try:
    claims = jwt.decode(
        token,
        signing_key.key,
        algorithms=["RS256"],
        issuer=f"https://cognito-idp.{self.aws_default_region}.amazonaws.com/{self.cognito_user_pool_id}",
        options={
            "verify_aud": False,
            "verify_signature": True,
            "verify_exp": True,
            "verify_iss": True,
            "require": ["token_use", "exp", "iss", "sub"],
        },
    )
except jwt.exceptions.ExpiredSignatureError as e:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    ) from e
except jwt.exceptions.InvalidTokenError as e:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    ) from e
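To make the claim-level part of these checks concrete, the sketch below spells out by hand what the require, verify_exp and issuer options amount to on an already-decoded claims dictionary. It is an illustration only: it does not verify any signature and is not a replacement for jwt.decode. The function name and the sample values are hypothetical.

```python
import time
from typing import Optional

REQUIRED_CLAIMS = ("token_use", "exp", "iss", "sub")


def standard_claims_are_valid(
    claims: dict, expected_issuer: str, now: Optional[float] = None
) -> bool:
    """Check required claims, expiry and issuer on decoded claims."""
    now = time.time() if now is None else now
    # Every required claim must be present
    if any(name not in claims for name in REQUIRED_CLAIMS):
        return False
    # 'exp' is a Unix timestamp; the token is expired once it is reached
    if claims["exp"] <= now:
        return False
    # The issuer must be exactly our User Pool URL
    return claims["iss"] == expected_issuer


# Hypothetical values, for illustration only
issuer = "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_EXAMPLE"
sample_claims = {"token_use": "access", "exp": 2_000, "iss": issuer, "sub": "user-123"}

print(standard_claims_are_valid(sample_claims, issuer, now=1_000))
```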

4. Check the token_use claim

Cognito can send ID or access tokens, each with a different set of attributes. Depending on the nature of the endpoint we want to protect we can choose to accept specific types.

ID tokens contain claims about the identity of the authenticated user, such as name, email, and phone_number.

Access tokens contain claims about the authorization of the authenticated user such as a list of the user's groups, and a list of scopes.

if self.required_token_use.value != claims["token_use"]:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    )

5. Verify the audience (aud)/client ID (client_id) claim

Depending on the type of token (ID or access), we check the aud or the client_id claim respectively. Either one should match the ID of the Cognito App Client created in the Cognito User Pool.

For each case, we check that the aud or client_id claim exists in claims, the same way we would check for a key in a Python dictionary, and compare its value against the Cognito App Client ID.

if self.required_token_use == CognitoTokenUse.ID:
    if "aud" not in claims:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    if claims["aud"] != self.cognito_app_client_id:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
elif self.required_token_use == CognitoTokenUse.ACCESS:
    if "client_id" not in claims:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    if claims["client_id"] != self.cognito_app_client_id:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
else:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    )
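The same rule can also be expressed more compactly with a lookup table that maps each token use to the claim carrying the App Client ID. The sketch below is an alternative formulation rather than the code used in this post. The helper name and the sample claim dictionaries are hypothetical and show only a few of the claims a real Cognito token carries.

```python
# Which claim holds the App Client ID for each kind of Cognito token
CLIENT_ID_CLAIM_BY_TOKEN_USE = {"id": "aud", "access": "client_id"}


def token_matches_app_client(
    claims: dict, required_token_use: str, app_client_id: str
) -> bool:
    """Check token_use and the matching aud/client_id claim in one place."""
    if claims.get("token_use") != required_token_use:
        return False
    claim_name = CLIENT_ID_CLAIM_BY_TOKEN_USE.get(required_token_use)
    if claim_name is None:
        # Unknown token use: reject rather than guess
        return False
    # A missing claim yields None, which never equals the client ID
    return claims.get(claim_name) == app_client_id


# Made-up, heavily trimmed claim sets, for illustration only
id_token_claims = {"token_use": "id", "aud": "example-client-id", "email": "user@example.com"}
access_token_claims = {"token_use": "access", "client_id": "example-client-id", "scope": "openid"}

print(token_matches_app_client(id_token_claims, "id", "example-client-id"))
print(token_matches_app_client(access_token_claims, "access", "example-client-id"))
```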

Complete dependency snippet

The complete snippet using all the above statements is as follows:

from enum import Enum
from typing import Annotated

import jwt
from fastapi import Header, HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED

from python_fastapi_cognito_jwt_verification.config import get_settings

jwks_client = jwt.PyJWKClient(
    f"https://cognito-idp.{get_settings().aws_default_region}.amazonaws.com/{get_settings().cognito_user_pool_id}/.well-known/jwks.json"
)


class CognitoTokenUse(Enum):
    ID = "id"
    ACCESS = "access"


class CognitoJWTAuthorizer:
    def __init__(
            self,
            required_token_use: CognitoTokenUse,
            aws_default_region: str,
            cognito_user_pool_id: str,
            cognito_app_client_id: str,
            jwks_client: jwt.PyJWKClient,
    ) -> None:
        """Verify an incoming JWT using official AWS guidelines.

        https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html#amazon-cognito-user-pools-using-tokens-manually-inspect

        Args:
            required_token_use (CognitoTokenUse): Required Cognito token type
            aws_default_region (str): AWS region
            cognito_user_pool_id (str): Cognito User Pool ID
            cognito_app_client_id (str): Cognito App Client ID
            jwks_client (jwt.PyJWKClient): An instance of a JWKS client that has
                                           retrieved the public Cognito JWKS

        Raises:
            fastapi.HTTPException: Raised if a verification check of the incoming
                                     JWT fails
        """
        self.required_token_use = required_token_use
        self.aws_default_region = aws_default_region
        self.cognito_user_pool_id = cognito_user_pool_id
        self.cognito_app_client_id = cognito_app_client_id
        self.jwks_client = jwks_client

    def __call__(self, authorization: Annotated[str | None, Header()] = None):
        """Verify an embedded Cognito JWT token in the 'Authorization' header.

        Args:
            authorization (str | None): 'Authorization' header of a FastAPI endpoint
        """
        if not authorization:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            )
        split_authorization_tokens: list[str] = authorization.split("Bearer ")

        if len(split_authorization_tokens) < 2:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            )

        token: str = split_authorization_tokens[1]

        # Attempt to retrieve the signing key of the incoming JWT
        try:
            signing_key: jwt.PyJWK = self.jwks_client.get_signing_key_from_jwt(token)
        except jwt.exceptions.PyJWTError as e:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            ) from e

        try:
            """
            * Verify the signature of the JWT
            * Verify that the algorithm used is RS256
            * Verification of audience 'aud' is taken care of later when we examine if
              the token is 'id' or 'access'
            * Verify that the token hasn't expired. Decode the token and compare the
              'exp' claim to the current time.
            * The issuer (iss) claim should match your user pool. For example, a user
              pool created in the us-east-1 region
              will have the following iss value: https://cognito-idp.us-east-1.amazonaws.com/<userpoolID>.
            * Require the existence of claims: 'token_use', 'exp', 'iss', 'sub'
            """
            claims = jwt.decode(
                token,
                signing_key.key,
                algorithms=["RS256"],
                issuer=f"https://cognito-idp.{self.aws_default_region}.amazonaws.com/{self.cognito_user_pool_id}",
                options={
                    "verify_aud": False,
                    "verify_signature": True,
                    "verify_exp": True,
                    "verify_iss": True,
                    "require": ["token_use", "exp", "iss", "sub"],
                },
            )
        except jwt.exceptions.ExpiredSignatureError as e:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            ) from e
        except jwt.exceptions.InvalidTokenError as e:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            ) from e

        """
        Check the token_use claim
        * If you are only accepting the access token in your web API operations,
          its value must be access.
        * If you are only using the ID token, its value must be id.
        * If you are using both ID and access tokens,
          the token_use claim must be either id or access.
        """
        if self.required_token_use.value != claims["token_use"]:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            )

        """
        The "aud" claim in an ID token and the "client_id" claim in an access token
        should match the app client ID that was created in the Amazon Cognito user pool.
        """
        if self.required_token_use == CognitoTokenUse.ID:
            if "aud" not in claims:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
                )
            if claims["aud"] != self.cognito_app_client_id:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
                )
        elif self.required_token_use == CognitoTokenUse.ACCESS:
            if "client_id" not in claims:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
                )
            if claims["client_id"] != self.cognito_app_client_id:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
                )
        else:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            )


cognito_jwt_authorizer_access_token = CognitoJWTAuthorizer(
    CognitoTokenUse.ACCESS,
    get_settings().aws_default_region,
    get_settings().cognito_user_pool_id,
    get_settings().cognito_app_client_id,
    jwks_client,
)

cognito_jwt_authorizer_id_token = CognitoJWTAuthorizer(
    CognitoTokenUse.ID,
    get_settings().aws_default_region,
    get_settings().cognito_user_pool_id,
    get_settings().cognito_app_client_id,
    jwks_client,
)

Using the Dependencies in a FastAPI endpoint

Putting it all together, we can create FastAPI endpoints that are protected using the above Dependencies.

The code below can be placed in a main.py file. It defines two protected FastAPI endpoints, one requiring a Cognito ID token (/protected-with-id-token) and one requiring a Cognito access token (/protected-with-access-token).

from fastapi import Depends, FastAPI

from python_fastapi_cognito_jwt_verification.dependencies import (
    cognito_jwt_authorizer_access_token,
    cognito_jwt_authorizer_id_token,
)

app = FastAPI()


@app.get("/protected-with-access-token", dependencies=[Depends(cognito_jwt_authorizer_access_token)])
def protected_with_access_token():
    return {"Hello": "World"}


@app.get("/protected-with-id-token", dependencies=[Depends(cognito_jwt_authorizer_id_token)])
def protected_with_id_token():
    return {"Hello": "World"}
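To call one of these endpoints, a client sends the token in the Authorization header. The sketch below only builds such a request with the standard library and does not send it. The base URL http://localhost:8000, the helper name and the placeholder token are assumptions for illustration.

```python
import urllib.request


def build_protected_request(base_url: str, path: str, token: str) -> urllib.request.Request:
    """Build a GET request that carries the JWT as a Bearer token."""
    return urllib.request.Request(
        f"{base_url}{path}",
        headers={"Authorization": f"Bearer {token}"},
        method="GET",
    )


# Placeholder token and a locally served app are assumed here
request = build_protected_request(
    "http://localhost:8000", "/protected-with-access-token", "abcdefg"
)
print(request.full_url, request.get_header("Authorization"))
# Passing the request to urllib.request.urlopen would send it to a running app.
```

With the dependency in place, a request whose token is missing or fails verification should receive a 401 response instead of the endpoint's JSON body.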

Full implementation

You can find the fully working implementation in a GitHub repository.