use crate::error::{JDbError, Result}; use crate::metadata::DBMetadata; use crate::model::JdbModel; use crate::query::QueryBuilder; use json::JsonValue; use sled::IVec; use std::convert::TryInto; use std::path::Path; pub const DB_METADATA_ID: u64 = 1; fn option_bytes_to_model(bytes: Option, id: u64) -> Result { if let Some(bytes) = bytes { Ok(T::try_from_bytes(&bytes)?) } else { log::debug!("{} of id {} was not found in the database", T::tree(), id); Err(JDbError::NotFound) } } pub struct Database { pub db: sled::Db, } impl Database { pub fn new(db_path: &Path) -> Result { let db = sled::open(db_path)?; let db = Self { db }; let version = match db.version() { Ok(version) => version, Err(_) => { let db_metadata = DBMetadata { id: Some(DB_METADATA_ID), version: 0, }; db.insert(db_metadata)?.version } }; log::info!("jDb Version V{}", version); Ok(db) } fn get_tree(&self) -> Result where T: JdbModel, { Ok(self.db.open_tree::(T::tree())?) } pub fn insert(&self, mut model: T) -> Result where T: JdbModel, { let id = match model.id() { Some(id) => id, None => { let id = self.db.generate_id()?; model.set_id(id); id } }; let match_count = self .filter(|o_id, o: &T| o_id != id && !o.check_unique(&model))? .count(); if match_count > 0 { log::debug!("{} is not unique: {:?}", T::tree(), model); return Err(JDbError::NotUnique); } let tree = self.get_tree::()?; let id_bytes = id.to_be_bytes(); tree.insert(id_bytes, model.to_bytes()?)?; Ok(model) } pub fn get(&self, id: u64) -> Result where T: JdbModel, { let tree = self.get_tree::()?; option_bytes_to_model(tree.get(id.to_be_bytes())?, id) } pub fn clear_tree(&self) -> Result<()> where T: JdbModel, { self.db.drop_tree(T::tree())?; Ok(()) } pub fn filter<'a, T>( &self, f: impl Fn(u64, &T) -> bool + 'a, ) -> Result + 'a> where T: JdbModel, { let tree = self.db.open_tree(T::tree())?; Ok(tree.iter().filter_map(move |e| { if let Ok((id, data)) = e { let id = u64::from_be_bytes(id.to_vec().try_into().unwrap()); let data = match T::try_from_bytes(&data) { Ok(data) => data, Err(err) => { log::debug!( "Invalid data: {}", String::from_utf8(data.to_vec()).unwrap_or_default() ); panic!("Unable to parse {} model from bytes: {}", T::tree(), err); } }; if f(id, &data) { Some(data) } else { None } } else { None } })) } pub fn run_query(&self, query_builder: QueryBuilder) -> Result> where T: JdbModel, { let result: Vec = self .filter(|id, loc: &T| { for query in &query_builder.queries { let res = query(id, loc); if !res { return false; } } true })? .collect(); if result.is_empty() { Err(JDbError::NotFound) } else { Ok(result) } } pub fn remove(&self, id: u64) -> Result where T: JdbModel, { let tree = self.db.open_tree(T::tree())?; option_bytes_to_model(tree.remove(id.to_be_bytes())?, id) } pub fn tree_iter(&self) -> Result where T: JdbModel, { Ok(self.db.open_tree(T::tree()).map(|tree| tree.iter())?) } pub fn version(&self) -> Result { Ok(self.get::(DB_METADATA_ID)?.version) } pub(crate) fn set_version(&mut self, version: u64) -> Result<()> { let mut md = self.get::(DB_METADATA_ID)?; md.version = version; self.insert(md)?; Ok(()) } pub fn dump_db(&self) -> Result { let mut json = JsonValue::new_object(); let mut global_array = JsonValue::new_array(); for model in self.db.iter() { let (_, model) = model?; let model_str = String::from_utf8(model.to_vec()).unwrap(); let model_json = json::parse(&model_str)?; global_array.push(json::from(model_json))?; } json.insert("global", global_array)?; for tree in self.db.tree_names() { let tree = self.db.open_tree(tree)?; let mut tree_array = JsonValue::new_array(); for model in tree.iter() { let (_, model) = model?; let model_str = String::from_utf8(model.to_vec()).unwrap(); let model_json = json::parse(&model_str)?; tree_array.push(json::from(model_json))?; } let tree_name = String::from_utf8(tree.name().to_vec()).unwrap(); json.insert(&tree_name, tree_array)?; } Ok(json) } pub fn import_db(&self, json: JsonValue) -> Result<()> { for model in json["global"].members() { let id_bytes = model["id"].as_u64().unwrap().to_be_bytes(); self.db.insert(id_bytes, model.to_string().as_bytes())?; } for (tree, models) in json.entries() { for model in models.members() { let id_bytes = model["id"].as_u64().unwrap().to_be_bytes(); if tree == "global" { self.db.insert(id_bytes, model.to_string().as_bytes())?; } else { let tree = self.db.open_tree(tree)?; tree.insert(id_bytes, model.to_string().as_bytes())?; } } } Ok(()) } } #[cfg(test)] mod tests { use crate::model::JdbModel; use crate::test::{cleanup, User, DB, LOCK}; use std::time::Instant; #[test] fn test_insert() { let _lock = LOCK.lock().unwrap(); cleanup(); let user = User::new("Test", 1); let user2 = DB.insert::(user.clone()).unwrap(); assert!(user2.id().is_some()); assert_eq!(user.name, user2.name); cleanup(); } #[test] fn test_unique_insert() { let _lock = LOCK.lock().unwrap(); cleanup(); let user1 = User::new("Test", 1); let user2 = User::new("Test", 1); DB.insert::(user1.clone()).unwrap(); assert_eq!(DB.insert::(user2.clone()).is_err(), true); cleanup(); } #[test] fn test_get() { let _lock = LOCK.lock().unwrap(); cleanup(); let user1 = User::new("Test", 1); let user1_db = DB.insert::(user1.clone()).unwrap(); let user1_get = DB.get::(user1_db.id().unwrap()).unwrap(); assert_eq!(user1_get.name, user1.name); cleanup(); } #[test] fn test_filter() { let _lock = LOCK.lock().unwrap(); cleanup(); let user = User::new("Test", 1); let user = DB.insert::(user.clone()).unwrap(); let count = DB .filter(|id: u64, u: &User| { assert_eq!(id, user.id().unwrap()); u.id().unwrap() == user.id().unwrap() }) .unwrap() .count(); assert_eq!(count, 1); cleanup(); } #[test] fn test_remove() { let _lock = LOCK.lock().unwrap(); cleanup(); let user = User::new("CoolZero123", 1); let user_insert = DB.insert::(user.clone()).unwrap(); let user_remove = DB.remove::(user_insert.id().unwrap()).unwrap(); assert!(DB.get::(user_insert.id().unwrap()).is_err()); assert_eq!(user_remove.id().unwrap(), user_insert.id().unwrap()); cleanup(); } #[test] fn test_speed() { let _lock = LOCK.lock().unwrap(); cleanup(); let insert_count = 1000; let timer = Instant::now(); for i in 0..insert_count { let user = User::new(&format!("User{}", i), 0); DB.insert::(user).unwrap(); } DB.db.flush().unwrap(); let sec_elapsed = timer.elapsed().as_secs_f32(); println!( "Completed in {}s. {} inserts per second", sec_elapsed, insert_count as f32 / sec_elapsed ); cleanup() } #[test] fn test_dump_and_load() { let _lock = LOCK.lock().unwrap(); cleanup(); let mut users = vec![]; for i in 0..10 { let user = User::new(&format!("User{}", i), 0); let u = DB.insert::(user).unwrap(); users.push(u); } let out = DB.dump_db().unwrap(); println!("{}", out); DB.import_db(out).unwrap(); for user in users { let import_user = DB.get::(user.id().unwrap()).unwrap(); assert_eq!(user.name, import_user.name); } cleanup(); } }