Unverified Commit 98aec1cc authored by Jonathan de Jong's avatar Jonathan de Jong Committed by GitHub
Browse files

Use inline type hints in `handlers/` and `rest/`. (#10382)

parent 36dc1541
Convert internal type variable syntax to reflect wider ecosystem use.
\ No newline at end of file
......@@ -38,10 +38,10 @@ class BaseHandler:
"""
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
......@@ -55,12 +55,12 @@ class BaseHandler:
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
) # type: Optional[Ratelimiter]
)
else:
self.admin_redaction_ratelimiter = None
......
......@@ -139,7 +139,7 @@ class AdminHandler(BaseHandler):
to_key = RoomStreamToken(None, stream_ordering)
# Events that we've processed in this room
written_events = set() # type: Set[str]
written_events: Set[str] = set()
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
......@@ -152,7 +152,7 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
# events "children".
unseen_to_child_events = {} # type: Dict[str, Set[str]]
unseen_to_child_events: Dict[str, Set[str]] = {}
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
......
......@@ -96,7 +96,7 @@ class ApplicationServicesHandler:
self.current_max, limit
)
events_by_room = {} # type: Dict[str, List[EventBase]]
events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
......@@ -275,7 +275,7 @@ class ApplicationServicesHandler:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
events = [] # type: List[JsonDict]
events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
......@@ -375,7 +375,7 @@ class ApplicationServicesHandler:
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]]
protocols: Dict[str, List[JsonDict]] = {}
# Collect up all the individual protocol responses out of the ASes
for s in services:
......
......@@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
......@@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
# A mapping of user ID to extra attributes to include in the login
# response.
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
async def validate_user_via_ui_auth(
self,
......@@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows.
"""
sid = None # type: Optional[str]
sid: Optional[str] = None
authdict = clientdict.pop("auth", {})
if "session" in authdict:
sid = authdict["session"]
......@@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
)
# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
errordict: Dict[str, Any] = {}
if "type" in authdict:
login_type = authdict["type"] # type: str
login_type: str = authdict["type"]
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
......@@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
params = {} # type: Dict[str, Any]
params: Dict[str, Any] = {}
for f in public_flows:
for stage in f:
......@@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
user_id_to_verify = await self.get_session_data(
user_id_to_verify: str = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
)
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify
......
......@@ -171,7 +171,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {} # type: Dict[str, List[Optional[str]]]
attributes: Dict[str, List[Optional[str]]] = {}
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
......
......@@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id
)
hosts = set() # type: Set[str]
hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
......@@ -613,20 +613,20 @@ class DeviceListUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled.
self._pending_updates = (
{}
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
self._pending_updates: Dict[
str, List[Tuple[str, str, Iterable[str], JsonDict]]
] = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
cache_name="device_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
) # type: ExpiringCache[str, Set[str]]
)
# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
......@@ -755,7 +755,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
seen_updates: Set[str] = self._seen_updates.get(user_id, set())
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
......
......@@ -203,7 +203,7 @@ class DeviceMessageHandler:
log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device.
if (
......
......@@ -237,9 +237,9 @@ class DirectoryHandler(BaseHandler):
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None
if self.hs.is_mine(room_alias):
result = await self.get_association_from_room_alias(
room_alias
) # type: Optional[RoomAliasMapping]
result: Optional[
RoomAliasMapping
] = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
......
......@@ -115,9 +115,9 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get(
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
)
# separate users by domain.
# make a map from domain to user_id to device_ids
......@@ -136,7 +136,7 @@ class E2eKeysHandler:
# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
failures: Dict[str, JsonDict] = {}
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
......@@ -151,11 +151,9 @@ class E2eKeysHandler:
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
query_list: List[Tuple[str, Optional[str]]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
......@@ -362,9 +360,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = [] # type: List[Tuple[str, Optional[str]]]
local_query: List[Tuple[str, Optional[str]]] = []
result_dict = {} # type: Dict[str, Dict[str, dict]]
result_dict: Dict[str, Dict[str, dict]] = {}
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
......@@ -402,9 +400,9 @@ class E2eKeysHandler:
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
"""Handle a device key query from a federated server"""
device_keys_query = query_body.get(
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
) # type: Dict[str, Optional[List[str]]]
)
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
......@@ -421,8 +419,8 @@ class E2eKeysHandler:
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict:
local_query = [] # type: List[Tuple[str, str, str]]
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
......@@ -439,8 +437,8 @@ class E2eKeysHandler:
results = await self.store.claim_e2e_one_time_keys(local_query)
# A map of user ID -> device ID -> key ID -> key.
json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
failures = {} # type: Dict[str, JsonDict]
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
failures: Dict[str, JsonDict] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
......@@ -768,8 +766,8 @@ class E2eKeysHandler:
Raises:
SynapseError: if the input is malformed
"""
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
signature_list: List["SignatureListItem"] = []
failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures
......@@ -930,8 +928,8 @@ class E2eKeysHandler:
Raises:
SynapseError: if the input is malformed
"""
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
signature_list: List["SignatureListItem"] = []
failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures
......@@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
......@@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
device_ids = [] # type: List[str]
device_ids: List[str] = []
logger.info("pending updates: %r", pending_updates)
......
......@@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = [] # type: List[JsonDict]
to_add: List[JsonDict] = []
for event in events:
if not isinstance(event, EventBase):
continue
......@@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = await self.store.get_users_in_room(
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
) # type: Iterable[str]
)
else:
users = [event.state_key]
......
......@@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]]
self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._room_backfill = Linearizer("room_backfill")
......@@ -368,7 +368,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(ours.values()) # type: List[StateMap[str]]
state_maps: List[StateMap[str]] = list(ours.values())
# we don't need this any more, let's delete it.
del ours
......@@ -845,7 +845,7 @@ class FederationHandler(BaseHandler):
# exact key to expect. Otherwise check it matches any key we
# have for that device.
current_keys = [] # type: Container[str]
current_keys: Container[str] = []
if device:
keys = device.get("keys", {}).get("keys", {})
......@@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
joined_domains = {} # type: Dict[str, int]
joined_domains: Dict[str, int] = {}
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
......@@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version(room_id)
event_map = {} # type: Dict[str, EventBase]
event_map: Dict[str, EventBase] = {}
async def get_event(event_id: str):
with nested_logging_context(event_id):
......@@ -1596,7 +1596,7 @@ class FederationHandler(BaseHandler):
# Ask the remote server to create a valid knock event for us. Once received,
# we sign the event
params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
origin, event, event_format_version = await self._make_and_verify_event(
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
)
......@@ -2453,14 +2453,14 @@ class FederationHandler(BaseHandler):
state_sets_d = await self.state_store.get_state_groups(
event.room_id, extrem_ids
)
state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]]
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
} # type: StateMap[str]
}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
......@@ -2817,7 +2817,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
else:
event_key = None
state_updates = {
......@@ -3156,7 +3156,7 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content)
last_exception = None # type: Optional[Exception]
last_exception: Optional[Exception] = None
# for each public key in the 3pid invite event
for public_key_object in event_auth.get_public_keys(invite_event):
......
......@@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict:
destinations = {} # type: Dict[str, Set[str]]
destinations: Dict[str, Set[str]] = {}
local_users = set()
for user_id in user_ids:
......@@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
failed_results = [] # type: List[str]
failed_results: List[str] = []
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
......
......@@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(
hs.get_clock(), "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self.snapshot_cache: ResponseCache[
Tuple[
str,
Optional[StreamToken],
Optional[StreamToken],
str,
Optional[int],
bool,
bool,
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
......
......@@ -81,7 +81,7 @@ class MessageHandler:
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
self._scheduled_expiry = None # type: Optional[IDelayedCall]
self._scheduled_expiry: Optional[IDelayedCall] = None
if not hs.config.worker_app:
run_as_background_process(
......@@ -196,9 +196,7 @@ class MessageHandler:
room_state_events = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
room_state = room_state_events[
event.event_id
] # type: Mapping[Any, EventBase]
room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
else:
raise AuthError(
403,
......@@ -421,9 +419,9 @@ class EventCreationHandler:
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = (
self.third_party_event_rules: "ThirdPartyEventRules" = (
self.hs.get_third_party_event_rules()
) # type: ThirdPartyEventRules
)
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
......@@ -440,7 +438,7 @@ class EventCreationHandler:
#
# map from room id to time-of-last-attempt.
#
self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
# The number of forward extremeities before a dummy event is sent.
self._dummy_events_threshold = hs.config.dummy_events_threshold
......@@ -465,9 +463,7 @@ class EventCreationHandler:
# Stores the state groups we've recently added to the joined hosts
# external cache. Note that the timeout must be significantly less than
# the TTL on the external cache.
self._external_cache_joined_hosts_updates = (
None
) # type: Optional[ExpiringCache]
self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
if self._external_cache.is_enabled():
self._external_cache_joined_hosts_updates = ExpiringCache(
"_external_cache_joined_hosts_updates",
......@@ -1299,7 +1295,7 @@ class EventCreationHandler:
# Validate a newly added alias or newly added alt_aliases.
original_alias = None
original_alt_aliases = [] # type: List[str]
original_alt_aliases: List[str] = []
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
......
......@@ -105,9 +105,9 @@ class OidcHandler:
assert provider_confs
self._token_generator = OidcSessionTokenGenerator(hs)
self._providers = {
self._providers: Dict[str, "OidcProvider"] = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
} # type: Dict[str, OidcProvider]
}
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
......@@ -178,7 +178,7 @@ class OidcHandler:
# are two.
for cookie_name, _ in _SESSION_COOKIES:
session = request.getCookie(cookie_name) # type: Optional[bytes]
session: Optional[bytes] = request.getCookie(cookie_name)
if session is not None:
break
else:
......@@ -277,7 +277,7 @@ class OidcProvider:
self._token_generator = token_generator
self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
self._callback_url: str = hs.config.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
......@@ -290,7 +290,7 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
client_secret = None # type: Union[None, str, JwtClientSecret]
client_secret: Optional[Union[str, JwtClientSecret]] = None
if provider.client_secret:
client_secret = provider.client_secret
elif provider.client_secret_jwt_key:
......@@ -305,7 +305,7 @@ class OidcProvider:
provider.client_id,
client_secret,
provider.client_auth_method,
) # type: ClientAuth
)
self._client_auth_method = provider.client_auth_method
# cache of metadata for the identity provider (endpoint uris, mostly). This is
......@@ -324,7 +324,7 @@ class OidcProvider:
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str
self._server_name: str = hs.config.server_name
# identifier for the external_ids table
self.idp_id = provider.idp_id
......@@ -1381,7 +1381,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
emails = [] # type: List[str]
emails: List[str] = []
email = render_template_field(self._config.email_template)