From 8c229920ad6bdccbb81110d88ccc139a09d65784 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20Garc=C3=ADa?=
 <dani-garcia@users.noreply.github.com>
Date: Sat, 4 Jan 2020 23:52:38 +0100
Subject: [PATCH] Protect websocket server against panics

---
 src/api/notifications.rs | 60 ++++++++++++++++++++++++++++++++--------
 1 file changed, 48 insertions(+), 12 deletions(-)

diff --git a/src/api/notifications.rs b/src/api/notifications.rs
index 4781aa2d..17422a6b 100644
--- a/src/api/notifications.rs
+++ b/src/api/notifications.rs
@@ -54,10 +54,11 @@ fn negotiate(_headers: Headers, _conn: DbConn) -> JsonResult {
 //
 // Websockets server
 //
+use std::io;
 use std::sync::Arc;
 use std::thread;
 
-use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender, WebSocket};
+use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender};
 
 use chashmap::CHashMap;
 use chrono::NaiveDateTime;
@@ -135,20 +136,51 @@ struct InitialMessage {
 const PING_MS: u64 = 15_000;
 const PING: Token = Token(1);
 
+const ID_KEY: &str = "id=";
+const ACCESS_TOKEN_KEY: &str = "access_token=";
+
+impl WSHandler {
+    fn err(&self, msg: &'static str) -> ws::Result<()> {
+        self.out.close(ws::CloseCode::Invalid)?;
+
+        // We need to specifically return an IO error so ws closes the connection
+        let io_error = io::Error::from(io::ErrorKind::InvalidData);
+        Err(ws::Error::new(ws::ErrorKind::Io(io_error), msg))
+    }
+}
+
 impl Handler for WSHandler {
     fn on_open(&mut self, hs: Handshake) -> ws::Result<()> {
-        // TODO: Improve this split
+        // Path == "/notifications/hub?id=<id>==&access_token=<access_token>"
         let path = hs.request.resource();
-        let mut query_split: Vec<_> = path.split('?').nth(1).unwrap().split('&').collect();
-        query_split.sort();
-        let access_token = &query_split[0][13..];
-        let _id = &query_split[1][3..];
+
+        let (_id, access_token) = match path.split('?').nth(1) {
+            Some(params) => {
+                let mut params_iter = params.split('&').take(2);
+
+                let mut id = None;
+                let mut access_token = None;
+                while let Some(val) = params_iter.next() {
+                    if val.starts_with(ID_KEY) {
+                        id = Some(&val[ID_KEY.len()..]);
+                    } else if val.starts_with(ACCESS_TOKEN_KEY) {
+                        access_token = Some(&val[ACCESS_TOKEN_KEY.len()..]);
+                    }
+                }
+
+                match (id, access_token) {
+                    (Some(a), Some(b)) => (a, b),
+                    _ => return self.err("Missing id or access token"),
+                }
+            }
+            None => return self.err("Missing query path"),
+        };
 
         // Validate the user
         use crate::auth;
         let claims = match auth::decode_login(access_token) {
             Ok(claims) => claims,
-            Err(_) => return Err(ws::Error::new(ws::ErrorKind::Internal, "Invalid access token provided")),
+            Err(_) => return self.err("Invalid access token provided"),
         };
 
         // Assign the user to the handler
@@ -190,10 +222,7 @@ impl Handler for WSHandler {
             // reschedule the timeout
             self.out.timeout(PING_MS, PING)
         } else {
-            Err(ws::Error::new(
-                ws::ErrorKind::Internal,
-                "Invalid timeout token provided",
-            ))
+            Ok(())
         }
     }
 }
@@ -362,7 +391,14 @@ pub fn start_notification_server() -> WebSocketUsers {
 
     if CONFIG.websocket_enabled() {
         thread::spawn(move || {
-            WebSocket::new(factory)
+            let mut settings = ws::Settings::default();
+            settings.max_connections = 500;
+            settings.queue_size = 2;
+            settings.panic_on_internal = false;
+
+            ws::Builder::new()
+                .with_settings(settings)
+                .build(factory)
                 .unwrap()
                 .listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port()))
                 .unwrap();
-- 
GitLab