From 2a321bac35b872d47d8ae8da4cba31d757e96a26 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Wed, 6 Nov 2024 22:21:06 +0000
Subject: [PATCH] Issue one time keys in upload order (#17903)

Currently, one-time-keys are issued in a somewhat random order. (In
practice, they are issued according to the lexicographical order of
their key IDs.) That can lead to a situation where a client gives up
hope of a given OTK ever being used and discards its private key, whilst
the public key is still on the server waiting to be issued.

Related: https://github.com/element-hq/element-meta/issues/2356
---
 changelog.d/17903.bugfix                      |  1 +
 synapse/handlers/e2e_keys.py                  |  2 +-
 .../storage/databases/main/end_to_end_keys.py | 25 +++++-
 .../delta/88/03_add_otk_ts_added_index.sql    | 18 +++++
 tests/handlers/test_e2e_keys.py               | 78 +++++++++++++++++--
 5 files changed, 116 insertions(+), 8 deletions(-)
 create mode 100644 changelog.d/17903.bugfix
 create mode 100644 synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql

diff --git a/changelog.d/17903.bugfix b/changelog.d/17903.bugfix
new file mode 100644
index 0000000000..a4d02fc983
--- /dev/null
+++ b/changelog.d/17903.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug in Synapse which could cause one-time keys to be issued in the incorrect order, causing message decryption failures.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index f78e66ad0a..315461fefb 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -615,7 +615,7 @@ class E2eKeysHandler:
         3. Attempt to fetch fallback keys from the database.
 
         Args:
-            local_query: An iterable of tuples of (user ID, device ID, algorithm).
+            local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
             always_include_fallback_keys: True to always include fallback keys.
 
         Returns:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 575aaf498b..1fbc49e7c5 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -99,6 +99,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
             unique=True,
         )
 
+        self.db_pool.updates.register_background_index_update(
+            update_name="add_otk_ts_added_index",
+            index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx",
+            table="e2e_one_time_keys_json",
+            columns=("user_id", "device_id", "algorithm", "ts_added_ms"),
+        )
+
 
 class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
     def __init__(
@@ -1122,7 +1129,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         """Take a list of one time keys out of the database.
 
         Args:
-            query_list: An iterable of tuples of (user ID, device ID, algorithm).
+            query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
 
         Returns:
             A tuple (results, missing) of:
@@ -1310,9 +1317,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             OTK was found.
         """
 
+        # Return the oldest keys from this device (based on `ts_added_ms`).
+        # Doing so means that keys are issued in the same order they were uploaded,
+        # which reduces the chances of a client expiring its copy of a (private)
+        # key while the public key is still on the server, waiting to be issued.
         sql = """
             SELECT key_id, key_json FROM e2e_one_time_keys_json
             WHERE user_id = ? AND device_id = ? AND algorithm = ?
+            ORDER BY ts_added_ms
             LIMIT ?
         """
 
@@ -1354,13 +1366,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             A list of tuples (user_id, device_id, algorithm, key_id, key_json)
             for each OTK claimed.
         """
+        # Find, delete, and return the oldest keys from each device (based on
+        # `ts_added_ms`).
+        #
+        # Doing so means that keys are issued in the same order they were uploaded,
+        # which reduces the chances of a client expiring its copy of a (private)
+        # key while the public key is still on the server, waiting to be issued.
         sql = """
             WITH claims(user_id, device_id, algorithm, claim_count) AS (
                 VALUES ?
             ), ranked_keys AS (
                 SELECT
                     user_id, device_id, algorithm, key_id, claim_count,
-                    ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
+                    ROW_NUMBER() OVER (
+                        PARTITION BY (user_id, device_id, algorithm)
+                        ORDER BY ts_added_ms
+                    ) AS r
                 FROM e2e_one_time_keys_json
                     JOIN claims USING (user_id, device_id, algorithm)
             )
diff --git a/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql b/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql
new file mode 100644
index 0000000000..7712ea68ad
--- /dev/null
+++ b/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql
@@ -0,0 +1,18 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+
+-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can
+-- efficiently be issued in the same order they were uploaded.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+    (8803, 'add_otk_ts_added_index', '{}');
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8a3dfdcf75..bca314db83 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -151,18 +151,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def test_claim_one_time_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
-        keys = {"alg1:k1": "key1"}
-
         res = self.get_success(
             self.handler.upload_keys_for_user(
-                local_user, device_id, {"one_time_keys": keys}
+                local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
             )
         )
         self.assertDictEqual(
             res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
         )
 
-        res2 = self.get_success(
+        # Keys should be returned in the order they were uploaded. To test, advance time
+        # a little, then upload a second key with an earlier key ID; it should get
+        # returned second.
+        self.reactor.advance(1)
+        res = self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
+            )
+        )
+        self.assertDictEqual(
+            res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
+        )
+
+        # now claim both keys back. They should be in the same order
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
                 self.requester,
@@ -171,12 +183,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
         self.assertEqual(
-            res2,
+            res,
             {
                 "failures": {},
                 "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
             },
         )
+        res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
+            },
+        )
 
     def test_claim_one_time_key_bulk(self) -> None:
         """Like test_claim_one_time_key but claims multiple keys in one handler call."""
@@ -336,6 +363,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                     counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
                 )
 
+    def test_claim_one_time_key_bulk_ordering(self) -> None:
+        """Bulk claims should return the earliest-uploaded keys, regardless of key ID."""
+
+        # Alice has lots of keys, uploaded in a specific order
+        alice = f"@alice:{self.hs.hostname}"
+        alice_dev = "alice_dev_1"
+
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                alice,
+                alice_dev,
+                {"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
+            )
+        )
+        # Advance time by 1s, to ensure that there is a difference in upload time.
+        self.reactor.advance(1)
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                alice,
+                alice_dev,
+                {"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
+            )
+        )
+
+        # Now claim some, and check we get the right ones.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {alice: {alice_dev: {"alg1": 2}}},
+                self.requester,
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        # We should get first-uploaded keys, even though they have later key ids. The
+        # order within that first upload is unspecified: any two of k20, k21, k22 will do.
+        self.assertEqual(claim_res["failures"], {})
+        claimed_keys = claim_res["one_time_keys"][alice][alice_dev]
+        self.assertEqual(len(claimed_keys), 2)
+        for key_id in claimed_keys:
+            self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
+
     def test_fallback_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
-- 
GitLab