Added auth handling for protected routes

main
Joey Hines 2024-02-10 11:02:07 -07:00
parent b9b14a7118
commit 5d2011861b
Signed by: joeyahines
GPG Key ID: 995E531F7A569DDB
4 changed files with 97 additions and 32 deletions

View File

@ -6,16 +6,19 @@ mod storage_manager;
use crate::config::PicOxConfig; use crate::config::PicOxConfig;
use crate::model::album::Album; use crate::model::album::Album;
use crate::model::api_key::{ApiKey, ApiPermissions};
use crate::model::image::{Image, ImageData}; use crate::model::image::{Image, ImageData};
use crate::state::Context; use crate::state::Context;
use crate::storage_manager::{StorageManager, StoreError}; use crate::storage_manager::{StorageManager, StoreError};
use axum::body::Bytes; use axum::body::Bytes;
use axum::extract::{DefaultBodyLimit, Multipart, Path, Query, State}; use axum::extract::{DefaultBodyLimit, Multipart, Path, Query, Request, State};
use axum::http::StatusCode; use axum::http::{HeaderMap, StatusCode};
use axum::middleware::Next;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{middleware, Json, Router};
use axum_macros::FromRequest; use axum_macros::FromRequest;
use base64::Engine;
use j_db::database::Database; use j_db::database::Database;
use j_db::model::JdbModel; use j_db::model::JdbModel;
use log::info; use log::info;
@ -24,14 +27,11 @@ use rand::thread_rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use base64::Engine;
use structopt::StructOpt; use structopt::StructOpt;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::model::api_key::{ApiKey, ApiPermissions};
type PicContext = Arc<Context>; type PicContext = Arc<Context>;
#[derive(StructOpt, Debug, Clone)] #[derive(StructOpt, Debug, Clone)]
#[structopt(about = "PicOx Commands")] #[structopt(about = "PicOx Commands")]
enum SubCommands { enum SubCommands {
@ -42,7 +42,7 @@ enum SubCommands {
/// Key description /// Key description
description: String, description: String,
/// API Key permissions (WRITE, DELETE) /// API Key permissions (WRITE, DELETE)
permissions: ApiPermissions permissions: ApiPermissions,
}, },
/// Dump the database state to json /// Dump the database state to json
Dump { Dump {
@ -52,18 +52,17 @@ enum SubCommands {
/// Import the database state from json /// Import the database state from json
Import { Import {
/// Path to json containing DB state /// Path to json containing DB state
db_file: PathBuf db_file: PathBuf,
} },
} }
#[derive(Debug, Clone, StructOpt)] #[derive(Debug, Clone, StructOpt)]
struct Args { struct Args {
/// Path to the config file /// Path to the config file
#[structopt(short, long, env = "PICOX_CONFIG")] #[structopt(short, long, env = "PICOX_CONFIG")]
pub config: PathBuf, pub config: PathBuf,
#[structopt(subcommand)] #[structopt(subcommand)]
pub command: SubCommands pub command: SubCommands,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -116,11 +115,13 @@ where
} }
#[allow(dead_code)] #[allow(dead_code)]
enum PicOxError { pub enum PicOxError {
StoreError(StoreError), StoreError(StoreError),
DbError(j_db::error::JDbError), DbError(j_db::error::JDbError),
AlbumNotFound, AlbumNotFound,
ImageNotFound, ImageNotFound,
TokenInvalid,
NoUserInHeader,
} }
impl From<StoreError> for PicOxError { impl From<StoreError> for PicOxError {
@ -161,6 +162,11 @@ impl IntoResponse for PicOxError {
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"Image not found".to_string(), "Image not found".to_string(),
), ),
PicOxError::TokenInvalid => (StatusCode::UNAUTHORIZED, "Token is invalid".to_string()),
PicOxError::NoUserInHeader => (
StatusCode::INTERNAL_SERVER_ERROR,
"User not found in header".to_string(),
),
}; };
(status, Response(ErrorResponse { message })).into_response() (status, Response(ErrorResponse { message })).into_response()
@ -169,12 +175,16 @@ impl IntoResponse for PicOxError {
async fn create_album( async fn create_album(
State(context): State<PicContext>, State(context): State<PicContext>,
headers: HeaderMap,
Json(album): Json<CreateAlbum>, Json(album): Json<CreateAlbum>,
) -> Result<Response<Album>, PicOxError> { ) -> Result<Response<Album>, PicOxError> {
info!("Creating new album '{}'", album.album_name); let user_id = get_user_id_from_headers(&headers)?;
let new_album = Album::new(&album.album_name, Vec::new(), user_id);
let new_album = Album::new(&album.album_name, Vec::new(), 0);
info!(
"Creating new album '{}' for user {}",
album.album_name, user_id
);
let new_album = context.db.insert(new_album)?; let new_album = context.db.insert(new_album)?;
Ok(Response(new_album)) Ok(Response(new_album))
@ -256,8 +266,17 @@ async fn query_images(
Ok(Response(images)) Ok(Response(images))
} }
fn get_user_id_from_headers(headers: &HeaderMap) -> Result<u64, PicOxError> {
let user = headers.get("user").ok_or(PicOxError::NoUserInHeader)?;
let user_str = user.to_str().unwrap();
let user: u64 = user_str.parse().unwrap();
Ok(user)
}
async fn add_image( async fn add_image(
State(context): State<PicContext>, State(context): State<PicContext>,
headers: HeaderMap,
mut img_data: Multipart, mut img_data: Multipart,
) -> Result<Response<Image>, PicOxError> { ) -> Result<Response<Image>, PicOxError> {
let mut data: Vec<u8> = Vec::new(); let mut data: Vec<u8> = Vec::new();
@ -277,6 +296,9 @@ async fn add_image(
} }
} }
let user = headers.get("user").unwrap().to_str().unwrap();
let user: u64 = user.parse().unwrap();
let mut album = let mut album =
Album::find_album_by_query(&context.db, metadata.clone().unwrap().album).unwrap(); Album::find_album_by_query(&context.db, metadata.clone().unwrap().album).unwrap();
@ -287,7 +309,7 @@ async fn add_image(
None, None,
ImageData::Bytes(data), ImageData::Bytes(data),
&file_name.unwrap(), &file_name.unwrap(),
0, user,
album.id().unwrap(), album.id().unwrap(),
) )
.await?; .await?;
@ -307,6 +329,26 @@ async fn query_album(
Ok(Response(resp)) Ok(Response(resp))
} }
async fn check_token_header(
State(context): State<PicContext>,
mut request: Request,
next: Next,
) -> Result<impl IntoResponse, PicOxError> {
let headers = request.headers();
if let Some(token) = headers.get("token") {
if let Some(api_key) = ApiKey::find_api_key_by_token(&context.db, token.to_str().unwrap())?
{
request
.headers_mut()
.insert("user", api_key.id().unwrap().into());
return Ok(next.run(request).await);
}
}
Err(PicOxError::TokenInvalid)
}
async fn run_picox(db: Database, config: PicOxConfig) { async fn run_picox(db: Database, config: PicOxConfig) {
let store_manager = StorageManager::new(config.storage_config.clone()); let store_manager = StorageManager::new(config.storage_config.clone());
@ -319,12 +361,15 @@ async fn run_picox(db: Database, config: PicOxConfig) {
let context = Arc::new(context); let context = Arc::new(context);
let app = Router::new() let app = Router::new()
// `GET /` goes to `root`
.route("/api/album/create", post(create_album))
.route("/api/album/:id", get(get_album))
.route("/api/album/", get(query_album))
.route("/api/image/", post(add_image)) .route("/api/image/", post(add_image))
.layer(DefaultBodyLimit::max(1024 * 1024 * 1024)) .layer(DefaultBodyLimit::max(1024 * 1024 * 1024))
.route("/api/album/create", post(create_album))
.layer(middleware::from_fn_with_state(
context.clone(),
check_token_header,
))
.route("/api/album/:id", get(get_album))
.route("/api/album/", get(query_album))
.route("/api/image/", get(query_images)) .route("/api/image/", get(query_images))
.with_state(context); .with_state(context);
@ -348,17 +393,22 @@ async fn main() {
SubCommands::Start => { SubCommands::Start => {
run_picox(db, config).await; run_picox(db, config).await;
} }
SubCommands::CreateKey { description, permissions } => { SubCommands::CreateKey {
description,
permissions,
} => {
let (token, api_key) = ApiKey::create_new_key(&db, description, permissions).unwrap(); let (token, api_key) = ApiKey::create_new_key(&db, description, permissions).unwrap();
info!("New Key info: {:?}", api_key); info!("New Key info: {:?}", api_key);
info!("Token: {}", base64::prelude::BASE64_STANDARD.encode(token)); info!("Token: {}", base64::prelude::BASE64_STANDARD.encode(token));
db.db.flush().unwrap(); db.db.flush().unwrap();
} }
SubCommands::Dump {out_path} => { SubCommands::Dump { out_path } => {
info!("Dumping database state to {:?}", out_path); info!("Dumping database state to {:?}", out_path);
tokio::fs::write(out_path, db.dump_db().unwrap().pretty(4)).await.unwrap(); tokio::fs::write(out_path, db.dump_db().unwrap().pretty(4))
.await
.unwrap();
} }
SubCommands::Import { db_file } => { SubCommands::Import { db_file } => {
info!("Importing database state from {:?}", db_file); info!("Importing database state from {:?}", db_file);
@ -370,5 +420,4 @@ async fn main() {
db.db.flush().unwrap(); db.db.flush().unwrap();
} }
} }
} }

View File

@ -1,7 +1,7 @@
use std::str::FromStr;
use base64::Engine; use base64::Engine;
use j_db::database::Database; use j_db::database::Database;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::str::FromStr;
use bitflags::bitflags; use bitflags::bitflags;
use rand::RngCore; use rand::RngCore;
@ -29,18 +29,22 @@ pub struct ApiKey {
pub description: String, pub description: String,
pub permissions: ApiPermissions, pub permissions: ApiPermissions,
id: Option<u64> id: Option<u64>,
} }
impl ApiKey { impl ApiKey {
pub fn create_new_key(db: &Database, description: String, permissions: ApiPermissions) -> Result<(Vec<u8>, ApiKey), j_db::error::JDbError>{ pub fn create_new_key(
db: &Database,
description: String,
permissions: ApiPermissions,
) -> Result<(Vec<u8>, ApiKey), j_db::error::JDbError> {
let mut key = [0u8; 32]; let mut key = [0u8; 32];
rand::thread_rng().fill_bytes(&mut key); rand::thread_rng().fill_bytes(&mut key);
let hash = sha2::Sha256::digest(key); let hash = sha2::Sha256::digest(key);
let api_key = ApiKey { let api_key = ApiKey {
token_hash: base64::prelude::BASE64_STANDARD.encode(&hash), token_hash: base64::prelude::BASE64_STANDARD.encode(hash),
description, description,
permissions, permissions,
id: None, id: None,
@ -50,6 +54,18 @@ impl ApiKey {
Ok((key.to_vec(), api_key)) Ok((key.to_vec(), api_key))
} }
pub fn find_api_key_by_token(
db: &Database,
token: &str,
) -> Result<Option<ApiKey>, j_db::error::JDbError> {
let token = base64::prelude::BASE64_STANDARD.decode(token).unwrap();
let hash = sha2::Sha256::digest(token);
let hash = base64::prelude::BASE64_STANDARD.encode(hash);
Ok(db
.filter(move |_, api_key: &ApiKey| api_key.token_hash == hash)?
.next())
}
} }
impl j_db::model::JdbModel for ApiKey { impl j_db::model::JdbModel for ApiKey {

View File

@ -1,3 +1,3 @@
pub mod album; pub mod album;
pub mod image;
pub mod api_key; pub mod api_key;
pub mod image;

View File

@ -48,8 +48,8 @@ pub trait Store: Send {
&mut self, &mut self,
img_data: ImageData, img_data: ImageData,
file_name: &str, file_name: &str,
album: u64,
created_by: u64, created_by: u64,
album: u64,
) -> Result<Image, StoreError> { ) -> Result<Image, StoreError> {
let (url, storage_location) = self.store_img(img_data, file_name).await?; let (url, storage_location) = self.store_img(img_data, file_name).await?;