mirror of
https://github.com/Astatin3/unshell.git
synced 2026-06-09 06:47:59 -06:00
JWT Authentication
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
use crate::{auth, structs::CurrentUser};
|
||||
use axum::{
|
||||
Extension, Router,
|
||||
extract::Path,
|
||||
middleware,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
};
|
||||
use unshell_lib::info;
|
||||
|
||||
pub async fn app() -> Router {
|
||||
Router::new().route("/auth", post(auth::sign_in)).route(
|
||||
"/api/{*path}",
|
||||
get(protected).layer(middleware::from_fn(auth::authorize)),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn protected(
|
||||
Path(path): Path<String>,
|
||||
Extension(currentUser): Extension<CurrentUser>,
|
||||
) -> impl IntoResponse {
|
||||
info!("{}", path);
|
||||
// Json(UserResponse {
|
||||
// email: currentUser.email,
|
||||
// first_name: currentUser.first_name,
|
||||
// last_name: currentUser.last_name,
|
||||
// })
|
||||
"Test"
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Json, Request},
|
||||
http::{self, Response, StatusCode},
|
||||
middleware::Next,
|
||||
};
|
||||
use bcrypt::{DEFAULT_COST, hash, verify};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{Header, TokenData, Validation, decode, encode};
|
||||
use serde_json::{Value, json};
|
||||
use unshell_lib::info;
|
||||
|
||||
use crate::{
|
||||
EXPIRE_DURATION, JWT_DECODING_KEY, JWT_ENCODING_KEY,
|
||||
structs::{AuthError, Cliams, CurrentUser, SignInData},
|
||||
};
|
||||
|
||||
pub fn hash_password(password: &str) -> Result<String, bcrypt::BcryptError> {
|
||||
let hash = hash(password, DEFAULT_COST)?;
|
||||
Ok(hash)
|
||||
}
|
||||
|
||||
pub fn encode_jwt(email: String) -> Result<(String, usize), StatusCode> {
|
||||
let now = Utc::now();
|
||||
let exp = (now + EXPIRE_DURATION).timestamp() as usize;
|
||||
let iat = now.timestamp() as usize;
|
||||
|
||||
let claim = Cliams { iat, exp, email };
|
||||
|
||||
let token = encode(&Header::default(), &claim, &JWT_ENCODING_KEY)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok((token, exp))
|
||||
}
|
||||
|
||||
pub fn decode_jwt(jwt: String) -> Result<TokenData<Cliams>, StatusCode> {
|
||||
decode(&jwt, &JWT_DECODING_KEY, &Validation::default())
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
pub async fn authorize(mut req: Request, next: Next) -> Result<Response<Body>, AuthError> {
|
||||
let auth_header = req.headers_mut().get(http::header::AUTHORIZATION);
|
||||
|
||||
let auth_header = match auth_header {
|
||||
Some(header) => header.to_str().map_err(|_| AuthError {
|
||||
message: "Empty header is not allowed".to_string(),
|
||||
status_code: StatusCode::FORBIDDEN,
|
||||
})?,
|
||||
None => {
|
||||
return Err(AuthError {
|
||||
message: "Please add the JWT token to the header".to_string(),
|
||||
status_code: StatusCode::FORBIDDEN,
|
||||
});
|
||||
}
|
||||
};
|
||||
let mut header = auth_header.split_whitespace();
|
||||
|
||||
let (_, token) = (header.next(), header.next());
|
||||
|
||||
let token_data = match decode_jwt(token.unwrap().to_string()) {
|
||||
Ok(data) => data,
|
||||
Err(_) => {
|
||||
return Err(AuthError {
|
||||
message: "Invalid Session".to_string(),
|
||||
status_code: StatusCode::UNAUTHORIZED,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch the user details from the database
|
||||
let current_user = match retrieve_user_by_email(&token_data.claims.email) {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(AuthError {
|
||||
message: "Unauthorized".to_string(),
|
||||
status_code: StatusCode::UNAUTHORIZED,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
req.extensions_mut().insert(current_user);
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
|
||||
pub async fn sign_in(Json(user_data): Json<SignInData>) -> Result<Json<Value>, StatusCode> {
|
||||
// 1. Retrieve user from the database
|
||||
let user = match retrieve_user_by_email(&user_data.username) {
|
||||
Some(user) => user,
|
||||
None => return Err(StatusCode::UNAUTHORIZED), // User not found
|
||||
};
|
||||
|
||||
// 2. Compare the password
|
||||
if !verify(&user_data.password, &user.password_hash)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
// Handle bcrypt errors
|
||||
{
|
||||
return Err(StatusCode::UNAUTHORIZED); // Wrong password
|
||||
}
|
||||
|
||||
info!(
|
||||
"Authenticated user {} for {}",
|
||||
user_data.username, EXPIRE_DURATION
|
||||
);
|
||||
|
||||
// 3. Generate JWT
|
||||
let (token, experation) =
|
||||
encode_jwt(user.email).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// 4. Return the token
|
||||
Ok(Json(json!({
|
||||
"token": token,
|
||||
"expiration": experation,
|
||||
})))
|
||||
}
|
||||
|
||||
fn retrieve_user_by_email(_email: &str) -> Option<CurrentUser> {
|
||||
let current_user: CurrentUser = CurrentUser {
|
||||
email: "foo".to_string(),
|
||||
first_name: "Eze".to_string(),
|
||||
last_name: "Sunday".to_string(),
|
||||
password_hash: hash_password("bar").unwrap(),
|
||||
};
|
||||
Some(current_user)
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
#![macro_use]
|
||||
|
||||
use chrono::Duration;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey};
|
||||
use static_init::dynamic;
|
||||
extern crate unshell_lib;
|
||||
|
||||
pub mod app;
|
||||
mod auth;
|
||||
mod structs;
|
||||
mod userdata;
|
||||
|
||||
static EXPIRE_DURATION: Duration = Duration::seconds(10);
|
||||
|
||||
#[dynamic]
|
||||
static JWT_SECRET: String = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
|
||||
#[dynamic]
|
||||
static JWT_ENCODING_KEY: EncodingKey = EncodingKey::from_secret(JWT_SECRET.as_bytes());
|
||||
#[dynamic]
|
||||
static JWT_DECODING_KEY: DecodingKey = DecodingKey::from_secret(JWT_SECRET.as_bytes());
|
||||
+15
-193
@@ -1,200 +1,22 @@
|
||||
//! Example JWT authorization/authentication.
|
||||
//!
|
||||
//! Run with
|
||||
//!
|
||||
//! ```not_rust
|
||||
//! JWT_SECRET=secret cargo run -p example-jwt
|
||||
//! ```
|
||||
use axum;
|
||||
use tokio::net::TcpListener;
|
||||
use unshell_lib::info;
|
||||
|
||||
use axum::{
|
||||
Json, RequestPartsExt, Router,
|
||||
extract::FromRequestParts,
|
||||
http::{StatusCode, request::Parts},
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
};
|
||||
use axum_extra::{
|
||||
TypedHeader,
|
||||
headers::{Authorization, authorization::Bearer},
|
||||
};
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::fmt::Display;
|
||||
use std::sync::LazyLock;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
// Quick instructions
|
||||
//
|
||||
// - get an authorization token:
|
||||
//
|
||||
// curl -s \
|
||||
// -w '\n' \
|
||||
// -H 'Content-Type: application/json' \
|
||||
// -d '{"client_id":"foo","client_secret":"bar"}' \
|
||||
// http://localhost:3000/authorize
|
||||
//
|
||||
// - visit the protected area using the authorized token
|
||||
//
|
||||
// curl -s \
|
||||
// -w '\n' \
|
||||
// -H 'Content-Type: application/json' \
|
||||
// -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \
|
||||
// http://localhost:3000/protected
|
||||
//
|
||||
// - try to visit the protected area using an invalid token
|
||||
//
|
||||
// curl -s \
|
||||
// -w '\n' \
|
||||
// -H 'Content-Type: application/json' \
|
||||
// -H 'Authorization: Bearer blahblahblah' \
|
||||
// http://localhost:3000/protected
|
||||
|
||||
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
|
||||
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
Keys::new(secret.as_bytes())
|
||||
});
|
||||
use unshell_server::app;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
unshell_lib::logger::PrettyLogger::init();
|
||||
|
||||
let app = Router::new()
|
||||
.route("/auth", post(authorize))
|
||||
.route("/protected", get(protected));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
||||
let listener = TcpListener::bind("127.0.0.1:3000")
|
||||
.await
|
||||
.unwrap();
|
||||
tracing::debug!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
||||
async fn protected(claims: Claims) -> Result<String, AuthError> {
|
||||
// Send the protected data to the user
|
||||
Ok(format!(
|
||||
"Welcome to the protected area :)\nYour data:\n{claims}",
|
||||
))
|
||||
}
|
||||
|
||||
async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
|
||||
println!("EEE");
|
||||
|
||||
// Check if the user sent the credentials
|
||||
if payload.client_id.is_empty() || payload.client_secret.is_empty() {
|
||||
return Err(AuthError::MissingCredentials);
|
||||
}
|
||||
// Here you can check the user credentials from a database
|
||||
if payload.client_id != "foo" || payload.client_secret != "bar" {
|
||||
return Err(AuthError::WrongCredentials);
|
||||
}
|
||||
let claims = Claims {
|
||||
sub: "b@b.com".to_owned(),
|
||||
company: "ACME".to_owned(),
|
||||
// Mandatory expiry time as UTC timestamp
|
||||
exp: 2000000000, // May 2033
|
||||
};
|
||||
// Create the authorization token
|
||||
let token = encode(&Header::default(), &claims, &KEYS.encoding)
|
||||
.map_err(|_| AuthError::TokenCreation)?;
|
||||
|
||||
// Send the authorized token
|
||||
Ok(Json(AuthBody::new(token)))
|
||||
}
|
||||
|
||||
impl Display for Claims {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthBody {
|
||||
fn new(access_token: String) -> Self {
|
||||
Self {
|
||||
access_token,
|
||||
token_type: "Bearer".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for Claims
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AuthError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Extract the token from the authorization header
|
||||
let TypedHeader(Authorization(bearer)) = parts
|
||||
.extract::<TypedHeader<Authorization<Bearer>>>()
|
||||
.await
|
||||
.map_err(|_| AuthError::InvalidToken)?;
|
||||
// Decode the user data
|
||||
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
|
||||
.map_err(|_| AuthError::InvalidToken)?;
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_message) = match self {
|
||||
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
|
||||
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
|
||||
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
|
||||
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
|
||||
};
|
||||
let body = Json(json!({
|
||||
"error": error_message,
|
||||
}));
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
struct Keys {
|
||||
encoding: EncodingKey,
|
||||
decoding: DecodingKey,
|
||||
}
|
||||
|
||||
impl Keys {
|
||||
fn new(secret: &[u8]) -> Self {
|
||||
Self {
|
||||
encoding: EncodingKey::from_secret(secret),
|
||||
decoding: DecodingKey::from_secret(secret),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
company: String,
|
||||
exp: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AuthBody {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AuthPayload {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum AuthError {
|
||||
WrongCredentials,
|
||||
MissingCredentials,
|
||||
TokenCreation,
|
||||
InvalidToken,
|
||||
.expect("Unable to start listener");
|
||||
|
||||
info!("Listening on {}", listener.local_addr().unwrap());
|
||||
|
||||
let app = app::app().await;
|
||||
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.expect("Error serving application");
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
use axum::{
|
||||
Json,
|
||||
body::Body,
|
||||
http::{Response, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SignInData {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CurrentUser {
|
||||
pub email: String,
|
||||
pub first_name: String,
|
||||
pub last_name: String,
|
||||
pub password_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Cliams {
|
||||
// pub exp: u128,
|
||||
// pub iat: u128,
|
||||
//
|
||||
pub exp: usize,
|
||||
pub iat: usize,
|
||||
pub email: String,
|
||||
}
|
||||
|
||||
pub struct AuthError {
|
||||
pub message: String,
|
||||
pub status_code: StatusCode,
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let body = Json(json!({
|
||||
"error": self.message,
|
||||
}));
|
||||
|
||||
(self.status_code, body).into_response()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
pub struct UserData {
|
||||
username: String,
|
||||
hash: Vec<u8>,
|
||||
}
|
||||
Reference in New Issue
Block a user