diff --git a/src/database/mod.rs b/src/database/mod.rs index ddc262c..ebdfdb1 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,19 +1,21 @@ -pub mod schema; +#[warn(dead_code)] pub mod models; +pub mod schema; use diesel::prelude::*; -use models::{Event, NewEvent}; use diesel::result::Error; use diesel::update; +use models::{Event, NewEvent}; use std::vec::Vec; +/// Establish a connection to the database pub fn establish_connection(database_url: String) -> MysqlConnection { MysqlConnection::establish(&database_url) .expect(&format!("Error connecting to {}", database_url)) } - +/// Insert an event into the database pub fn insert_event(databse_url: String, new_event: &NewEvent) -> Event { - use schema::events::dsl::{id, events}; + use schema::events::dsl::{events, id}; let connection = establish_connection(databse_url); @@ -25,22 +27,29 @@ pub fn insert_event(databse_url: String, new_event: &NewEvent) -> Event { events.order(id).first(&connection).unwrap() } +/// Get an event by name pub fn get_event_by_name(database_url: String, name: String) -> Result { use schema::events::dsl::{event_name, events}; let connection = establish_connection(database_url); - events.filter(event_name.eq(&name)).get_result::(&connection) + events + .filter(event_name.eq(&name)) + .get_result::(&connection) } +/// Get event by its message id pub fn get_event_by_msg_id(database_url: String, msg_id: String) -> Result { - use schema::events::dsl::{message_id, events}; + use schema::events::dsl::{events, message_id}; let connection = establish_connection(database_url); - events.filter(message_id.eq(&msg_id)).get_result::(&connection) + events + .filter(message_id.eq(&msg_id)) + .get_result::(&connection) } +/// Get all events pub fn get_all_events(database_url: String) -> Result, Error> { use schema::events::dsl::{event_time, events}; @@ -49,10 +58,13 @@ pub fn get_all_events(database_url: String) -> Result, Error> { events.order(event_time).load(&connection) } -pub fn set_reminder(database_url: String, event_id: i32) -> Result { +/// Set the reminder state of an event +pub fn set_reminder(database_url: String, event_id: i32, state: i32) -> Result { use schema::events::dsl::{events, id, reminder_sent}; let connection = establish_connection(database_url); let target = events.filter(id.eq(event_id)); - update(target).set(reminder_sent.eq(1)).execute(&connection) -} \ No newline at end of file + update(target) + .set(reminder_sent.eq(state)) + .execute(&connection) +} diff --git a/src/database/models.rs b/src/database/models.rs index 1596bf5..98e2604 100644 --- a/src/database/models.rs +++ b/src/database/models.rs @@ -1,24 +1,37 @@ use super::schema::events; -use chrono::{NaiveDateTime}; +use chrono::NaiveDateTime; #[derive(Queryable)] pub struct Event { + /// Event ID pub id: i32, + /// Event name pub event_name: String, + /// Event long description pub event_desc: String, + /// Event datetime pub event_time: NaiveDateTime, + /// Event discord message id pub message_id: String, + /// Event message thumbnail link pub thumbnail_link: String, + /// Reminder sent tracker pub reminder_sent: i32, } -#[derive(Insertable)] -#[table_name="events"] +#[derive(Insertable, Clone)] +#[table_name = "events"] pub struct NewEvent { + /// Event name pub event_name: String, + /// Event long description pub event_desc: String, + /// Event datetime pub event_time: NaiveDateTime, + /// Event discord message id pub message_id: String, + /// Event message thumbnail link pub thumbnail_link: String, + /// Reminder sent tracker pub reminder_sent: i32, } diff --git a/src/hypebot_config.rs b/src/hypebot_config.rs index 90504ba..0350261 100644 --- a/src/hypebot_config.rs +++ b/src/hypebot_config.rs @@ -1,5 +1,9 @@ -use config::{ConfigError, Config, File}; +use chrono_tz::Tz; +use config::{Config, ConfigError, File}; +use serde::de::{self, Error, Visitor}; +use serde::{Deserialize, Deserializer}; use serenity::prelude::TypeMapKey; +use std::fmt; #[derive(Debug, Deserialize)] pub struct HypeBotConfig { @@ -9,7 +13,37 @@ pub struct HypeBotConfig { pub prefix: String, pub event_channel: u64, pub event_roles: Vec, - pub event_timezone: String, + #[serde(deserialize_with = "from_tz_string")] + pub event_timezone: Tz, +} + +struct ConfigValueVisitor; +impl<'de> Visitor<'de> for ConfigValueVisitor { + type Value = String; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "Unable to parse timezone field.") + } + + fn visit_str(self, s: &str) -> Result + where + E: de::Error, + { + Ok(s.to_string()) + } +} + +fn from_tz_string<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let string = deserializer.deserialize_struct("Value", &["into_str"], ConfigValueVisitor)?; + + let tz: Tz = string.parse().ok().ok_or(D::Error::custom( + "Unable to parse datetime, should be in format \"Country/City\"", + ))?; + + Ok(tz) } impl HypeBotConfig { @@ -24,3 +58,17 @@ impl HypeBotConfig { impl TypeMapKey for HypeBotConfig { type Value = HypeBotConfig; } + +impl Clone for HypeBotConfig { + fn clone(&self) -> Self { + HypeBotConfig { + db_url: self.db_url.clone(), + default_thumbnail_link: self.default_thumbnail_link.clone(), + discord_key: self.discord_key.clone(), + prefix: self.prefix.clone(), + event_channel: self.event_channel.clone(), + event_roles: self.event_roles.clone(), + event_timezone: self.event_timezone.clone(), + } + } +} diff --git a/src/main.rs b/src/main.rs index 9e1149c..ca86e3f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,40 +1,37 @@ #[macro_use] extern crate diesel; #[macro_use] -extern crate serde_derive; -#[macro_use] extern crate diesel_migrations; extern crate serde; -use serenity::client::Client; -use serenity::Result; -use serenity::model::channel::{Message, Reaction}; -use serenity::prelude::{EventHandler, Context, ShareMap, RwLock}; -use serenity::utils::{content_safe, ContentSafeOptions, Colour}; -use serenity::framework::standard::StandardFramework; -use serenity::framework::standard::CommandResult; -use serenity::framework::standard::macros::{command, group,}; -use serenity::framework::standard::Args; -use serenity::prelude::TypeMapKey; -use serenity::model::error::Error; -use clap::{Arg, App}; -use chrono::{DateTime, Utc, NaiveDateTime, Datelike, Timelike, TimeZone}; +use chrono::{DateTime, Datelike, NaiveDateTime, TimeZone, Timelike, Utc}; use chrono_tz::Tz; +use clap::{App, Arg}; +use serenity::client::Client; +use serenity::framework::standard::macros::{command, group}; +use serenity::framework::standard::Args; +use serenity::framework::standard::{CommandError, CommandResult, StandardFramework}; +use serenity::http::Http; +use serenity::model::channel::{Message, Reaction}; +use serenity::model::prelude::{ChannelId, Ready}; +use serenity::prelude::TypeMapKey; +use serenity::prelude::{Context, EventHandler, RwLock, ShareMap}; +use serenity::utils::{content_safe, Colour, ContentSafeOptions}; +use serenity::CacheAndHttp; +use serenity::Result; use std::process::exit; +use std::sync::Arc; +use std::thread; use std::thread::sleep; use std::time::Duration; -use std::sync::Arc; -use serenity::CacheAndHttp; -use std::thread; -mod hypebot_config; mod database; +mod hypebot_config; +use crate::database::models::NewEvent; use database::*; use hypebot_config::HypeBotConfig; -use crate::database::models::{NewEvent}; -use serenity::http::Http; -use serenity::model::prelude::Ready; +use serenity::model::user::User; const INTERESTED_EMOJI: &str = "\u{2705}"; const UNINTERESTED_EMOJI: &str = "\u{274C}"; @@ -47,46 +44,47 @@ struct EventCommands; /// Handler for Discord events struct Handler; - - impl EventHandler for Handler { - /// On bot ready - fn ready(&self, _: Context, ready: Ready) { - println!("Connected as {}", ready.user.name); - } - /// On reaction fn reaction_add(&self, ctx: Context, reaction: Reaction) { if reaction.emoji.as_data() == INTERESTED_EMOJI { - let data = ctx.data.read(); - let config = data.get::().unwrap(); - let db_link = config.db_url.clone(); - let message_id = reaction.message_id.0.to_string(); + if let Ok(config) = get_config(&ctx.data) { + let db_link = config.db_url.clone(); + let message_id = reaction.message_id.0.to_string(); - let event = match get_event_by_msg_id(db_link, message_id) { - Ok(event) => event, - Err(_) => { - return; + let event = match get_event_by_msg_id(db_link, message_id) { + Ok(event) => event, + Err(_) => { + return; + } + }; + + if let Ok(user) = ctx.http.get_user(reaction.user_id.0) { + if let Ok(dm_channel) = user.create_dm_channel(&ctx.http) { + dm_channel + .send_message(&ctx.http, |m| { + m.content(format!( + "You have signed up to receive reminders for **{}**!", + &event.event_name + )) + }) + .ok(); + } } - }; - - let user = ctx.http.get_user(reaction.user_id.0).unwrap(); - - let dm_channel = user.create_dm_channel(&ctx.http).unwrap(); - - dm_channel.send_message(&ctx.http, |m| { - m.content(format!("You have signed up to receive reminders for **{}**!", &event.event_name)) - }).ok(); + } } + } - + /// On bot ready + fn ready(&self, _: Context, ready: Ready) { + println!("Connected as {}", ready.user.name); } } /// Struct for storing drafted events struct DraftEvent { pub event: NewEvent, - pub creator_id: u64 + pub creator_id: u64, } impl TypeMapKey for DraftEvent { @@ -94,15 +92,18 @@ impl TypeMapKey for DraftEvent { } embed_migrations!("migrations/"); -fn main() { +fn main() -> clap::Result<()> { // Initialize arg parser let mut app = App::new("Hype Bot") - .about("Hype Bot: Hype Up Your Discord Events!").arg(Arg::with_name("config") - .index(1) - .short("c"). - long("config") - .value_name("CONFIG_PATH") - .help("Config file path")); + .about("Hype Bot: Hype Up Your Discord Events!") + .arg( + Arg::with_name("config") + .index(1) + .short("c") + .long("config") + .value_name("CONFIG_PATH") + .help("Config file path"), + ); // Get arg parser let matches = app.clone().get_matches(); @@ -123,13 +124,19 @@ fn main() { embedded_migrations::run(&connection).expect("Unable to run migrations"); // New client - let mut client = Client::new(cfg.discord_key.clone(), Handler) - .expect("Error creating client"); + let mut client = + Client::new(cfg.discord_key.clone(), Handler).expect("Error creating client"); // Configure client - client.with_framework(StandardFramework::new() - .configure(|c| c.prefix(cfg.prefix.as_str().clone())) - .group(&EVENTCOMMANDS_GROUP)); + client.with_framework( + StandardFramework::new() + .configure(|c| { + c.prefix(cfg.prefix.as_str().clone()) + .allow_dm(false) + .ignore_bots(true) + }) + .group(&EVENTCOMMANDS_GROUP), + ); // Copy config data to client data { @@ -144,115 +151,196 @@ fn main() { thumbnail_link: String::new(), reminder_sent: 0 as i32, }, - creator_id: 0 + creator_id: 0, }); } let data = client.data.clone(); let cache_and_http = client.cache_and_http.clone(); - thread::spawn(move || { - send_reminders(&cache_and_http, &data) - }); + thread::spawn(move || send_reminders(&cache_and_http, &data)); // Start bot println!("Starting Hypebot!"); if let Err(why) = client.start() { println!("An error occurred while running the client: {:?}", why); } - } - else { + } else { // Print help - app.print_help().unwrap(); + app.print_help()?; } + + Ok(()) } +/// Thread to send reminders to users fn send_reminders(cache_and_http: &Arc, data: &Arc>) { let sleep_duration = Duration::from_secs(60); + let config = get_config(data).unwrap(); loop { sleep(sleep_duration); - if let Some(config) = data.read().get::() { - let http = &cache_and_http.http; - let event_channel_id = config.event_channel; + let http = &cache_and_http.http; + let event_channel_id = config.event_channel; - if let Ok(events)= get_all_events(config.db_url.clone()) { - for event in events { - let utc_time = DateTime::::from_utc(event.event_time.clone(), Utc); - let time_to_event = utc_time - chrono::offset::Utc::now(); + // Get all current events + if let Ok(events) = get_all_events(config.db_url.clone()) { + for event in events { + // Get time to event + let utc_time = DateTime::::from_utc(event.event_time.clone(), Utc); + let time_to_event = (utc_time - chrono::offset::Utc::now()).num_minutes(); + // If the event starts in less than 10 minutes + if time_to_event <= 10 && time_to_event > 0 && event.reminder_sent == 1 { + // Get message isd + if let Ok(message_id) = event.message_id.parse::() { + if let Ok(message) = http.get_message(event_channel_id, message_id) { + let reaction_users = message + .reaction_users(&http, INTERESTED_EMOJI, None, None) + .unwrap_or(Vec::::new()); - if time_to_event.num_minutes() < 10 && event.reminder_sent != 1{ - if let Ok(message_id) = event.message_id.parse::() { - let message = http.get_message(event_channel_id, message_id).unwrap(); - for user in message.reaction_users(&http,INTERESTED_EMOJI, None, None).unwrap() { + // Send reminder to each reacted user + for user in reaction_users { if let Ok(dm_channel) = user.create_dm_channel(&http) { - dm_channel.send_message(&http, |m| { - m.content(format!("Hello! **{}** begins in **{} minutes**!", &event.event_name, time_to_event.num_minutes())) - }).ok(); + dm_channel + .send_message(&http, |m| { + m.content(format!( + "Hello! **{}** begins in **{} minutes**!", + &event.event_name, time_to_event + )) + }) + .ok(); } } - - set_reminder(config.db_url.clone(), event.id).ok(); } + + set_reminder(config.db_url.clone(), event.id, 1).ok(); } } } } - } } -fn send_event_msg(http: &Http, config: &HypeBotConfig, channel_id: u64, event: &NewEvent) -> Result { - let channel = http.get_channel(channel_id).unwrap(); +/// Sends the event message to the event channel +fn send_event_msg( + http: &Http, + config: &HypeBotConfig, + channel_id: u64, + event: &NewEvent, + react: bool, +) -> Result { + let channel = http.get_channel(channel_id)?; - let tz: Tz = config.event_timezone.parse().unwrap(); let utc_time = DateTime::::from_utc(event.event_time.clone(), Utc); - let native_time = utc_time.with_timezone(&tz); + + let native_time = utc_time.with_timezone(&config.event_timezone); // Send message let msg = channel.id().send_message(&http, |m| { m.embed(|e| { e.title(event.event_name.clone()) .color(Colour::PURPLE) - .description(format!("**{}**\n{}", native_time.format("%A, %B %d @ %I:%M %P %t %Z"), event.event_desc)) + .description(format!( + "**{}**\n{}", + native_time.format("%A, %B %d @ %I:%M %P %t %Z"), + event.event_desc + )) .thumbnail(event.thumbnail_link.clone()) - .footer(|f| { f.text("Local Event Time") }) + .footer(|f| f.text("Local Event Time")) .timestamp(utc_time.to_rfc3339()) }) })?; - // Add reacts - msg.react(http, INTERESTED_EMOJI).unwrap(); - msg.react(http, UNINTERESTED_EMOJI).unwrap(); + if react { + // Add reacts + msg.react(http, INTERESTED_EMOJI)?; + msg.react(http, UNINTERESTED_EMOJI)?; + } Ok(msg) } -#[command] -fn confirm_event(ctx: &mut Context, msg: &Message, _args: Args) -> CommandResult { +/// Updates the draft event stored in the context data +fn update_draft_event( + ctx: &Context, + event_name: String, + event_desc: String, + thumbnail: String, + event_time: NaiveDateTime, + creator_id: u64, +) -> CommandResult { + let mut data = ctx.data.write(); + let mut draft_event = data + .get_mut::() + .ok_or(CommandError("Unable get draft event!".to_string()))?; + + draft_event.event.event_name = event_name; + draft_event.event.event_desc = event_desc; + draft_event.event.thumbnail_link = thumbnail; + draft_event.event.message_id = String::new(); + draft_event.event.event_time = event_time; + draft_event.creator_id = creator_id; + Ok(()) +} + +/// Sends the draft event stored in the context data +fn send_draft_event(ctx: &Context, channel: ChannelId) -> CommandResult { let data = ctx.data.read(); - let config = data.get::().unwrap(); - let draft_event = match data.get::() { - Some(draft_event) => Ok(draft_event), - None => Err(Error::ItemMissing) - }?; - let new_event = &draft_event.event; + let config = data + .get::() + .ok_or(CommandError("Config not found!".to_string()))?; + let draft_event = data + .get::() + .ok_or(CommandError("Draft event not found!".to_string()))?; - if draft_event.creator_id == msg.author.id.0 { - let event_msg = send_event_msg(&ctx.http, config, config.event_channel, new_event)?; - msg.reply(&ctx, "Event posted!")?; + channel.send_message(&ctx, |m| { + m.content(format!( + "Draft message, use the `confirm_event` command to post it." + )) + })?; + send_event_msg(&ctx.http, config, channel.0, &draft_event.event, false)?; + Ok(()) +} - let new_event = NewEvent { - message_id: event_msg.id.0.to_string(), - event_time: new_event.event_time.clone(), - event_desc: new_event.event_desc.clone(), - event_name: new_event.event_name.clone(), - thumbnail_link: new_event.event_name.clone(), - reminder_sent: 0, - }; +/// Gets the config from context data +fn get_config(data: &Arc>) -> std::result::Result { + let data_read = data.read(); + let config = data_read + .get::() + .ok_or(CommandError("Unable to get config".to_string()))?; - insert_event(config.db_url.clone(), &new_event); + Ok(config.clone()) +} - } - else { - msg.reply(&ctx, format!("You do not have a pending event!"))?; +#[command] +/// Posts the pending event in the shared context +fn confirm_event(ctx: &mut Context, msg: &Message, _args: Args) -> CommandResult { + let config = get_config(&ctx.data)?; + let data = ctx.data.read(); + + // Get draft event + if let Some(draft_event) = data.get::() { + let new_event = &draft_event.event; + // Check to to see if message author is the owner of the pending event + if draft_event.creator_id == msg.author.id.0 { + // Send event message + let event_msg = + send_event_msg(&ctx.http, &config, config.event_channel, new_event, true)?; + + msg.reply(&ctx, "Event posted!")?; + + let new_event = NewEvent { + message_id: event_msg.id.0.to_string(), + event_time: new_event.event_time.clone(), + event_desc: new_event.event_desc.clone(), + event_name: new_event.event_name.clone(), + thumbnail_link: new_event.event_name.clone(), + reminder_sent: 0, + }; + + insert_event(config.db_url.clone(), &new_event); + } else { + msg.reply(&ctx, format!("You do not have a pending event!"))?; + } + } else { + msg.reply(&ctx, format!("There are no pending events!!"))?; } Ok(()) @@ -260,69 +348,60 @@ fn confirm_event(ctx: &mut Context, msg: &Message, _args: Args) -> CommandResult #[command] /// Creates an event and announce it -fn create_event (ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { - let mut event_name; - let mut description; - let thumbnail_link; - let date; +fn create_event(ctx: &mut Context, msg: &Message, mut args: Args) -> CommandResult { + // Get config + let config = get_config(&ctx.data)?; + let guild_id = msg + .guild_id + .ok_or(CommandError("Unable to get guild ID".to_string()))?; - { - // Open config - let data = ctx.data.read(); - let config = data.get::().unwrap(); + // Parse args + let event_name = args.single::()?.replace("\"", ""); + let date_string = args.single::()?.replace("\"", ""); + let description = args.single::()?.replace("\"", ""); + let thumbnail_link = match args.single::() { + Ok(link) => link.replace("<", "").replace(">", ""), + Err(_) => config.default_thumbnail_link.clone(), + }; - // Parse args - event_name = args.single::()?.replace("\"", ""); - let date_string = args.single::()?.replace("\"", ""); - description = args.single::()?.replace("\"", ""); - thumbnail_link = match args.single::() { - Ok(link) => link.replace("<", "").replace(">", ""), - Err(_) => config.default_thumbnail_link.clone(), - }; + // Parse date + let tz: Tz = config.event_timezone; + let input_date = NaiveDateTime::parse_from_str(date_string.as_str(), "%I:%M%p %Y-%m-%d")?; - // Parse date - let tz: Tz = config.event_timezone.parse()?; - let input_date = NaiveDateTime::parse_from_str(date_string.as_str(), "%I:%M%p %Y-%m-%d")?; - let input_date = tz.ymd(input_date.date().year(), input_date.date().month(), input_date.date().day()) - .and_hms(input_date.time().hour(), input_date.time().minute(), input_date.time().second()); - date = input_date.with_timezone(&Utc); + let input_date = tz + .ymd( + input_date.date().year(), + input_date.date().month(), + input_date.date().day(), + ) + .and_hms( + input_date.time().hour(), + input_date.time().minute(), + input_date.time().second(), + ); - // Clean channel, role, and everyone pings - let settings = ContentSafeOptions::default() - .clean_role(true) - .clean_here(true) - .clean_user(true) - .clean_everyone(true) - .display_as_member_from(msg.guild_id.unwrap()); + let event_time = input_date.with_timezone(&Utc).naive_utc(); - description = content_safe(&ctx.cache, description, &settings); - event_name = content_safe(&ctx.cache, event_name, &settings); - } + // Clean channel, role, and everyone pings + let settings = ContentSafeOptions::default() + .clean_role(true) + .clean_here(true) + .clean_user(true) + .clean_everyone(true) + .display_as_member_from(guild_id); - { - let mut data = ctx.data.write(); - let mut draft_event = match data.get_mut::() { - Some(event) => event, - None => { - println!("Error"); - panic!("Can't get write lock") - } - }; - draft_event.event.event_name = event_name; - draft_event.event.event_desc = description; - draft_event.event.thumbnail_link = thumbnail_link; - draft_event.event.message_id = String::new(); - draft_event.event.event_time = date.naive_utc(); + let description = content_safe(&ctx.cache, description, &settings); + let event_name = content_safe(&ctx.cache, event_name, &settings); - draft_event.creator_id = msg.author.id.0; - } - - { - let data = ctx.data.read(); - let config = data.get::().unwrap(); - msg.reply(&ctx, format!("Draft message, use the `confirm_event` command to post it."))?; - send_event_msg(&ctx.http, config, msg.channel_id.0, &data.get::().unwrap().event)?; - } + update_draft_event( + &ctx, + event_name, + description, + thumbnail_link, + event_time, + msg.author.id.0, + )?; + send_draft_event(&ctx, msg.channel_id)?; Ok(()) -} \ No newline at end of file +}