From d0fc1e904a3060b0f459be9aa7df9b9f1501e294 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erikj@element.io>
Date: Thu, 7 Nov 2024 15:26:14 +0000
Subject: [PATCH] Fix cancellation tests with new Twisted.  (#17906)

The latest Twisted release changed how they implemented `__await__` on
deferreds, which broke the machinery we used to test cancellation.

This PR changes things a bit to instead patch the `__await__` method,
which is a stable API. This mostly doesn't change the core logic, except
for fixing two bugs:
  - We previously did not intercept all await points
- After cancellation we now need to not only unblock currently blocked
await points, but also make sure we don't block any future await points.

c.f. https://github.com/twisted/twisted/pull/12226

---------

Co-authored-by: Devon Hudson <devon.dmytro@gmail.com>
---
 changelog.d/17906.bugfix   |   1 +
 tests/http/server/_base.py | 107 ++++++++++++++++++++++++++++---------
 2 files changed, 84 insertions(+), 24 deletions(-)
 create mode 100644 changelog.d/17906.bugfix

diff --git a/changelog.d/17906.bugfix b/changelog.d/17906.bugfix
new file mode 100644
index 0000000000..f38ce6a590
--- /dev/null
+++ b/changelog.d/17906.bugfix
@@ -0,0 +1 @@
+Fix tests to run with latest Twisted.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 731b0c4e59..dff5a5d262 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -27,6 +27,7 @@ from typing import (
     Callable,
     ContextManager,
     Dict,
+    Generator,
     List,
     Optional,
     Set,
@@ -49,7 +50,10 @@ from synapse.http.server import (
     respond_with_json,
 )
 from synapse.http.site import SynapseRequest
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.logging.context import (
+    LoggingContext,
+    make_deferred_yieldable,
+)
 from synapse.types import JsonDict
 
 from tests.server import FakeChannel, make_request
@@ -199,7 +203,7 @@ def make_request_with_cancellation_test(
     #
     # We would like to trigger a cancellation at the first `await`, re-run the
     # request and cancel at the second `await`, and so on. By patching
-    # `Deferred.__next__`, we can intercept `await`s, track which ones we have or
+    # `Deferred.__await__`, we can intercept `await`s, track which ones we have or
     # have not seen, and force them to block when they wouldn't have.
 
     # The set of previously seen `await`s.
@@ -211,7 +215,7 @@ def make_request_with_cancellation_test(
     )
 
     for request_number in itertools.count(1):
-        deferred_patch = Deferred__next__Patch(seen_awaits, request_number)
+        deferred_patch = Deferred__await__Patch(seen_awaits, request_number)
 
         try:
             with mock.patch(
@@ -250,6 +254,8 @@ def make_request_with_cancellation_test(
                             )
 
                 if respond_mock.called:
+                    _log_for_request(request_number, "--- response finished ---")
+
                     # The request ran to completion and we are done with testing it.
 
                     # `respond_with_json` writes the response asynchronously, so we
@@ -311,8 +317,8 @@ def make_request_with_cancellation_test(
     assert False, "unreachable"  # noqa: B011
 
 
-class Deferred__next__Patch:
-    """A `Deferred.__next__` patch that will intercept `await`s and force them
+class Deferred__await__Patch:
+    """A `Deferred.__await__` patch that will intercept `await`s and force them
     to block once it sees a new `await`.
 
     When done with the patch, `unblock_awaits()` must be called to clean up after any
@@ -322,7 +328,7 @@ class Deferred__next__Patch:
 
     Usage:
         seen_awaits = set()
-        deferred_patch = Deferred__next__Patch(seen_awaits, 1)
+        deferred_patch = Deferred__await__Patch(seen_awaits, 1)
         try:
             with deferred_patch.patch():
                 # do things
@@ -335,14 +341,14 @@ class Deferred__next__Patch:
         """
         Args:
             seen_awaits: The set of stack traces of `await`s that have been previously
-                seen. When the `Deferred.__next__` patch sees a new `await`, it will add
+                seen. When the `Deferred.__await__` patch sees a new `await`, it will add
                 it to the set.
             request_number: The request number to log against.
         """
         self._request_number = request_number
         self._seen_awaits = seen_awaits
 
-        self._original_Deferred___next__ = Deferred.__next__  # type: ignore[misc,unused-ignore]
+        self._original_Deferred__await__ = Deferred.__await__  # type: ignore[misc,unused-ignore]
 
         # The number of `await`s on `Deferred`s we have seen so far.
         self.awaits_seen = 0
@@ -350,8 +356,13 @@ class Deferred__next__Patch:
         # Whether we have seen a new `await` not in `seen_awaits`.
         self.new_await_seen = False
 
+        # Whether to block new await points we see. This gets set to False once
+        # we have cancelled the request to allow things to run after
+        # cancellation.
+        self._block_new_awaits = True
+
         # To force `await`s on resolved `Deferred`s to block, we make up a new
-        # unresolved `Deferred` and return it out of `Deferred.__next__` /
+        # unresolved `Deferred` and return it out of `Deferred.__await__` /
         # `coroutine.send()`. We have to resolve it later, in case the `await`ing
         # coroutine is part of some shared processing, such as `@cached`.
         self._to_unblock: Dict[Deferred, Union[object, Failure]] = {}
@@ -360,15 +371,15 @@ class Deferred__next__Patch:
         self._previous_stack: List[inspect.FrameInfo] = []
 
     def patch(self) -> ContextManager[Mock]:
-        """Returns a context manager which patches `Deferred.__next__`."""
+        """Returns a context manager which patches `Deferred.__await__`."""
 
-        def Deferred___next__(
-            deferred: "Deferred[T]", value: object = None
-        ) -> "Deferred[T]":
-            """Intercepts `await`s on `Deferred`s and rigs them to block once we have
-            seen enough of them.
+        def Deferred___await__(
+            deferred: "Deferred[T]",
+        ) -> Generator["Deferred[T]", None, T]:
+            """Intercepts calls to `__await__`, which returns a generator
+            yielding deferreds that we await on.
 
-            `Deferred.__next__` will normally:
+            The generator for `__await__` will normally:
                 * return `self` if the `Deferred` is unresolved, in which case
                    `coroutine.send()` will return the `Deferred`, and
                    `_defer.inlineCallbacks` will stop running the coroutine until the
@@ -376,9 +387,43 @@ class Deferred__next__Patch:
                 * raise a `StopIteration(result)`, containing the result of the `await`.
                 * raise another exception, which will come out of the `await`.
             """
+
+            # Get the original generator.
+            gen = self._original_Deferred__await__(deferred)
+
+            # Run the generator, handling each iteration to see if we need to
+            # block.
+            try:
+                while True:
+                    # We've hit a new await point (or the deferred has
+                    # completed), handle it.
+                    handle_next_iteration(deferred)
+
+                    # Continue on.
+                    yield gen.send(None)
+            except StopIteration as e:
+                # We need to convert `StopIteration` into a normal return.
+                return e.value
+
+        def handle_next_iteration(
+            deferred: "Deferred[T]",
+        ) -> None:
+            """Intercepts `await`s on `Deferred`s and rigs them to block once we have
+            seen enough of them.
+
+            Args:
+                deferred: The deferred that we've captured and are intercepting
+                    `await` calls within.
+            """
+            if not self._block_new_awaits:
+                # We're no longer blocking awaits points
+                return
+
             self.awaits_seen += 1
 
-            stack = _get_stack(skip_frames=1)
+            stack = _get_stack(
+                skip_frames=2  # Ignore this function and `Deferred___await__` in stack trace
+            )
             stack_hash = _hash_stack(stack)
 
             if stack_hash not in self._seen_awaits:
@@ -389,20 +434,29 @@ class Deferred__next__Patch:
             if not self.new_await_seen:
                 # This `await` isn't interesting. Let it proceed normally.
 
+                _log_await_stack(
+                    stack,
+                    self._previous_stack,
+                    self._request_number,
+                    "already seen",
+                )
+
                 # Don't log the stack. It's been seen before in a previous run.
                 self._previous_stack = stack
 
-                return self._original_Deferred___next__(deferred, value)
+                return
 
             # We want to block at the current `await`.
             if deferred.called and not deferred.paused:
-                # This `Deferred` already has a result.
-                # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait
-                # on. This blocks the coroutine that did this `await`.
+                # This `Deferred` already has a result. We chain a new,
+                # unresolved, `Deferred` to the end of this Deferred that it
+                # will wait on. This blocks the coroutine that did this `await`.
                 # We queue it up for unblocking later.
                 new_deferred: "Deferred[T]" = Deferred()
                 self._to_unblock[new_deferred] = deferred.result
 
+                deferred.addBoth(lambda _: make_deferred_yieldable(new_deferred))
+
                 _log_await_stack(
                     stack,
                     self._previous_stack,
@@ -411,7 +465,9 @@ class Deferred__next__Patch:
                 )
                 self._previous_stack = stack
 
-                return make_deferred_yieldable(new_deferred)
+                # Continue iterating on the deferred now that we've blocked it
+                # again.
+                return
 
             # This `Deferred` does not have a result yet.
             # The `await` will block normally, so we don't have to do anything.
@@ -423,9 +479,9 @@ class Deferred__next__Patch:
             )
             self._previous_stack = stack
 
-            return self._original_Deferred___next__(deferred, value)
+            return
 
-        return mock.patch.object(Deferred, "__next__", new=Deferred___next__)
+        return mock.patch.object(Deferred, "__await__", new=Deferred___await__)
 
     def unblock_awaits(self) -> None:
         """Unblocks any shared processing that we forced to block.
@@ -433,6 +489,9 @@ class Deferred__next__Patch:
         Must be called when done, otherwise processing shared between multiple requests,
         such as database queries started by `@cached`, will become permanently stuck.
         """
+        # Also disable blocking at future await points
+        self._block_new_awaits = False
+
         to_unblock = self._to_unblock
         self._to_unblock = {}
         for deferred, result in to_unblock.items():
-- 
GitLab