From b14cb111032404b2762bafecb39f664c3eb2be9a Mon Sep 17 00:00:00 2001 From: Anshul Gupta Date: Thu, 4 Jul 2024 14:54:22 +0200 Subject: [PATCH] Refactor init (#29) * Refator Init for new functionality * Refactor Login/Email Verification Functionality * removed comment * Linting and Tests * Linting and Tests * Refactor Init with minimal changes * linting * Requested changes and email verification * ruff format and linting * ruff --- quick_test.py | 3 +- tabpfn_client/client.py | 29 ++++++++++++++--- tabpfn_client/config.py | 7 +++- tabpfn_client/estimator.py | 11 ++----- tabpfn_client/prompt_agent.py | 32 +++++++++++++++++++ tabpfn_client/server_config.yaml | 10 +++--- tabpfn_client/service_wrapper.py | 19 ++++++++--- tabpfn_client/tests/unit/test_client.py | 22 +++++++++---- .../tests/unit/test_tabpfn_classifier.py | 4 +++ .../tests/unit/test_tabpfn_regressor.py | 4 +++ 10 files changed, 109 insertions(+), 32 deletions(-) diff --git a/quick_test.py b/quick_test.py index 1064aa1..8138c4e 100644 --- a/quick_test.py +++ b/quick_test.py @@ -3,7 +3,7 @@ from sklearn.datasets import load_breast_cancer, load_diabetes from sklearn.model_selection import train_test_split -from tabpfn_client import UserDataClient, init +from tabpfn_client import UserDataClient from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor logging.basicConfig(level=logging.DEBUG) @@ -21,7 +21,6 @@ X, y, test_size=0.33, random_state=42 ) - init() tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted", n_estimators=3) # print("checking estimator", check_estimator(tabpfn)) tabpfn.fit(X_train[:99], y_train[:99]) diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index bb8a9eb..4fe1c4c 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -220,7 +220,7 @@ def try_connection(self) -> bool: return found_valid_connection - def try_authenticate(self, access_token) -> bool: + def is_auth_token_outdated(self, access_token) -> bool | None: """ Check if the provided access token is valid and return True if successful. """ @@ -230,11 +230,13 @@ def try_authenticate(self, access_token) -> bool: headers={"Authorization": f"Bearer {access_token}"}, ) - self._validate_response(response, "try_authenticate", only_version_check=True) - + self._validate_response( + response, "is_auth_token_outdated", only_version_check=True + ) if response.status_code == 200: is_authenticated = True - + elif response.status_code == 403: + is_authenticated = None return is_authenticated def validate_email(self, email: str) -> tuple[bool, str]: @@ -312,7 +314,8 @@ def register( is_created = False message = response.json()["detail"] - return is_created, message + access_token = response.json()["token"] if is_created else None + return is_created, message, access_token def login(self, email: str, password: str) -> tuple[str, str]: """ @@ -381,6 +384,22 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]: message = response.json()["detail"] return sent, message + def send_verification_email(self, access_token: str) -> tuple[bool, str]: + """ + Let the server send an email for verifying the email. + """ + response = self.httpx_client.post( + self.server_endpoints.send_verification_email.path, + headers={"Authorization": f"Bearer {access_token}"}, + ) + if response.status_code == 200: + sent = True + message = response.json()["message"] + else: + sent = False + message = response.json()["detail"] + return sent, message + def retrieve_greeting_messages(self) -> list[str]: """ Retrieve greeting messages that are new for the user. diff --git a/tabpfn_client/config.py b/tabpfn_client/config.py index 2f05f3a..75c2cb1 100644 --- a/tabpfn_client/config.py +++ b/tabpfn_client/config.py @@ -35,8 +35,13 @@ def init(use_server=True): is_valid_token_set = user_auth_handler.try_reuse_existing_token() - if is_valid_token_set: + if isinstance(is_valid_token_set, bool) and is_valid_token_set: PromptAgent.prompt_reusing_existing_token() + elif ( + isinstance(is_valid_token_set, tuple) and is_valid_token_set[1] is not None + ): + print("Your email is not verified. Please verify your email to continue...") + PromptAgent.reverify_email(is_valid_token_set[1], user_auth_handler) else: if not PromptAgent.prompt_terms_and_cond(): raise RuntimeError( diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py index 2265552..7d043b7 100644 --- a/tabpfn_client/estimator.py +++ b/tabpfn_client/estimator.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, asdict import numpy as np +from tabpfn_client import init from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin from sklearn.utils.validation import check_is_fitted @@ -183,10 +184,7 @@ def __init__( def fit(self, X, y): # assert init() is called - if not config.g_tabpfn_config.is_initialized: - raise RuntimeError( - "tabpfn_client.init() must be called before using TabPFNClassifier" - ) + init() if config.g_tabpfn_config.use_server: try: @@ -313,10 +311,7 @@ def __init__( def fit(self, X, y): # assert init() is called - if not config.g_tabpfn_config.is_initialized: - raise RuntimeError( - "tabpfn_client.init() must be called before using TabPFNRegressor" - ) + init() if config.g_tabpfn_config.use_server: try: diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index 5b9a299..f4e7b8f 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -55,6 +55,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): ] ) choice = cls._choice_with_retries(prompt, ["1", "2"]) + email = "" # Registration if choice == "1": @@ -207,6 +208,37 @@ def prompt_reusing_existing_token(cls): print(cls.indent(prompt)) + @classmethod + def reverify_email( + cls, access_token, user_auth_handler: "UserAuthenticationClient" + ): + prompt = "\n".join( + [ + "Please check your inbox for the verification email.", + "Note: The email might be in your spam folder or could have expired.", + ] + ) + print(cls.indent(prompt)) + retry_verification = "\n".join( + [ + "Do you want to resend email verification link? (y/n): ", + ] + ) + choice = cls._choice_with_retries(retry_verification, ["y", "n"]) + if choice == "y": + # get user email from user_auth_handler and resend verification email + sent, message = user_auth_handler.send_verification_email(access_token) + if not sent: + print(cls.indent("Failed to send verification email: " + message)) + else: + print( + cls.indent( + "A verification email has been sent, provided the details are correct!" + ) + + "\n" + ) + return + @classmethod def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]): for message in greeting_messages: diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index e5d2ef2..3440cbf 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -34,6 +34,11 @@ endpoints: methods: [ "POST" ] description: "User login" + send_verification_email: + path: "/auth/send_verification_email/" + methods: [ "POST" ] + description: "Send verifiaction email or for reverification" + send_reset_password_email: path: "/auth/send_reset_password_email/" methods: [ "POST" ] @@ -44,11 +49,6 @@ endpoints: methods: [ "GET" ] description: "Retrieve new greeting messages" - add_user_information: - path: "/add_user_information/" - methods: [ "POST" ] - description: "Add additional user information to database" - protected_root: path: "/protected/" methods: [ "GET" ] diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index f9f4ffc..cc3026e 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -44,9 +44,11 @@ def set_token_by_registration( validation_link: str, additional_info: dict, ) -> tuple[bool, str]: - is_created, message = self.service_client.register( + is_created, message, access_token = self.service_client.register( email, password, password_confirm, validation_link, additional_info ) + if access_token is not None: + self.set_token(access_token) return is_created, message def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]: @@ -58,7 +60,7 @@ def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]: self.set_token(access_token) return True, message - def try_reuse_existing_token(self) -> bool: + def try_reuse_existing_token(self) -> bool | tuple[bool, str]: if self.service_client.access_token is None: if not self.CACHED_TOKEN_FILE.exists(): return False @@ -68,10 +70,12 @@ def try_reuse_existing_token(self) -> bool: else: access_token = self.service_client.access_token - is_valid = self.service_client.try_authenticate(access_token) - if not is_valid: + is_valid = self.service_client.is_auth_token_outdated(access_token) + if is_valid is False: self._reset_token() return False + elif is_valid is None: + return False, access_token logger.debug(f"Reusing existing access token? {is_valid}") self.set_token(access_token) @@ -95,6 +99,10 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]: sent, message = self.service_client.send_reset_password_email(email) return sent, message + def send_verification_email(self, access_token: str) -> tuple[bool, str]: + sent, message = self.service_client.send_verification_email(access_token) + return sent, message + class UserDataClient(ServiceClientWrapper): """ @@ -175,7 +183,8 @@ def __init__(self, service_client=ServiceClient()): def fit(self, X, y) -> None: if not self.service_client.is_initialized: raise RuntimeError( - "Either email is not verified or Service client is not initialized. Please Verify your email and try again!" + "Dear TabPFN User, please initialize the client first by verifying your E-mail address sent to your registered E-mail account." + "Please Note: The email verification token expires in 30 minutes." ) self.last_train_set_uid = self.service_client.upload_train_set(X, y) diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index 294e51a..adb1740 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -58,7 +58,7 @@ def test_validate_email_invalid(self, mock_server): @with_mock_server() def test_register_user(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond( - 200, json={"message": "dummy_message"} + 200, json={"message": "dummy_message", "token": "DUMMY_TOKEN"} ) self.assertTrue( self.client.register( @@ -78,7 +78,7 @@ def test_register_user(self, mock_server): @with_mock_server() def test_register_user_with_invalid_email(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond( - 401, json={"detail": "dummy_message"} + 401, json={"detail": "dummy_message", "token": None} ) self.assertFalse( self.client.register( @@ -98,7 +98,7 @@ def test_register_user_with_invalid_email(self, mock_server): @with_mock_server() def test_register_user_with_invalid_validation_link(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond( - 401, json={"detail": "dummy_message"} + 401, json={"detail": "dummy_message", "token": None} ) self.assertFalse( self.client.register( @@ -118,7 +118,7 @@ def test_register_user_with_invalid_validation_link(self, mock_server): @with_mock_server() def test_register_user_with_limit_reached(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond( - 401, json={"detail": "dummy_message"} + 401, json={"detail": "dummy_message", "token": "DUMMY_TOKEN"} ) self.assertFalse( self.client.register( @@ -138,12 +138,12 @@ def test_register_user_with_limit_reached(self, mock_server): @with_mock_server() def test_invalid_auth_token(self, mock_server): mock_server.router.get(mock_server.endpoints.protected_root.path).respond(401) - self.assertFalse(self.client.try_authenticate("fake_token")) + self.assertFalse(self.client.is_auth_token_outdated("fake_token")) @with_mock_server() def test_valid_auth_token(self, mock_server): mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) - self.assertTrue(self.client.try_authenticate("true_token")) + self.assertTrue(self.client.is_auth_token_outdated("true_token")) @with_mock_server() def test_send_reset_password_email(self, mock_server): @@ -155,6 +155,16 @@ def test_send_reset_password_email(self, mock_server): (True, "Password reset email sent!"), ) + @with_mock_server() + def test_send_verification_email(self, mock_server): + mock_server.router.post( + mock_server.endpoints.send_verification_email.path + ).respond(200, json={"message": "Verification Email sent!"}) + self.assertEqual( + self.client.send_verification_email("test"), + (True, "Verification Email sent!"), + ) + @with_mock_server() def test_retrieve_greeting_messages(self, mock_server): mock_server.router.get( diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index 863d14b..f6cd64f 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -58,6 +58,10 @@ def test_init_remote_classifier( mock_server.endpoints.retrieve_greeting_messages.path ).respond(200, json={"messages": []}) + mock_server.router.get(mock_server.endpoints.protected_root.path).respond( + 200, json={"message": "Welcome to the protected zone, user!"} + ) + mock_predict_response = [[1, 0.0], [0.9, 0.1], [0.01, 0.99]] predict_route = mock_server.router.post(mock_server.endpoints.predict.path) predict_route.respond(200, json={"classification": mock_predict_response}) diff --git a/tabpfn_client/tests/unit/test_tabpfn_regressor.py b/tabpfn_client/tests/unit/test_tabpfn_regressor.py index dea087e..ff5c594 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_regressor.py +++ b/tabpfn_client/tests/unit/test_tabpfn_regressor.py @@ -56,6 +56,10 @@ def test_init_remote_regressor( mock_server.endpoints.retrieve_greeting_messages.path ).respond(200, json={"messages": []}) + mock_server.router.get(mock_server.endpoints.protected_root.path).respond( + 200, json={"message": "Welcome to the protected zone, user!"} + ) + mock_predict_response = { "mean": [100, 200, 300], "median": [110, 210, 310],