diff --git a/Cargo.toml b/Cargo.toml index 81067b6..3c11452 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,12 +7,6 @@ edition = "2021" name = "spoticord" path = "src/main.rs" -[profile.release] -lto = true -codegen-units = 1 -strip = true -opt-level = "z" - [dependencies] chrono = "0.4.22" dotenv = "0.15.0" diff --git a/src/bot/commands/mod.rs b/src/bot/commands/mod.rs index 4a64166..6c08f02 100644 --- a/src/bot/commands/mod.rs +++ b/src/bot/commands/mod.rs @@ -44,13 +44,32 @@ pub async fn respond_message( } } +pub async fn defer_message( + ctx: &Context, + command: &ApplicationCommandInteraction, + ephemeral: bool, +) { + if let Err(why) = command + .create_interaction_response(&ctx.http, |response| { + response + .kind(InteractionResponseType::DeferredChannelMessageWithSource) + .interaction_response_data(|message| message.ephemeral(ephemeral)) + }) + .await + { + error!("Error deferring message: {:?}", why); + } +} + pub type CommandOutput = Pin + Send>>; pub type CommandExecutor = fn(Context, ApplicationCommandInteraction) -> CommandOutput; +#[derive(Clone)] pub struct CommandManager { commands: HashMap, } +#[derive(Clone)] pub struct CommandInfo { pub name: String, pub executor: CommandExecutor, diff --git a/src/bot/commands/music/join.rs b/src/bot/commands/music/join.rs index 1347ac1..e65339f 100644 --- a/src/bot/commands/music/join.rs +++ b/src/bot/commands/music/join.rs @@ -5,7 +5,7 @@ use serenity::{ }; use crate::{ - bot::commands::{respond_message, CommandOutput}, + bot::commands::{defer_message, respond_message, CommandOutput}, session::manager::{SessionCreateError, SessionManager}, utils::embed::{EmbedBuilder, Status}, }; @@ -46,6 +46,7 @@ pub fn run(ctx: Context, command: ApplicationCommandInteraction) -> CommandOutpu // Check if another session is already active in this server let session_opt = session_manager.get_session(guild.id).await; + if let Some(session) = &session_opt { if let Some(owner) = session.get_owner().await { let msg = if owner == command.user.id { @@ -91,6 +92,8 @@ pub fn run(ctx: Context, command: ApplicationCommandInteraction) -> CommandOutpu return; } + defer_message(&ctx, &command, true).await; + if let Some(session) = &session_opt { if let Err(why) = session.update_owner(&ctx, command.user.id).await { // Need to link first diff --git a/src/bot/events.rs b/src/bot/events.rs index 16a49a6..40188e1 100644 --- a/src/bot/events.rs +++ b/src/bot/events.rs @@ -23,7 +23,10 @@ impl EventHandler for Handler { debug!("Ready received, logged in as {}", ready.user.name); - command_manager.register_commands(&ctx).await; + // Set this to true only when a command is removed/updated/created + if false { + command_manager.register_commands(&ctx).await; + } ctx.set_activity(Activity::listening(MOTD)).await; @@ -32,10 +35,15 @@ impl EventHandler for Handler { // INTERACTION_CREATE event, emitted when the bot receives an interaction (slash command, button, etc.) async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + trace!("interaction_create START"); + if let Interaction::ApplicationCommand(command) = interaction { // Commands must only be executed inside of guilds - if command.guild_id.is_none() { - command + + let guild_id = match command.guild_id { + Some(guild_id) => guild_id, + None => { + if let Err(why) = command .create_interaction_response(&ctx.http, |response| { response .kind(serenity::model::prelude::interaction::InteractionResponseType::ChannelMessageWithSource) @@ -43,17 +51,20 @@ impl EventHandler for Handler { message.content("You can only execute commands inside of a server") }) }) - .await - .unwrap(); + .await { + error!("Failed to send run-in-guild-only error message: {}", why); + } - return; - } + trace!("interaction_create END2"); + return; + } + }; trace!( "Received command interaction: command={} user={} guild={}", command.data.name, command.user.id, - command.guild_id.unwrap() + guild_id ); let data = ctx.data.read().await; @@ -61,5 +72,7 @@ impl EventHandler for Handler { command_manager.execute_command(&ctx, command).await; } + + trace!("interaction_create END"); } } diff --git a/src/consts.rs b/src/consts.rs index 920f012..519980d 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -1,3 +1,3 @@ pub const VERSION: &str = env!("CARGO_PKG_VERSION"); -pub const MOTD: &str = "OPEN BETA (v2)"; +pub const MOTD: &str = "UNSTABLE BETA (v2)"; // pub const MOTD: &str = "some good 'ol music"; diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index 57e824c..6b231b9 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -1,6 +1,6 @@ use std::sync::{Arc, Mutex}; -use ipc_channel::ipc::{self, IpcError, IpcOneShotServer, IpcReceiver, IpcSender}; +use ipc_channel::ipc::{self, IpcError, IpcOneShotServer, IpcReceiver, IpcSender, TryRecvError}; use self::packet::IpcPacket; @@ -66,4 +66,8 @@ impl Client { pub fn recv(&self) -> Result { self.rx.lock().unwrap().recv() } + + pub fn try_recv(&self) -> Result { + self.rx.lock().unwrap().try_recv() + } } diff --git a/src/main.rs b/src/main.rs index 83855c7..aabe8a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,7 +23,7 @@ mod session; mod stats; mod utils; -#[tokio::main(flavor = "multi_thread")] +#[tokio::main] async fn main() { if std::env::var("RUST_LOG").is_err() { #[cfg(debug_assertions)] @@ -39,6 +39,14 @@ async fn main() { env_logger::init(); + let orig_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |panic_info| { + error!("Panic: {}", panic_info); + + orig_hook(panic_info); + std::process::exit(1); + })); + let args: Vec = env::args().collect(); if args.len() > 2 { @@ -95,6 +103,7 @@ async fn main() { let shard_manager = client.shard_manager.clone(); let cache = client.cache_and_http.cache.clone(); + #[cfg(unix)] let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate()).unwrap(); // Background tasks diff --git a/src/session/manager.rs b/src/session/manager.rs index 365145e..dd489a1 100644 --- a/src/session/manager.rs +++ b/src/session/manager.rs @@ -49,11 +49,12 @@ impl SessionManager { channel_id: ChannelId, owner_id: UserId, ) -> Result<(), SessionCreateError> { + // Create session first to make sure locks are kept for as little time as possible + let session = SpoticordSession::new(ctx, guild_id, channel_id, owner_id).await?; + let mut sessions = self.sessions.write().await; let mut owner_map = self.owner_map.write().await; - let session = SpoticordSession::new(ctx, guild_id, channel_id, owner_id).await?; - sessions.insert(guild_id, Arc::new(session)); owner_map.insert(owner_id, guild_id); diff --git a/src/session/mod.rs b/src/session/mod.rs index 8c8bc9b..8935f19 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -4,7 +4,7 @@ use crate::{ ipc::{self, packet::IpcPacket, Client}, utils::{self, spotify}, }; -use ipc_channel::ipc::IpcError; +use ipc_channel::ipc::{IpcError, TryRecvError}; use librespot::core::spotify_id::{SpotifyAudioType, SpotifyId}; use log::*; use serenity::{ @@ -21,6 +21,7 @@ use songbird::{ use std::{ process::{Command, Stdio}, sync::Arc, + time::Duration, }; use tokio::sync::Mutex; @@ -262,11 +263,18 @@ impl SpoticordSession { // Required for IpcPacket::TrackChange to work tokio::task::yield_now().await; - let msg = match ipc_client.recv() { + let msg = match ipc_client.try_recv() { Ok(msg) => msg, Err(why) => { - if let IpcError::Disconnected = why { - break; + if let TryRecvError::Empty = why { + // No message, wait a bit and try again + tokio::time::sleep(Duration::from_millis(25)).await; + + continue; + } else if let TryRecvError::IpcError(why) = &why { + if let IpcError::Disconnected = why { + break; + } } error!("Failed to receive IPC message: {:?}", why); @@ -407,8 +415,10 @@ impl SpoticordSession { } }; - let mut owner = self.owner.write().await; - *owner = Some(owner_id); + { + let mut owner = self.owner.write().await; + *owner = Some(owner_id); + } session_manager.set_owner(owner_id, self.guild_id).await; diff --git a/src/utils/spotify.rs b/src/utils/spotify.rs index 4857233..7414b8a 100644 --- a/src/utils/spotify.rs +++ b/src/utils/spotify.rs @@ -46,34 +46,53 @@ pub async fn get_username(token: impl Into) -> Result { let token = token.into(); let client = reqwest::Client::new(); - let response = match client - .get("https://api.spotify.com/v1/me") - .bearer_auth(token) - .send() - .await - { - Ok(response) => response, - Err(why) => { - error!("Failed to get username: {}", why); - return Err(format!("{}", why)); - } - }; + let mut retries = 3; - let body: Value = match response.json().await { - Ok(body) => body, - Err(why) => { - error!("Failed to parse body: {}", why); - return Err(format!("{}", why)); - } - }; + loop { + let response = match client + .get("https://api.spotify.com/v1/me") + .bearer_auth(&token) + .send() + .await + { + Ok(response) => response, + Err(why) => { + error!("Failed to get username: {}", why); + return Err(format!("{}", why)); + } + }; - if let Value::String(username) = &body["id"] { - trace!("Got username: {}", username); - return Ok(username.clone()); + if response.status().as_u16() >= 500 && retries > 0 { + retries -= 1; + continue; + } + + if response.status() != 200 { + return Err( + format!( + "Failed to get track info: Invalid status code: {}", + response.status() + ) + .into(), + ); + } + + let body: Value = match response.json().await { + Ok(body) => body, + Err(why) => { + error!("Failed to parse body: {}", why); + return Err(format!("{}", why)); + } + }; + + if let Value::String(username) = &body["id"] { + trace!("Got username: {}", username); + return Ok(username.clone()); + } + + error!("Missing 'id' field in body: {:#?}", body); + return Err("Failed to parse body: Invalid body received".to_string()); } - - error!("Missing 'id' field in body"); - Err("Failed to parse body: Invalid body received".to_string()) } pub async fn get_track_info( @@ -83,26 +102,35 @@ pub async fn get_track_info( let token = token.into(); let client = reqwest::Client::new(); - let response = client - .get(format!( - "https://api.spotify.com/v1/tracks/{}", - track.to_base62()? - )) - .bearer_auth(token) - .send() - .await?; + let mut retries = 3; - if response.status() != 200 { - return Err( - format!( - "Failed to get track info: Invalid status code: {}", - response.status() - ) - .into(), - ); + loop { + let response = client + .get(format!( + "https://api.spotify.com/v1/tracks/{}", + track.to_base62()? + )) + .bearer_auth(&token) + .send() + .await?; + + if response.status().as_u16() >= 500 && retries > 0 { + retries -= 1; + continue; + } + + if response.status() != 200 { + return Err( + format!( + "Failed to get track info: Invalid status code: {}", + response.status() + ) + .into(), + ); + } + + return Ok(response.json().await?); } - - Ok(response.json().await?) } pub async fn get_episode_info( @@ -112,24 +140,33 @@ pub async fn get_episode_info( let token = token.into(); let client = reqwest::Client::new(); - let response = client - .get(format!( - "https://api.spotify.com/v1/episodes/{}", - episode.to_base62()? - )) - .bearer_auth(token) - .send() - .await?; + let mut retries = 3; - if response.status() != 200 { - return Err( - format!( - "Failed to get episode info: Invalid status code: {}", - response.status() - ) - .into(), - ); + loop { + let response = client + .get(format!( + "https://api.spotify.com/v1/episodes/{}", + episode.to_base62()? + )) + .bearer_auth(&token) + .send() + .await?; + + if response.status().as_u16() >= 500 && retries > 0 { + retries -= 1; + continue; + } + + if response.status() != 200 { + return Err( + format!( + "Failed to get episode info: Invalid status code: {}", + response.status() + ) + .into(), + ); + } + + return Ok(response.json().await?); } - - Ok(response.json().await?) }