Skip to content

Commit

Permalink
Refactor init (#29)
Browse files Browse the repository at this point in the history
* Refator Init for new functionality

* Refactor Login/Email Verification Functionality

* removed comment

* Linting and Tests

* Linting and Tests

* Refactor Init with minimal changes

* linting

* Requested changes and email verification

* ruff format and linting

* ruff
  • Loading branch information
anshulg954 committed Jul 4, 2024
1 parent c678739 commit b14cb11
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 32 deletions.
3 changes: 1 addition & 2 deletions quick_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sklearn.datasets import load_breast_cancer, load_diabetes
from sklearn.model_selection import train_test_split

from tabpfn_client import UserDataClient, init
from tabpfn_client import UserDataClient
from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor

logging.basicConfig(level=logging.DEBUG)
Expand All @@ -21,7 +21,6 @@
X, y, test_size=0.33, random_state=42
)

init()
tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted", n_estimators=3)
# print("checking estimator", check_estimator(tabpfn))
tabpfn.fit(X_train[:99], y_train[:99])
Expand Down
29 changes: 24 additions & 5 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def try_connection(self) -> bool:

return found_valid_connection

def try_authenticate(self, access_token) -> bool:
def is_auth_token_outdated(self, access_token) -> bool | None:
"""
Check if the provided access token is valid and return True if successful.
"""
Expand All @@ -230,11 +230,13 @@ def try_authenticate(self, access_token) -> bool:
headers={"Authorization": f"Bearer {access_token}"},
)

self._validate_response(response, "try_authenticate", only_version_check=True)

self._validate_response(
response, "is_auth_token_outdated", only_version_check=True
)
if response.status_code == 200:
is_authenticated = True

elif response.status_code == 403:
is_authenticated = None
return is_authenticated

def validate_email(self, email: str) -> tuple[bool, str]:
Expand Down Expand Up @@ -312,7 +314,8 @@ def register(
is_created = False
message = response.json()["detail"]

return is_created, message
access_token = response.json()["token"] if is_created else None
return is_created, message, access_token

def login(self, email: str, password: str) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -381,6 +384,22 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]:
message = response.json()["detail"]
return sent, message

def send_verification_email(self, access_token: str) -> tuple[bool, str]:
"""
Let the server send an email for verifying the email.
"""
response = self.httpx_client.post(
self.server_endpoints.send_verification_email.path,
headers={"Authorization": f"Bearer {access_token}"},
)
if response.status_code == 200:
sent = True
message = response.json()["message"]
else:
sent = False
message = response.json()["detail"]
return sent, message

def retrieve_greeting_messages(self) -> list[str]:
"""
Retrieve greeting messages that are new for the user.
Expand Down
7 changes: 6 additions & 1 deletion tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ def init(use_server=True):

is_valid_token_set = user_auth_handler.try_reuse_existing_token()

if is_valid_token_set:
if isinstance(is_valid_token_set, bool) and is_valid_token_set:
PromptAgent.prompt_reusing_existing_token()
elif (
isinstance(is_valid_token_set, tuple) and is_valid_token_set[1] is not None
):
print("Your email is not verified. Please verify your email to continue...")
PromptAgent.reverify_email(is_valid_token_set[1], user_auth_handler)
else:
if not PromptAgent.prompt_terms_and_cond():
raise RuntimeError(
Expand Down
11 changes: 3 additions & 8 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass, asdict

import numpy as np
from tabpfn_client import init
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_is_fitted

Expand Down Expand Up @@ -183,10 +184,7 @@ def __init__(

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
raise RuntimeError(
"tabpfn_client.init() must be called before using TabPFNClassifier"
)
init()

if config.g_tabpfn_config.use_server:
try:
Expand Down Expand Up @@ -313,10 +311,7 @@ def __init__(

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
raise RuntimeError(
"tabpfn_client.init() must be called before using TabPFNRegressor"
)
init()

if config.g_tabpfn_config.use_server:
try:
Expand Down
32 changes: 32 additions & 0 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
]
)
choice = cls._choice_with_retries(prompt, ["1", "2"])
email = ""

# Registration
if choice == "1":
Expand Down Expand Up @@ -207,6 +208,37 @@ def prompt_reusing_existing_token(cls):

print(cls.indent(prompt))

@classmethod
def reverify_email(
cls, access_token, user_auth_handler: "UserAuthenticationClient"
):
prompt = "\n".join(
[
"Please check your inbox for the verification email.",
"Note: The email might be in your spam folder or could have expired.",
]
)
print(cls.indent(prompt))
retry_verification = "\n".join(
[
"Do you want to resend email verification link? (y/n): ",
]
)
choice = cls._choice_with_retries(retry_verification, ["y", "n"])
if choice == "y":
# get user email from user_auth_handler and resend verification email
sent, message = user_auth_handler.send_verification_email(access_token)
if not sent:
print(cls.indent("Failed to send verification email: " + message))
else:
print(
cls.indent(
"A verification email has been sent, provided the details are correct!"
)
+ "\n"
)
return

@classmethod
def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]):
for message in greeting_messages:
Expand Down
10 changes: 5 additions & 5 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ endpoints:
methods: [ "POST" ]
description: "User login"

send_verification_email:
path: "/auth/send_verification_email/"
methods: [ "POST" ]
description: "Send verifiaction email or for reverification"

send_reset_password_email:
path: "/auth/send_reset_password_email/"
methods: [ "POST" ]
Expand All @@ -44,11 +49,6 @@ endpoints:
methods: [ "GET" ]
description: "Retrieve new greeting messages"

add_user_information:
path: "/add_user_information/"
methods: [ "POST" ]
description: "Add additional user information to database"

protected_root:
path: "/protected/"
methods: [ "GET" ]
Expand Down
19 changes: 14 additions & 5 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ def set_token_by_registration(
validation_link: str,
additional_info: dict,
) -> tuple[bool, str]:
is_created, message = self.service_client.register(
is_created, message, access_token = self.service_client.register(
email, password, password_confirm, validation_link, additional_info
)
if access_token is not None:
self.set_token(access_token)
return is_created, message

def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
Expand All @@ -58,7 +60,7 @@ def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
self.set_token(access_token)
return True, message

def try_reuse_existing_token(self) -> bool:
def try_reuse_existing_token(self) -> bool | tuple[bool, str]:
if self.service_client.access_token is None:
if not self.CACHED_TOKEN_FILE.exists():
return False
Expand All @@ -68,10 +70,12 @@ def try_reuse_existing_token(self) -> bool:
else:
access_token = self.service_client.access_token

is_valid = self.service_client.try_authenticate(access_token)
if not is_valid:
is_valid = self.service_client.is_auth_token_outdated(access_token)
if is_valid is False:
self._reset_token()
return False
elif is_valid is None:
return False, access_token

logger.debug(f"Reusing existing access token? {is_valid}")
self.set_token(access_token)
Expand All @@ -95,6 +99,10 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]:
sent, message = self.service_client.send_reset_password_email(email)
return sent, message

def send_verification_email(self, access_token: str) -> tuple[bool, str]:
sent, message = self.service_client.send_verification_email(access_token)
return sent, message


class UserDataClient(ServiceClientWrapper):
"""
Expand Down Expand Up @@ -175,7 +183,8 @@ def __init__(self, service_client=ServiceClient()):
def fit(self, X, y) -> None:
if not self.service_client.is_initialized:
raise RuntimeError(
"Either email is not verified or Service client is not initialized. Please Verify your email and try again!"
"Dear TabPFN User, please initialize the client first by verifying your E-mail address sent to your registered E-mail account."
"Please Note: The email verification token expires in 30 minutes."
)

