Unverified commit bdfde6dc authored by Jonathan de Jong, committed by GitHub

Use inline type hints in `http/federation/`, `storage/` and `util/` (#10381)

parent 3acf85c8
Convert internal type variable syntax to reflect wider ecosystem use.
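The change is mechanical throughout: `# type:` comments, mypy's annotation syntax from before Python 3.6, are replaced with the PEP 526 inline variable annotations that the wider ecosystem has since standardised on. A minimal sketch of the pattern (names are hypothetical, not taken from the diff):

```python
from typing import Dict, List, Optional

# Old style: the type checker reads the annotation from a trailing
# comment. This was the only option before Python 3.6.
legacy_cache = {}  # type: Dict[str, List[int]]

# New style (PEP 526): the annotation is part of the assignment itself,
# so it is validated by the parser and visible to more tooling.
inline_cache: Dict[str, List[int]] = {}

# A bare annotation declares a type without binding a value; the diff
# uses this where a variable is first bound later, e.g. by unpacking.
pending: Optional[bytes]
pending = None
```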
@@ -70,10 +70,8 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
 logger = logging.getLogger(__name__)

-_well_known_cache = TTLCache("well-known")  # type: TTLCache[bytes, Optional[bytes]]
-_had_valid_well_known_cache = TTLCache(
-    "had-valid-well-known"
-)  # type: TTLCache[bytes, bool]
+_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
+_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")


 @attr.s(slots=True, frozen=True)
@@ -130,9 +128,10 @@ class WellKnownResolver:
         # requests for the same server in parallel?
         try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = await self._fetch_well_known(
-                    server_name
-                )  # type: Optional[bytes], float
+                result: Optional[bytes]
+                cache_period: float
+
+                result, cache_period = await self._fetch_well_known(server_name)

         except _FetchWellKnownFailure as e:
             if prev_result and e.temporary:
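The `WellKnownResolver` hunk above shows the one conversion that is not a one-line swap: a type comment could annotate a tuple-unpacking assignment (`# type: Optional[bytes], float`), but `a, b = ...` cannot carry an inline annotation, so each target is declared separately first. A self-contained sketch of the same shape, with a stand-in `fetch` function:

```python
from typing import Optional, Tuple

def fetch() -> Tuple[Optional[bytes], float]:
    # Stand-in for WellKnownResolver._fetch_well_known.
    return None, 3600.0

# PEP 526 forbids annotating `result, cache_period = ...` directly, so
# the targets get bare annotations before the unpacking assignment.
result: Optional[bytes]
cache_period: float
result, cache_period = fetch()
```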
@@ -92,14 +92,12 @@ class BackgroundUpdater:
         self.db_pool = database

         # if a background update is currently running, its name.
-        self._current_background_update = None  # type: Optional[str]
-
-        self._background_update_performance = (
-            {}
-        )  # type: Dict[str, BackgroundUpdatePerformance]
-        self._background_update_handlers = (
-            {}
-        )  # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
+        self._current_background_update: Optional[str] = None
+
+        self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
+        self._background_update_handlers: Dict[
+            str, Callable[[JsonDict, int], Awaitable[int]]
+        ] = {}
         self._all_done = False

     def start_doing_background_updates(self) -> None:
@@ -411,7 +409,7 @@ class BackgroundUpdater:
             c.execute(sql)

         if isinstance(self.db_pool.engine, engines.PostgresEngine):
-            runner = create_index_psql  # type: Optional[Callable[[Connection], None]]
+            runner: Optional[Callable[[Connection], None]] = create_index_psql
         elif psql_only:
             runner = None
         else:
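The `runner` hunk illustrates another recurring shape: when branches assign values of different types, the inline annotation goes on the first assignment and fixes the variable's (wider) type for the whole scope; the other branches assign without repeating it. A sketch under hypothetical names:

```python
from typing import Callable, Optional

def create_index_fast() -> None:
    print("creating index")

use_fast_path = True

# Annotating the first assignment declares the variable for every
# branch; later plain assignments must stay within Optional[...].
if use_fast_path:
    runner: Optional[Callable[[], None]] = create_index_fast
else:
    runner = None

if runner is not None:
    runner()
```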
@@ -670,8 +670,8 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks = []  # type: List[_CallbackListEntry]
-        exception_callbacks = []  # type: List[_CallbackListEntry]
+        after_callbacks: List[_CallbackListEntry] = []
+        exception_callbacks: List[_CallbackListEntry] = []

         if not current_context():
             logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -1090,7 +1090,7 @@ class DatabasePool:
                 return False

         # We didn't find any existing rows, so insert a new one
-        allvalues = {}  # type: Dict[str, Any]
+        allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
         allvalues.update(values)
         allvalues.update(insertion_values)
@@ -1121,7 +1121,7 @@ class DatabasePool:
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
         """
-        allvalues = {}  # type: Dict[str, Any]
+        allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
         allvalues.update(insertion_values or {})
@@ -1257,7 +1257,7 @@ class DatabasePool:
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
         """
-        allnames = []  # type: List[str]
+        allnames: List[str] = []
         allnames.extend(key_names)
         allnames.extend(value_names)
@@ -1566,7 +1566,7 @@ class DatabasePool:
         """
         keyvalues = keyvalues or {}

-        results = []  # type: List[Dict[str, Any]]
+        results: List[Dict[str, Any]] = []

         if not iterable:
             return results
@@ -1978,7 +1978,7 @@ class DatabasePool:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")

         where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
-        arg_list = []  # type: List[Any]
+        arg_list: List[Any] = []
         if filters:
             where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
             arg_list += list(filters.values())
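Most of the `DatabasePool` changes annotate empty literals (`{}`, `[]`, `()`). The annotation does real work there: mypy cannot infer an element type from an empty container and, under common strictness settings, reports "Need type annotation" for a bare `x = []`. A short sketch reusing the variable names from the hunks above:

```python
from typing import Any, Dict, List

# Without annotations, mypy has no way to know what these will hold.
allvalues: Dict[str, Any] = {}
allvalues.update({"room_id": "!abc:example.com"})

results: List[Dict[str, Any]] = []
results.append(allvalues)
```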
@@ -48,9 +48,7 @@ def _make_exclusive_regex(
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-        exclusive_user_pattern = re.compile(
-            exclusive_user_regex
-        )  # type: Optional[Pattern]
+        exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex)
     else:
         # We handle this case specially otherwise the constructed regex
         # will always match
@@ -247,7 +247,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         txn.execute(sql, query_params)

-        result = {}  # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+        result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
         for (user_id, device_id, display_name, key_json) in txn:
             if include_deleted_devices:
                 deleted_devices.remove((user_id, device_id))
@@ -62,9 +62,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         )

         # Cache of event ID to list of auth event IDs and their depths.
-        self._event_auth_cache = LruCache(
+        self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
             500000, "_event_auth_cache", size_callback=len
-        )  # type: LruCache[str, List[Tuple[str, int]]]
+        )

         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
@@ -137,10 +137,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         initial_events = set(event_ids)

         # All the events that we've found that are reachable from the events.
-        seen_events = set()  # type: Set[str]
+        seen_events: Set[str] = set()

         # A map from chain ID to max sequence number of the given events.
-        event_chains = {}  # type: Dict[int, int]
+        event_chains: Dict[int, int] = {}

         sql = """
             SELECT event_id, chain_id, sequence_number
@@ -182,7 +182,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         """

         # A map from chain ID to max sequence number *reachable* from any event ID.
-        chains = {}  # type: Dict[int, int]
+        chains: Dict[int, int] = {}

         # Add all linked chains reachable from initial set of chains.
         for batch in batch_iter(event_chains, 1000):
@@ -353,14 +353,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         initial_events = set(state_sets[0]).union(*state_sets[1:])

         # Map from event_id -> (chain ID, seq no)
-        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+        chain_info: Dict[str, Tuple[int, int]] = {}

         # Map from chain ID -> seq no -> event Id
-        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+        chain_to_event: Dict[int, Dict[int, str]] = {}

         # All the chains that we've found that are reachable from the state
         # sets.
-        seen_chains = set()  # type: Set[int]
+        seen_chains: Set[int] = set()

         sql = """
             SELECT event_id, chain_id, sequence_number
@@ -392,9 +392,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         # Corresponds to `state_sets`, except as a map from chain ID to max
         # sequence number reachable from the state set.
-        set_to_chain = []  # type: List[Dict[int, int]]
+        set_to_chain: List[Dict[int, int]] = []
         for state_set in state_sets:
-            chains = {}  # type: Dict[int, int]
+            chains: Dict[int, int] = {}
             set_to_chain.append(chains)

             for event_id in state_set:
@@ -446,7 +446,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         # Mapping from chain ID to the range of sequence numbers that should be
         # pulled from the database.
-        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+        chain_to_gap: Dict[int, Tuple[int, int]] = {}

         for chain_id in seen_chains:
             min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
@@ -555,7 +555,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         }

         # The sorted list of events whose auth chains we should walk.
-        search = []  # type: List[Tuple[int, str]]
+        search: List[Tuple[int, str]] = []

         # We need to get the depth of the initial events for sorting purposes.
         sql = """
@@ -578,7 +578,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         search.sort()

         # Map from event to its auth events
-        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+        event_to_auth_events: Dict[str, Set[str]] = {}

         base_sql = """
             SELECT a.event_id, auth_id, depth
@@ -759,7 +759,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # object because we might not have the same amount of rows in each of them. To do
         # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
-        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
         for row in txn:
             summaries[(row[0], row[1])] = _EventPushSummary(
                 unread_count=row[2],
@@ -109,10 +109,8 @@ class PersistEventsStore:
         # Ideally we'd move these ID gens here, unfortunately some other ID
         # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen = (
-            self.store._backfill_id_gen
-        )  # type: MultiWriterIdGenerator
-        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen

         # This should only exist on instances that are configured to write
         assert (
@@ -221,7 +219,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []  # type: List[str]
+        results: List[str] = []

         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -508,7 +506,7 @@ class PersistEventsStore:
         """

         # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+        chain_map: Dict[str, Tuple[int, int]] = {}

         # Set of event IDs to calculate chain ID/seq numbers for.
         events_to_calc_chain_id_for = set(event_to_room_id)
@@ -817,8 +815,8 @@ class PersistEventsStore:
         # new chain if the sequence number has already been allocated.
         #
-        existing_chains = set()  # type: Set[int]
-        tree = []  # type: List[Tuple[str, Optional[str]]]
+        existing_chains: Set[int] = set()
+        tree: List[Tuple[str, Optional[str]]] = []

         # We need to do this in a topologically sorted order as we want to
         # generate chain IDs/sequence numbers of an event's auth events before
@@ -848,7 +846,7 @@ class PersistEventsStore:
             )
             txn.execute(sql % (clause,), args)
-            chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+            chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}

         # Allocate the new events chain ID/sequence numbers.
         #
@@ -858,8 +856,8 @@ class PersistEventsStore:
         # number of new chain IDs in one call, replacing all temporary
         # objects with real allocated chain IDs.
-        unallocated_chain_ids = set()  # type: Set[object]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        unallocated_chain_ids: Set[object] = set()
+        new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
         for event_id, auth_event_id in tree:
             # If we reference an auth_event_id we fetch the allocated chain ID,
             # either from the existing `chain_map` or the newly generated
@@ -870,7 +868,7 @@ class PersistEventsStore:
                 if not existing_chain_id:
                     existing_chain_id = chain_map[auth_event_id]

-            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            new_chain_tuple: Optional[Tuple[Any, int]] = None
             if existing_chain_id:
                 # We found a chain ID/sequence number candidate, check its
                 # not already taken.
@@ -897,9 +895,9 @@ class PersistEventsStore:
         )

         # Map from potentially temporary chain ID to real chain ID
-        chain_id_to_allocated_map = dict(
+        chain_id_to_allocated_map: Dict[Any, int] = dict(
             zip(unallocated_chain_ids, newly_allocated_chain_ids)
-        )  # type: Dict[Any, int]
+        )
         chain_id_to_allocated_map.update((c, c) for c in existing_chains)

         return {
@@ -1175,9 +1173,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = (
-            OrderedDict()
-        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
+        new_events_and_contexts: OrderedDict[
+            str, Tuple[EventBase, EventContext]
+        ] = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -1205,7 +1203,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}  # type: Dict[str, int]
+        depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1885,7 +1883,7 @@ class PersistEventsStore:
             ),
         )

-        room_to_event_ids = {}  # type: Dict[str, List[str]]
+        room_to_event_ids: Dict[str, List[str]] = {}
         for e, _ in events_and_contexts:
             room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
@@ -2012,7 +2010,7 @@ class PersistEventsStore:
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}  # type: Dict[str, List[EventBase]]
+        events_by_room: Dict[str, List[EventBase]] = {}
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
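One property worth noting for hot paths like `PersistEventsStore`: PEP 526 annotations on local variables are never evaluated at runtime, so the converted form costs nothing when the function runs; only module- and class-level annotations are evaluated and stored in `__annotations__`. A quick demonstration:

```python
from typing import Dict

module_level: Dict[str, int] = {}  # annotation is evaluated and stored

def persist() -> None:
    # This annotation is never evaluated at runtime; only the type
    # checker sees it, so there is no per-call overhead.
    local_only: Dict[str, int] = {}

print(__annotations__)  # {'module_level': typing.Dict[str, int]}
persist()
```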
@@ -960,9 +960,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             event_to_types = {row[0]: (row[1], row[2]) for row in rows}

             # Calculate the new last position we've processed up to.
-            new_last_depth = rows[-1][3] if rows else last_depth  # type: int
-            new_last_stream = rows[-1][4] if rows else last_stream  # type: int
-            new_last_room_id = rows[-1][5] if rows else ""  # type: str
+            new_last_depth: int = rows[-1][3] if rows else last_depth
+            new_last_stream: int = rows[-1][4] if rows else last_stream
+            new_last_room_id: str = rows[-1][5] if rows else ""

             # Map from room_id to last depth/stream_ordering processed for the room,
             # excluding the last room (which we're likely still processing). We also
@@ -989,7 +989,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 retcols=("event_id", "auth_id"),
             )

-            event_to_auth_chain = {}  # type: Dict[str, List[str]]
+            event_to_auth_chain: Dict[str, List[str]] = {}
             for row in auth_events:
                 event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
@@ -1365,10 +1365,10 @@ class EventsWorkerStore(SQLBaseStore):
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.

-        rows = await self.db_pool.runInteraction(
+        rows: List[Tuple] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
-        )  # type: List[Tuple]
+        )

         # if we've got fewer rows than the limit, we're good
         if len(rows) < target_row_count:
@@ -1469,7 +1469,7 @@ class EventsWorkerStore(SQLBaseStore):
         """
         mapping = {}
-        txn_id_to_event = {}  # type: Dict[Tuple[str, int, str], str]
+        txn_id_to_event: Dict[Tuple[str, int, str], str] = {}

         for event in events:
             token_id = getattr(event.internal_metadata, "token_id", None)
@@ -115,7 +115,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         logger.info("[purge] looking for events to delete")

         should_delete_expr = "state_key IS NULL"
-        should_delete_params = ()  # type: Tuple[Any, ...]
+        should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
@@ -79,9 +79,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)

         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = StreamIdGenerator(
-                db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen: Union[
+                StreamIdGenerator, SlavedIdTracker
+            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
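The `_push_rules_stream_id_gen` hunk combines two of the patterns above: the first branch declares a `Union` so the `else` branch may assign the other member, and because the inline annotation no longer fits the line, black splits the subscript across lines. The same shape with stand-in classes:

```python
from typing import Union

class StreamIdGenerator:
    """Stand-in for the writer-side ID generator."""

class SlavedIdTracker:
    """Stand-in for the read-replica ID tracker."""

is_writer = True

# Declaring the Union on the first assignment lets the other branch
# assign the second member; long annotations wrap like the real hunk.
if is_writer:
    stream_id_gen: Union[
        StreamIdGenerator, SlavedIdTracker
    ] = StreamIdGenerator()
else:
    stream_id_gen = SlavedIdTracker()
```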
@@ -1744,7 +1744,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         items = keyvalues.items()
         where_clause = " AND ".join(k + " = ?" for k, _ in items)
-        values = [v for _, v in items]  # type: List[Union[str, int]]
+        values: List[Union[str, int]] = [v for _, v in items]
         # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
         # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
         # clause and values before we handle that. This seems to be only used in the "set password" handler.
@@ -1085,9 +1085,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
         # then filtering the results.
         if from_token.topological is not None:
-            from_bound = (
-                from_token.as_historical_tuple()
-            )  # type: Tuple[Optional[int], int]
+            from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
         elif direction == "b":
             from_bound = (
                 None,
@@ -1099,7 +1097,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 from_token.stream,
             )

-        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        to_bound: Optional[Tuple[Optional[int], int]] = None
         if to_token:
             if to_token.topological is not None:
                 to_bound = to_token.as_historical_tuple()
@@ -42,7 +42,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )

-        tags_by_room = {}  # type: Dict[str, Dict[str, JsonDict]]
+        tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
         for row in rows:
             room_tags = tags_by_room.setdefault(row["room_id"], {})
             room_tags[row["tag"]] = db_to_json(row["content"])
@@ -224,12 +224,12 @@ class UIAuthWorkerStore(SQLBaseStore):
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
     ):
         # Get the current value.
-        result = self.db_pool.simple_select_one_txn(
+        result: Dict[str, Any] = self.db_pool.simple_select_one_txn(  # type: ignore
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
             retcols=("serverdict",),
-        )  # type: Dict[str, Any]  # type: ignore
+        )

         # Update it and add it back to the database.
         serverdict = db_to_json(result["serverdict"])
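The `UIAuthWorkerStore` hunk is the fiddliest: the old closing parenthesis carried both the type comment and a `# type: ignore`. Once the annotation moves inline, the ignore comment has to sit on the line mypy attributes the error to, which for a multi-line call is the opening line of the statement, exactly as the hunk above does. A sketch with a stand-in helper:

```python
from typing import Any, Dict, Optional

def simple_select_one(table: str) -> Optional[Dict[str, Any]]:
    # Stand-in for db_pool.simple_select_one_txn, which may return None.
    return {"serverdict": "{}"}

# The Optional-to-Dict mismatch is reported on the statement's first
# line, so that is where "# type: ignore" must now live.
result: Dict[str, Any] = simple_select_one(  # type: ignore
    table="ui_auth_sessions",
)
```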
@@ -307,7 +307,7 @@ class EventsPersistenceStorage:
             matched the transcation ID; the existing event is returned in such
             a case.
         """
-        partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+        partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -384,7 +384,7 @@ class EventsPersistenceStorage:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
         """
-        replaced_events = {}  # type: Dict[str, str]
+        replaced_events: Dict[str, str] = {}
         if not events_and_contexts:
             return replaced_events
@@ -440,16 +440,14 @@ class EventsPersistenceStorage:
         # Set of remote users which were in rooms the server has left. We
         # should check if we still share any rooms and if not we mark their
         # device lists as stale.
-        potentially_left_users = set()  # type: Set[str]
+        potentially_left_users: Set[str] = set()

         if not backfilled:
             with Measure(self._clock, "_calculate_state_and_extrem"):
                 # Work out the new "current state" for each room.
                 # We do this by working out what the new extremities are and then
                 # calculating the state from that.
-                events_by_room = (
-                    {}
-                )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+                events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
                 for event, context in chunk:
                     events_by_room.setdefault(event.room_id, []).append(
                         (event, context)
@@ -622,9 +620,9 @@ class EventsPersistenceStorage:
         )

         # Remove any events which are prev_events of any existing events.
-        existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
-            result
-        )  # type: Collection[str]
+        existing_prevs: Collection[
+            str
+        ] = await self.persist_events_store._get_events_which_are_prevs(result)
         result.difference_update(existing_prevs)

         # Finally handle the case where the new events have soft-failed prev
@@ -256,7 +256,7 @@ def _setup_new_database(
         for database in databases
     )

-    directory_entries = []  # type: List[_DirectoryListing]
+    directory_entries: List[_DirectoryListing] = []
     for directory in directories:
         directory_entries.extend(
             _DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -424,10 +424,10 @@ def _upgrade_existing_database(
             directories.append(os.path.join(schema_path, database, "delta", str(v)))

         # Used to check if we have any duplicate file names
-        file_name_counter = Counter()  # type: CounterType[str]
+        file_name_counter: CounterType[str] = Counter()

         # Now find which directories have anything of interest.
-        directory_entries = []  # type: List[_DirectoryListing]
+        directory_entries: List[_DirectoryListing] = []
         for directory in directories:
             logger.debug("Looking for schema deltas in %s", directory)
             try:
@@ -91,7 +91,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        type_dict = {}  # type: Dict[str, Optional[Set[str]]]
+        type_dict: Dict[str, Optional[Set[str]]] = {}
         for typ, s in types:
             if typ in type_dict:
                 if type_dict[typ] is None:
@@ -194,7 +194,7 @@ class StateFilter:
         """
         where_clause = ""
-        where_args = []  # type: List[str]
+        where_args: List[str] = []

         if self.is_full():
             return where_clause, where_args