Verifying a JSON Web Token from Amazon Cognito in Python and FastAPI
- Author: Angelos Panagiotopoulos (@angelospanag)
NOTE: This post is intentionally structured similarly to Verifying a JSON Web Token from Amazon Cognito in Go and Gin, but showcases the same methodology using Python-related technologies.
- Intro
- Retrieving a public JWKS
- Verifying an incoming JWT in a FastAPI Dependency
- Complete middleware snippet
- Using the Dependencies in a FastAPI endpoint
- Full implementation
Intro
One popular way to integrate Amazon Cognito authentication/authorization with a backend requires decoding and verifying JSON Web Tokens (JWTs) on the server side. AWS details the steps required to validate an incoming JWT produced by Amazon Cognito in its official documentation, and offers an example using a dedicated JavaScript library.
We can implement the above steps in Python and FastAPI, using PyJWT.
Retrieving a public JWKS
The JWKS URI exposes the public keys that correspond to the private keys used to sign your users' tokens. As soon as a Cognito User Pool is created, it publishes its JWKS at a public URI, which is composed as follows:
https://cognito-idp.<region>.amazonaws.com/<userPoolId>/.well-known/jwks.json
where region is the AWS region of your User Pool and userPoolId the ID of the User Pool.
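As a minimal sketch, the URI can be assembled by a small helper. The function names and the example values below are hypothetical; only the URL layout comes from the text above. The User Pool base URL is also the value Cognito places in the iss claim, which becomes relevant in the verification steps later on.

```python
def cognito_issuer(region: str, user_pool_id: str) -> str:
    """Base URL of a User Pool; Cognito also uses it as the `iss` claim."""
    return f"https://cognito-idp.{region}.amazonaws.com/{user_pool_id}"


def cognito_jwks_uri(region: str, user_pool_id: str) -> str:
    """Public JWKS document published under the User Pool's base URL."""
    return f"{cognito_issuer(region, user_pool_id)}/.well-known/jwks.json"


# Placeholder values, not a real User Pool.
print(cognito_jwks_uri("us-east-1", "us-east-1_EXAMPLE"))
```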
The PyJWT library offers a convenient class, jwt.PyJWKClient, that takes care of retrieving the key set with an HTTP request and parsing it. Creating the client is a single statement:
import jwt

from python_fastapi_cognito_jwt_verification.config import get_settings

# JWKS client pointing at the public key set of our Cognito User Pool
jwks_client = jwt.PyJWKClient(
    f"https://cognito-idp.{get_settings().aws_default_region}.amazonaws.com/{get_settings().cognito_user_pool_id}/.well-known/jwks.json"
)
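The get_settings function is imported from a config module that is not shown in this post. The following is only a hypothetical sketch of what such a module could look like, using nothing but the standard library; the environment variable names are assumptions, and the actual project may well use a different mechanism.

```python
import os
from dataclasses import dataclass
from functools import lru_cache


@dataclass(frozen=True)
class Settings:
    aws_default_region: str
    cognito_user_pool_id: str
    cognito_app_client_id: str


@lru_cache
def get_settings() -> Settings:
    """Read the settings once from the environment and cache the result."""
    # Hypothetical variable names; a missing variable raises KeyError.
    return Settings(
        aws_default_region=os.environ["AWS_DEFAULT_REGION"],
        cognito_user_pool_id=os.environ["COGNITO_USER_POOL_ID"],
        cognito_app_client_id=os.environ["COGNITO_APP_CLIENT_ID"],
    )
```

Caching the result means the repeated get_settings() calls in the snippets of this post all return the same object.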
Verifying an incoming JWT in a FastAPI Dependency
1. Defining a FastAPI Dependency that performs Cognito authentication
FastAPI offers an elegant dependency injection mechanism in the form of Dependencies. It can also be leveraged to enforce authentication/authorization, and it is ideal for protecting endpoints using JWTs.
Even better, a FastAPI Dependency can take the form of a callable instance of a class, which can be instantiated with values of our choosing in its __init__ method and where our JWT verification logic can be implemented inside its __call__ method.
Following this convention, we can start implementing a dependency called CognitoJWTAuthorizer in a file dependencies.py. Its constructor takes the token use (required_token_use, which can be access or id depending on the type of Cognito token we accept), the AWS Region (aws_default_region), the Cognito User Pool ID (cognito_user_pool_id), the ID of the Cognito App Client that we've created (cognito_app_client_id) and the instantiated JWKS client used to look up the signing key (jwks_client).
The __call__ method can take as an argument a FastAPI header, authorization: Annotated[str | None, Header()], whose variable name authorization is automatically mapped to the actual Authorization header of an incoming HTTP request.
from enum import Enum
from typing import Annotated

import jwt
from fastapi import Header, HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED

from python_fastapi_cognito_jwt_verification.config import get_settings

jwks_client = jwt.PyJWKClient(
    f"https://cognito-idp.{get_settings().aws_default_region}.amazonaws.com/{get_settings().cognito_user_pool_id}/.well-known/jwks.json"
)


class CognitoTokenUse(Enum):
    ID = "id"
    ACCESS = "access"


class CognitoJWTAuthorizer:
    def __init__(
        self,
        required_token_use: CognitoTokenUse,
        aws_default_region: str,
        cognito_user_pool_id: str,
        cognito_app_client_id: str,
        jwks_client: jwt.PyJWKClient,
    ) -> None:
        self.required_token_use = required_token_use
        self.aws_default_region = aws_default_region
        self.cognito_user_pool_id = cognito_user_pool_id
        self.cognito_app_client_id = cognito_app_client_id
        self.jwks_client = jwks_client

    def __call__(self, authorization: Annotated[str | None, Header()] = None):
        # The verification steps described next form the body of this method
        ...
2. Retrieving the JWT from an incoming request
By convention, an incoming request contains a JWT in its Authorization header using a Bearer token.
The Authorization header including the Bearer token has the format:
Authorization: Bearer abcdefg
Using the variable authorization that we passed to the __call__ method, we can retrieve the raw value of the JWT by splitting the string containing it.
# Reject requests without an Authorization header
if not authorization:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    )

# Everything after "Bearer " is treated as the raw JWT
split_authorization_tokens: list[str] = authorization.split("Bearer ")
if len(split_authorization_tokens) < 2:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    )
token: str = split_authorization_tokens[1]
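Splitting on "Bearer " is lenient: a header value such as "xBearer abc" also yields a token. A stricter variant could be sketched as follows. The helper name extract_bearer_token is hypothetical and not part of the original implementation; it returns None where the dependency would raise the 401 response.

```python
from typing import Optional


def extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from a 'Bearer <token>' header value, else None."""
    if not authorization:
        return None
    # Split into the scheme and whatever follows the first space.
    scheme, _, token = authorization.strip().partition(" ")
    token = token.strip()
    # The scheme is matched case-insensitively; the token must be a single,
    # non-empty word.
    if scheme.lower() != "bearer" or not token or " " in token:
        return None
    return token
```

Inside __call__, a None result would then be translated into the same HTTPException as above.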
3. Verify the JWT signature, signing algorithm, issuer (iss) and existence of expiry time (exp)
Now, using PyJWT, we can perform some first rudimentary checks against the JWT using some convenient methods offered by the library, also shown in its official documentation.
First, we can retrieve the signing key of the JWT:
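try:
    signing_key: jwt.PyJWK = self.jwks_client.get_signing_key_from_jwt(token)
except (
    jwt.exceptions.InvalidTokenError,
    # Raised when no key in the JWKS matches the token, or the JWKS
    # cannot be fetched; it is not a subclass of InvalidTokenError
    jwt.exceptions.PyJWKClientError,
) as e:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    ) from e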
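To make the key lookup less opaque: as far as we understand PyJWT, get_signing_key_from_jwt reads the kid (key ID) from the unverified JWT header and picks the JWKS entry with the same kid. The sketch below decodes a header by hand using only the standard library. The helper names and the hand-built token are hypothetical and serve only as an illustration; nothing here verifies a signature.

```python
import base64
import json


def peek_jwt_header(token: str) -> dict:
    """Decode the first JWT segment WITHOUT verifying anything."""
    header_segment = token.split(".")[0]
    # JWT segments are base64url-encoded with the '=' padding stripped.
    padded = header_segment + "=" * (-len(header_segment) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))


def _b64url(data: dict) -> str:
    """Encode a dict the way a JWT segment is encoded."""
    raw = json.dumps(data).encode()
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()


# Hand-built, unsigned token with made-up values, only to exercise the helper.
fake_token = ".".join(
    [
        _b64url({"alg": "RS256", "kid": "example-key-id"}),
        _b64url({"sub": "123"}),
        "signature",
    ]
)
print(peek_jwt_header(fake_token)["kid"])
```

Values read this way must never be trusted on their own; they only tell the client which public key to use for the actual verification.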
Then we can attempt to perform the majority of the required checks against the JWT:
- The signature of the JWT is verified using the jwt.decode method by providing the retrieved signing key, in the form of the signing_key.key argument, along with some extra checks we define below.
- One of the most important verifications is defining the specific valid algorithms that the parser will accept, and confirming that the incoming token is using those exclusively. In the case of Amazon Cognito, the asymmetric algorithm RS256 is used to sign the token. This can be enforced using the algorithms argument and passing a list of algorithms, in our case just ["RS256"]. Trusting whatever algorithm the token header declares, instead of pinning this list, would leave us open to algorithm confusion attacks.
- The token should not be expired. The jwt.decode method can automate this check for us, by passing verify_exp: True as a key-value pair in the options argument.
- The issuer claim (iss) should match our User Pool. For example, a User Pool created in the us-east-1 region will have the following iss value: https://cognito-idp.us-east-1.amazonaws.com/<userpoolID>. We can verify this using the issuer argument and also pass verify_iss: True as a key-value pair in the options argument, just to be extra explicit about this check.
- We can also require the existence of specific claims in the token, by passing a list of them in the require key of the options argument. Our verification process needs the claims ["token_use", "exp", "iss", "sub"].
- Verification of the audience claim (aud) is taken care of later, when we examine if the token is a Cognito 'id' or 'access' token. That is why for now we pass verify_aud: False as a key-value pair in the options argument.
The above checks can be succinctly defined in one statement, facilitated by PyJWT, as follows:
try:
    claims = jwt.decode(
        token,
        signing_key.key,
        algorithms=["RS256"],
        issuer=f"https://cognito-idp.{self.aws_default_region}.amazonaws.com/{self.cognito_user_pool_id}",
        options={
            "verify_aud": False,
            "verify_signature": True,
            "verify_exp": True,
            "verify_iss": True,
            "require": ["token_use", "exp", "iss", "sub"],
        },
    )
except jwt.exceptions.ExpiredSignatureError as e:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    ) from e
except jwt.exceptions.InvalidTokenError as e:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    ) from e
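To see what pinning the algorithms list buys us, the following sketch builds a token signed with the symmetric HS256 algorithm and tries to decode it while only allowing RS256. The secret and the payload are made up for illustration, and the snippet assumes a PyJWT 2.x installation; a real Cognito token would be signed with RS256.

```python
import jwt

# Hypothetical stand-in for a forged token: signed with the symmetric HS256
# algorithm instead of the RS256 that Cognito uses.
SECRET = "hypothetical-shared-secret-at-least-32-bytes"
forged_token = jwt.encode({"sub": "attacker"}, SECRET, algorithm="HS256")


def rejected_for_algorithm(token: str) -> bool:
    """Return True if decoding fails because the algorithm is not allowed."""
    try:
        jwt.decode(token, SECRET, algorithms=["RS256"])
    except jwt.exceptions.InvalidAlgorithmError:
        return True
    return False


print(rejected_for_algorithm(forged_token))
```

InvalidAlgorithmError is a subclass of InvalidTokenError, so the except clauses of the dependency above already turn this case into a 401 response.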
4. Check the token_use claim
Cognito can issue ID or access tokens, each with a different set of claims. Depending on the nature of the endpoint we want to protect, we can choose to accept specific types.
ID tokens contain claims about the identity of the authenticated user, such as name, email, and phone_number.
Access tokens contain claims about the authorization of the authenticated user such as a list of the user's groups, and a list of scopes.
# The token type must be the one this authorizer was configured to accept
if self.required_token_use.value != claims["token_use"]:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    )
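The dependency above accepts exactly one token type per instance. If an endpoint should accept both ID and access tokens, the check could be generalised to a set of allowed values. This is a hypothetical variation, not part of the original implementation; the helper name token_use_allowed is made up, and the enum is repeated only to keep the sketch self-contained.

```python
from enum import Enum


class CognitoTokenUse(Enum):
    ID = "id"
    ACCESS = "access"


def token_use_allowed(claims: dict, allowed: set[CognitoTokenUse]) -> bool:
    """Return True if the token_use claim is one of the allowed token types."""
    allowed_values = {use.value for use in allowed}
    # A missing claim yields None, which is never in the allowed set.
    return claims.get("token_use") in allowed_values
```

For example, token_use_allowed(claims, {CognitoTokenUse.ID, CognitoTokenUse.ACCESS}) would let either token type through, while a single-element set reproduces the check above.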
5. Verify the audience (aud)/client ID (client_id) claim
Depending on the type of token, we check either the aud claim (ID token) or the client_id claim (access token); its value should match the ID of the Cognito App Client created in the Cognito User Pool.
For each case, we can check the existence of the aud or client_id claim in claims, the same way we would check for keys in a Python dictionary, and compare its value against the Cognito App Client ID.
if self.required_token_use == CognitoTokenUse.ID:
    # ID tokens carry the App Client ID in the "aud" claim
    if "aud" not in claims:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    if claims["aud"] != self.cognito_app_client_id:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
elif self.required_token_use == CognitoTokenUse.ACCESS:
    # Access tokens carry the App Client ID in the "client_id" claim
    if "client_id" not in claims:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    if claims["client_id"] != self.cognito_app_client_id:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
else:
    raise HTTPException(
        status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
    )
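The two branches differ only in the name of the claim they inspect, so the same rule can also be expressed as a lookup table. The sketch below is a hypothetical, framework-free restatement of the check; the names CLIENT_ID_CLAIM and matches_app_client are made up, and it returns a boolean instead of raising an HTTPException.

```python
# Which claim carries the App Client ID for each Cognito token type.
CLIENT_ID_CLAIM = {"id": "aud", "access": "client_id"}


def matches_app_client(claims: dict, token_use: str, app_client_id: str) -> bool:
    """Return True if the token was issued for the given App Client."""
    claim_name = CLIENT_ID_CLAIM.get(token_use)
    if claim_name is None:
        # Unknown token type: reject, like the final else branch above.
        return False
    # A missing claim yields None, which never equals a client ID string.
    return claims.get(claim_name) == app_client_id
```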
Complete middleware snippet
The complete snippet using all the above statements is as follows:
from enum import Enum
from typing import Annotated

import jwt
from fastapi import Header, HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED

from python_fastapi_cognito_jwt_verification.config import get_settings

jwks_client = jwt.PyJWKClient(
    f"https://cognito-idp.{get_settings().aws_default_region}.amazonaws.com/{get_settings().cognito_user_pool_id}/.well-known/jwks.json"
)


class CognitoTokenUse(Enum):
    ID = "id"
    ACCESS = "access"
class CognitoJWTAuthorizer:
    def __init__(
        self,
        required_token_use: CognitoTokenUse,
        aws_default_region: str,
        cognito_user_pool_id: str,
        cognito_app_client_id: str,
        jwks_client: jwt.PyJWKClient,
    ) -> None:
        """Verify an incoming JWT using official AWS guidelines.

        https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html#amazon-cognito-user-pools-using-tokens-manually-inspect

        Args:
            required_token_use (CognitoTokenUse): Required Cognito token type
            aws_default_region (str): AWS region
            cognito_user_pool_id (str): Cognito User Pool ID
            cognito_app_client_id (str): Cognito App Client ID
            jwks_client (jwt.PyJWKClient): An instance of a JWKS client used to
                retrieve the public Cognito JWKS
        """
        self.required_token_use = required_token_use
        self.aws_default_region = aws_default_region
        self.cognito_user_pool_id = cognito_user_pool_id
        self.cognito_app_client_id = cognito_app_client_id
        self.jwks_client = jwks_client

    def __call__(self, authorization: Annotated[str | None, Header()] = None):
        """Verify an embedded Cognito JWT token in the 'Authorization' header.

        Args:
            authorization (str | None): 'Authorization' header of a FastAPI endpoint

        Raises:
            fastapi.HTTPException: Raised if a verification check of the incoming
                JWT fails
        """
        if not authorization:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            )

        split_authorization_tokens: list[str] = authorization.split("Bearer ")
        if len(split_authorization_tokens) < 2:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            )
        token: str = split_authorization_tokens[1]

        # Attempt to retrieve the signing key of the incoming JWT
        try:
            signing_key: jwt.PyJWK = self.jwks_client.get_signing_key_from_jwt(token)
        except (
            jwt.exceptions.InvalidTokenError,
            # Raised when no key in the JWKS matches the token, or the JWKS
            # cannot be fetched; it is not a subclass of InvalidTokenError
            jwt.exceptions.PyJWKClientError,
        ) as e:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            ) from e
        # Checks performed by jwt.decode below:
        # * Verify the signature of the JWT
        # * Verify that the algorithm used is RS256
        # * Verification of the audience 'aud' is taken care of later, when we
        #   examine if the token is 'id' or 'access'
        # * Verify that the token hasn't expired, by comparing the 'exp' claim
        #   to the current time
        # * The issuer (iss) claim should match your user pool. For example, a
        #   user pool created in the us-east-1 region will have the following
        #   iss value: https://cognito-idp.us-east-1.amazonaws.com/<userpoolID>
        # * Require the existence of claims: 'token_use', 'exp', 'iss', 'sub'
        try:
            claims = jwt.decode(
                token,
                signing_key.key,
                algorithms=["RS256"],
                issuer=f"https://cognito-idp.{self.aws_default_region}.amazonaws.com/{self.cognito_user_pool_id}",
                options={
                    "verify_aud": False,
                    "verify_signature": True,
                    "verify_exp": True,
                    "verify_iss": True,
                    "require": ["token_use", "exp", "iss", "sub"],
                },
            )
        except jwt.exceptions.ExpiredSignatureError as e:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            ) from e
        except jwt.exceptions.InvalidTokenError as e:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            ) from e
        # Check the token_use claim
        # * If you are only accepting the access token in your web API
        #   operations, its value must be access.
        # * If you are only using the ID token, its value must be id.
        # * If you are using both ID and access tokens, the token_use claim
        #   must be either id or access.
        if self.required_token_use.value != claims["token_use"]:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            )

        # The "aud" claim in an ID token and the "client_id" claim in an access
        # token should match the app client ID that was created in the Amazon
        # Cognito user pool.
        if self.required_token_use == CognitoTokenUse.ID:
            if "aud" not in claims:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
                )
            if claims["aud"] != self.cognito_app_client_id:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
                )
        elif self.required_token_use == CognitoTokenUse.ACCESS:
            if "client_id" not in claims:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
                )
            if claims["client_id"] != self.cognito_app_client_id:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
                )
        else:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="Unauthorized"
            )


# Ready-made dependency instances, one per accepted token type
cognito_jwt_authorizer_access_token = CognitoJWTAuthorizer(
    CognitoTokenUse.ACCESS,
    get_settings().aws_default_region,
    get_settings().cognito_user_pool_id,
    get_settings().cognito_app_client_id,
    jwks_client,
)

cognito_jwt_authorizer_id_token = CognitoJWTAuthorizer(
    CognitoTokenUse.ID,
    get_settings().aws_default_region,
    get_settings().cognito_user_pool_id,
    get_settings().cognito_app_client_id,
    jwks_client,
)
Using the Dependencies in a FastAPI endpoint
Putting it all together, we can create FastAPI endpoints that are protected using the above Dependencies.
The code below can be used in a main.py file and shows how to define two protected FastAPI endpoints, one requiring a Cognito ID token (/protected-with-id-token) and one requiring a Cognito access token (/protected-with-access-token).
from fastapi import Depends, FastAPI

from python_fastapi_cognito_jwt_verification.dependencies import (
    cognito_jwt_authorizer_access_token,
    cognito_jwt_authorizer_id_token,
)

app = FastAPI()


@app.get("/protected-with-access-token", dependencies=[Depends(cognito_jwt_authorizer_access_token)])
def protected_with_access_token():
    return {"Hello": "World"}


@app.get("/protected-with-id-token", dependencies=[Depends(cognito_jwt_authorizer_id_token)])
def protected_with_id_token():
    return {"Hello": "World"}
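To try the endpoints out, a client has to send the token in the Authorization header. The following is a minimal sketch using only the standard library. The helper name build_authorized_request is hypothetical, and the URL assumes the app is served locally on port 8000; the token value is a placeholder for a real Cognito access token.

```python
import urllib.request


def build_authorized_request(url: str, token: str) -> urllib.request.Request:
    """Prepare a GET request carrying the JWT as a Bearer token."""
    return urllib.request.Request(url, headers={"Authorization": f"Bearer {token}"})


# Assumed local address; replace the placeholder with a real access token.
request = build_authorized_request(
    "http://localhost:8000/protected-with-access-token", "<access token>"
)
# urllib.request.urlopen(request) would send it. Note that urlopen raises
# urllib.error.HTTPError for a 401 response instead of returning it.
```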
Full implementation
You can find the fully working implementation in a GitHub repository.