From a1fbd6d72917c0435da28b7f21c2f23aecc008b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Garc=C3=ADa?=
 <dani-garcia@users.noreply.github.com>
Date: Sun, 17 Mar 2024 15:11:20 +0100
Subject: [PATCH] Improve JWT key initialization and avoid saving public key
 (#4085)

---
 src/auth.rs                 | 59 +++++++++++++++++++++++++++----------
 src/config.rs               |  5 +---
 src/db/models/attachment.rs |  2 +-
 src/main.rs                 | 27 +----------------
 src/util.rs                 | 42 ++------------------------
 5 files changed, 48 insertions(+), 87 deletions(-)

diff --git a/src/auth.rs b/src/auth.rs
index 85b6359e..7eabbc1e 100644
--- a/src/auth.rs
+++ b/src/auth.rs
@@ -2,9 +2,10 @@
 //
 use chrono::{Duration, Utc};
 use num_traits::FromPrimitive;
-use once_cell::sync::Lazy;
+use once_cell::sync::{Lazy, OnceCell};
 
 use jsonwebtoken::{self, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header};
+use openssl::rsa::Rsa;
 use serde::de::DeserializeOwned;
 use serde::ser::Serialize;
 
@@ -26,23 +27,45 @@ static JWT_SEND_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|send", CONFIG.do
 static JWT_ORG_API_KEY_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|api.organization", CONFIG.domain_origin()));
 static JWT_FILE_DOWNLOAD_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|file_download", CONFIG.domain_origin()));
 
-static PRIVATE_RSA_KEY: Lazy<EncodingKey> = Lazy::new(|| {
-    let key =
-        std::fs::read(CONFIG.private_rsa_key()).unwrap_or_else(|e| panic!("Error loading private RSA Key. \n{e}"));
-    EncodingKey::from_rsa_pem(&key).unwrap_or_else(|e| panic!("Error decoding private RSA Key.\n{e}"))
-});
-static PUBLIC_RSA_KEY: Lazy<DecodingKey> = Lazy::new(|| {
-    let key = std::fs::read(CONFIG.public_rsa_key()).unwrap_or_else(|e| panic!("Error loading public RSA Key. \n{e}"));
-    DecodingKey::from_rsa_pem(&key).unwrap_or_else(|e| panic!("Error decoding public RSA Key.\n{e}"))
-});
+static PRIVATE_RSA_KEY: OnceCell<EncodingKey> = OnceCell::new();
+static PUBLIC_RSA_KEY: OnceCell<DecodingKey> = OnceCell::new();
 
-pub fn load_keys() {
-    Lazy::force(&PRIVATE_RSA_KEY);
-    Lazy::force(&PUBLIC_RSA_KEY);
+pub fn initialize_keys() -> Result<(), crate::error::Error> {
+    let mut priv_key_buffer = Vec::with_capacity(2048);
+
+    let priv_key = {
+        let mut priv_key_file = File::options().create(true).read(true).write(true).open(CONFIG.private_rsa_key())?;
+
+        #[allow(clippy::verbose_file_reads)]
+        let bytes_read = priv_key_file.read_to_end(&mut priv_key_buffer)?;
+
+        if bytes_read > 0 {
+            Rsa::private_key_from_pem(&priv_key_buffer[..bytes_read])?
+        } else {
+            // Only create the key if the file doesn't exist or is empty
+            let rsa_key = openssl::rsa::Rsa::generate(2048)?;
+            priv_key_buffer = rsa_key.private_key_to_pem()?;
+            priv_key_file.write_all(&priv_key_buffer)?;
+            info!("Private key created correctly.");
+            rsa_key
+        }
+    };
+
+    let pub_key_buffer = priv_key.public_key_to_pem()?;
+
+    let enc = EncodingKey::from_rsa_pem(&priv_key_buffer)?;
+    let dec: DecodingKey = DecodingKey::from_rsa_pem(&pub_key_buffer)?;
+    if PRIVATE_RSA_KEY.set(enc).is_err() {
+        err!("PRIVATE_RSA_KEY must only be initialized once")
+    }
+    if PUBLIC_RSA_KEY.set(dec).is_err() {
+        err!("PUBLIC_RSA_KEY must only be initialized once")
+    }
+    Ok(())
 }
 
 pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
-    match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) {
+    match jsonwebtoken::encode(&JWT_HEADER, claims, PRIVATE_RSA_KEY.wait()) {
         Ok(token) => token,
         Err(e) => panic!("Error encoding jwt {e}"),
     }
@@ -56,7 +79,7 @@ fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Err
     validation.set_issuer(&[issuer]);
 
     let token = token.replace(char::is_whitespace, "");
-    match jsonwebtoken::decode(&token, &PUBLIC_RSA_KEY, &validation) {
+    match jsonwebtoken::decode(&token, PUBLIC_RSA_KEY.wait(), &validation) {
         Ok(d) => Ok(d.claims),
         Err(err) => match *err.kind() {
             ErrorKind::InvalidToken => err!("Token is invalid"),
@@ -799,7 +822,11 @@ impl<'r> FromRequest<'r> for OwnerHeaders {
 //
 // Client IP address detection
 //
-use std::net::IpAddr;
+use std::{
+    fs::File,
+    io::{Read, Write},
+    net::IpAddr,
+};
 
 pub struct ClientIp {
     pub ip: IpAddr,
diff --git a/src/config.rs b/src/config.rs
index 2f0e9264..e174c66b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1164,7 +1164,7 @@ impl Config {
     }
 
     pub fn delete_user_config(&self) -> Result<(), Error> {
-        crate::util::delete_file(&CONFIG_FILE)?;
+        std::fs::remove_file(&*CONFIG_FILE)?;
 
         // Empty user config
         let usr = ConfigBuilder::default();
@@ -1189,9 +1189,6 @@ impl Config {
     pub fn private_rsa_key(&self) -> String {
         format!("{}.pem", CONFIG.rsa_key_filename())
     }
-    pub fn public_rsa_key(&self) -> String {
-        format!("{}.pub.pem", CONFIG.rsa_key_filename())
-    }
     pub fn mail_enabled(&self) -> bool {
         let inner = &self.inner.read().unwrap().config;
         inner._enable_smtp && (inner.smtp_host.is_some() || inner.use_sendmail)
diff --git a/src/db/models/attachment.rs b/src/db/models/attachment.rs
index 8f05e6b4..f8eca72f 100644
--- a/src/db/models/attachment.rs
+++ b/src/db/models/attachment.rs
@@ -103,7 +103,7 @@ impl Attachment {
 
             let file_path = &self.get_file_path();
 
-            match crate::util::delete_file(file_path) {
+            match std::fs::remove_file(file_path) {
                 // Ignore "file not found" errors. This can happen when the
                 // upstream caller has already cleaned up the file as part of
                 // its own error handling.
diff --git a/src/main.rs b/src/main.rs
index 05f43c5a..53b72606 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -71,7 +71,7 @@ async fn main() -> Result<(), Error> {
     let extra_debug = matches!(level, LF::Trace | LF::Debug);
 
     check_data_folder().await;
-    check_rsa_keys().unwrap_or_else(|_| {
+    auth::initialize_keys().unwrap_or_else(|_| {
         error!("Error creating keys, exiting...");
         exit(1);
     });
@@ -444,31 +444,6 @@ async fn container_data_folder_is_persistent(data_folder: &str) -> bool {
     true
 }
 
-fn check_rsa_keys() -> Result<(), crate::error::Error> {
-    // If the RSA keys don't exist, try to create them
-    let priv_path = CONFIG.private_rsa_key();
-    let pub_path = CONFIG.public_rsa_key();
-
-    if !util::file_exists(&priv_path) {
-        let rsa_key = openssl::rsa::Rsa::generate(2048)?;
-
-        let priv_key = rsa_key.private_key_to_pem()?;
-        crate::util::write_file(&priv_path, &priv_key)?;
-        info!("Private key created correctly.");
-    }
-
-    if !util::file_exists(&pub_path) {
-        let rsa_key = openssl::rsa::Rsa::private_key_from_pem(&std::fs::read(&priv_path)?)?;
-
-        let pub_key = rsa_key.public_key_to_pem()?;
-        crate::util::write_file(&pub_path, &pub_key)?;
-        info!("Public key created correctly.");
-    }
-
-    auth::load_keys();
-    Ok(())
-}
-
 fn check_web_vault() {
     if !CONFIG.web_vault_enabled() {
         return;
diff --git a/src/util.rs b/src/util.rs
index 0bf37959..2f04fe34 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,11 +1,7 @@
 //
 // Web Headers and caching
 //
-use std::{
-    collections::HashMap,
-    io::{Cursor, ErrorKind},
-    ops::Deref,
-};
+use std::{collections::HashMap, io::Cursor, ops::Deref, path::Path};
 
 use num_traits::ToPrimitive;
 use rocket::{
@@ -334,40 +330,6 @@ impl Fairing for BetterLogging {
     }
 }
 
-//
-// File handling
-//
-use std::{
-    fs::{self, File},
-    io::Result as IOResult,
-    path::Path,
-};
-
-pub fn file_exists(path: &str) -> bool {
-    Path::new(path).exists()
-}
-
-pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> {
-    use std::io::Write;
-    let mut f = match File::create(path) {
-        Ok(file) => file,
-        Err(e) => {
-            if e.kind() == ErrorKind::PermissionDenied {
-                error!("Can't create '{}': Permission denied", path);
-            }
-            return Err(From::from(e));
-        }
-    };
-
-    f.write_all(content)?;
-    f.flush()?;
-    Ok(())
-}
-
-pub fn delete_file(path: &str) -> IOResult<()> {
-    fs::remove_file(path)
-}
-
 pub fn get_display_size(size: i64) -> String {
     const UNITS: [&str; 6] = ["bytes", "KB", "MB", "GB", "TB", "PB"];
 
@@ -444,7 +406,7 @@ pub fn get_env_str_value(key: &str) -> Option<String> {
     match (value_from_env, value_file) {
         (Ok(_), Ok(_)) => panic!("You should not define both {key} and {key_file}!"),
         (Ok(v_env), Err(_)) => Some(v_env),
-        (Err(_), Ok(v_file)) => match fs::read_to_string(v_file) {
+        (Err(_), Ok(v_file)) => match std::fs::read_to_string(v_file) {
             Ok(content) => Some(content.trim().to_string()),
             Err(e) => panic!("Failed to load {key}: {e:?}"),
         },
-- 
GitLab