Skip to content

Commit

Permalink
switch dashmap for RwLock<HashMap<_, _>>
Browse files Browse the repository at this point in the history
  • Loading branch information
m1guelpf committed Nov 17, 2023
1 parent fdd3a34 commit f534919
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 37 deletions.
14 changes: 0 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ log = "0.4"
rand = "0.8"
url = "2.4.1"
indoc = "2.0.4"
dashmap = "5.5"
axum = "0.6.20"
dotenvy = "0.15.7"
serde_with = "3.3"
Expand Down
16 changes: 11 additions & 5 deletions src/bot/join_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,26 @@ pub async fn join_handler(
.await?
.id;

join_requests.insert((msg.chat.id, user.id), JoinRequest::new(msg_id));
let mut join_requests_wr = join_requests.write().await;
join_requests_wr.insert((msg.chat.id, user.id), JoinRequest::new(msg_id));
drop(join_requests_wr);

tokio::spawn({
let bot = bot.clone();
let ban_after = chat_cfg.ban_after;
async move {
tokio::time::sleep(ban_after).await;

if let Some((_, data)) = join_requests.remove(&(msg.chat.id, user.id)) {
if !data.is_verified {
let mut join_requests_wr = join_requests.write().await;
if let Some(join_req) = join_requests_wr.remove(&(msg.chat.id, user.id)) {
drop(join_requests_wr);

if !join_req.is_verified {
bot.ban_chat_member(msg.chat.id, user.id)
.await
.expect("Failed to ban the member after timeout");

if let Some(msg_id) = data.msg_id {
if let Some(msg_id) = join_req.msg_id {
bot.delete_message(msg.chat.id, msg_id)
.await
.expect("Failed to delete the message after timeout");
Expand All @@ -91,7 +96,8 @@ pub async fn on_verified(
user_id: UserId,
join_requests: JoinRequests,
) -> HandlerResult {
let mut join_req = join_requests
let mut join_requests_w = join_requests.write().await;
let join_req = join_requests_w
.get_mut(&(chat_id, user_id))
.ok_or("Can't find the message id in group dialogue")?;

Expand Down
6 changes: 3 additions & 3 deletions src/bot/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use dashmap::DashMap;
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use teloxide::{
dispatching::{MessageFilterExt, UpdateFilterExt},
prelude::{dptree, Dispatcher},
Expand All @@ -8,6 +7,7 @@ use teloxide::{
utils::command::BotCommands,
Bot,
};
use tokio::sync::RwLock;

use crate::{bot::commands::Command, config::AppConfig};
pub use join_check::on_verified;
Expand All @@ -16,7 +16,7 @@ mod commands;
mod join_check;

type HandlerResult = Result<(), HandlerError>;
pub type JoinRequests = Arc<DashMap<(ChatId, UserId), JoinRequest>>;
pub type JoinRequests = Arc<RwLock<HashMap<(ChatId, UserId), JoinRequest>>>;
type HandlerError = Box<dyn std::error::Error + Send + Sync>;

#[derive(Clone)]
Expand Down
17 changes: 5 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
use dashmap::DashMap;
use dotenvy::dotenv;
use std::sync::Arc;
use teloxide::{
requests::Requester,
types::{ChatId, UserId},
Bot,
};
use std::{collections::HashMap, sync::Arc};
use teloxide::{requests::Requester, Bot};
use tokio::sync::RwLock;

use crate::{
bot::{JoinRequest, JoinRequests},
config::AppConfig,
};
use crate::{bot::JoinRequests, config::AppConfig};

mod bot;
mod config;
Expand All @@ -22,7 +15,7 @@ async fn main() {
pretty_env_logger::init();

let config = AppConfig::try_read().expect("Failed to read config");
let join_requests: JoinRequests = Arc::new(DashMap::<(ChatId, UserId), JoinRequest>::new());
let join_requests: JoinRequests = Arc::new(RwLock::new(HashMap::new()));

let bot = Bot::new(&config.bot_token);
let bot_data = bot.get_me().await.expect("Failed to get bot account");
Expand Down
8 changes: 6 additions & 2 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ async fn verify_page(
Path((chat_id, user_id)): Path<(ChatId, UserId)>,
Extension(join_reqs): Extension<JoinRequests>,
) -> Result<Html<String>, StatusCode> {
let join_reqs = join_reqs.read().await;

let join_req = join_reqs
.get(&(chat_id, user_id))
.ok_or(StatusCode::NOT_FOUND)?;
Expand Down Expand Up @@ -107,7 +109,9 @@ async fn verify_api(
Extension(join_reqs): Extension<JoinRequests>,
Json(req): Json<VerifyRequest>,
) -> Result<&'static str, StatusCode> {
let join_req = join_reqs
let join_reqs_r = join_reqs.read().await;

let join_req = join_reqs_r
.get(&(chat_id, user_id))
.ok_or(StatusCode::NOT_FOUND)?;
let msg_id = join_req.msg_id.ok_or(StatusCode::CONFLICT)?;
Expand All @@ -131,7 +135,7 @@ async fn verify_api(
.error_for_status()
.map_err(|_| StatusCode::BAD_REQUEST)?;

drop(join_req);
drop(join_reqs_r);

on_verified(bot, chat_id, user_id, join_reqs)
.await
Expand Down

0 comments on commit f534919

Please sign in to comment.