From 5d2011861b3d2f1c406ec5a72f46b82e014c8b88 Mon Sep 17 00:00:00 2001 From: Joey Hines Date: Sat, 10 Feb 2024 11:02:07 -0700 Subject: [PATCH] Added auth handling for protected routes --- src/main.rs | 99 ++++++++++++++++++++++++++++---------- src/model/api_key.rs | 26 ++++++++-- src/model/mod.rs | 2 +- src/storage_manager/mod.rs | 2 +- 4 files changed, 97 insertions(+), 32 deletions(-) diff --git a/src/main.rs b/src/main.rs index 87ebf79..c081bd8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,16 +6,19 @@ mod storage_manager; use crate::config::PicOxConfig; use crate::model::album::Album; +use crate::model::api_key::{ApiKey, ApiPermissions}; use crate::model::image::{Image, ImageData}; use crate::state::Context; use crate::storage_manager::{StorageManager, StoreError}; use axum::body::Bytes; -use axum::extract::{DefaultBodyLimit, Multipart, Path, Query, State}; -use axum::http::StatusCode; +use axum::extract::{DefaultBodyLimit, Multipart, Path, Query, Request, State}; +use axum::http::{HeaderMap, StatusCode}; +use axum::middleware::Next; use axum::response::IntoResponse; use axum::routing::{get, post}; -use axum::{Json, Router}; +use axum::{middleware, Json, Router}; use axum_macros::FromRequest; +use base64::Engine; use j_db::database::Database; use j_db::model::JdbModel; use log::info; @@ -24,14 +27,11 @@ use rand::thread_rng; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; -use base64::Engine; use structopt::StructOpt; use tokio::sync::RwLock; -use crate::model::api_key::{ApiKey, ApiPermissions}; type PicContext = Arc; - #[derive(StructOpt, Debug, Clone)] #[structopt(about = "PicOx Commands")] enum SubCommands { @@ -42,7 +42,7 @@ enum SubCommands { /// Key description description: String, /// API Key permissions (WRITE, DELETE) - permissions: ApiPermissions + permissions: ApiPermissions, }, /// Dump the database state to json Dump { @@ -52,18 +52,17 @@ enum SubCommands { /// Import the database state from json Import { /// Path to json containing DB state - db_file: PathBuf - } + db_file: PathBuf, + }, } - #[derive(Debug, Clone, StructOpt)] struct Args { /// Path to the config file #[structopt(short, long, env = "PICOX_CONFIG")] pub config: PathBuf, #[structopt(subcommand)] - pub command: SubCommands + pub command: SubCommands, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -116,11 +115,13 @@ where } #[allow(dead_code)] -enum PicOxError { +pub enum PicOxError { StoreError(StoreError), DbError(j_db::error::JDbError), AlbumNotFound, ImageNotFound, + TokenInvalid, + NoUserInHeader, } impl From for PicOxError { @@ -161,6 +162,11 @@ impl IntoResponse for PicOxError { StatusCode::INTERNAL_SERVER_ERROR, "Image not found".to_string(), ), + PicOxError::TokenInvalid => (StatusCode::UNAUTHORIZED, "Token is invalid".to_string()), + PicOxError::NoUserInHeader => ( + StatusCode::INTERNAL_SERVER_ERROR, + "User not found in header".to_string(), + ), }; (status, Response(ErrorResponse { message })).into_response() @@ -169,12 +175,16 @@ impl IntoResponse for PicOxError { async fn create_album( State(context): State, + headers: HeaderMap, Json(album): Json, ) -> Result, PicOxError> { - info!("Creating new album '{}'", album.album_name); - - let new_album = Album::new(&album.album_name, Vec::new(), 0); + let user_id = get_user_id_from_headers(&headers)?; + let new_album = Album::new(&album.album_name, Vec::new(), user_id); + info!( + "Creating new album '{}' for user {}", + album.album_name, user_id + ); let new_album = context.db.insert(new_album)?; Ok(Response(new_album)) @@ -256,8 +266,17 @@ async fn query_images( Ok(Response(images)) } +fn get_user_id_from_headers(headers: &HeaderMap) -> Result { + let user = headers.get("user").ok_or(PicOxError::NoUserInHeader)?; + let user_str = user.to_str().unwrap(); + let user: u64 = user_str.parse().unwrap(); + + Ok(user) +} + async fn add_image( State(context): State, + headers: HeaderMap, mut img_data: Multipart, ) -> Result, PicOxError> { let mut data: Vec = Vec::new(); @@ -277,6 +296,9 @@ async fn add_image( } } + let user = headers.get("user").unwrap().to_str().unwrap(); + let user: u64 = user.parse().unwrap(); + let mut album = Album::find_album_by_query(&context.db, metadata.clone().unwrap().album).unwrap(); @@ -287,7 +309,7 @@ async fn add_image( None, ImageData::Bytes(data), &file_name.unwrap(), - 0, + user, album.id().unwrap(), ) .await?; @@ -307,6 +329,26 @@ async fn query_album( Ok(Response(resp)) } +async fn check_token_header( + State(context): State, + mut request: Request, + next: Next, +) -> Result { + let headers = request.headers(); + + if let Some(token) = headers.get("token") { + if let Some(api_key) = ApiKey::find_api_key_by_token(&context.db, token.to_str().unwrap())? + { + request + .headers_mut() + .insert("user", api_key.id().unwrap().into()); + return Ok(next.run(request).await); + } + } + + Err(PicOxError::TokenInvalid) +} + async fn run_picox(db: Database, config: PicOxConfig) { let store_manager = StorageManager::new(config.storage_config.clone()); @@ -319,12 +361,15 @@ async fn run_picox(db: Database, config: PicOxConfig) { let context = Arc::new(context); let app = Router::new() - // `GET /` goes to `root` - .route("/api/album/create", post(create_album)) - .route("/api/album/:id", get(get_album)) - .route("/api/album/", get(query_album)) .route("/api/image/", post(add_image)) .layer(DefaultBodyLimit::max(1024 * 1024 * 1024)) + .route("/api/album/create", post(create_album)) + .layer(middleware::from_fn_with_state( + context.clone(), + check_token_header, + )) + .route("/api/album/:id", get(get_album)) + .route("/api/album/", get(query_album)) .route("/api/image/", get(query_images)) .with_state(context); @@ -348,17 +393,22 @@ async fn main() { SubCommands::Start => { run_picox(db, config).await; } - SubCommands::CreateKey { description, permissions } => { - let (token, api_key) = ApiKey::create_new_key(&db, description, permissions).unwrap(); + SubCommands::CreateKey { + description, + permissions, + } => { + let (token, api_key) = ApiKey::create_new_key(&db, description, permissions).unwrap(); info!("New Key info: {:?}", api_key); info!("Token: {}", base64::prelude::BASE64_STANDARD.encode(token)); db.db.flush().unwrap(); } - SubCommands::Dump {out_path} => { + SubCommands::Dump { out_path } => { info!("Dumping database state to {:?}", out_path); - tokio::fs::write(out_path, db.dump_db().unwrap().pretty(4)).await.unwrap(); + tokio::fs::write(out_path, db.dump_db().unwrap().pretty(4)) + .await + .unwrap(); } SubCommands::Import { db_file } => { info!("Importing database state from {:?}", db_file); @@ -370,5 +420,4 @@ async fn main() { db.db.flush().unwrap(); } } - } diff --git a/src/model/api_key.rs b/src/model/api_key.rs index 87355fd..d7621f9 100644 --- a/src/model/api_key.rs +++ b/src/model/api_key.rs @@ -1,7 +1,7 @@ -use std::str::FromStr; use base64::Engine; use j_db::database::Database; use serde::{Deserialize, Serialize}; +use std::str::FromStr; use bitflags::bitflags; use rand::RngCore; @@ -29,18 +29,22 @@ pub struct ApiKey { pub description: String, pub permissions: ApiPermissions, - id: Option + id: Option, } impl ApiKey { - pub fn create_new_key(db: &Database, description: String, permissions: ApiPermissions) -> Result<(Vec, ApiKey), j_db::error::JDbError>{ + pub fn create_new_key( + db: &Database, + description: String, + permissions: ApiPermissions, + ) -> Result<(Vec, ApiKey), j_db::error::JDbError> { let mut key = [0u8; 32]; rand::thread_rng().fill_bytes(&mut key); let hash = sha2::Sha256::digest(key); let api_key = ApiKey { - token_hash: base64::prelude::BASE64_STANDARD.encode(&hash), + token_hash: base64::prelude::BASE64_STANDARD.encode(hash), description, permissions, id: None, @@ -50,6 +54,18 @@ impl ApiKey { Ok((key.to_vec(), api_key)) } + + pub fn find_api_key_by_token( + db: &Database, + token: &str, + ) -> Result, j_db::error::JDbError> { + let token = base64::prelude::BASE64_STANDARD.decode(token).unwrap(); + let hash = sha2::Sha256::digest(token); + let hash = base64::prelude::BASE64_STANDARD.encode(hash); + Ok(db + .filter(move |_, api_key: &ApiKey| api_key.token_hash == hash)? + .next()) + } } impl j_db::model::JdbModel for ApiKey { @@ -64,4 +80,4 @@ impl j_db::model::JdbModel for ApiKey { fn tree() -> String { "ApiKey".to_string() } -} \ No newline at end of file +} diff --git a/src/model/mod.rs b/src/model/mod.rs index 8ada03b..dba073a 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -1,3 +1,3 @@ pub mod album; -pub mod image; pub mod api_key; +pub mod image; diff --git a/src/storage_manager/mod.rs b/src/storage_manager/mod.rs index d3c8fa4..c8efd8d 100644 --- a/src/storage_manager/mod.rs +++ b/src/storage_manager/mod.rs @@ -48,8 +48,8 @@ pub trait Store: Send { &mut self, img_data: ImageData, file_name: &str, - album: u64, created_by: u64, + album: u64, ) -> Result { let (url, storage_location) = self.store_img(img_data, file_name).await?;