Validating Okta Access Tokens in Python with PyJWT
Every week, almost without fail, I come across one thing that confuses, entertains, or most commonly infuriates me. I’ve decided to keep a record of my adventures.
There is a great blog post by Renzo Lucioni from a few years ago that talks about validating JWTs with JWKs using PyJWT. While this blog post doesn't directly speak to Okta or access tokens, the principles are generally the same. Since 2019 when Renzo wrote this, a few things have improved in PyJWT. Today we'll take a look at the current flow for validating Okta Access Tokens using PyJWT.
Please Note: Okta has their own Python library for validating JWTs. There are a couple of reasons why you might not want to use it, but most of them center on the fact that it only works for Okta. For instance, it ONLY supports 'RS256' (which is the current RFC recommendation, but is expected to change) and it will only look up Okta's jwks_uris.
Background
The use of Access Tokens versus ID Tokens can be highly idiomatic and has been covered in many other places. However, when a Resource Server (app) gets an Access Token, it is required to validate that token. Exactly how that is done is not discussed in RFC 6749.
There are two common forms of access tokens: "opaque" (honestly, there is no defined term here, this has just become common lingo) and JWT. The manners in which these can be validated differ. One of the core benefits of using JWTs is that they can be validated offline, because they contain information on the validity of the token, signed by the Authorization Server (issuing server). Alternatively, with opaque Access Tokens, a Resource Server is required to speak directly to the Authorization Server to determine validity, typically via an Introspection Endpoint (RFC 7662).
Okta uses JWTs as its OAuth2 Access Tokens. As mentioned, this is a smart choice because it allows offline validation, although introspection is also supported. Because this approach is so commonly implemented, RFC 9068 exists to define the usage of JWTs as Access Tokens. Because Okta has on occasion violated RFCs, let's take a gander at some of the requirements:
- iss, exp, aud, sub (included by default), iat, and jti are mandatory and present
- The optional auth_time claim is present
- Additional non-standard claims are present: ver, uid
- The optional `scope` claim has been renamed scp (interestingly, 'scope' is not reserved in Okta's claim manager. Since this claim is technically optional this isn't an issue, but as it could have special meaning to Resource Servers, this can lead to unexpected and possibly insecure configuration)
- client_id is present but has been renamed to cid (interestingly, 'client_id' is not reserved in Okta's claim manager. This can lead to unexpected and possibly insecure configuration if a rogue administrator were to add the value. This is also against the RFC)
- The "typ" header must be `at+jwt` and is missing
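If you want to run this checklist against one of your own tokens, the header and payload can be read without verifying anything. Here is a minimal sketch using only the standard library. The helper names (`b64url_decode`, `rfc9068_gaps`, `_encode`) and the sample token are made up for illustration, and since no signature is checked, this is for inspection only and never for trust decisions.

```python
import base64
import json

# Claims RFC 9068 marks as required in a JWT access token.
REQUIRED_CLAIMS = ("iss", "exp", "aud", "sub", "client_id", "iat", "jti")


def b64url_decode(segment):
    """Decode an unpadded base64url JWT segment into bytes."""
    padding = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + padding)


def rfc9068_gaps(token):
    """List RFC 9068 deviations. Does NOT verify the signature."""
    header_b64, payload_b64, _signature = token.split(".")
    header = json.loads(b64url_decode(header_b64))
    payload = json.loads(b64url_decode(payload_b64))

    gaps = [f"missing claim: {c}" for c in REQUIRED_CLAIMS if c not in payload]
    if header.get("typ", "").lower() not in ("at+jwt", "application/at+jwt"):
        gaps.append("typ header is not at+jwt")
    return gaps


def _encode(obj):
    """Hypothetical helper: build an unpadded base64url segment for the demo."""
    raw = json.dumps(obj).encode()
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()


# Made-up, unsigned token shaped like the Okta claims listed above.
sample = ".".join([
    _encode({"alg": "RS256", "kid": "example-kid"}),
    _encode({"iss": "i", "exp": 1, "aud": "a", "sub": "s", "iat": 0,
             "jti": "j", "cid": "c", "scp": ["openid"], "ver": 1, "uid": "u"}),
    "signature-placeholder",
])
print(rfc9068_gaps(sample))
```

For the sample token this reports the missing client_id claim and the missing typ header, which lines up with the last two bullets.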
The Problem
In spite of some minor deviations from the RFC, Okta's Access Tokens are pretty standard. I've talked in the past about the problems that can arise with usage of Okta's Introspection Endpoint and why validating an access token on the resource server side could be beneficial. In the past when I discussed this I didn't describe how to do this. That is what I'll discuss today.
The Solution
When using JWTs as access tokens, the RFC indicates that the Authorization Server should provide metadata about the JWKs via a "jwks_uri". This information, along with the issuer, which we'll need to validate against the claim, is available at https://{okta-tenant}/oauth2/{authorization-server-id}/.well-known/oauth-authorization-server (as required by RFC). While the example in the RFC for OAuth Authorization Server Metadata (RFC 8414) uses https://{authorization-server}/jwks.json as the jwks_uri, Okta uses https://{okta-tenant}/oauth2/{authorization-server-id}/v1/keys instead to serve public key information.
Let's talk about PyJWT. It can be used for several different use cases, but it has one distinct configuration not common among Python libraries: an optional dependency, the cryptography library. In this case, we're going to be validating signatures, so we'll need that installed. To ensure that crypto support is present, first check the value of jwt.algorithms.has_crypto:
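To make the discovery step concrete, here is a small sketch that builds the metadata URL in the shape described above and pulls out the two fields we care about. The function names are my own, and `example.okta.com` and `default` are placeholder values. A real metadata document would be fetched over HTTPS; here a hand-written one stands in for it.

```python
import json


def metadata_url(okta_tenant, auth_server_id):
    """Build the discovery URL in the shape described above."""
    return (
        f"https://{okta_tenant}/oauth2/{auth_server_id}"
        "/.well-known/oauth-authorization-server"
    )


def issuer_and_jwks_uri(metadata_json):
    """Pull the two fields needed for validation out of the metadata document."""
    metadata = json.loads(metadata_json)
    return metadata["issuer"], metadata["jwks_uri"]


# Placeholder document; a real one would be downloaded from metadata_url(...).
example_doc = json.dumps({
    "issuer": "https://example.okta.com/oauth2/default",
    "jwks_uri": "https://example.okta.com/oauth2/default/v1/keys",
})
issuer, jwks_uri = issuer_and_jwks_uri(example_doc)
print(metadata_url("example.okta.com", "default"))
print(issuer, jwks_uri)
```

Reading both values from the same document keeps the issuer you validate against and the keys you validate with tied to one Authorization Server.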
if not jwt.algorithms.has_crypto:
    print("No crypto support for JWT, please install the cryptography dependency")
    return False
Once we've ensured that crypto support is present, the next step depends on your goal. PyJWT can validate a JWT purely locally or it can ingest and cache the contents of a jwks_uri to validate a JWT.
There are trade-offs to each approach. Downloading the jwks_uri will provide better security if the keys are revoked. Unfortunately, if the Resource Server contacts the jwks_uri each time and the Authorization Server is ever down, validation will fail (by throwing a PyJWKClientError). The purely local alternative requires the Resource Server to be updated manually each time a key change is made. For Okta this defaults to once every 3 months (this can be manually overridden, or queried for expiration via the API). The middle path is to use a cache. PyJWT uses functools' lru_cache to support this option as well.
A conversation about cache lifetimes is important here. Tl;dr, it's important to weigh Okta being down for a bit against key rotation timelines when it comes to caching. PyJWT will default to a 5-minute cache. This means that every 3 months, there are likely 5 minutes of downtime on your Resource Server. (Note: It is unclear if Okta automatically switches issued keys before the expiration of the previous key. This is best left for a future blog. We've also seen them do funky stuff like validate expired keys in a previous blog.) Obviously, lowering the cache lifetime, waiting for the cache to time out, or restarting the service will resolve the issue. This also speaks to why short-lived access tokens are logical.
Assuming the jwks_uri is up, PyJWT will download the jwks_uri content (or reuse its cached copy if that hasn't expired), optionally cache it, and then compare the `kid` (note: alg is not matched) from the JWT Access Token header against the JWKs specified in the jwks_uri. If it finds a match, it will return the public key; otherwise it will throw a PyJWKClientError. Keep in mind get_signing_key_from_jwt() can also throw a DecodeError exception (as well as others), and as such users should catch the more general PyJWTError:
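To illustrate the trade-off, here is a toy time-based cache that also serves a stale copy when a refresh fails. To be clear, this is my own sketch of the idea and not how PyJWT implements its cache. The `CachedJwks` class, its parameters, and the injectable `clock` are all hypothetical, and `fetch` stands in for whatever downloads the jwks_uri.

```python
import time


class CachedJwks:
    """Hypothetical time-based cache around a callable that returns a JWKS dict."""

    def __init__(self, fetch, lifespan=300, clock=time.monotonic):
        self._fetch = fetch          # e.g. a function that downloads the jwks_uri
        self._lifespan = lifespan    # seconds before a refresh is attempted
        self._clock = clock          # injectable so the behavior can be tested
        self._value = None
        self._fetched_at = None

    def get(self):
        now = self._clock()
        fresh = (
            self._fetched_at is not None
            and now - self._fetched_at < self._lifespan
        )
        if fresh:
            return self._value
        try:
            self._value = self._fetch()
            self._fetched_at = now
        except Exception:
            # Authorization Server unreachable: serve the stale copy if one exists.
            if self._value is None:
                raise
        return self._value
```

Serving stale keys keeps the Resource Server up while Okta is down, at the cost of honoring a revoked key for longer. Whether that is acceptable depends on your access token lifetimes.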
jwks_url = "https://{okta_tenant}/oauth2/{auth_server_id}/v1/keys"
try:
    jwks_client = jwt.PyJWKClient(jwks_url, cache_jwk_set=True, lifespan=360)
    signing_key = jwks_client.get_signing_key_from_jwt(access_token)
except jwt.exceptions.PyJWTError as err:
    print(f"Error: {err}")
    return False
Let's say you are fine doing this manually. This could be because you've manually increased your key lifetimes, you're willing to trade upkeep for stability, or because you're falling back in the event the server is unavailable. Whatever the reason, this is supported by PyJWT as well. In a demonstration of simple design, PyJWT's PyJWK() class is instantiated with the dictionary value of a key (which can be found on the jwks_uri, or converted from a pem file). It would look something like the following (Note: please keep in mind that you should also validate the kid from the JWT header against the stored key):
try:
    # Example public key provided by Okta, wrapped in PyJWK so that
    # signing_key.key can be passed to jwt.decode() later on
    signing_key = jwt.PyJWK({
        "kty": "RSA",
        "alg": "RS256",
        "kid": "4STRLzk9N8JJbdz7kRk8mIqn27kotgV0oEv17Ient0E",
        "use": "sig",
        "e": "AQAB",
        "n": "6eUqFyymfSQVCV2xRhrlxs3O-NxrZQ84bpIowajLBfREGadUru0ItIvcoUVw0E3TFB8-udAlOeXcKGgr33dOZ_ZuQtlTxUdU09_oHmuF2t8CdIUV6hxJL2trF9uedKYjoRX_EhPSBZ5V1JYH8oFhcoD0yMzn_Z5yZz1hIYr5uz2tt4v6wmL2_Yw2z7cXfC0DIn6XPxTVDG1uk10kZ57Q6VrnPhkTXQYNC0BrflvumzPM-t6VJun5MJaNPb0JZwnv25fB0b-3JHFXDcPunSuVlr3n7ao3mD_Xo2bkG8Ak2pMtfPvigu-Z2-X9Uln47_smczkXkOmCkrUce-JBFs7TLw",
    })
except jwt.exceptions.PyJWTError as err:
    print(f"Error: {err}")
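The kid check mentioned in the note above could look something like the following. This is a standard-library sketch with made-up helper names (`unverified_kid`, `kid_matches`) and a made-up demo token. PyJWT's jwt.get_unverified_header() gives you the same header dictionary if you would rather not decode it by hand.

```python
import base64
import json


def unverified_kid(token):
    """Read the kid from a JWT header without verifying the signature."""
    header_b64 = token.split(".")[0]
    padding = "=" * (-len(header_b64) % 4)
    header = json.loads(base64.urlsafe_b64decode(header_b64 + padding))
    return header.get("kid")


def kid_matches(token, stored_jwk):
    """True only when both kids are present and identical."""
    kid = unverified_kid(token)
    return kid is not None and kid == stored_jwk.get("kid")


# Made-up token: a real header, an empty payload ("e30" is "{}"), no real signature.
demo_header = base64.urlsafe_b64encode(
    json.dumps({"alg": "RS256", "kid": "example-kid"}).encode()
).rstrip(b"=").decode()
demo_token = f"{demo_header}.e30.signature-placeholder"

print(kid_matches(demo_token, {"kid": "example-kid"}))
print(kid_matches(demo_token, {"kid": "some-other-kid"}))
```

A mismatch here usually means the key has rotated and the stored copy needs updating, so it is worth logging separately from a signature failure.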
Now that we have the signing key, we can decode and validate the JWT. The jwt.decode() function takes a number of arguments. The most important are the JWT, the public key, the algorithms, the audience, the issuer, and an options dictionary.
By default, if no options dictionary is passed (or specific values overridden), PyJWT will validate the signature (requires the algorithm), the expiration time (exp), the not before time (nbf), the issued at time (iat), the audience (aud), and the issuer (iss). This can throw any number of different exceptions including DecodeError, ExpiredSignatureError, InvalidAudienceError, and InvalidIssuerError, just to name a few. Suffice it to say, any exception means that the access token should be treated as invalid.
Because PyJWT only accepts the jwks_url and not the full Authorization Server metadata it doesn't know anything about the issuer claim. As a result, developers MUST ensure that the value of issuer from the metadata is provided as the `issuer` parameter to `jwt.decode()`.
Likewise, `audience` is configured per Authorization Server. This value is NOT included in the Authorization Server metadata but must be set in the `audience` parameter within `jwt.decode()`.
The last area of interest is the algorithm validation. There are situations where allowing the attacker to control the algorithm being used can lead to algorithm confusion attacks. In Okta's case the only supported key alg value is RS256. This results in a call similar to the following:
data = jwt.decode(
    access_token,
    signing_key.key,
    algorithms=["RS256"],
    issuer="https://{okta_tenant}/oauth2/{auth_server_id}",
    audience="{audience_value_from_okta_auth_server}",
    # This should be default, but just to be doubly clear
    options={
        "verify_signature": True,
        "verify_exp": True,
        "verify_nbf": True,
        "verify_iat": True,
        "verify_aud": True,
        "verify_iss": True,
    },
)
Bringing it all together, you'd get something like the following:
import jwt  # requires cryptography


def validate_access_token(access_token):
    if not jwt.algorithms.has_crypto:
        print("No crypto support for JWT, please install the cryptography dependency")
        return False

    okta_auth_server = "https://{okta_tenant}/oauth2/{auth_server_id}"
    jwks_url = f"{okta_auth_server}/v1/keys"
    try:
        jwks_client = jwt.PyJWKClient(jwks_url, cache_jwk_set=True, lifespan=360)
        signing_key = jwks_client.get_signing_key_from_jwt(access_token)
        data = jwt.decode(
            access_token,
            signing_key.key,
            algorithms=["RS256"],
            issuer=okta_auth_server,
            audience="vpn_auth_server",
            options={
                "verify_signature": True,
                "verify_exp": True,
                "verify_nbf": True,
                "verify_iat": True,
                "verify_aud": True,
                "verify_iss": True,
            },
        )
        return data
    except jwt.exceptions.PyJWTError as err:
        print(f"Error: {err}")
        return False
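Once the function hands back the decoded claims, a Resource Server will typically want to look at what the token is allowed to do. Here is one possible sketch. `has_scopes` is a name I made up, and it assumes Okta's scp claim arrives as a list of strings (with a fallback for a space-delimited string, the form RFC 9068 uses for scope).

```python
def has_scopes(claims, required):
    """Check required scopes against the scp claim (assumed to be a list of strings)."""
    if not claims:  # validate_access_token() returns False on failure
        return False
    granted = claims.get("scp", [])
    if isinstance(granted, str):  # tolerate a space-delimited string
        granted = granted.split()
    return set(required).issubset(granted)


# Made-up claims standing in for the return value of validate_access_token().
print(has_scopes({"scp": ["openid", "profile"]}, ["openid"]))
print(has_scopes({"scp": ["openid"]}, ["openid", "admin"]))
print(has_scopes(False, ["openid"]))
```

Treating a failed validation as "no scopes" keeps the calling code from having to special-case the False return value.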
“I trust that none will stretch the seams in putting on the coat, for it may do good service to him whom it fits”. Have a great rest of 2022!