From 04014c24d6fcedc347bd4c26fcf2410de81dad07 Mon Sep 17 00:00:00 2001 From: Joey Hines Date: Sat, 16 Jan 2021 13:58:46 -0600 Subject: [PATCH] Code cleanup + Reduced copies needed to access config or the draft event from global data + fmt + clippy --- src/database/models.rs | 17 +++++++++- src/discord/events.rs | 23 +++++++------ src/discord/mod.rs | 75 +++++++++++++++++++++++------------------- src/hypebot_config.rs | 3 +- src/main.rs | 44 ++++++++----------------- 5 files changed, 86 insertions(+), 76 deletions(-) diff --git a/src/database/models.rs b/src/database/models.rs index 6230913..ebcfe9f 100644 --- a/src/database/models.rs +++ b/src/database/models.rs @@ -1,5 +1,5 @@ use super::schema::events; -use chrono::NaiveDateTime; +use chrono::{NaiveDateTime, Utc}; #[derive(Queryable, Clone, Debug)] pub struct Event { @@ -58,3 +58,18 @@ pub struct NewEvent { /// Reminder sent tracker pub reminder_sent: i32, } + +impl Default for NewEvent { + fn default() -> Self { + Self { + message_id: String::default(), + event_time: Utc::now().naive_utc(), + event_name: String::default(), + organizer: String::default(), + event_desc: String::default(), + event_loc: String::default(), + thumbnail_link: String::default(), + reminder_sent: i32::default(), + } + } +} diff --git a/src/discord/events.rs b/src/discord/events.rs index 8e84b51..0b9b7d8 100644 --- a/src/discord/events.rs +++ b/src/discord/events.rs @@ -21,8 +21,8 @@ use url::Url; /// **Note** /// You can only post events you have created. Only one preview event can exist at a time. async fn confirm(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { - let config = get_config(&ctx.data).await?; - let draft_event = get_draft_event(&ctx.data).await?; + let config = get_config(&ctx.data.read().await).await?; + let draft_event = get_draft_event(&ctx.data.read().await).await?; let mut new_event = draft_event.event.clone(); // Check to to see if message author is the owner of the pending event @@ -61,7 +61,7 @@ async fn confirm(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { /// The user or group that is organizing the event, defaults to the user creating the event async fn create(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { // Get config - let config = get_config(&ctx.data).await?; + let config = get_config(&ctx.data.read().await).await?; // Parse args let event_name = match args.find::() { @@ -173,7 +173,7 @@ async fn create(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { /// /// `~cancel "event name"` async fn cancel(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let config = get_config(&ctx.data).await?; + let config = get_config(&ctx.data.read().await).await?; // Parse args let event_name = args.single::()?.replace("\"", ""); @@ -186,12 +186,15 @@ async fn cancel(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { .await?; let cancel_msg = format!("**{}** has been canceled!", event.event_name.clone()); - if let Ok(reaction_users) = message - .reaction_users(&ctx.http, INTERESTED_EMOJI, None, None) - .await - { - for user in reaction_users { - send_dm_message(ctx, user, &cancel_msg).await; + // Only send a cancel message if the even has not already happened + if event.event_time > Utc::now().naive_utc() { + if let Ok(reaction_users) = message + .reaction_users(&ctx.http, INTERESTED_EMOJI, None, None) + .await + { + for user in reaction_users { + send_dm_message(ctx, user, &cancel_msg).await; + } } } diff --git a/src/discord/mod.rs b/src/discord/mod.rs index 1adb3ac..acca503 100644 --- a/src/discord/mod.rs +++ b/src/discord/mod.rs @@ -14,23 +14,24 @@ use serenity::Result; use std::collections::HashMap; use std::sync::Arc; use strfmt::strfmt; +use tokio::sync::RwLockReadGuard; pub mod events; /// Struct for storing drafted events -#[derive(Clone)] +#[derive(Debug, Clone, Default)] pub struct DraftEvent { pub event: NewEvent, pub creator_id: u64, } impl TypeMapKey for DraftEvent { - type Value = DraftEvent; + type Value = Arc; } /// Send a message to a reaction user pub async fn send_message_to_reaction_users(ctx: &Context, reaction: &Reaction, msg_text: &str) { - if let Ok(config) = get_config(&ctx.data).await { + if let Ok(config) = get_config(&ctx.data.read().await).await { let db_link = config.db_url.clone(); let message_id = reaction.message_id.0.to_string(); @@ -149,46 +150,49 @@ pub async fn update_draft_event( creator_id: u64, ) -> CommandResult { let mut data = ctx.data.write().await; - let mut draft_event = data + let draft_event = data .get_mut::() .ok_or_else(|| CommandError::from("Unable get draft event!".to_string()))?; - draft_event.event.event_name = event_name; - draft_event.event.event_desc = event_desc; - draft_event.event.event_loc = location; - draft_event.event.organizer = organizer; - draft_event.event.thumbnail_link = thumbnail; - draft_event.event.message_id = String::new(); - draft_event.event.event_time = event_time; - draft_event.creator_id = creator_id; + let new_draft_event = DraftEvent { + event: NewEvent { + event_name, + event_desc, + event_loc: location, + organizer, + event_time, + message_id: "".to_string(), + thumbnail_link: thumbnail, + reminder_sent: 0, + }, + creator_id, + }; + + *draft_event = Arc::new(new_draft_event); + Ok(()) } /// Sends the draft event stored in the context data pub async fn send_draft_event(ctx: &Context, channel: ChannelId) -> CommandResult { let data = ctx.data.read().await; - let config = data - .get::() - .ok_or_else(|| CommandError::from("Config not found!".to_string()))?; - let draft_event = data - .get::() - .ok_or_else(|| CommandError::from("Draft event not found!".to_string()))?; + let config = get_config(&data).await?; + let draft_event = get_draft_event(&data).await?; channel .send_message(&ctx, |m| { m.content("Draft message, use the `confirm` command to post it.".to_string()) }) .await?; - send_event_msg(ctx, config, channel.0, &draft_event.event, false).await?; + send_event_msg(ctx, &config, channel.0, &draft_event.event, false).await?; Ok(()) } /// Gets the config from context data pub async fn get_config( - data: &Arc>, -) -> std::result::Result { - let data_read = data.read().await; - let config = data_read + data: &RwLockReadGuard<'_, TypeMap>, +) -> std::result::Result, CommandError> { + let config = data .get::() .ok_or_else(|| CommandError::from("Unable to get config".to_string()))?; @@ -197,12 +201,11 @@ pub async fn get_config( /// Gets the draft event from context data pub async fn get_draft_event( - data: &Arc>, -) -> std::result::Result { - let data_read = data.read().await; - let draft_event = data_read + data: &RwLockReadGuard<'_, TypeMap>, +) -> std::result::Result, CommandError> { + let draft_event = data .get::() - .ok_or_else(|| CommandError::from("Unable to queued event".to_string()))?; + .ok_or_else(|| CommandError::from("Unable to get queued event".to_string()))?; Ok(draft_event.clone()) } @@ -225,7 +228,7 @@ pub async fn log_error( #[hook] pub async fn permission_check(ctx: &Context, msg: &Message, _command_name: &str) -> bool { if let Some(guild_id) = msg.guild_id { - if let Ok(config) = get_config(&ctx.data).await { + if let Ok(config) = get_config(&ctx.data.read().await).await { if let Ok(roles) = ctx.http.get_guild_roles(guild_id.0).await { for role in roles { if config.event_roles.contains(&role.id.0) { @@ -244,9 +247,11 @@ pub async fn permission_check(ctx: &Context, msg: &Message, _command_name: &str) /// Schedule event reminders pub async fn schedule_event(ctx: &Context, event: &Event) { - let config = get_config(&ctx.data).await.expect("Unable to get config"); + let config = get_config(&ctx.data.read().await) + .await + .expect("Unable to get config"); - if let Some(reminders) = config.reminders { + if let Some(reminders) = &config.reminders { let event_time: DateTime = DateTime::::from_utc(event.event_time, Utc); for reminder in reminders { @@ -278,7 +283,9 @@ pub async fn send_reminders_task( duration.num_milliseconds() as u64, )) .await; - let config = get_config(&ctx.data).await.expect("Unable to get config"); + let config = get_config(&ctx.data.read().await) + .await + .expect("Unable to get config"); let event_channel_id = config.event_channel; if let Ok(message_id) = event.message_id.parse::() { @@ -303,7 +310,9 @@ pub async fn send_reminders_task( /// Delete event #[allow(dead_code)] pub async fn delete_event(http: &Arc, data: &Arc>, event: &Event) { - let config = get_config(&data).await.expect("Unable to get config"); + let config = get_config(&data.read().await) + .await + .expect("Unable to get config"); remove_event(config.db_url.clone(), event.id).ok(); if let Ok(message_id) = event.message_id.parse::() { diff --git a/src/hypebot_config.rs b/src/hypebot_config.rs index 9d067d3..ab4f76f 100644 --- a/src/hypebot_config.rs +++ b/src/hypebot_config.rs @@ -4,6 +4,7 @@ use serde::de::{self, Error, Visitor}; use serde::{Deserialize, Deserializer}; use serenity::prelude::TypeMapKey; use std::fmt; +use std::sync::Arc; #[derive(Debug, Deserialize, Clone)] pub struct EventReminder { @@ -65,5 +66,5 @@ impl HypeBotConfig { } impl TypeMapKey for HypeBotConfig { - type Value = HypeBotConfig; + type Value = Arc; } diff --git a/src/main.rs b/src/main.rs index 3d5fb9f..3fcd406 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,6 @@ use std::collections::HashSet; use std::path::Path; use std::process::exit; -use chrono::Utc; use clap::{App, Arg}; use log::LevelFilter; use log4rs::append::console::ConsoleAppender; @@ -33,7 +32,6 @@ use serenity::model::id::UserId; use serenity::model::prelude::Ready; use serenity::prelude::{Context, EventHandler}; -use database::models::NewEvent; use database::*; use discord::events::{CANCEL_COMMAND, CONFIRM_COMMAND, CREATE_COMMAND}; use discord::{ @@ -41,6 +39,7 @@ use discord::{ DraftEvent, }; use hypebot_config::HypeBotConfig; +use std::sync::Arc; mod database; mod discord; @@ -65,13 +64,9 @@ struct Handler; impl EventHandler for Handler { /// On reaction add async fn reaction_add(&self, ctx: Context, reaction: Reaction) { - let config = match get_config(&ctx.data).await { - Ok(config) => config, - Err(e) => { - error!("Unable to get config: {}", e); - return; - } - }; + let config = get_config(&ctx.data.read().await) + .await + .expect("Unable to get config"); if reaction.channel_id.0 == config.event_channel && reaction.emoji.as_data().chars().next().unwrap() == INTERESTED_EMOJI { @@ -86,13 +81,10 @@ impl EventHandler for Handler { /// On reaction remove async fn reaction_remove(&self, ctx: Context, reaction: Reaction) { - let config = match get_config(&ctx.data).await { - Ok(config) => config, - Err(e) => { - error!("Unable to get config: {}", e); - return; - } - }; + let config = get_config(&ctx.data.read().await) + .await + .expect("Unable to get config"); + if reaction.channel_id.0 == config.event_channel && reaction.emoji.as_data().chars().next().unwrap() == INTERESTED_EMOJI { @@ -110,7 +102,9 @@ impl EventHandler for Handler { info!("Connected to Discord as {}", ready.user.name); // Schedule current events - let config = get_config(&ctx.data).await.expect("Unable to get config"); + let config = get_config(&ctx.data.read().await) + .await + .expect("Unable to get config"); for event in get_all_events(config.db_url.clone()).unwrap() { if event.reminder_sent == 0 { schedule_event(&ctx, &event).await; @@ -249,20 +243,8 @@ async fn main() -> HypeBotResult<()> { // Copy config data to client data and setup scheduler { let mut data = client.data.write().await; - data.insert::(cfg); - data.insert::(DraftEvent { - event: NewEvent { - message_id: String::new(), - event_time: Utc::now().naive_utc(), - event_name: String::new(), - organizer: String::new(), - event_desc: String::new(), - event_loc: String::new(), - thumbnail_link: String::new(), - reminder_sent: 0 as i32, - }, - creator_id: 0, - }); + data.insert::(Arc::new(cfg)); + data.insert::(Arc::new(DraftEvent::default())); } // Start bot