//! Example JWT authorization/authentication. //! //! Run with //! //! ```not_rust //! JWT_SECRET=secret cargo run -p example-jwt //! ``` use axum::{ Json, RequestPartsExt, Router, extract::FromRequestParts, http::{StatusCode, request::Parts}, response::{IntoResponse, Response}, routing::{get, post}, }; use axum_extra::{ TypedHeader, headers::{Authorization, authorization::Bearer}, }; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::fmt::Display; use std::sync::LazyLock; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // Quick instructions // // - get an authorization token: // // curl -s \ // -w '\n' \ // -H 'Content-Type: application/json' \ // -d '{"client_id":"foo","client_secret":"bar"}' \ // http://localhost:3000/authorize // // - visit the protected area using the authorized token // // curl -s \ // -w '\n' \ // -H 'Content-Type: application/json' \ // -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \ // http://localhost:3000/protected // // - try to visit the protected area using an invalid token // // curl -s \ // -w '\n' \ // -H 'Content-Type: application/json' \ // -H 'Authorization: Bearer blahblahblah' \ // http://localhost:3000/protected static KEYS: LazyLock = LazyLock::new(|| { let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); Keys::new(secret.as_bytes()) }); #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let app = Router::new() .route("/auth", post(authorize)) .route("/protected", get(protected)); let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") .await .unwrap(); tracing::debug!("listening on {}", listener.local_addr().unwrap()); axum::serve(listener, app).await.unwrap(); } async fn protected(claims: Claims) -> Result { // Send the protected data to the user Ok(format!( "Welcome to the protected area :)\nYour data:\n{claims}", )) } async fn authorize(Json(payload): Json) -> Result, AuthError> { println!("EEE"); // Check if the user sent the credentials if payload.client_id.is_empty() || payload.client_secret.is_empty() { return Err(AuthError::MissingCredentials); } // Here you can check the user credentials from a database if payload.client_id != "foo" || payload.client_secret != "bar" { return Err(AuthError::WrongCredentials); } let claims = Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned(), // Mandatory expiry time as UTC timestamp exp: 2000000000, // May 2033 }; // Create the authorization token let token = encode(&Header::default(), &claims, &KEYS.encoding) .map_err(|_| AuthError::TokenCreation)?; // Send the authorized token Ok(Json(AuthBody::new(token))) } impl Display for Claims { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Email: {}\nCompany: {}", self.sub, self.company) } } impl AuthBody { fn new(access_token: String) -> Self { Self { access_token, token_type: "Bearer".to_string(), } } } impl FromRequestParts for Claims where S: Send + Sync, { type Rejection = AuthError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await .map_err(|_| AuthError::InvalidToken)?; // Decode the user data let token_data = decode::(bearer.token(), &KEYS.decoding, &Validation::default()) .map_err(|_| AuthError::InvalidToken)?; Ok(token_data.claims) } } impl IntoResponse for AuthError { fn into_response(self) -> Response { let (status, error_message) = match self { AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"), AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), }; let body = Json(json!({ "error": error_message, })); (status, body).into_response() } } struct Keys { encoding: EncodingKey, decoding: DecodingKey, } impl Keys { fn new(secret: &[u8]) -> Self { Self { encoding: EncodingKey::from_secret(secret), decoding: DecodingKey::from_secret(secret), } } } #[derive(Debug, Clone, Serialize, Deserialize)] struct Claims { sub: String, company: String, exp: usize, } #[derive(Debug, Serialize)] struct AuthBody { access_token: String, token_type: String, } #[derive(Debug, Deserialize)] struct AuthPayload { client_id: String, client_secret: String, } #[derive(Debug)] enum AuthError { WrongCredentials, MissingCredentials, TokenCreation, InvalidToken, }