diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py index d253938329b..0dd520cdae7 100644 --- a/synapse/api/auth/__init__.py +++ b/synapse/api/auth/__init__.py @@ -188,6 +188,22 @@ def get_access_token_from_request(request: Request) -> str: request """ + @staticmethod + def get_ip_address_from_request(request: Request) -> str: + """ + Extract the IPv4 or IPv6 address from a client request. + + Args: + request: The request to process. + + Returns: + The IPv4 or IPv6 address of the client. + + Raises: + SynapseError: If an IP address could not be extracted from the + request. + """ + async def check_user_in_room_or_world_readable( self, room_id: str, requester: Requester, allow_departed_users: bool = False ) -> Tuple[str, Optional[str]]: diff --git a/synapse/api/auth/base.py b/synapse/api/auth/base.py index 76c8c71628a..e78ad7b21b6 100644 --- a/synapse/api/auth/base.py +++ b/synapse/api/auth/base.py @@ -19,10 +19,12 @@ # # import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple from netaddr import IPAddress +from twisted.internet.address import IPv4Address, IPv6Address from twisted.web.server import Request from synapse import event_auth @@ -31,6 +33,7 @@ AuthError, Codes, MissingClientTokenError, + SynapseError, UnstableSpecAuthError, ) from synapse.appservice import ApplicationService @@ -291,6 +294,37 @@ def get_access_token_from_request(request: Request) -> str: return query_params[0].decode("ascii") + @staticmethod + def get_ip_address_from_request(request: Request) -> str: + """ + Extract the IPv4 or IPv6 address from a client request. + + Args: + request: The request to process. + + Returns: + The IPv4 or IPv6 address of the client. + + Raises: + SynapseError: If an IP address could not be extracted from the + request. + """ + client_address = request.getClientAddress() + if not isinstance(client_address, IPv4Address) and not isinstance( + client_address, IPv6Address + ): + logger.error( + "Unable to view IP address of the requester. " \ + "Check that you are setting the X-Forwarded-For header correctly in your reverse proxy." + ) + raise SynapseError( + HTTPStatus.INTERNAL_SERVER_ERROR, + "Unable to read client IP address", + Codes.UNKNOWN, + ) + + return client_address.host + @cancellable async def get_appservice_user( self, request: Request, access_token: str @@ -326,7 +360,8 @@ async def get_appservice_user( return None if app_service.ip_range_whitelist: - ip_address = IPAddress(request.getClientAddress().host) + ip_address_str = self.get_ip_address_from_request(request) + ip_address = IPAddress(ip_address_str) if ip_address not in app_service.ip_range_whitelist: return None diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2d1990cce5b..fba86bb7c20 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -567,7 +567,7 @@ async def check_ui_auth( await self.store.set_ui_auth_clientdict(sid, clientdict) user_agent = get_request_user_agent(request) - clientip = request.getClientAddress().host + clientip = self.auth.get_ip_address_from_request(request) await self.store.add_user_agent_ip_to_ui_auth_session( session.session_id, user_agent, clientip diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index be757201fc6..fa49f47db2c 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -57,6 +57,7 @@ class IdentityHandler: def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() self.store = hs.get_datastores().main # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) @@ -97,9 +98,8 @@ async def ratelimit_request_token_requests( address: The actual threepid ID, e.g. the phone number or email address """ - await self._3pid_validation_ratelimiter_ip.ratelimit( - None, (medium, request.getClientAddress().host) - ) + ip_address = self._auth.get_ip_address_from_request(request) + await self._3pid_validation_ratelimiter_ip.ratelimit(None, (medium, ip_address)) await self._3pid_validation_ratelimiter_address.ratelimit( None, (medium, address) ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 735cfa0a0f8..cc13a284ca7 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -205,6 +205,7 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self._is_mine_server_name = hs.is_mine_server_name self._registration_handler = hs.get_registration_handler() + self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._error_template = hs.config.sso.sso_error_template @@ -505,12 +506,13 @@ async def complete_sso_login_request( auth_provider_session_id, ) + ip_address = self._auth.get_ip_address_from_request(request) user_id = await self._register_mapped_user( attributes, auth_provider_id, remote_user_id, get_request_user_agent(request), - request.getClientAddress().host, + ip_address, ) new_user = True elif self._sso_update_profile_information: @@ -1080,6 +1082,8 @@ async def register_sso_user(self, request: Request, session_id: str) -> None: if session.use_avatar: attributes.picture = session.avatar_url + ip_address = self._auth.get_ip_address_from_request(request) + # the following will raise a 400 error if the username has been taken in the # meantime. user_id = await self._register_mapped_user( @@ -1087,7 +1091,7 @@ async def register_sso_user(self, request: Request, session_id: str) -> None: session.auth_provider_id, session.remote_user_id, get_request_user_agent(request), - request.getClientAddress().host, + ip_address, ) logger.info( diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 600bb51a7e7..c74af357463 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -134,6 +134,8 @@ async def on_POST(self, request: Request, stagetype: str) -> None: if not session: raise SynapseError(400, "No session supplied") + ip_address = self.auth.get_ip_address_from_request(request) + if stagetype == LoginType.RECAPTCHA: response = parse_string(request, "g-recaptcha-response") @@ -144,7 +146,9 @@ async def on_POST(self, request: Request, stagetype: str) -> None: try: await self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, authdict, request.getClientAddress().host + LoginType.RECAPTCHA, + authdict, + ip_address, ) except LoginError as e: # Authentication failed, let user try again @@ -164,7 +168,9 @@ async def on_POST(self, request: Request, stagetype: str) -> None: try: await self.auth_handler.add_oob_auth( - LoginType.TERMS, authdict, request.getClientAddress().host + LoginType.TERMS, + authdict, + ip_address, ) except LoginError as e: # Authentication failed, let user try again @@ -195,7 +201,7 @@ async def on_POST(self, request: Request, stagetype: str) -> None: await self.auth_handler.add_oob_auth( LoginType.REGISTRATION_TOKEN, authdict, - request.getClientAddress().host, + ip_address, ) except LoginError as e: html = self.registration_token_template.render( diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 921232a3ea4..5ee4ba1176f 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -205,6 +205,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: ) request_info = request.request_info() + ip_address = self.auth.get_ip_address_from_request(request) try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: @@ -224,9 +225,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: ) if appservice.is_rate_limited(): - await self._address_ratelimiter.ratelimit( - None, request.getClientAddress().host - ) + await self._address_ratelimiter.ratelimit(None, ip_address) result = await self._do_appservice_login( login_submission, @@ -238,27 +237,21 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: self.jwt_enabled and login_submission["type"] == LoginRestServlet.JWT_TYPE ): - await self._address_ratelimiter.ratelimit( - None, request.getClientAddress().host - ) + await self._address_ratelimiter.ratelimit(None, ip_address) result = await self._do_jwt_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, request_info=request_info, ) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - await self._address_ratelimiter.ratelimit( - None, request.getClientAddress().host - ) + await self._address_ratelimiter.ratelimit(None, ip_address) result = await self._do_token_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, request_info=request_info, ) else: - await self._address_ratelimiter.ratelimit( - None, request.getClientAddress().host - ) + await self._address_ratelimiter.ratelimit(None, ip_address) result = await self._do_other_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 4c044ae900e..43732bae06d 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -192,7 +192,8 @@ async def on_GET( respond_404(request) return - ip_address = request.getClientAddress().host + ip_address = self.auth.get_ip_address_from_request(request) + remote_resp_function = ( self.thumbnailer.select_or_generate_remote_thumbnail if self.dynamic_thumbnails @@ -263,7 +264,8 @@ async def on_GET( request, media_id, file_name, max_timeout_ms ) else: - ip_address = request.getClientAddress().host + ip_address = self.auth.get_ip_address_from_request(request) + await self.media_repo.get_remote_media( request, server_name, diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index b42006e4cee..3798f46ae13 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -329,6 +329,7 @@ class UsernameAvailabilityRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs + self._auth = hs.get_auth() self.server_name = hs.hostname self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( @@ -361,7 +362,7 @@ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if self.inhibit_user_in_use_error: return 200, {"available": True} - ip = request.getClientAddress().host + ip = self._auth.get_ip_address_from_request(request) with self.ratelimiter.ratelimit(ip) as wait_deferred: await wait_deferred @@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs + self._auth = hs.get_auth() self.store = hs.get_datastores().main self.ratelimiter = Ratelimiter( store=self.store, @@ -403,7 +405,8 @@ def __init__(self, hs: "HomeServer"): ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: - await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,)) + ip_address = self._auth.get_ip_address_from_request(request) + await self.ratelimiter.ratelimit(None, (ip_address,)) if not self.hs.config.registration.enable_registration: raise SynapseError( @@ -456,7 +459,7 @@ def __init__(self, hs: "HomeServer"): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) - client_addr = request.getClientAddress().host + client_addr = self.auth.get_ip_address_from_request(request) await self.ratelimiter.ratelimit(None, client_addr, update=False) @@ -916,7 +919,7 @@ def __init__(self, hs: "HomeServer"): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) - client_addr = request.getClientAddress().host + client_addr = self.auth.get_ip_address_from_request(request) await self.ratelimiter.ratelimit(None, client_addr, update=False) diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py index 3c3f703667a..2a29cfd1c9d 100644 --- a/synapse/rest/media/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -49,6 +49,7 @@ class DownloadResource(RestServlet): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() + self._auth = hs.get_auth() self.media_repo = media_repo self._is_mine_server_name = hs.is_mine_server_name @@ -97,7 +98,7 @@ async def on_GET( respond_404(request) return - ip_address = request.getClientAddress().host + ip_address = self._auth.get_ip_address_from_request(request) await self.media_repo.get_remote_media( request, server_name, diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 536fea4c32f..5701f0d99e0 100644 --- a/synapse/rest/media/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -58,6 +58,7 @@ def __init__( ): super().__init__() + self._auth = hs.get_auth() self.store = hs.get_datastores().main self.media_repo = media_repo self.media_storage = media_storage @@ -120,7 +121,7 @@ async def on_GET( respond_404(request) return - ip_address = request.getClientAddress().host + ip_address = self._auth.get_ip_address_from_request(request) remote_resp_function = ( self.thumbnail_provider.select_or_generate_remote_thumbnail if self.dynamic_thumbnails diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e7fcd928d71..1907a33112e 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -23,6 +23,7 @@ import pymacaroons +from twisted.internet.address import IPv4Address from twisted.internet.testing import MemoryReactor from synapse.api.auth.internal import InternalAuth @@ -118,7 +119,7 @@ def test_get_user_by_req_appservice_valid_token(self) -> None: self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) - request.getClientAddress.return_value.host = "127.0.0.1" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) @@ -137,7 +138,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None: self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) - request.getClientAddress.return_value.host = "192.168.10.10" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="192.168.10.10", port=12345) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) @@ -156,7 +157,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None: self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) - request.getClientAddress.return_value.host = "131.111.8.42" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="131.111.8.42", port=12345) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() f = self.get_failure( @@ -209,7 +210,7 @@ class FakeUserInfo: self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) - request.getClientAddress.return_value.host = "127.0.0.1" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -231,7 +232,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None: self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) - request.getClientAddress.return_value.host = "127.0.0.1" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -261,7 +262,7 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None: self.store.get_device = AsyncMock(return_value={"hidden": False}) request = Mock(args={}) - request.getClientAddress.return_value.host = "127.0.0.1" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] @@ -296,7 +297,7 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None: self.store.get_device = AsyncMock(return_value=None) request = Mock(args={}) - request.getClientAddress.return_value.host = "127.0.0.1" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] @@ -320,7 +321,7 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non self.store.mark_access_token_as_used = AsyncMock(return_value=None) self.store.get_user_locked_status = AsyncMock(return_value=False) request = Mock(args={}) - request.getClientAddress.return_value.host = "127.0.0.1" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) @@ -341,7 +342,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.store.insert_client_ip = AsyncMock(return_value=None) self.store.mark_access_token_as_used = AsyncMock(return_value=None) request = Mock(args={}) - request.getClientAddress.return_value.host = "127.0.0.1" + request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index f677f3be2a6..1135b38f2fb 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -21,6 +21,7 @@ from typing import Any, Dict from unittest.mock import AsyncMock, Mock +from twisted.internet.address import IPv4Address from twisted.internet.testing import MemoryReactor from synapse.handlers.cas import CasResponse @@ -234,6 +235,7 @@ def _mock_request() -> Mock: "write", ] ) + mock.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. mock._disconnected = False return mock diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5207382f00f..37a006fa0f7 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -25,6 +25,7 @@ import pymacaroons +from twisted.internet.address import IPv4Address from twisted.internet.testing import MemoryReactor from synapse.handlers.sso import MappingException @@ -1684,5 +1685,5 @@ def _build_callback_request( request.args = {} request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] - request.getClientAddress.return_value.host = ip_address + request.getClientAddress.return_value = IPv4Address(type="TCP", host=ip_address, port=12345) return request diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index f7cbf911139..77c07564d96 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -24,6 +24,7 @@ import attr +from twisted.internet.address import IPv4Address from twisted.internet.testing import MemoryReactor from synapse.api.errors import RedirectException @@ -424,4 +425,5 @@ def _mock_request() -> Mock: ) # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. mock._disconnected = False + mock.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345) return mock