diff --git a/changelog.d/566.feature b/changelog.d/566.feature new file mode 100644 index 00000000..8bd5cdde --- /dev/null +++ b/changelog.d/566.feature @@ -0,0 +1 @@ +Add a config option `homeserver_allow_list` to specify which homeservers can access Sydent. diff --git a/sydent/config/__init__.py b/sydent/config/__init__.py index b8354685..f3c78680 100644 --- a/sydent/config/__init__.py +++ b/sydent/config/__init__.py @@ -51,7 +51,8 @@ # 'prometheus_addr': '', # The address to bind to. Empty string means bind to all. # The following can be added to your local config file to enable sentry support. # 'sentry_dsn': 'https://...' # The DSN has configured in the sentry instance project. - # Whether clients and homeservers can register an association using v1 endpoints. + # Whether clients and homeservers can register an association using v1 endpoints. This + # option is now deprecated and will be superceded by the option `enable_v1_access` "enable_v1_associations": "true", "delete_tokens_on_bind": "true", # Prevent outgoing requests from being sent to the following blacklisted @@ -72,6 +73,12 @@ # This whitelist overrides `ip.blacklist` and defaults to an empty # list. "ip.whitelist": "", + # A list of homeservers that are allowed to register with this identity server. Defaults to + # allowing all homeservers. If a list is specified, the config option `enable_v1_access` must be + # set to 'false'. + "homeserver_allow_list": "", + # If set to 'false', entirely disable access via the V1 api. + "enable_v1_access": "true", }, "db": { "db.file": os.environ.get("SYDENT_DB_PATH", "sydent.db"), diff --git a/sydent/config/general.py b/sydent/config/general.py index 6ed305c7..db81a8d8 100644 --- a/sydent/config/general.py +++ b/sydent/config/general.py @@ -82,10 +82,6 @@ def parse_config(self, cfg: "ConfigParser") -> bool: self.sentry_enabled = cfg.has_option("general", "sentry_dsn") self.sentry_dsn = cfg.get("general", "sentry_dsn", fallback=None) - self.enable_v1_associations = parse_cfg_bool( - cfg.get("general", "enable_v1_associations") - ) - self.delete_tokens_on_bind = parse_cfg_bool( cfg.get("general", "delete_tokens_on_bind") ) @@ -99,6 +95,26 @@ def parse_config(self, cfg: "ConfigParser") -> bool: self.ip_blacklist = generate_ip_set(ip_blacklist) self.ip_whitelist = generate_ip_set(ip_whitelist) + self.enable_v1_access = parse_cfg_bool(cfg.get("general", "enable_v1_access")) + + homeserver_allow_list = list_from_comma_sep_string( + cfg.get("general", "homeserver_allow_list") + ) + if homeserver_allow_list and self.enable_v1_access: + raise RuntimeError( + """The V1 api must be disabled for the `homeserver_allow_list` to function, if you have + specified a `homeserver_allow_list` in the config file please ensure that the config + option `enable_v1_access` is set to 'false'.""" + ) + self.homeserver_allow_list = homeserver_allow_list + + if not self.enable_v1_access: + self.enable_v1_associations = False + else: + self.enable_v1_associations = parse_cfg_bool( + cfg.get("general", "enable_v1_associations") + ) + return False diff --git a/sydent/http/httpserver.py b/sydent/http/httpserver.py index 65e59170..9e1ea86c 100644 --- a/sydent/http/httpserver.py +++ b/sydent/http/httpserver.py @@ -99,40 +99,53 @@ def __init__(self, sydent: "Sydent", lookup_pepper: str) -> None: identity.putChild(b"api", api) identity.putChild(b"v2", v2) identity.putChild(b"versions", VersionsServlet()) - api.putChild(b"v1", v1) - validate.putChild(b"email", email) - validate.putChild(b"msisdn", msisdn) - - validate_v2.putChild(b"email", email_v2) - validate_v2.putChild(b"msisdn", msisdn_v2) - - v1.putChild(b"validate", validate) - - v1.putChild(b"lookup", LookupServlet(sydent)) - v1.putChild(b"bulk_lookup", BulkLookupServlet(sydent)) - - v1.putChild(b"pubkey", pubkey) pubkey.putChild(b"isvalid", PubkeyIsValidServlet(sydent)) pubkey.putChild(b"ed25519:0", Ed25519Servlet(sydent)) pubkey.putChild(b"ephemeral", ephemeralPubkey) ephemeralPubkey.putChild(b"isvalid", EphemeralPubkeyIsValidServlet(sydent)) - threepid_v2.putChild( - b"getValidated3pid", GetValidated3pidServlet(sydent, require_auth=True) - ) - threepid_v2.putChild(b"bind", ThreePidBindServlet(sydent, require_auth=True)) - threepid_v2.putChild(b"unbind", unbind) + # v1 + if self.sydent.config.general.enable_v1_access: + api.putChild(b"v1", v1) + validate.putChild(b"email", email) + validate.putChild(b"msisdn", msisdn) + v1.putChild(b"validate", validate) + + v1.putChild(b"lookup", LookupServlet(sydent)) + v1.putChild(b"bulk_lookup", BulkLookupServlet(sydent)) + + v1.putChild(b"pubkey", pubkey) + + threepid_v1.putChild(b"getValidated3pid", GetValidated3pidServlet(sydent)) + threepid_v1.putChild(b"unbind", unbind) + v1.putChild(b"3pid", threepid_v1) + + email.putChild(b"requestToken", EmailRequestCodeServlet(sydent)) + email.putChild(b"submitToken", EmailValidateCodeServlet(sydent)) + + msisdn.putChild(b"requestToken", MsisdnRequestCodeServlet(sydent)) + msisdn.putChild(b"submitToken", MsisdnValidateCodeServlet(sydent)) + + v1.putChild(b"store-invite", StoreInviteServlet(sydent)) + + v1.putChild(b"sign-ed25519", BlindlySignStuffServlet(sydent)) - threepid_v1.putChild(b"getValidated3pid", GetValidated3pidServlet(sydent)) - threepid_v1.putChild(b"unbind", unbind) if self.sydent.config.general.enable_v1_associations: threepid_v1.putChild(b"bind", ThreePidBindServlet(sydent)) - v1.putChild(b"3pid", threepid_v1) + # v2 + # note v2 loses the /api so goes on 'identity' not 'api' + identity.putChild(b"v2", v2) - email.putChild(b"requestToken", EmailRequestCodeServlet(sydent)) - email.putChild(b"submitToken", EmailValidateCodeServlet(sydent)) + validate_v2.putChild(b"email", email_v2) + validate_v2.putChild(b"msisdn", msisdn_v2) + + threepid_v2.putChild( + b"getValidated3pid", GetValidated3pidServlet(sydent, require_auth=True) + ) + threepid_v2.putChild(b"bind", ThreePidBindServlet(sydent, require_auth=True)) + threepid_v2.putChild(b"unbind", unbind) email_v2.putChild( b"requestToken", EmailRequestCodeServlet(sydent, require_auth=True) @@ -141,9 +154,6 @@ def __init__(self, sydent: "Sydent", lookup_pepper: str) -> None: b"submitToken", EmailValidateCodeServlet(sydent, require_auth=True) ) - msisdn.putChild(b"requestToken", MsisdnRequestCodeServlet(sydent)) - msisdn.putChild(b"submitToken", MsisdnValidateCodeServlet(sydent)) - msisdn_v2.putChild( b"requestToken", MsisdnRequestCodeServlet(sydent, require_auth=True) ) @@ -151,14 +161,6 @@ def __init__(self, sydent: "Sydent", lookup_pepper: str) -> None: b"submitToken", MsisdnValidateCodeServlet(sydent, require_auth=True) ) - v1.putChild(b"store-invite", StoreInviteServlet(sydent)) - - v1.putChild(b"sign-ed25519", BlindlySignStuffServlet(sydent)) - - # v2 - # note v2 loses the /api so goes on 'identity' not 'api' - identity.putChild(b"v2", v2) - # v2 exclusive APIs v2.putChild(b"terms", TermsServlet(sydent)) account = AccountServlet(sydent) diff --git a/sydent/http/servlets/registerservlet.py b/sydent/http/servlets/registerservlet.py index 7e68d547..17752828 100644 --- a/sydent/http/servlets/registerservlet.py +++ b/sydent/http/servlets/registerservlet.py @@ -53,6 +53,14 @@ async def render_POST(self, request: Request) -> JsonDict: matrix_server = args["matrix_server_name"].lower() + if self.sydent.config.general.homeserver_allow_list: + if matrix_server not in self.sydent.config.general.homeserver_allow_list: + request.setResponseCode(403) + return { + "errcode": "M_UNAUTHORIZED", + "error": "This homeserver is not authorized to access this server.", + } + if not is_valid_matrix_server_name(matrix_server): request.setResponseCode(400) return { diff --git a/tests/test_register.py b/tests/test_register.py index 88feb2c6..de66125f 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -98,3 +98,30 @@ def test_federation_does_not_return_json(self) -> None: # Check that we haven't just returned the generic error message in asyncjsonwrap self.assertNotEqual(channel.json_body["error"], "Internal Server Error") self.assertIn("JSON", channel.json_body["error"]) + + +class RegisterAllowListTestCase(unittest.TestCase): + """ + Test registering works with the `homeserver_allow_list` config option specified + """ + + def test_registering_not_allowed_if_homeserver_not_in_allow_list(self) -> None: + config = { + "general": { + "homeserver_allow_list": "friendly.com, example.com", + "enable_v1_access": "false", + } + } + # Create a new sydent with a homeserver_allow_list specified + self.sydent = make_sydent(test_config=config) + self.sydent.run() + + request, channel = make_request( + self.sydent.reactor, + self.sydent.clientApiHttpServer.factory, + "POST", + "/_matrix/identity/v2/account/register", + content={"matrix_server_name": "not.example.com", "access_token": "foo"}, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") diff --git a/tests/test_start.py b/tests/test_start.py index 3d49b27c..705c254b 100644 --- a/tests/test_start.py +++ b/tests/test_start.py @@ -9,3 +9,18 @@ class StartupTestCase(unittest.TestCase): def test_start(self): sydent = make_sydent() sydent.run() + + def test_homeserver_allow_list_refuses_to_start_if_v1_not_disabled(self): + """ + Test that Sydent throws a runtime error if `homeserver_allow_list` is specified + but the v1 API has not been disabled + """ + config = { + "general": { + "homeserver_allow_list": "friendly.com, example.com", + "enable_v1_access": "true", + } + } + + with self.assertRaises(RuntimeError): + make_sydent(test_config=config)