From a14657d124da22cf90a4c1b1e24f19cc13624c07 Mon Sep 17 00:00:00 2001
From: dullbananas <dull.bananas0@gmail.com>
Date: Thu, 19 Oct 2023 06:31:51 -0700
Subject: [PATCH] Refactor rate limiter and improve rate limit bucket cleanup
 (#3937)

* Update rate_limiter.rs

* Update mod.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update mod.rs

* Update scheduled_tasks.rs

* Shrink `RateLimitBucket`

* Update rate_limiter.rs

* Update mod.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update mod.rs

* Update rate_limiter.rs

* fmt

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* Update rate_limiter.rs

* rerun ci

* Update rate_limiter.rs

* Undo changes to  fields

* Manually undo changes to RateLimitBucket fields

* fmt

* Bucket cleanup loop in rate_limit/mod.rs

* Remove rate limit bucket cleanup from scheduled_tasks.rs

* Remove ;

* Remove UNINITIALIZED_TOKEN_AMOUNT

* Update rate_limiter.rs

* fmt

* Update rate_limiter.rs

* fmt

* Update rate_limiter.rs

* fmt

* Update rate_limiter.rs

* stuff

* MapLevel trait

* fix merge

* Prevent negative numbers in buckets

* Clean up MapLevel::check

* MapLevel::remove_full_buckets

* stuff

* Use remove_full_buckets to avoid allocations

* stuff

* remove tx

* Remove RateLimitConfig

* Rename settings_updated_channel to rate_limit_cell

* Remove global rate limit cell

* impl Default for RateLimitCell

* bucket_configs doc comment to explain EnumMap

* improve test_rate_limiter

* rename default to with_test_config

---------

Co-authored-by: Dessalines <dessalines@users.noreply.github.com>
Co-authored-by: Nutomic <me@nutomic.com>
---
 Cargo.lock                                  |   1 +
 Cargo.toml                                  |   1 +
 crates/api_common/Cargo.toml                |   1 +
 crates/api_common/src/claims.rs             |   6 +-
 crates/api_common/src/context.rs            |   2 +-
 crates/api_common/src/utils.rs              |  37 +-
 crates/api_crud/src/site/create.rs          |   5 +-
 crates/api_crud/src/site/update.rs          |   5 +-
 crates/apub/src/objects/mod.rs              |   8 +-
 crates/utils/Cargo.toml                     |   2 +-
 crates/utils/src/rate_limit/mod.rs          | 250 +++++-------
 crates/utils/src/rate_limit/rate_limiter.rs | 420 ++++++++++++--------
 src/lib.rs                                  |   4 +-
 src/scheduled_tasks.rs                      |  11 -
 src/session_middleware.rs                   |   6 +-
 15 files changed, 376 insertions(+), 383 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index debd8bcc5..073e0a95c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2659,6 +2659,7 @@ dependencies = [
  "anyhow",
  "chrono",
  "encoding",
+ "enum-map",
  "futures",
  "getrandom",
  "jsonwebtoken",
diff --git a/Cargo.toml b/Cargo.toml
index 356abb035..9bf1000b6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -129,6 +129,7 @@ rustls = { version = "0.21.3", features = ["dangerous_configuration"] }
 futures-util = "0.3.28"
 tokio-postgres = "0.7.8"
 tokio-postgres-rustls = "0.10.0"
+enum-map = "2.6"
 
 [dependencies]
 lemmy_api = { workspace = true }
diff --git a/crates/api_common/Cargo.toml b/crates/api_common/Cargo.toml
index 5325350c8..a01e6008c 100644
--- a/crates/api_common/Cargo.toml
+++ b/crates/api_common/Cargo.toml
@@ -68,6 +68,7 @@ actix-web = { workspace = true, optional = true }
 jsonwebtoken = { version = "8.3.0", optional = true }
 # necessary for wasmt compilation
 getrandom = { version = "0.2.10", features = ["js"] }
+enum-map = { workspace = true }
 
 [dev-dependencies]
 serial_test = { workspace = true }
diff --git a/crates/api_common/src/claims.rs b/crates/api_common/src/claims.rs
index 6676840dc..09191ad71 100644
--- a/crates/api_common/src/claims.rs
+++ b/crates/api_common/src/claims.rs
@@ -88,7 +88,7 @@ mod tests {
     traits::Crud,
     utils::build_db_pool_for_tests,
   };
-  use lemmy_utils::rate_limit::{RateLimitCell, RateLimitConfig};
+  use lemmy_utils::rate_limit::RateLimitCell;
   use reqwest::Client;
   use reqwest_middleware::ClientBuilder;
   use serial_test::serial;
@@ -103,9 +103,7 @@ mod tests {
       pool_.clone(),
       ClientBuilder::new(Client::default()).build(),
       secret,
-      RateLimitCell::new(RateLimitConfig::builder().build())
-        .await
-        .clone(),
+      RateLimitCell::with_test_config(),
     );
 
     let inserted_instance = Instance::read_or_create(pool, "my_domain.tld".to_string())
diff --git a/crates/api_common/src/context.rs b/crates/api_common/src/context.rs
index 0d448ef97..888a98741 100644
--- a/crates/api_common/src/context.rs
+++ b/crates/api_common/src/context.rs
@@ -46,7 +46,7 @@ impl LemmyContext {
   pub fn secret(&self) -> &Secret {
     &self.secret
   }
-  pub fn settings_updated_channel(&self) -> &RateLimitCell {
+  pub fn rate_limit_cell(&self) -> &RateLimitCell {
     &self.rate_limit_cell
   }
 }
diff --git a/crates/api_common/src/utils.rs b/crates/api_common/src/utils.rs
index b3dcd7558..5ba9a34c3 100644
--- a/crates/api_common/src/utils.rs
+++ b/crates/api_common/src/utils.rs
@@ -7,6 +7,7 @@ use crate::{
 use actix_web::cookie::{Cookie, SameSite};
 use anyhow::Context;
 use chrono::{DateTime, Days, Local, TimeZone, Utc};
+use enum_map::{enum_map, EnumMap};
 use lemmy_db_schema::{
   newtypes::{CommunityId, DbUrl, PersonId, PostId},
   source::{
@@ -34,7 +35,7 @@ use lemmy_utils::{
   email::{send_email, translations::Lang},
   error::{LemmyError, LemmyErrorExt, LemmyErrorType, LemmyResult},
   location_info,
-  rate_limit::RateLimitConfig,
+  rate_limit::{ActionType, BucketConfig},
   settings::structs::Settings,
   utils::slurs::build_slur_regex,
 };
@@ -390,25 +391,21 @@ fn lang_str_to_lang(lang: &str) -> Lang {
 }
 
 pub fn local_site_rate_limit_to_rate_limit_config(
-  local_site_rate_limit: &LocalSiteRateLimit,
-) -> RateLimitConfig {
-  let l = local_site_rate_limit;
-  RateLimitConfig {
-    message: l.message,
-    message_per_second: l.message_per_second,
-    post: l.post,
-    post_per_second: l.post_per_second,
-    register: l.register,
-    register_per_second: l.register_per_second,
-    image: l.image,
-    image_per_second: l.image_per_second,
-    comment: l.comment,
-    comment_per_second: l.comment_per_second,
-    search: l.search,
-    search_per_second: l.search_per_second,
-    import_user_settings: l.import_user_settings,
-    import_user_settings_per_second: l.import_user_settings_per_second,
-  }
+  l: &LocalSiteRateLimit,
+) -> EnumMap<ActionType, BucketConfig> {
+  enum_map! {
+    ActionType::Message => (l.message, l.message_per_second),
+    ActionType::Post => (l.post, l.post_per_second),
+    ActionType::Register => (l.register, l.register_per_second),
+    ActionType::Image => (l.image, l.image_per_second),
+    ActionType::Comment => (l.comment, l.comment_per_second),
+    ActionType::Search => (l.search, l.search_per_second),
+    ActionType::ImportUserSettings => (l.import_user_settings, l.import_user_settings_per_second),
+  }
+  .map(|_key, (capacity, secs_to_refill)| BucketConfig {
+    capacity: u32::try_from(capacity).unwrap_or(0),
+    secs_to_refill: u32::try_from(secs_to_refill).unwrap_or(0),
+  })
 }
 
 pub fn local_site_to_slur_regex(local_site: &LocalSite) -> Option<Regex> {
diff --git a/crates/api_crud/src/site/create.rs b/crates/api_crud/src/site/create.rs
index 61dfd7c77..1449f4844 100644
--- a/crates/api_crud/src/site/create.rs
+++ b/crates/api_crud/src/site/create.rs
@@ -119,10 +119,7 @@ pub async fn create_site(
 
   let rate_limit_config =
     local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit);
-  context
-    .settings_updated_channel()
-    .send(rate_limit_config)
-    .await?;
+  context.rate_limit_cell().set_config(rate_limit_config);
 
   Ok(Json(SiteResponse {
     site_view,
diff --git a/crates/api_crud/src/site/update.rs b/crates/api_crud/src/site/update.rs
index 3afc79559..b9d8f6a7f 100644
--- a/crates/api_crud/src/site/update.rs
+++ b/crates/api_crud/src/site/update.rs
@@ -157,10 +157,7 @@ pub async fn update_site(
 
   let rate_limit_config =
     local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit);
-  context
-    .settings_updated_channel()
-    .send(rate_limit_config)
-    .await?;
+  context.rate_limit_cell().set_config(rate_limit_config);
 
   Ok(Json(SiteResponse {
     site_view,
diff --git a/crates/apub/src/objects/mod.rs b/crates/apub/src/objects/mod.rs
index b3653172a..6e27c0d09 100644
--- a/crates/apub/src/objects/mod.rs
+++ b/crates/apub/src/objects/mod.rs
@@ -61,10 +61,7 @@ pub(crate) mod tests {
   use anyhow::anyhow;
   use lemmy_api_common::{context::LemmyContext, request::build_user_agent};
   use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool_for_tests};
-  use lemmy_utils::{
-    rate_limit::{RateLimitCell, RateLimitConfig},
-    settings::SETTINGS,
-  };
+  use lemmy_utils::{rate_limit::RateLimitCell, settings::SETTINGS};
   use reqwest::{Client, Request, Response};
   use reqwest_middleware::{ClientBuilder, Middleware, Next};
   use task_local_extensions::Extensions;
@@ -101,8 +98,7 @@ pub(crate) mod tests {
       jwt_secret: String::new(),
     };
 
-    let rate_limit_config = RateLimitConfig::builder().build();
-    let rate_limit_cell = RateLimitCell::new(rate_limit_config).await;
+    let rate_limit_cell = RateLimitCell::with_test_config();
 
     let context = LemmyContext::create(pool, client, secret, rate_limit_cell.clone());
     let config = FederationConfig::builder()
diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml
index 20611702e..dc9714b0d 100644
--- a/crates/utils/Cargo.toml
+++ b/crates/utils/Cargo.toml
@@ -47,7 +47,7 @@ smart-default = "0.7.1"
 lettre = { version = "0.10.4", features = ["tokio1", "tokio1-native-tls"] }
 markdown-it = "0.5.1"
 ts-rs = { workspace = true, optional = true }
-enum-map = "2.6"
+enum-map = { workspace = true }
 
 [dev-dependencies]
 reqwest = { workspace = true }
diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs
index 114daf452..63090749b 100644
--- a/crates/utils/src/rate_limit/mod.rs
+++ b/crates/utils/src/rate_limit/mod.rs
@@ -1,9 +1,9 @@
 use crate::error::{LemmyError, LemmyErrorType};
 use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform};
-use enum_map::enum_map;
+use enum_map::{enum_map, EnumMap};
 use futures::future::{ok, Ready};
-use rate_limiter::{InstantSecs, RateLimitStorage, RateLimitType};
-use serde::{Deserialize, Serialize};
+pub use rate_limiter::{ActionType, BucketConfig};
+use rate_limiter::{InstantSecs, RateLimitState};
 use std::{
   future::Future,
   net::{IpAddr, Ipv4Addr, SocketAddr},
@@ -14,208 +14,140 @@ use std::{
   task::{Context, Poll},
   time::Duration,
 };
-use tokio::sync::{mpsc, mpsc::Sender, OnceCell};
-use typed_builder::TypedBuilder;
 
 pub mod rate_limiter;
 
-#[derive(Debug, Deserialize, Serialize, Clone, TypedBuilder)]
-pub struct RateLimitConfig {
-  #[builder(default = 180)]
-  /// Maximum number of messages created in interval
-  pub message: i32,
-  #[builder(default = 60)]
-  /// Interval length for message limit, in seconds
-  pub message_per_second: i32,
-  #[builder(default = 6)]
-  /// Maximum number of posts created in interval
-  pub post: i32,
-  #[builder(default = 300)]
-  /// Interval length for post limit, in seconds
-  pub post_per_second: i32,
-  #[builder(default = 3)]
-  /// Maximum number of registrations in interval
-  pub register: i32,
-  #[builder(default = 3600)]
-  /// Interval length for registration limit, in seconds
-  pub register_per_second: i32,
-  #[builder(default = 6)]
-  /// Maximum number of image uploads in interval
-  pub image: i32,
-  #[builder(default = 3600)]
-  /// Interval length for image uploads, in seconds
-  pub image_per_second: i32,
-  #[builder(default = 6)]
-  /// Maximum number of comments created in interval
-  pub comment: i32,
-  #[builder(default = 600)]
-  /// Interval length for comment limit, in seconds
-  pub comment_per_second: i32,
-  #[builder(default = 60)]
-  /// Maximum number of searches created in interval
-  pub search: i32,
-  #[builder(default = 600)]
-  /// Interval length for search limit, in seconds
-  pub search_per_second: i32,
-  #[builder(default = 1)]
-  /// Maximum number of user settings imports in interval
-  pub import_user_settings: i32,
-  #[builder(default = 24 * 60 * 60)]
-  /// Interval length for importing user settings, in seconds (defaults to 24 hours)
-  pub import_user_settings_per_second: i32,
-}
-
-#[derive(Debug, Clone)]
-struct RateLimit {
-  pub rate_limiter: RateLimitStorage,
-  pub rate_limit_config: RateLimitConfig,
-}
-
 #[derive(Debug, Clone)]
-pub struct RateLimitedGuard {
-  rate_limit: Arc<Mutex<RateLimit>>,
-  type_: RateLimitType,
+pub struct RateLimitChecker {
+  state: Arc<Mutex<RateLimitState>>,
+  action_type: ActionType,
 }
 
 /// Single instance of rate limit config and buckets, which is shared across all threads.
 #[derive(Clone)]
 pub struct RateLimitCell {
-  tx: Sender<RateLimitConfig>,
-  rate_limit: Arc<Mutex<RateLimit>>,
+  state: Arc<Mutex<RateLimitState>>,
 }
 
 impl RateLimitCell {
-  /// Initialize cell if it wasnt initialized yet. Otherwise returns the existing cell.
-  pub async fn new(rate_limit_config: RateLimitConfig) -> &'static Self {
-    static LOCAL_INSTANCE: OnceCell<RateLimitCell> = OnceCell::const_new();
-    LOCAL_INSTANCE
-      .get_or_init(|| async {
-        let (tx, mut rx) = mpsc::channel::<RateLimitConfig>(4);
-        let rate_limit = Arc::new(Mutex::new(RateLimit {
-          rate_limiter: Default::default(),
-          rate_limit_config,
-        }));
-        let rate_limit2 = rate_limit.clone();
-        tokio::spawn(async move {
-          while let Some(r) = rx.recv().await {
-            rate_limit2
-              .lock()
-              .expect("Failed to lock rate limit mutex for updating")
-              .rate_limit_config = r;
-          }
-        });
-        RateLimitCell { tx, rate_limit }
-      })
-      .await
-  }
+  pub fn new(rate_limit_config: EnumMap<ActionType, BucketConfig>) -> Self {
+    let state = Arc::new(Mutex::new(RateLimitState::new(rate_limit_config)));
 
-  /// Call this when the config was updated, to update all in-memory cells.
-  pub async fn send(&self, config: RateLimitConfig) -> Result<(), LemmyError> {
-    self.tx.send(config).await?;
-    Ok(())
-  }
+    let state_weak_ref = Arc::downgrade(&state);
 
-  /// Remove buckets older than the given duration
-  pub fn remove_older_than(&self, mut duration: Duration) {
-    let mut guard = self
-      .rate_limit
-      .lock()
-      .expect("Failed to lock rate limit mutex for reading");
-    let rate_limit = &guard.rate_limit_config;
-
-    // If any rate limit interval is greater than `duration`, then the largest interval is used instead. This preserves buckets that would not pass the rate limit check.
-    let max_interval_secs = enum_map! {
-      RateLimitType::Message => rate_limit.message_per_second,
-      RateLimitType::Post => rate_limit.post_per_second,
-      RateLimitType::Register => rate_limit.register_per_second,
-      RateLimitType::Image => rate_limit.image_per_second,
-      RateLimitType::Comment => rate_limit.comment_per_second,
-      RateLimitType::Search => rate_limit.search_per_second,
-      RateLimitType::ImportUserSettings => rate_limit.import_user_settings_per_second
-    }
-    .into_values()
-    .max()
-    .and_then(|max| u64::try_from(max).ok())
-    .unwrap_or(0);
+    tokio::spawn(async move {
+      let hour = Duration::from_secs(3600);
 
-    duration = std::cmp::max(duration, Duration::from_secs(max_interval_secs));
+      // This loop stops when all other references to `state` are dropped
+      while let Some(state) = state_weak_ref.upgrade() {
+        tokio::time::sleep(hour).await;
+        state
+          .lock()
+          .expect("Failed to lock rate limit mutex for reading")
+          .remove_full_buckets(InstantSecs::now());
+      }
+    });
 
-    guard
-      .rate_limiter
-      .remove_older_than(duration, InstantSecs::now())
+    RateLimitCell { state }
+  }
+
+  pub fn set_config(&self, config: EnumMap<ActionType, BucketConfig>) {
+    self
+      .state
+      .lock()
+      .expect("Failed to lock rate limit mutex for updating")
+      .set_config(config);
   }
 
-  pub fn message(&self) -> RateLimitedGuard {
-    self.kind(RateLimitType::Message)
+  pub fn message(&self) -> RateLimitChecker {
+    self.new_checker(ActionType::Message)
   }
 
-  pub fn post(&self) -> RateLimitedGuard {
-    self.kind(RateLimitType::Post)
+  pub fn post(&self) -> RateLimitChecker {
+    self.new_checker(ActionType::Post)
   }
 
-  pub fn register(&self) -> RateLimitedGuard {
-    self.kind(RateLimitType::Register)
+  pub fn register(&self) -> RateLimitChecker {
+    self.new_checker(ActionType::Register)
   }
 
-  pub fn image(&self) -> RateLimitedGuard {
-    self.kind(RateLimitType::Image)
+  pub fn image(&self) -> RateLimitChecker {
+    self.new_checker(ActionType::Image)
   }
 
-  pub fn comment(&self) -> RateLimitedGuard {
-    self.kind(RateLimitType::Comment)
+  pub fn comment(&self) -> RateLimitChecker {
+    self.new_checker(ActionType::Comment)
   }
 
-  pub fn search(&self) -> RateLimitedGuard {
-    self.kind(RateLimitType::Search)
+  pub fn search(&self) -> RateLimitChecker {
+    self.new_checker(ActionType::Search)
   }
 
-  pub fn import_user_settings(&self) -> RateLimitedGuard {
-    self.kind(RateLimitType::ImportUserSettings)
+  pub fn import_user_settings(&self) -> RateLimitChecker {
+    self.new_checker(ActionType::ImportUserSettings)
   }
 
-  fn kind(&self, type_: RateLimitType) -> RateLimitedGuard {
-    RateLimitedGuard {
-      rate_limit: self.rate_limit.clone(),
-      type_,
+  fn new_checker(&self, action_type: ActionType) -> RateLimitChecker {
+    RateLimitChecker {
+      state: self.state.clone(),
+      action_type,
     }
   }
+
+  pub fn with_test_config() -> Self {
+    Self::new(enum_map! {
+      ActionType::Message => BucketConfig {
+        capacity: 180,
+        secs_to_refill: 60,
+      },
+      ActionType::Post => BucketConfig {
+        capacity: 6,
+        secs_to_refill: 300,
+      },
+      ActionType::Register => BucketConfig {
+        capacity: 3,
+        secs_to_refill: 3600,
+      },
+      ActionType::Image => BucketConfig {
+        capacity: 6,
+        secs_to_refill: 3600,
+      },
+      ActionType::Comment => BucketConfig {
+        capacity: 6,
+        secs_to_refill: 600,
+      },
+      ActionType::Search => BucketConfig {
+        capacity: 60,
+        secs_to_refill: 600,
+      },
+      ActionType::ImportUserSettings => BucketConfig {
+        capacity: 1,
+        secs_to_refill: 24 * 60 * 60,
+      },
+    })
+  }
 }
 
 pub struct RateLimitedMiddleware<S> {
-  rate_limited: RateLimitedGuard,
+  checker: RateLimitChecker,
   service: Rc<S>,
 }
 
-impl RateLimitedGuard {
+impl RateLimitChecker {
   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
   pub fn check(self, ip_addr: IpAddr) -> bool {
     // Does not need to be blocking because the RwLock in settings never held across await points,
     // and the operation here locks only long enough to clone
-    let mut guard = self
-      .rate_limit
+    let mut state = self
+      .state
       .lock()
       .expect("Failed to lock rate limit mutex for reading");
-    let rate_limit = &guard.rate_limit_config;
-
-    let (kind, interval) = match self.type_ {
-      RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
-      RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
-      RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second),
-      RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
-      RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
-      RateLimitType::Search => (rate_limit.search, rate_limit.search_per_second),
-      RateLimitType::ImportUserSettings => (
-        rate_limit.import_user_settings,
-        rate_limit.import_user_settings_per_second,
-      ),
-    };
-    let limiter = &mut guard.rate_limiter;
-
-    limiter.check_rate_limit_full(self.type_, ip_addr, kind, interval, InstantSecs::now())
+
+    state.check(self.action_type, ip_addr, InstantSecs::now())
   }
 }
 
-impl<S> Transform<S, ServiceRequest> for RateLimitedGuard
+impl<S> Transform<S, ServiceRequest> for RateLimitChecker
 where
   S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
   S::Future: 'static,
@@ -228,7 +160,7 @@ where
 
   fn new_transform(&self, service: S) -> Self::Future {
     ok(RateLimitedMiddleware {
-      rate_limited: self.clone(),
+      checker: self.clone(),
       service: Rc::new(service),
     })
   }
@@ -252,11 +184,11 @@ where
   fn call(&self, req: ServiceRequest) -> Self::Future {
     let ip_addr = get_ip(&req.connection_info());
 
-    let rate_limited = self.rate_limited.clone();
+    let checker = self.checker.clone();
     let service = self.service.clone();
 
     Box::pin(async move {
-      if rate_limited.check(ip_addr) {
+      if checker.check(ip_addr) {
         service.call(req).await
       } else {
         let (http_req, _) = req.into_parts();
diff --git a/crates/utils/src/rate_limit/rate_limiter.rs b/crates/utils/src/rate_limit/rate_limiter.rs
index 7ba1345c5..d0dad5df2 100644
--- a/crates/utils/src/rate_limit/rate_limiter.rs
+++ b/crates/utils/src/rate_limit/rate_limiter.rs
@@ -1,15 +1,13 @@
-use enum_map::{enum_map, EnumMap};
+use enum_map::EnumMap;
 use once_cell::sync::Lazy;
 use std::{
   collections::HashMap,
   hash::Hash,
   net::{IpAddr, Ipv4Addr, Ipv6Addr},
-  time::{Duration, Instant},
+  time::Instant,
 };
 use tracing::debug;
 
-const UNINITIALIZED_TOKEN_AMOUNT: f32 = -2.0;
-
 static START_TIME: Lazy<Instant> = Lazy::new(Instant::now);
 
 /// Smaller than `std::time::Instant` because it uses a smaller integer for seconds and doesn't
@@ -26,27 +24,48 @@ impl InstantSecs {
         .expect("server has been running for over 136 years"),
     }
   }
-
-  fn secs_since(self, earlier: Self) -> u32 {
-    self.secs.saturating_sub(earlier.secs)
-  }
-
-  fn to_instant(self) -> Instant {
-    *START_TIME + Duration::from_secs(self.secs.into())
-  }
 }
 
-#[derive(PartialEq, Debug, Clone)]
-struct RateLimitBucket {
+#[derive(PartialEq, Debug, Clone, Copy)]
+struct Bucket {
   last_checked: InstantSecs,
   /// This field stores the amount of tokens that were present at `last_checked`.
   /// The amount of tokens steadily increases until it reaches the bucket's capacity.
   /// Performing the rate-limited action consumes 1 token.
-  tokens: f32,
+  tokens: u32,
+}
+
+#[derive(PartialEq, Debug, Copy, Clone)]
+pub struct BucketConfig {
+  pub capacity: u32,
+  pub secs_to_refill: u32,
+}
+
+impl Bucket {
+  fn update(self, now: InstantSecs, config: BucketConfig) -> Self {
+    let secs_since_last_checked = now.secs.saturating_sub(self.last_checked.secs);
+
+    // For `secs_since_last_checked` seconds, the amount of tokens increases by `capacity` every `secs_to_refill` seconds.
+    // The amount of tokens added per second is `capacity / secs_to_refill`.
+    // The expression below is like `secs_since_last_checked * (capacity / secs_to_refill)` but with precision and non-overflowing multiplication.
+    let added_tokens = u64::from(secs_since_last_checked) * u64::from(config.capacity)
+      / u64::from(config.secs_to_refill);
+
+    // The amount of tokens there would be if the bucket had infinite capacity
+    let unbounded_tokens = self.tokens + (added_tokens as u32);
+
+    // Bucket stops filling when capacity is reached
+    let tokens = std::cmp::min(unbounded_tokens, config.capacity);
+
+    Bucket {
+      last_checked: now,
+      tokens,
+    }
+  }
 }
 
 #[derive(Debug, enum_map::Enum, Copy, Clone, AsRefStr)]
-pub(crate) enum RateLimitType {
+pub enum ActionType {
   Message,
   Register,
   Post,
@@ -56,179 +75,228 @@ pub(crate) enum RateLimitType {
   ImportUserSettings,
 }
 
-type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
-
 #[derive(PartialEq, Debug, Clone)]
 struct RateLimitedGroup<C> {
-  total: EnumMap<RateLimitType, RateLimitBucket>,
+  total: EnumMap<ActionType, Bucket>,
   children: C,
 }
 
-impl<C: Default> RateLimitedGroup<C> {
-  fn new(now: InstantSecs) -> Self {
-    RateLimitedGroup {
-      total: enum_map! {
-        _ => RateLimitBucket {
-          last_checked: now,
-          tokens: UNINITIALIZED_TOKEN_AMOUNT,
-        },
-      },
-      children: Default::default(),
+type Map<K, C> = HashMap<K, RateLimitedGroup<C>>;
+
+/// Implemented for `()`, `Map<T, ()>`, `Map<T, Map<U, ()>>`, etc.
+trait MapLevel: Default {
+  type CapacityFactors;
+  type AddrParts;
+
+  fn check(
+    &mut self,
+    action_type: ActionType,
+    now: InstantSecs,
+    configs: EnumMap<ActionType, BucketConfig>,
+    capacity_factors: Self::CapacityFactors,
+    addr_parts: Self::AddrParts,
+  ) -> bool;
+
+  /// Remove full buckets and return `true` if there's any buckets remaining
+  fn remove_full_buckets(
+    &mut self,
+    now: InstantSecs,
+    configs: EnumMap<ActionType, BucketConfig>,
+  ) -> bool;
+}
+
+impl<K: Eq + Hash, C: MapLevel> MapLevel for Map<K, C> {
+  type CapacityFactors = (u32, C::CapacityFactors);
+  type AddrParts = (K, C::AddrParts);
+
+  fn check(
+    &mut self,
+    action_type: ActionType,
+    now: InstantSecs,
+    configs: EnumMap<ActionType, BucketConfig>,
+    (capacity_factor, child_capacity_factors): Self::CapacityFactors,
+    (addr_part, child_addr_parts): Self::AddrParts,
+  ) -> bool {
+    // Multiplies capacities by `capacity_factor` for groups in `self`
+    let adjusted_configs = configs.map(|_, config| BucketConfig {
+      capacity: config.capacity.saturating_mul(capacity_factor),
+      ..config
+    });
+
+    // Remove groups that are no longer needed if the hash map's existing allocation has no space for new groups.
+    // This is done before calling `HashMap::entry` because that immediately allocates just like `HashMap::insert`.
+    if (self.capacity() == self.len()) && !self.contains_key(&addr_part) {
+      self.remove_full_buckets(now, configs);
     }
+
+    let group = self
+      .entry(addr_part)
+      .or_insert(RateLimitedGroup::new(now, adjusted_configs));
+
+    #[allow(clippy::indexing_slicing)]
+    let total_passes = group.check_total(action_type, now, adjusted_configs[action_type]);
+
+    let children_pass = group.children.check(
+      action_type,
+      now,
+      configs,
+      child_capacity_factors,
+      child_addr_parts,
+    );
+
+    total_passes && children_pass
   }
 
-  fn check_total(
+  fn remove_full_buckets(
     &mut self,
-    type_: RateLimitType,
     now: InstantSecs,
-    capacity: i32,
-    secs_to_refill: i32,
+    configs: EnumMap<ActionType, BucketConfig>,
   ) -> bool {
-    let capacity = capacity as f32;
-    let secs_to_refill = secs_to_refill as f32;
+    self.retain(|_key, group| {
+      let some_children_remaining = group.children.remove_full_buckets(now, configs);
+
+      // Evaluated if `some_children_remaining` is false
+      let total_has_refill_in_future = || {
+        group.total.into_iter().all(|(action_type, bucket)| {
+          #[allow(clippy::indexing_slicing)]
+          let config = configs[action_type];
+          bucket.update(now, config).tokens != config.capacity
+        })
+      };
 
-    #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` funciton
-    let bucket = &mut self.total[type_];
+      some_children_remaining || total_has_refill_in_future()
+    });
 
-    if bucket.tokens == UNINITIALIZED_TOKEN_AMOUNT {
-      bucket.tokens = capacity;
-    }
+    self.shrink_to_fit();
 
-    let secs_since_last_checked = now.secs_since(bucket.last_checked) as f32;
-    bucket.last_checked = now;
+    !self.is_empty()
+  }
+}
 
-    // For `secs_since_last_checked` seconds, increase `bucket.tokens`
-    // by `capacity` every `secs_to_refill` seconds
-    bucket.tokens += {
-      let tokens_per_sec = capacity / secs_to_refill;
-      secs_since_last_checked * tokens_per_sec
-    };
+impl MapLevel for () {
+  type CapacityFactors = ();
+  type AddrParts = ();
 
-    // Prevent `bucket.tokens` from exceeding `capacity`
-    if bucket.tokens > capacity {
-      bucket.tokens = capacity;
+  fn check(
+    &mut self,
+    _: ActionType,
+    _: InstantSecs,
+    _: EnumMap<ActionType, BucketConfig>,
+    _: Self::CapacityFactors,
+    _: Self::AddrParts,
+  ) -> bool {
+    true
+  }
+
+  fn remove_full_buckets(&mut self, _: InstantSecs, _: EnumMap<ActionType, BucketConfig>) -> bool {
+    false
+  }
+}
+
+impl<C: Default> RateLimitedGroup<C> {
+  fn new(now: InstantSecs, configs: EnumMap<ActionType, BucketConfig>) -> Self {
+    RateLimitedGroup {
+      total: configs.map(|_, config| Bucket {
+        last_checked: now,
+        tokens: config.capacity,
+      }),
+      // `HashMap::new()` or `()`
+      children: Default::default(),
     }
+  }
 
-    if bucket.tokens < 1.0 {
+  fn check_total(
+    &mut self,
+    action_type: ActionType,
+    now: InstantSecs,
+    config: BucketConfig,
+  ) -> bool {
+    #[allow(clippy::indexing_slicing)] // `EnumMap` has no `get` function
+    let bucket = &mut self.total[action_type];
+
+    let new_bucket = bucket.update(now, config);
+
+    if new_bucket.tokens == 0 {
       // Not enough tokens yet
-      debug!(
-        "Rate limited type: {}, time_passed: {}, allowance: {}",
-        type_.as_ref(),
-        secs_since_last_checked,
-        bucket.tokens
-      );
+      // Setting `bucket` to `new_bucket` here is useless and would cause the bucket to start over at 0 tokens because of rounding
       false
     } else {
       // Consume 1 token
-      bucket.tokens -= 1.0;
+      *bucket = new_bucket;
+      bucket.tokens -= 1;
       true
     }
   }
 }
 
 /// Rate limiting based on rate type and IP addr
-#[derive(PartialEq, Debug, Clone, Default)]
-pub struct RateLimitStorage {
-  /// One bucket per individual IPv4 address
+#[derive(PartialEq, Debug, Clone)]
+pub struct RateLimitState {
+  /// Each individual IPv4 address gets one `RateLimitedGroup`.
   ipv4_buckets: Map<Ipv4Addr, ()>,
-  /// Seperate buckets for 48, 56, and 64 bit prefixes of IPv6 addresses
+  /// All IPv6 addresses that share the same first 64 bits share the same `RateLimitedGroup`.
+  ///
+  /// The same thing happens for the first 48 and 56 bits, but with increased capacity.
+  ///
+  /// This is done because all users can easily switch to any other IPv6 address that has the same first 64 bits.
+  /// It could be as low as 48 bits for some networks, which is the reason for 48 and 56 bit address groups.
   ipv6_buckets: Map<[u8; 6], Map<u8, Map<u8, ()>>>,
+  /// This stores a `BucketConfig` for each `ActionType`. `EnumMap` makes it impossible to have a missing `BucketConfig`.
+  bucket_configs: EnumMap<ActionType, BucketConfig>,
 }
 
-impl RateLimitStorage {
+impl RateLimitState {
+  pub fn new(bucket_configs: EnumMap<ActionType, BucketConfig>) -> Self {
+    RateLimitState {
+      ipv4_buckets: HashMap::new(),
+      ipv6_buckets: HashMap::new(),
+      bucket_configs,
+    }
+  }
+
   /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478
   ///
   /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
-  pub(super) fn check_rate_limit_full(
-    &mut self,
-    type_: RateLimitType,
-    ip: IpAddr,
-    capacity: i32,
-    secs_to_refill: i32,
-    now: InstantSecs,
-  ) -> bool {
-    let mut result = true;
-
-    match ip {
+  pub fn check(&mut self, action_type: ActionType, ip: IpAddr, now: InstantSecs) -> bool {
+    let result = match ip {
       IpAddr::V4(ipv4) => {
-        // Only used by one address.
-        let group = self
+        self
           .ipv4_buckets
-          .entry(ipv4)
-          .or_insert(RateLimitedGroup::new(now));
-
-        result &= group.check_total(type_, now, capacity, secs_to_refill);
+          .check(action_type, now, self.bucket_configs, (1, ()), (ipv4, ()))
       }
 
       IpAddr::V6(ipv6) => {
         let (key_48, key_56, key_64) = split_ipv6(ipv6);
-
-        // Contains all addresses with the same first 48 bits. These addresses might be part of the same network.
-        let group_48 = self
-          .ipv6_buckets
-          .entry(key_48)
-          .or_insert(RateLimitedGroup::new(now));
-        result &= group_48.check_total(type_, now, capacity.saturating_mul(16), secs_to_refill);
-
-        // Contains all addresses with the same first 56 bits. These addresses might be part of the same network.
-        let group_56 = group_48
-          .children
-          .entry(key_56)
-          .or_insert(RateLimitedGroup::new(now));
-        result &= group_56.check_total(type_, now, capacity.saturating_mul(4), secs_to_refill);
-
-        // A group with no children. It is shared by all addresses with the same first 64 bits. These addresses are always part of the same network.
-        let group_64 = group_56
-          .children
-          .entry(key_64)
-          .or_insert(RateLimitedGroup::new(now));
-
-        result &= group_64.check_total(type_, now, capacity, secs_to_refill);
+        self.ipv6_buckets.check(
+          action_type,
+          now,
+          self.bucket_configs,
+          (16, (4, (1, ()))),
+          (key_48, (key_56, (key_64, ()))),
+        )
       }
     };
 
     if !result {
-      debug!("Rate limited IP: {ip}");
+      debug!("Rate limited IP: {ip}, type: {action_type:?}");
     }
 
     result
   }
 
-  /// Remove buckets older than the given duration
-  pub(super) fn remove_older_than(&mut self, duration: Duration, now: InstantSecs) {
-    // Only retain buckets that were last used after `instant`
-    let Some(instant) = now.to_instant().checked_sub(duration) else {
-      return;
-    };
-
-    let is_recently_used = |group: &RateLimitedGroup<_>| {
-      group
-        .total
-        .values()
-        .all(|bucket| bucket.last_checked.to_instant() > instant)
-    };
-
-    retain_and_shrink(&mut self.ipv4_buckets, |_, group| is_recently_used(group));
-
-    retain_and_shrink(&mut self.ipv6_buckets, |_, group_48| {
-      retain_and_shrink(&mut group_48.children, |_, group_56| {
-        retain_and_shrink(&mut group_56.children, |_, group_64| {
-          is_recently_used(group_64)
-        });
-        !group_56.children.is_empty()
-      });
-      !group_48.children.is_empty()
-    })
+  /// Remove buckets that are now full
+  pub fn remove_full_buckets(&mut self, now: InstantSecs) {
+    self
+      .ipv4_buckets
+      .remove_full_buckets(now, self.bucket_configs);
+    self
+      .ipv6_buckets
+      .remove_full_buckets(now, self.bucket_configs);
   }
-}
 
-fn retain_and_shrink<K, V, F>(map: &mut HashMap<K, V>, f: F)
-where
-  K: Eq + Hash,
-  F: FnMut(&K, &mut V) -> bool,
-{
-  map.retain(f);
-  map.shrink_to_fit();
+  pub fn set_config(&mut self, new_configs: EnumMap<ActionType, BucketConfig>) {
+    self.bucket_configs = new_configs;
+  }
 }
 
 fn split_ipv6(ip: Ipv6Addr) -> ([u8; 6], u8, u8) {
@@ -241,6 +309,8 @@ mod tests {
   #![allow(clippy::unwrap_used)]
   #![allow(clippy::indexing_slicing)]
 
+  use super::{ActionType, BucketConfig, InstantSecs, RateLimitState, RateLimitedGroup};
+
   #[test]
   fn test_split_ipv6() {
     let ip = std::net::Ipv6Addr::new(
@@ -254,9 +324,20 @@ mod tests {
 
   #[test]
   fn test_rate_limiter() {
-    let mut rate_limiter = super::RateLimitStorage::default();
-    let mut now = super::InstantSecs::now();
+    let bucket_configs = enum_map::enum_map! {
+      ActionType::Message => BucketConfig {
+        capacity: 2,
+        secs_to_refill: 1,
+      },
+      _ => BucketConfig {
+        capacity: 2,
+        secs_to_refill: 1,
+      },
+    };
+    let mut rate_limiter = RateLimitState::new(bucket_configs);
+    let mut now = InstantSecs::now();
 
+    // Do 1 `Message` and 1 `Post` action for each IP address, and expect the limit not to be reached
     let ips = [
       "123.123.123.123",
       "1:2:3::",
@@ -266,66 +347,71 @@ mod tests {
     ];
     for ip in ips {
       let ip = ip.parse().unwrap();
-      let message_passed =
-        rate_limiter.check_rate_limit_full(super::RateLimitType::Message, ip, 2, 1, now);
-      let post_passed =
-        rate_limiter.check_rate_limit_full(super::RateLimitType::Post, ip, 3, 1, now);
+      let message_passed = rate_limiter.check(ActionType::Message, ip, now);
+      let post_passed = rate_limiter.check(ActionType::Post, ip, now);
       assert!(message_passed);
       assert!(post_passed);
     }
 
     #[allow(clippy::indexing_slicing)]
-    let expected_buckets = |factor: f32, tokens_consumed: f32| {
-      let mut buckets = super::RateLimitedGroup::<()>::new(now).total;
-      buckets[super::RateLimitType::Message] = super::RateLimitBucket {
-        last_checked: now,
-        tokens: (2.0 * factor) - tokens_consumed,
-      };
-      buckets[super::RateLimitType::Post] = super::RateLimitBucket {
-        last_checked: now,
-        tokens: (3.0 * factor) - tokens_consumed,
-      };
+    let expected_buckets = |factor: u32, tokens_consumed: u32| {
+      let adjusted_configs = bucket_configs.map(|_, config| BucketConfig {
+        capacity: config.capacity.saturating_mul(factor),
+        ..config
+      });
+      let mut buckets = RateLimitedGroup::<()>::new(now, adjusted_configs).total;
+      buckets[ActionType::Message].tokens -= tokens_consumed;
+      buckets[ActionType::Post].tokens -= tokens_consumed;
       buckets
     };
 
-    let bottom_group = |tokens_consumed| super::RateLimitedGroup {
-      total: expected_buckets(1.0, tokens_consumed),
+    let bottom_group = |tokens_consumed| RateLimitedGroup {
+      total: expected_buckets(1, tokens_consumed),
       children: (),
     };
 
     assert_eq!(
       rate_limiter,
-      super::RateLimitStorage {
-        ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1.0)),].into(),
+      RateLimitState {
+        bucket_configs,
+        ipv4_buckets: [([123, 123, 123, 123].into(), bottom_group(1))].into(),
         ipv6_buckets: [(
           [0, 1, 0, 2, 0, 3],
-          super::RateLimitedGroup {
-            total: expected_buckets(16.0, 4.0),
+          RateLimitedGroup {
+            total: expected_buckets(16, 4),
             children: [
               (
                 0,
-                super::RateLimitedGroup {
-                  total: expected_buckets(4.0, 1.0),
-                  children: [(0, bottom_group(1.0)),].into(),
+                RateLimitedGroup {
+                  total: expected_buckets(4, 1),
+                  children: [(0, bottom_group(1))].into(),
                 }
               ),
               (
                 4,
-                super::RateLimitedGroup {
-                  total: expected_buckets(4.0, 3.0),
-                  children: [(0, bottom_group(1.0)), (5, bottom_group(2.0)),].into(),
+                RateLimitedGroup {
+                  total: expected_buckets(4, 3),
+                  children: [(0, bottom_group(1)), (5, bottom_group(2))].into(),
                 }
               ),
             ]
             .into(),
           }
-        ),]
+        )]
         .into(),
       }
     );
 
+    // Do 2 `Message` actions for 1 IP address and expect only the 2nd one to fail
+    for expected_to_pass in [true, false] {
+      let ip = "1:2:3:0400::".parse().unwrap();
+      let passed = rate_limiter.check(ActionType::Message, ip, now);
+      assert_eq!(passed, expected_to_pass);
+    }
+
+    // Expect `remove_full_buckets` to remove everything when called 2 seconds later
     now.secs += 2;
-    rate_limiter.remove_older_than(std::time::Duration::from_secs(1), now);
+    rate_limiter.remove_full_buckets(now);
     assert!(rate_limiter.ipv4_buckets.is_empty());
     assert!(rate_limiter.ipv6_buckets.is_empty());
   }
diff --git a/src/lib.rs b/src/lib.rs
index c093faaca..2df231dd5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -156,7 +156,7 @@ pub async fn start_lemmy_server(args: CmdArgs) -> Result<(), LemmyError> {
   // Set up the rate limiter
   let rate_limit_config =
     local_site_rate_limit_to_rate_limit_config(&site_view.local_site_rate_limit);
-  let rate_limit_cell = RateLimitCell::new(rate_limit_config).await;
+  let rate_limit_cell = RateLimitCell::new(rate_limit_config);
 
   println!(
     "Starting http server at {}:{}",
@@ -298,7 +298,7 @@ fn create_http_server(
     .expect("Should always be buildable");
 
   let context: LemmyContext = federation_config.deref().clone();
-  let rate_limit_cell = federation_config.settings_updated_channel().clone();
+  let rate_limit_cell = federation_config.rate_limit_cell().clone();
   let self_origin = settings.get_protocol_and_hostname();
   // Create Http server with websocket support
   let server = HttpServer::new(move || {
diff --git a/src/scheduled_tasks.rs b/src/scheduled_tasks.rs
index 99dd16829..8db74ef9d 100644
--- a/src/scheduled_tasks.rs
+++ b/src/scheduled_tasks.rs
@@ -78,17 +78,6 @@ pub async fn setup(context: LemmyContext) -> Result<(), LemmyError> {
     }
   });
 
-  let context_1 = context.clone();
-  // Remove old rate limit buckets after 1 to 2 hours of inactivity
-  scheduler.every(CTimeUnits::hour(1)).run(move || {
-    let context = context_1.clone();
-
-    async move {
-      let hour = Duration::from_secs(3600);
-      context.settings_updated_channel().remove_older_than(hour);
-    }
-  });
-
   let context_1 = context.clone();
   // Overwrite deleted & removed posts and comments every day
   scheduler.every(CTimeUnits::days(1)).run(move || {
diff --git a/src/session_middleware.rs b/src/session_middleware.rs
index ae82cd44d..f50e0eccd 100644
--- a/src/session_middleware.rs
+++ b/src/session_middleware.rs
@@ -112,7 +112,7 @@ mod tests {
     traits::Crud,
     utils::build_db_pool_for_tests,
   };
-  use lemmy_utils::rate_limit::{RateLimitCell, RateLimitConfig};
+  use lemmy_utils::rate_limit::RateLimitCell;
   use reqwest::Client;
   use reqwest_middleware::ClientBuilder;
   use serial_test::serial;
@@ -131,9 +131,7 @@ mod tests {
       pool_.clone(),
       ClientBuilder::new(Client::default()).build(),
       secret,
-      RateLimitCell::new(RateLimitConfig::builder().build())
-        .await
-        .clone(),
+      RateLimitCell::with_test_config(),
     );
 
     let inserted_instance = Instance::read_or_create(pool, "my_domain.tld".to_string())
-- 
GitLab