self.last_train_set_uid = self.service_client.upload_train_set(X, y)
Expand Down
22 changes: 16 additions & 6 deletions tabpfn_client/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_validate_email_invalid(self, mock_server):
@with_mock_server()
def test_register_user(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(
200, json={"message": "dummy_message"}
200, json={"message": "dummy_message", "token": "DUMMY_TOKEN"}
)
self.assertTrue(
self.client.register(
Expand All @@ -78,7 +78,7 @@ def test_register_user(self, mock_server):
@with_mock_server()
def test_register_user_with_invalid_email(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(
401, json={"detail": "dummy_message"}
401, json={"detail": "dummy_message", "token": None}
)
self.assertFalse(
self.client.register(
Expand All @@ -98,7 +98,7 @@ def test_register_user_with_invalid_email(self, mock_server):
@with_mock_server()
def test_register_user_with_invalid_validation_link(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(
401, json={"detail": "dummy_message"}
401, json={"detail": "dummy_message", "token": None}
)
self.assertFalse(
self.client.register(
Expand All @@ -118,7 +118,7 @@ def test_register_user_with_invalid_validation_link(self, mock_server):
@with_mock_server()
def test_register_user_with_limit_reached(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(
401, json={"detail": "dummy_message"}
401, json={"detail": "dummy_message", "token": "DUMMY_TOKEN"}
)
self.assertFalse(
self.client.register(
Expand All @@ -138,12 +138,12 @@ def test_register_user_with_limit_reached(self, mock_server):
@with_mock_server()
def test_invalid_auth_token(self, mock_server):
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(401)
self.assertFalse(self.client.try_authenticate("fake_token"))
self.assertFalse(self.client.is_auth_token_outdated("fake_token"))

@with_mock_server()
def test_valid_auth_token(self, mock_server):
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
self.assertTrue(self.client.try_authenticate("true_token"))
self.assertTrue(self.client.is_auth_token_outdated("true_token"))

@with_mock_server()
def test_send_reset_password_email(self, mock_server):
Expand All @@ -155,6 +155,16 @@ def test_send_reset_password_email(self, mock_server):
(True, "Password reset email sent!"),
)

@with_mock_server()
def test_send_verification_email(self, mock_server):
mock_server.router.post(
mock_server.endpoints.send_verification_email.path
).respond(200, json={"message": "Verification Email sent!"})
self.assertEqual(
self.client.send_verification_email("test"),
(True, "Verification Email sent!"),
)

@with_mock_server()
def test_retrieve_greeting_messages(self, mock_server):
mock_server.router.get(
Expand Down
4 changes: 4 additions & 0 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def test_init_remote_classifier(
mock_server.endpoints.retrieve_greeting_messages.path
).respond(200, json={"messages": []})

mock_server.router.get(mock_server.endpoints.protected_root.path).respond(
200, json={"message": "Welcome to the protected zone, user!"}
)

mock_predict_response = [[1, 0.0], [0.9, 0.1], [0.01, 0.99]]
predict_route = mock_server.router.post(mock_server.endpoints.predict.path)
predict_route.respond(200, json={"classification": mock_predict_response})
Expand Down
4 changes: 4 additions & 0 deletions tabpfn_client/tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def test_init_remote_regressor(
mock_server.endpoints.retrieve_greeting_messages.path
).respond(200, json={"messages": []})

mock_server.router.get(mock_server.endpoints.protected_root.path).respond(
200, json={"message": "Welcome to the protected zone, user!"}
)

mock_predict_response = {
"mean": [100, 200, 300],
"median": [110, 210, 310],
Expand Down

0 comments on commit b14cb11

Please sign in to comment.