From f53491901f0947d026c6285a9765cb3af35836b6 Mon Sep 17 00:00:00 2001 From: Miguel Piedrafita Date: Thu, 16 Nov 2023 17:17:27 -0800 Subject: [PATCH] switch `dashmap` for `RwLock>` --- Cargo.lock | 14 -------------- Cargo.toml | 1 - src/bot/join_check/mod.rs | 16 +++++++++++----- src/bot/mod.rs | 6 +++--- src/main.rs | 17 +++++------------ src/server/mod.rs | 8 ++++++-- 6 files changed, 25 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 762d13f..5b5fd17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -349,19 +349,6 @@ dependencies = [ "syn 2.0.32", ] -[[package]] -name = "dashmap" -version = "5.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" -dependencies = [ - "cfg-if", - "hashbrown 0.14.0", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "deranged" version = "0.3.8" @@ -2305,7 +2292,6 @@ version = "0.1.0" dependencies = [ "axum", "config", - "dashmap", "dotenvy", "humantime-serde", "indoc", diff --git a/Cargo.toml b/Cargo.toml index f46f8e1..2deda32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,6 @@ log = "0.4" rand = "0.8" url = "2.4.1" indoc = "2.0.4" -dashmap = "5.5" axum = "0.6.20" dotenvy = "0.15.7" serde_with = "3.3" diff --git a/src/bot/join_check/mod.rs b/src/bot/join_check/mod.rs index fe6f8a5..8568273 100644 --- a/src/bot/join_check/mod.rs +++ b/src/bot/join_check/mod.rs @@ -57,7 +57,9 @@ pub async fn join_handler( .await? .id; - join_requests.insert((msg.chat.id, user.id), JoinRequest::new(msg_id)); + let mut join_requests_wr = join_requests.write().await; + join_requests_wr.insert((msg.chat.id, user.id), JoinRequest::new(msg_id)); + drop(join_requests_wr); tokio::spawn({ let bot = bot.clone(); @@ -65,13 +67,16 @@ pub async fn join_handler( async move { tokio::time::sleep(ban_after).await; - if let Some((_, data)) = join_requests.remove(&(msg.chat.id, user.id)) { - if !data.is_verified { + let mut join_requests_wr = join_requests.write().await; + if let Some(join_req) = join_requests_wr.remove(&(msg.chat.id, user.id)) { + drop(join_requests_wr); + + if !join_req.is_verified { bot.ban_chat_member(msg.chat.id, user.id) .await .expect("Failed to ban the member after timeout"); - if let Some(msg_id) = data.msg_id { + if let Some(msg_id) = join_req.msg_id { bot.delete_message(msg.chat.id, msg_id) .await .expect("Failed to delete the message after timeout"); @@ -91,7 +96,8 @@ pub async fn on_verified( user_id: UserId, join_requests: JoinRequests, ) -> HandlerResult { - let mut join_req = join_requests + let mut join_requests_w = join_requests.write().await; + let join_req = join_requests_w .get_mut(&(chat_id, user_id)) .ok_or("Can't find the message id in group dialogue")?; diff --git a/src/bot/mod.rs b/src/bot/mod.rs index 0310f75..da8a158 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -1,5 +1,4 @@ -use dashmap::DashMap; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use teloxide::{ dispatching::{MessageFilterExt, UpdateFilterExt}, prelude::{dptree, Dispatcher}, @@ -8,6 +7,7 @@ use teloxide::{ utils::command::BotCommands, Bot, }; +use tokio::sync::RwLock; use crate::{bot::commands::Command, config::AppConfig}; pub use join_check::on_verified; @@ -16,7 +16,7 @@ mod commands; mod join_check; type HandlerResult = Result<(), HandlerError>; -pub type JoinRequests = Arc>; +pub type JoinRequests = Arc>>; type HandlerError = Box; #[derive(Clone)] diff --git a/src/main.rs b/src/main.rs index b3b019d..fd5381c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,9 @@ -use dashmap::DashMap; use dotenvy::dotenv; -use std::sync::Arc; -use teloxide::{ - requests::Requester, - types::{ChatId, UserId}, - Bot, -}; +use std::{collections::HashMap, sync::Arc}; +use teloxide::{requests::Requester, Bot}; +use tokio::sync::RwLock; -use crate::{ - bot::{JoinRequest, JoinRequests}, - config::AppConfig, -}; +use crate::{bot::JoinRequests, config::AppConfig}; mod bot; mod config; @@ -22,7 +15,7 @@ async fn main() { pretty_env_logger::init(); let config = AppConfig::try_read().expect("Failed to read config"); - let join_requests: JoinRequests = Arc::new(DashMap::<(ChatId, UserId), JoinRequest>::new()); + let join_requests: JoinRequests = Arc::new(RwLock::new(HashMap::new())); let bot = Bot::new(&config.bot_token); let bot_data = bot.get_me().await.expect("Failed to get bot account"); diff --git a/src/server/mod.rs b/src/server/mod.rs index 2a835b7..e6cdb5b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -50,6 +50,8 @@ async fn verify_page( Path((chat_id, user_id)): Path<(ChatId, UserId)>, Extension(join_reqs): Extension, ) -> Result, StatusCode> { + let join_reqs = join_reqs.read().await; + let join_req = join_reqs .get(&(chat_id, user_id)) .ok_or(StatusCode::NOT_FOUND)?; @@ -107,7 +109,9 @@ async fn verify_api( Extension(join_reqs): Extension, Json(req): Json, ) -> Result<&'static str, StatusCode> { - let join_req = join_reqs + let join_reqs_r = join_reqs.read().await; + + let join_req = join_reqs_r .get(&(chat_id, user_id)) .ok_or(StatusCode::NOT_FOUND)?; let msg_id = join_req.msg_id.ok_or(StatusCode::CONFLICT)?; @@ -131,7 +135,7 @@ async fn verify_api( .error_for_status() .map_err(|_| StatusCode::BAD_REQUEST)?; - drop(join_req); + drop(join_reqs_r); on_verified(bot, chat_id, user_id, join_reqs) .await