use thiserror::Error; use log::trace; use reqwest::{header::HeaderMap, Client, Error, Response, StatusCode}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::{json, Value}; use serenity::prelude::TypeMapKey; use crate::utils; #[derive(Debug, Error)] pub enum DatabaseError { #[error("An error has occured during an I/O operation: {0}")] IOError(String), #[error("An error has occured during a parsing operation: {0}")] ParseError(String), #[error("An invalid status code was returned from a request: {0}")] InvalidStatusCode(StatusCode), #[error("An invalid input body was provided: {0}")] InvalidInputBody(String), } #[derive(Serialize, Deserialize)] struct GetAccessTokenResponse { id: String, access_token: String, } #[derive(Deserialize)] pub struct User { pub id: String, pub device_name: String, pub request: Option, pub accounts: Option>, } #[derive(Deserialize)] pub struct Account { pub user_id: String, pub r#type: String, pub access_token: String, pub refresh_token: String, pub expires: u64, } #[derive(Deserialize)] pub struct Request { pub token: String, pub user_id: String, pub expires: u64, } pub struct Database { base_url: String, default_headers: Option, } // Request options #[derive(Debug, Clone)] struct RequestOptions { pub method: Method, pub path: String, pub body: Option, pub headers: Option, } #[derive(Debug, Clone)] #[allow(dead_code)] enum Body { Json(Value), Text(String), } #[derive(Debug, Clone)] #[allow(dead_code)] enum Method { Get, Post, Put, Delete, Patch, } impl Database { pub fn new(base_url: impl Into, default_headers: Option) -> Self { Self { base_url: base_url.into(), default_headers, } } async fn request(&self, options: RequestOptions) -> Result { let builder = Client::builder(); let mut headers: HeaderMap = HeaderMap::new(); let mut url = self.base_url.clone(); url.push_str(&options.path); if let Some(default_headers) = &self.default_headers { headers.extend(default_headers.clone()); } if let Some(request_headers) = options.headers { headers.extend(request_headers); } trace!("Requesting {} with headers: {:?}", url, headers); let client = builder.default_headers(headers).build()?; let mut request = match options.method { Method::Get => client.get(url), Method::Post => client.post(url), Method::Put => client.put(url), Method::Delete => client.delete(url), Method::Patch => client.patch(url), }; request = if let Some(body) = options.body { match body { Body::Json(json) => request.json(&json), Body::Text(text) => request.body(text), } } else { request }; let response = request.send().await?; Ok(response) } async fn simple_get( &self, path: impl Into, ) -> Result { let response = match self .request(RequestOptions { method: Method::Get, path: path.into(), body: None, headers: None, }) .await { Ok(response) => response, Err(error) => return Err(DatabaseError::IOError(error.to_string())), }; match response.status() { StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED | StatusCode::NO_CONTENT => {} status => return Err(DatabaseError::InvalidStatusCode(status)), }; let body = match response.json::().await { Ok(body) => body, Err(error) => return Err(DatabaseError::ParseError(error.to_string())), }; Ok(body) } async fn json_post( &self, value: impl Serialize, path: impl Into, ) -> Result { let body = json!(value); let response = match self .request(RequestOptions { method: Method::Post, path: path.into(), body: Some(Body::Json(body)), headers: None, }) .await { Ok(response) => response, Err(error) => return Err(DatabaseError::IOError(error.to_string())), }; match response.status() { StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED | StatusCode::NO_CONTENT => {} status => return Err(DatabaseError::InvalidStatusCode(status)), }; let body = match response.json::().await { Ok(body) => body, Err(error) => return Err(DatabaseError::ParseError(error.to_string())), }; Ok(body) } } impl Database { // Get Spoticord user pub async fn get_user(&self, user_id: impl Into) -> Result { let path = format!("/user/{}", user_id.into()); self.simple_get(path).await } // Get the Spotify access token for a user pub async fn get_access_token( &self, user_id: impl Into + Send, ) -> Result { let body: GetAccessTokenResponse = self .simple_get(format!("/user/{}/spotify/access_token", user_id.into())) .await?; Ok(body.access_token) } // Get the Spotify account for a user pub async fn get_user_account( &self, user_id: impl Into + Send, ) -> Result { let body: Account = self .simple_get(format!("/account/{}/spotify", user_id.into())) .await?; Ok(body) } // Get the Request for a user pub async fn get_user_request( &self, user_id: impl Into + Send, ) -> Result { let body: Request = self .simple_get(format!("/request/by-user/{}", user_id.into())) .await?; Ok(body) } // Create a Spoticord user pub async fn create_user(&self, user_id: impl Into) -> Result { let body = json!({ "id": user_id.into(), }); let user: User = self.json_post(body, "/user/new").await?; Ok(user) } // Create the link Request for a user pub async fn create_user_request( &self, user_id: impl Into + Send, ) -> Result { let body = json!({ "user_id": user_id.into(), "expires": utils::get_time() + (1000 * 60 * 60) }); let response = match self .request(RequestOptions { method: Method::Post, path: "/request".into(), body: Some(Body::Json(body)), headers: None, }) .await { Ok(response) => response, Err(err) => return Err(DatabaseError::IOError(err.to_string())), }; match response.status() { StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED | StatusCode::NO_CONTENT => {} status => return Err(DatabaseError::InvalidStatusCode(status)), }; let body = match response.json::().await { Ok(body) => body, Err(error) => return Err(DatabaseError::ParseError(error.to_string())), }; Ok(body) } pub async fn delete_user_account( &self, user_id: impl Into + Send, ) -> Result<(), DatabaseError> { let response = match self .request(RequestOptions { method: Method::Delete, path: format!("/account/{}/spotify", user_id.into()), body: None, headers: None, }) .await { Ok(response) => response, Err(err) => return Err(DatabaseError::IOError(err.to_string())), }; match response.status() { StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED | StatusCode::NO_CONTENT => {} status => return Err(DatabaseError::InvalidStatusCode(status)), }; Ok(()) } pub async fn update_user_device_name( &self, user_id: impl Into, name: impl Into, ) -> Result<(), DatabaseError> { let device_name: String = name.into(); if device_name.len() > 16 || device_name.is_empty() { return Err(DatabaseError::InvalidInputBody( "Invalid device name length".into(), )); } let body = json!({ "device_name": device_name }); let response = match self .request(RequestOptions { method: Method::Patch, path: format!("/user/{}", user_id.into()), body: Some(Body::Json(body)), headers: None, }) .await { Ok(response) => response, Err(err) => return Err(DatabaseError::IOError(err.to_string())), }; match response.status() { StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED | StatusCode::NO_CONTENT => { Ok(()) } status => Err(DatabaseError::InvalidStatusCode(status)), } } } impl TypeMapKey for Database { type Value = Database; }