124 lines
4.2 KiB
Rust
124 lines
4.2 KiB
Rust
mod settings;
|
|
|
|
use crate::api_endpoint::{ApiEndpoint, RequestType};
|
|
use crate::context::Context;
|
|
use crate::helper::get_token_from_req;
|
|
use crate::model::settings::Settings;
|
|
use crate::Result;
|
|
use geoffrey_models::models::parameters::{GeoffreyParam, ModelRequest};
|
|
use geoffrey_models::models::response::api_error::GeoffreyAPIError;
|
|
use geoffrey_models::models::response::APIResponse;
|
|
use geoffrey_models::models::token::{Permissions, Token};
|
|
use serde::de::DeserializeOwned;
|
|
use serde::Serialize;
|
|
use std::fmt::Debug;
|
|
use std::sync::Arc;
|
|
use warp::filters::BoxedFilter;
|
|
use warp::Filter;
|
|
|
|
pub trait ModelEndpoint: ApiEndpoint {
|
|
type Req: GeoffreyParam + 'static;
|
|
type Resp: Serialize + DeserializeOwned + Send + Debug;
|
|
|
|
fn token_permission() -> Vec<Permissions>;
|
|
fn run_endpoint(ctx: Arc<Context>, req: &Self::Req) -> Result<Self::Resp>;
|
|
|
|
fn check_token_permission(token: &Token) -> bool {
|
|
for perm in Self::token_permission() {
|
|
if !token.check_permission(perm) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
true
|
|
}
|
|
}
|
|
|
|
pub fn handle_remote_model_request<T: ModelEndpoint>(
|
|
ctx: Arc<Context>,
|
|
req: ModelRequest<T::Req>,
|
|
) -> Result<T::Resp> {
|
|
let token = get_token_from_req(&ctx.db, &req)?;
|
|
|
|
if let Some(token) = token {
|
|
if T::check_token_permission(&token) {
|
|
return T::run_endpoint(ctx, &req.params);
|
|
}
|
|
}
|
|
|
|
Err(GeoffreyAPIError::TokenNotAuthorized)
|
|
}
|
|
|
|
#[allow(clippy::needless_return)]
|
|
pub fn create_remote_model_filter<T: ModelEndpoint>(
|
|
ctx: Arc<Context>,
|
|
) -> BoxedFilter<(impl warp::Reply,)> {
|
|
let filter = warp::path(T::endpoint_name())
|
|
.and(warp::any().map(move || ctx.clone()))
|
|
.and(warp::body::json())
|
|
.map(|ctx: Arc<Context>, req: ModelRequest<T::Req>| {
|
|
log::info!("Running model query {}", T::endpoint_name());
|
|
log::debug!("Request params: {:?}", req.params);
|
|
|
|
let reply = handle_remote_model_request::<T>(ctx, req);
|
|
|
|
if let Ok(reply) = reply {
|
|
log::debug!("Successfully processed model request");
|
|
warp::reply::json(&APIResponse::Response::<T::Resp>(reply))
|
|
} else {
|
|
let e = reply.err().unwrap();
|
|
let msg = e.to_string();
|
|
log::warn!("Got error when processing model request '{:?}': {}", e, msg);
|
|
warp::reply::json(&APIResponse::<T::Resp>::Error { error: e, msg })
|
|
}
|
|
});
|
|
|
|
if T::request_type() == RequestType::POST {
|
|
return filter.and(warp::post()).boxed();
|
|
} else {
|
|
return filter.and(warp::get()).boxed();
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::needless_return)]
|
|
pub fn create_local_model_filter<T: ModelEndpoint>(
|
|
ctx: Arc<Context>,
|
|
) -> BoxedFilter<(impl warp::Reply,)> {
|
|
let filter = warp::path(T::endpoint_name())
|
|
.and(warp::any().map(move || ctx.clone()))
|
|
.and(warp::body::json())
|
|
.map(|ctx: Arc<Context>, req: T::Req| {
|
|
log::info!("Running local model query {}", T::endpoint_name());
|
|
log::debug!("Request params: {:?}", req);
|
|
let reply = T::run_endpoint(ctx, &req);
|
|
|
|
if let Ok(reply) = reply {
|
|
log::debug!("Successfully processed model request");
|
|
warp::reply::json(&APIResponse::Response::<T::Resp>(reply))
|
|
} else {
|
|
let e = reply.err().unwrap();
|
|
let msg = e.to_string();
|
|
log::warn!("Got error when processing model request '{:?}': {}", e, msg);
|
|
warp::reply::json(&APIResponse::<T::Resp>::Error { error: e, msg })
|
|
}
|
|
});
|
|
|
|
if T::request_type() == RequestType::POST {
|
|
return filter.and(warp::post()).boxed();
|
|
} else {
|
|
return filter.and(warp::get()).boxed();
|
|
}
|
|
}
|
|
|
|
pub fn remote_model_filter(
|
|
ctx: Arc<Context>,
|
|
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
|
warp::path("model").and(create_remote_model_filter::<Settings>(ctx))
|
|
}
|
|
|
|
pub fn local_model_filter(
|
|
ctx: Arc<Context>,
|
|
) -> impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone {
|
|
warp::path("model").and(create_local_model_filter::<Settings>(ctx))
|
|
}
|