JWT Authentication

This commit is contained in:
Michael Mikovsky
2025-11-29 13:15:09 -07:00
parent fcb8c6f6f5
commit a10bdce38f
18 changed files with 1198 additions and 583 deletions
+29
View File
@@ -0,0 +1,29 @@
use crate::{auth, structs::CurrentUser};
use axum::{
Extension, Router,
extract::Path,
middleware,
response::IntoResponse,
routing::{get, post},
};
use unshell_lib::info;
pub async fn app() -> Router {
Router::new().route("/auth", post(auth::sign_in)).route(
"/api/{*path}",
get(protected).layer(middleware::from_fn(auth::authorize)),
)
}
pub async fn protected(
Path(path): Path<String>,
Extension(currentUser): Extension<CurrentUser>,
) -> impl IntoResponse {
info!("{}", path);
// Json(UserResponse {
// email: currentUser.email,
// first_name: currentUser.first_name,
// last_name: currentUser.last_name,
// })
"Test"
}
+124
View File
@@ -0,0 +1,124 @@
use axum::{
body::Body,
extract::{Json, Request},
http::{self, Response, StatusCode},
middleware::Next,
};
use bcrypt::{DEFAULT_COST, hash, verify};
use chrono::Utc;
use jsonwebtoken::{Header, TokenData, Validation, decode, encode};
use serde_json::{Value, json};
use unshell_lib::info;
use crate::{
EXPIRE_DURATION, JWT_DECODING_KEY, JWT_ENCODING_KEY,
structs::{AuthError, Cliams, CurrentUser, SignInData},
};
pub fn hash_password(password: &str) -> Result<String, bcrypt::BcryptError> {
let hash = hash(password, DEFAULT_COST)?;
Ok(hash)
}
pub fn encode_jwt(email: String) -> Result<(String, usize), StatusCode> {
let now = Utc::now();
let exp = (now + EXPIRE_DURATION).timestamp() as usize;
let iat = now.timestamp() as usize;
let claim = Cliams { iat, exp, email };
let token = encode(&Header::default(), &claim, &JWT_ENCODING_KEY)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok((token, exp))
}
pub fn decode_jwt(jwt: String) -> Result<TokenData<Cliams>, StatusCode> {
decode(&jwt, &JWT_DECODING_KEY, &Validation::default())
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
pub async fn authorize(mut req: Request, next: Next) -> Result<Response<Body>, AuthError> {
let auth_header = req.headers_mut().get(http::header::AUTHORIZATION);
let auth_header = match auth_header {
Some(header) => header.to_str().map_err(|_| AuthError {
message: "Empty header is not allowed".to_string(),
status_code: StatusCode::FORBIDDEN,
})?,
None => {
return Err(AuthError {
message: "Please add the JWT token to the header".to_string(),
status_code: StatusCode::FORBIDDEN,
});
}
};
let mut header = auth_header.split_whitespace();
let (_, token) = (header.next(), header.next());
let token_data = match decode_jwt(token.unwrap().to_string()) {
Ok(data) => data,
Err(_) => {
return Err(AuthError {
message: "Invalid Session".to_string(),
status_code: StatusCode::UNAUTHORIZED,
});
}
};
// Fetch the user details from the database
let current_user = match retrieve_user_by_email(&token_data.claims.email) {
Some(user) => user,
None => {
return Err(AuthError {
message: "Unauthorized".to_string(),
status_code: StatusCode::UNAUTHORIZED,
});
}
};
req.extensions_mut().insert(current_user);
Ok(next.run(req).await)
}
pub async fn sign_in(Json(user_data): Json<SignInData>) -> Result<Json<Value>, StatusCode> {
// 1. Retrieve user from the database
let user = match retrieve_user_by_email(&user_data.username) {
Some(user) => user,
None => return Err(StatusCode::UNAUTHORIZED), // User not found
};
// 2. Compare the password
if !verify(&user_data.password, &user.password_hash)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
// Handle bcrypt errors
{
return Err(StatusCode::UNAUTHORIZED); // Wrong password
}
info!(
"Authenticated user {} for {}",
user_data.username, EXPIRE_DURATION
);
// 3. Generate JWT
let (token, experation) =
encode_jwt(user.email).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// 4. Return the token
Ok(Json(json!({
"token": token,
"expiration": experation,
})))
}
fn retrieve_user_by_email(_email: &str) -> Option<CurrentUser> {
let current_user: CurrentUser = CurrentUser {
email: "foo".to_string(),
first_name: "Eze".to_string(),
last_name: "Sunday".to_string(),
password_hash: hash_password("bar").unwrap(),
};
Some(current_user)
}
+21
View File
@@ -0,0 +1,21 @@
#![macro_use]
use chrono::Duration;
use jsonwebtoken::{DecodingKey, EncodingKey};
use static_init::dynamic;
extern crate unshell_lib;
pub mod app;
mod auth;
mod structs;
mod userdata;
static EXPIRE_DURATION: Duration = Duration::seconds(10);
#[dynamic]
static JWT_SECRET: String = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
#[dynamic]
static JWT_ENCODING_KEY: EncodingKey = EncodingKey::from_secret(JWT_SECRET.as_bytes());
#[dynamic]
static JWT_DECODING_KEY: DecodingKey = DecodingKey::from_secret(JWT_SECRET.as_bytes());
+15 -193
View File
@@ -1,200 +1,22 @@
//! Example JWT authorization/authentication.
//!
//! Run with
//!
//! ```not_rust
//! JWT_SECRET=secret cargo run -p example-jwt
//! ```
use axum;
use tokio::net::TcpListener;
use unshell_lib::info;
use axum::{
Json, RequestPartsExt, Router,
extract::FromRequestParts,
http::{StatusCode, request::Parts},
response::{IntoResponse, Response},
routing::{get, post},
};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use std::sync::LazyLock;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// Quick instructions
//
// - get an authorization token:
//
// curl -s \
// -w '\n' \
// -H 'Content-Type: application/json' \
// -d '{"client_id":"foo","client_secret":"bar"}' \
// http://localhost:3000/authorize
//
// - visit the protected area using the authorized token
//
// curl -s \
// -w '\n' \
// -H 'Content-Type: application/json' \
// -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \
// http://localhost:3000/protected
//
// - try to visit the protected area using an invalid token
//
// curl -s \
// -w '\n' \
// -H 'Content-Type: application/json' \
// -H 'Authorization: Bearer blahblahblah' \
// http://localhost:3000/protected
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
Keys::new(secret.as_bytes())
});
use unshell_server::app;
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
unshell_lib::logger::PrettyLogger::init();
let app = Router::new()
.route("/auth", post(authorize))
.route("/protected", get(protected));
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
let listener = TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
async fn protected(claims: Claims) -> Result<String, AuthError> {
// Send the protected data to the user
Ok(format!(
"Welcome to the protected area :)\nYour data:\n{claims}",
))
}
async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
println!("EEE");
// Check if the user sent the credentials
if payload.client_id.is_empty() || payload.client_secret.is_empty() {
return Err(AuthError::MissingCredentials);
}
// Here you can check the user credentials from a database
if payload.client_id != "foo" || payload.client_secret != "bar" {
return Err(AuthError::WrongCredentials);
}
let claims = Claims {
sub: "b@b.com".to_owned(),
company: "ACME".to_owned(),
// Mandatory expiry time as UTC timestamp
exp: 2000000000, // May 2033
};
// Create the authorization token
let token = encode(&Header::default(), &claims, &KEYS.encoding)
.map_err(|_| AuthError::TokenCreation)?;
// Send the authorized token
Ok(Json(AuthBody::new(token)))
}
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
}
}
impl AuthBody {
fn new(access_token: String) -> Self {
Self {
access_token,
token_type: "Bearer".to_string(),
}
}
}
impl<S> FromRequestParts<S> for Claims
where
S: Send + Sync,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| AuthError::InvalidToken)?;
// Decode the user data
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
.map_err(|_| AuthError::InvalidToken)?;
Ok(token_data.claims)
}
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
};
let body = Json(json!({
"error": error_message,
}));
(status, body).into_response()
}
}
struct Keys {
encoding: EncodingKey,
decoding: DecodingKey,
}
impl Keys {
fn new(secret: &[u8]) -> Self {
Self {
encoding: EncodingKey::from_secret(secret),
decoding: DecodingKey::from_secret(secret),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Claims {
sub: String,
company: String,
exp: usize,
}
#[derive(Debug, Serialize)]
struct AuthBody {
access_token: String,
token_type: String,
}
#[derive(Debug, Deserialize)]
struct AuthPayload {
client_id: String,
client_secret: String,
}
#[derive(Debug)]
enum AuthError {
WrongCredentials,
MissingCredentials,
TokenCreation,
InvalidToken,
.expect("Unable to start listener");
info!("Listening on {}", listener.local_addr().unwrap());
let app = app::app().await;
axum::serve(listener, app)
.await
.expect("Error serving application");
}
+47
View File
@@ -0,0 +1,47 @@
use axum::{
Json,
body::Body,
http::{Response, StatusCode},
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize)]
pub struct SignInData {
pub username: String,
pub password: String,
}
#[derive(Debug, Clone)]
pub struct CurrentUser {
pub email: String,
pub first_name: String,
pub last_name: String,
pub password_hash: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Cliams {
// pub exp: u128,
// pub iat: u128,
//
pub exp: usize,
pub iat: usize,
pub email: String,
}
pub struct AuthError {
pub message: String,
pub status_code: StatusCode,
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response<Body> {
let body = Json(json!({
"error": self.message,
}));
(self.status_code, body).into_response()
}
}
+4
View File
@@ -0,0 +1,4 @@
pub struct UserData {
username: String,
hash: Vec<u8>,
}