From 23f99bbe24b48a01037c909e59fc487c78a78bf8 Mon Sep 17 00:00:00 2001 From: Joshua Ashton Date: Tue, 26 Dec 2023 10:53:36 +0000 Subject: [PATCH] Playlist support, new shit --- Cargo.toml | 22 +-- src/main.rs | 386 ++++++++++++++++++++++++++++------------------------ 2 files changed, 219 insertions(+), 189 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 54c454d..f05f253 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "winter" version = "0.1.0" authors = ["Joshua Ashton "] -edition = "2018" +edition = "2021" [dependencies] tracing = "0.1" @@ -13,16 +13,21 @@ base64-stream = "1.2.7" rand = "0.8.5" [dependencies.songbird] -features = ["builtin-queue", "yt-dlp"] -version = "0.3.2" +features = ["builtin-queue"] +git = "https://github.com/serenity-rs/songbird" +branch = "current" + +[dependencies.symphonia] +version = "0.5.2" +features = ["aac", "mp3", "isomp4", "alac"] [dependencies.serenity] -version = "0.11" -features = ["client", "standard_framework", "voice", "rustls_backend"] +version = "0.12" +features = ["cache", "framework", "standard_framework", "voice", "http", "rustls_backend"] [dependencies.tokio] -version = "1.0" -features = ["macros", "rt-multi-thread", "signal"] +version = "1" +features = ["macros", "rt-multi-thread", "signal", "sync"] [dependencies.serde] version = "1.0" @@ -40,6 +45,3 @@ features = [ "fast-rng", # Use a faster (but still sufficiently random) RNG "macro-diagnostics", # Enable better diagnostics for compile-time UUIDs ] - -#[patch.crates-io] -#songbird = { git = "https://github.com/Erk-/songbird", branch="do-not-fail-if-new-opcode" } diff --git a/src/main.rs b/src/main.rs index 39b01bc..353a490 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,58 +14,74 @@ use std::{ sync::{ Arc, }, - collections::HashSet, - collections::HashMap, fs, }; +use reqwest::Client as HttpClient; + use serenity::{ async_trait, - client::{Client, Context, EventHandler, ClientBuilder}, + client::{Client, Context, EventHandler}, framework::{ standard::{ - macros::{command, group, help}, - help_commands, + macros::{command, group}, Args, CommandResult, - HelpOptions, - CommandGroup, + Configuration, }, StandardFramework, }, http::Http, - model::{channel::Message, gateway::Ready, prelude::ChannelId, prelude::UserId, prelude::GuildId}, + model::{channel::Message, gateway::Ready, prelude::ChannelId}, prelude::{GatewayIntents, Mentionable, TypeMapKey}, - Result as SerenityResult, + Result as SerenityResult, builder::{CreateEmbed, CreateMessage}, gateway::ShardManager, }; -use serenity::utils::Colour; -use serenity::client::bridge::gateway::ShardManager; - -use tracing::{error, info}; +use serenity::all::standard::HelpOptions; +use serenity::all::standard::CommandGroup; +use serenity::all::GuildId; +use serenity::all::UserId; +use serenity::all::ClientBuilder; +use serenity::all::Colour; +use serenity::all::standard::help_commands; +use serenity::all::standard::macros::help; use songbird::{ - input::{ - restartable::Restartable, Input, - }, + input::YoutubeDl, Event, EventContext, EventHandler as VoiceEventHandler, SerenityInit, - TrackEvent, + TrackEvent, tracks::Track, }; -use serde::{Deserialize, Serialize}; +struct HttpKey; +impl TypeMapKey for HttpKey { + type Value = HttpClient; +} + +struct ShardManagerContainer; + +impl TypeMapKey for ShardManagerContainer { + type Value = Arc; +} + +struct Handler; + +use serde::{Deserialize, Serialize}; +use tracing::error; + +use std::fs; use std::fs::File; use std::io::Write; +use std::collections::HashSet; +use std::collections::HashMap; use base64_stream::FromBase64Writer; -use tokio::sync::Mutex; +use tokio::{sync::Mutex, process::Command}; use uuid::Uuid; -struct Handler; - #[async_trait] impl EventHandler for Handler { async fn ready(&self, _: Context, ready: Ready) { @@ -73,16 +89,11 @@ impl EventHandler for Handler { } } -pub struct ShardManagerContainer; - -impl TypeMapKey for ShardManagerContainer { - type Value = Arc>; -} - #[group] #[commands( leave, mute, play, skip, stop, ping, unmute, volume, vox, chaos, restart, seek, tts, tts_list )] + struct General; #[help] @@ -208,6 +219,12 @@ pub async fn winter_get(ctx: &Context) -> Option>> { ///////// +async fn get_http_client(ctx: &Context) -> HttpClient { + let data = ctx.data.read().await; + data.get::() + .cloned() + .expect("Guaranteed to exist in the typemap.") +} #[tokio::main] async fn main() { @@ -216,10 +233,8 @@ async fn main() { // Configure the client with your Discord bot token in the environment. let token = env::var("DISCORD_TOKEN").expect("Expected a token in the environment"); - let framework = StandardFramework::new() - .configure(|c| c.prefix("~")) - .group(&GENERAL_GROUP) - .help(&MY_HELP); + let framework = StandardFramework::new().group(&GENERAL_GROUP); + framework.configure(Configuration::new().prefix("~")); let intents = GatewayIntents::non_privileged() | GatewayIntents::MESSAGE_CONTENT; @@ -229,6 +244,7 @@ async fn main() { .framework(framework) .register_songbird() .register_winter() + .type_map_insert::(HttpClient::new()) .await .expect("Err creating client"); @@ -241,7 +257,7 @@ async fn main() { tokio::spawn(async move { tokio::signal::ctrl_c().await.expect("Could not register ctrl+c handler"); - shard_manager.lock().await.shutdown_all().await; + shard_manager.shutdown_all().await; }); if let Err(why) = client.start().await { @@ -250,13 +266,15 @@ async fn main() { } async fn ensure_joined(ctx: &Context, msg: &Message) -> bool { - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let (guild_id, channel_id) = { + let guild = msg.guild(&ctx.cache).unwrap(); + let channel_id = guild + .voice_states + .get(&msg.author.id) + .and_then(|voice_state| voice_state.channel_id); - let channel_id = guild - .voice_states - .get(&msg.author.id) - .and_then(|voice_state| voice_state.channel_id); + (guild.id, channel_id) + }; let connect_to = match channel_id { Some(channel) => channel, @@ -289,9 +307,7 @@ async fn ensure_joined(ctx: &Context, msg: &Message) -> bool { } } - let (handler_lock, success) = manager.join(guild_id, connect_to).await; - - if let Ok(_channel) = success { + if let Ok(handler_lock) = manager.join(guild_id, connect_to).await { check_msg( msg.channel_id .say(&ctx.http, &format!("Joined {}", connect_to.mention())) @@ -382,8 +398,7 @@ impl VoiceEventHandler for TrackEndNotifier { #[command] #[only_in(guilds)] async fn leave(ctx: &Context, msg: &Message) -> CommandResult { - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let guild_id = msg.guild_id.unwrap(); let manager = songbird::get(ctx) .await @@ -411,8 +426,7 @@ async fn leave(ctx: &Context, msg: &Message) -> CommandResult { #[command] #[only_in(guilds)] async fn mute(ctx: &Context, msg: &Message) -> CommandResult { - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let guild_id = msg.guild_id.unwrap(); let manager = songbird::get(ctx) .await @@ -472,24 +486,42 @@ impl VoiceEventHandler for SongEndNotifier { } } +#[derive(Deserialize, Serialize, Debug)] +pub struct YtDlpOutput { + pub artist: Option, + pub album: Option, + pub channel: Option, + pub duration: Option, + pub filesize: Option, + pub http_headers: Option>, + pub release_date: Option, + pub thumbnail: Option, + pub title: Option, + pub track: Option, + pub upload_date: Option, + pub uploader: Option, + pub url: String, + pub webpage_url: Option, +} + #[command] #[only_in(guilds)] -async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { +async fn play(ctx: &Context, msg: &Message, args: Args) -> CommandResult { + let guild_id = msg.guild_id.unwrap(); + let url = args.rest().to_string(); if url.is_empty() { reply(&ctx, &msg, "Tell me what you want!", "You must provide a URL or search term for me to play video or audio!", None, false).await; return Ok(()); } - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let http_client = get_http_client(ctx).await; let manager = songbird::get(ctx) .await .expect("Songbird Voice client placed in at initialisation.") .clone(); - let winter_lock = winter_get(ctx) .await .expect("Winter placed in at initialisation.") @@ -501,71 +533,87 @@ async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { return Ok(()); } + let do_search = !url.starts_with("http"); + + let sources = if do_search { + let mut vec = Vec::new(); + vec.push(YoutubeDl::new_search(http_client, url)); + vec + } else { + let ytdl_args = [ + "-j", + url.as_str(), + "--flat-playlist", + ]; + + let mut output = Command::new("yt-dlp") + .args(ytdl_args) + .output() + .await + .unwrap(); + + if !output.status.success() { + println!("Fuck! Couldn't run yt-dlp"); + return Ok(()); + } + + let out = output + .stdout + .split_mut(|&b| b == b'\n') + .filter_map(|x| (!x.is_empty()).then(|| serde_json::from_slice(x))) + .collect::, _>>() + .unwrap(); + + let mut vec = Vec::new(); + for playlist_src in out { + vec.push(YoutubeDl::new(http_client.clone(), playlist_src.url)); + } + vec + }; + if let Some(handler_lock) = manager.get(guild_id) { let mut handler = handler_lock.lock().await; - // Here, we use lazy restartable sources to make sure that we don't pay - // for decoding, playback on tracks which aren't actually live yet. - let source : Restartable; + for source in sources { + let mut input = songbird::input::Input::from(source); - if url.starts_with("http:") || url.starts_with("https:") { - source = match Restartable::ytdl(url, false).await { - Ok(source) => source, - Err(why) => { - println!("Err starting source: {:?}", why); - reply(&ctx, &msg, "Oh no!", &format!("Error playing video. Reason: {:?}", why), None, true).await; - return Ok(()); - }, - }; - } else { - source = match Restartable::ytdl_search(url, false).await { - Ok(source) => source, - Err(why) => { - println!("Err starting source: {:?}", why); - reply(&ctx, &msg, "Oh no!", &format!("Error playing video. Reason: {:?}", why), None, true).await; - return Ok(()); - }, - }; - } + let aux_metadata = input.aux_metadata().await.unwrap(); - let input = Input::from(source); - let title = input.metadata.title.clone().unwrap_or("Unknown Title".to_string()); - let artist = input.metadata.artist.clone().unwrap_or("Unknown Artist".to_string()); - let thumbnail = input.metadata.thumbnail.clone(); - let mut duration = "Unknown duration".to_string(); - if input.metadata.duration.is_some() { - let meta_duration = input.metadata.duration.unwrap(); + let title = aux_metadata.title.clone().unwrap_or("Unknown Title".to_string()); + let artist = aux_metadata.artist.clone().unwrap_or("Unknown Artist".to_string()); + let thumbnail = aux_metadata.thumbnail.clone(); + let mut duration = "Unknown duration".to_string(); + if aux_metadata.duration.is_some() { + let meta_duration = aux_metadata.duration.unwrap(); - let seconds = meta_duration.as_secs() % 60; - let minutes = (meta_duration.as_secs() / 60) % 60; - let hours = (meta_duration.as_secs() / 60) / 60; - duration = format!("{:02}:{:02}:{:02}", hours, minutes, seconds).to_string(); - } + let seconds = meta_duration.as_secs() % 60; + let minutes = (meta_duration.as_secs() / 60) % 60; + let hours = (meta_duration.as_secs() / 60) / 60; + duration = format!("{:02}:{:02}:{:02}", hours, minutes, seconds).to_string(); + } - let (mut track, _handle) = songbird::tracks::create_player(input); - track.set_volume(winter.options.get_volume(guild_id)); - handler.enqueue(track); + let track = songbird::tracks::Track::new(input) + .volume(winter.options.get_volume(guild_id)); + handler.enqueue(track).await; + let mut embed = CreateEmbed::new() + .title(title.clone()) + .field("Arist", artist, true) + .field("Duration", duration, true) + .field("Queue Position", &format!("{:02}", handler.queue().len()), true) + .color(Colour::from_rgb(202,255,239)); - let msg = msg - .channel_id - .send_message(&ctx.http, |m| { - m.embed(|e| { - e.title(title.clone()); - if thumbnail.is_some() { - e.thumbnail(thumbnail.unwrap()); - } - e.field("Artist", artist, true); - e.field("Duration", duration, true); - e.field("Queue Position", &format!("{:02}", handler.queue().len()), true); - e.color(Colour::from_rgb(202,255,239)); - e - }) - }) - .await; + if thumbnail.is_some() { + embed = embed.thumbnail(thumbnail.unwrap()); + } - if let Err(why) = msg { - println!("Error sending message: {:?}", why); + let builder = CreateMessage::new().embed(embed); + + let msg = msg.channel_id.send_message(&ctx, builder).await; + + if let Err(why) = msg { + println!("Error sending message: {:?}", why); + } } } @@ -575,6 +623,8 @@ async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { #[command] #[only_in(guilds)] async fn chaos(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { + let guild_id = msg.guild_id.unwrap(); + let url = match args.single::() { Ok(url) => url, Err(_) => { @@ -594,8 +644,7 @@ async fn chaos(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { return Ok(()); } - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let http_client = get_http_client(ctx).await; let manager = songbird::get(ctx) .await @@ -614,22 +663,21 @@ async fn chaos(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { return Ok(()); } + let do_search = !url.starts_with("http"); + if let Some(handler_lock) = manager.get(guild_id) { let mut handler = handler_lock.lock().await; - let source = match songbird::ytdl(&url).await { - Ok(source) => source, - Err(why) => { - println!("Err starting source: {:?}", why); - reply(&ctx, &msg, "Oh no!", &format!("Error playing video. Reason: {:?}", why), None, true).await; - return Ok(()); - }, + let src = if do_search { + YoutubeDl::new_search(http_client, url) + } else { + YoutubeDl::new(http_client, url) }; check_msg(msg.reply(&ctx.http, "Added song to chaos mode.").await); - let (mut track, _handle) = songbird::tracks::create_player(source.into()); - track.set_volume(winter.options.get_volume(guild_id)); + let track = songbird::tracks::Track::new(src.into()) + .volume(winter.options.get_volume(guild_id)); handler.play(track); } @@ -639,6 +687,8 @@ async fn chaos(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { #[command] #[only_in(guilds)] async fn vox(ctx: &Context, msg: &Message, args: Args) -> CommandResult { + let guild_id = msg.guild_id.unwrap(); + let line = args.rest(); if line.is_empty() { check_msg( @@ -651,14 +701,19 @@ async fn vox(ctx: &Context, msg: &Message, args: Args) -> CommandResult { let words = line.split(" "); - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; - let manager = songbird::get(ctx) .await .expect("Songbird Voice client placed in at initialisation.") .clone(); + + let winter_lock = winter_get(ctx) + .await + .expect("Winter placed in at initialisation.") + .clone(); + + let winter = winter_lock.lock().await; + if !ensure_joined(ctx, msg).await { return Ok(()); } @@ -671,19 +726,12 @@ async fn vox(ctx: &Context, msg: &Message, args: Args) -> CommandResult { if word.chars().all(|x| x.is_alphanumeric()) { let vox_path = format!("./assets/vox/{}.wav", word); - let source = match songbird::ffmpeg(&vox_path).await { - Ok(source) => source, - Err(why) => { - println!("Err starting source: {:?}", why); - reply(&ctx, &msg, "Oh no!", &format!("Error playing video. Reason: {:?}", why), None, true).await; - return Ok(()); - }, - }; - - handler.enqueue_source(source.into()); + let source = songbird::input::File::new(vox_path); + let track = Track::new(source.into()) + .volume(winter.options.get_volume(guild_id)); + handler.enqueue(track).await; } } - check_msg( msg.channel_id @@ -770,7 +818,7 @@ use rand::seq::SliceRandom; #[command] #[only_in(guilds)] -async fn tts_list(ctx: &Context, msg: &Message, args: Args) -> CommandResult { +async fn tts_list(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { let mut voice_list : String = "".to_string(); for voice in TIKTOK_VOICES { voice_list += voice; @@ -783,6 +831,8 @@ async fn tts_list(ctx: &Context, msg: &Message, args: Args) -> CommandResult { #[command] #[only_in(guilds)] async fn tts(ctx: &Context, msg: &Message, args: Args) -> CommandResult { + let guild_id = msg.guild_id.unwrap(); + let mut line = args.rest(); if line.is_empty() { check_msg( @@ -810,9 +860,6 @@ async fn tts(ctx: &Context, msg: &Message, args: Args) -> CommandResult { } } - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; - let manager = songbird::get(ctx) .await .expect("Songbird Voice client placed in at initialisation.") @@ -862,17 +909,10 @@ async fn tts(ctx: &Context, msg: &Message, args: Args) -> CommandResult { writer.write_all(response.data.as_bytes()).unwrap(); writer.flush().unwrap(); - let source = match songbird::ffmpeg(file_path).await { - Ok(source) => source, - Err(why) => { - println!("Err starting source: {:?}", why); - reply(&ctx, &msg, "Oh no!", &format!("Error playing tts. Reason: {:?}", why), None, true).await; - return Ok(()); - }, - }; + let source = songbird::input::File::new(file_path); + let track = Track::new(source.into()) + .volume(winter.options.get_volume(guild_id)); - let (mut track, _handle) = songbird::tracks::create_player(source.into()); - track.set_volume(winter.options.get_volume(guild_id)); handler.play(track); check_msg( @@ -891,8 +931,7 @@ async fn tts(ctx: &Context, msg: &Message, args: Args) -> CommandResult { #[command] #[only_in(guilds)] async fn skip(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let guild_id = msg.guild_id.unwrap(); let manager = songbird::get(ctx) .await @@ -922,8 +961,7 @@ async fn skip(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { #[command] #[only_in(guilds)] async fn stop(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let guild_id = msg.guild_id.unwrap(); let manager = songbird::get(ctx) .await @@ -947,8 +985,7 @@ async fn stop(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { #[command] #[only_in(guilds)] async fn unmute(ctx: &Context, msg: &Message) -> CommandResult { - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let guild_id = msg.guild_id.unwrap(); let manager = songbird::get(ctx) .await .expect("Songbird Voice client placed in at initialisation.") @@ -988,8 +1025,7 @@ async fn volume(ctx: &Context, msg: &Message, args: Args) -> CommandResult { let volume_human = volume_human_eval.unwrap() as f32; let volume = volume_human / 100.0; - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let guild_id = msg.guild_id.unwrap(); let manager = songbird::get(ctx) .await @@ -1037,8 +1073,7 @@ async fn seek(ctx: &Context, msg: &Message, args: Args) -> CommandResult { let seek = std::time::Duration::from_secs_f32(seek_human); - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; + let guild_id = msg.guild_id.unwrap(); let manager = songbird::get(ctx) .await @@ -1049,14 +1084,7 @@ async fn seek(ctx: &Context, msg: &Message, args: Args) -> CommandResult { let handler = handler_lock.lock().await; for (_, track) in handler.queue().current_queue().iter().enumerate() { - if !track.is_seekable() { - reply(&ctx, &msg, "I can't do that!", "This track is not seekable. (Livestream, etc)", None, true).await; - return Ok(()); - } - if !track.seek_time(seek).is_ok() { - reply(&ctx, &msg, "I can't do that!", "Failed to set current position", None, true).await; - return Ok(()); - } + let _success = track.seek(seek); } } @@ -1075,7 +1103,7 @@ async fn restart(ctx: &Context, msg: &Message) -> CommandResult { if let Some(manager) = data.get::() { msg.reply(ctx, "Shutting down!").await?; - manager.lock().await.shutdown_all().await; + manager.shutdown_all().await; } else { msg.reply(ctx, "There was a problem getting the shard manager").await?; @@ -1093,23 +1121,23 @@ fn check_msg(result: SerenityResult) { } async fn reply>(ctx: &Context, context_msg: &Message, title: S, desc: S, image: Option, error: bool) { - let msg = context_msg - .channel_id - .send_message(&ctx.http, |m| { - m.embed(|e| { - e.title(title.into()).description(desc.into()); - if image.is_some() { - e.image(image.unwrap().into()); - } - if error { - e.color(Colour::from_rgb(255,218,218)); - } else { - e.color(Colour::from_rgb(223,255,198)); - } - e - }) - }) - .await; + let mut embed = CreateEmbed::new() + .title(title.into()) + .description(desc.into()); + + if image.is_some() { + embed = embed.image(image.unwrap().into()); + } + + if error { + embed = embed.color(Colour::from_rgb(255,218,218)); + } else { + embed = embed.color(Colour::from_rgb(223,255,198)); + } + + let builder = CreateMessage::new().embed(embed); + + let msg = context_msg.channel_id.send_message(&ctx, builder).await; if let Err(why) = msg { println!("Error sending message: {:?}", why);