Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor init #29

Merged
merged 10 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions quick_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sklearn.datasets import load_breast_cancer, load_diabetes
from sklearn.model_selection import train_test_split

from tabpfn_client import UserDataClient, init
from tabpfn_client import UserDataClient
from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor

logging.basicConfig(level=logging.DEBUG)
Expand All @@ -21,7 +21,6 @@
X, y, test_size=0.33, random_state=42
)

init()
tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted", n_estimators=3)
# print("checking estimator", check_estimator(tabpfn))
tabpfn.fit(X_train[:99], y_train[:99])
Expand Down
29 changes: 24 additions & 5 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def try_connection(self) -> bool:

return found_valid_connection

def try_authenticate(self, access_token) -> bool:
def is_auth_token_outdated(self, access_token) -> bool | None:
"""
Check if the provided access token is valid and return True if successful.
"""
Expand All @@ -230,11 +230,13 @@ def try_authenticate(self, access_token) -> bool:
headers={"Authorization": f"Bearer {access_token}"},
)

self._validate_response(response, "try_authenticate", only_version_check=True)

self._validate_response(
response, "is_auth_token_outdated", only_version_check=True
)
if response.status_code == 200:
is_authenticated = True

elif response.status_code == 403:
is_authenticated = None
return is_authenticated

def validate_email(self, email: str) -> tuple[bool, str]:
Expand Down Expand Up @@ -312,7 +314,8 @@ def register(
is_created = False
message = response.json()["detail"]

return is_created, message
access_token = response.json()["token"] if is_created else None
return is_created, message, access_token

def login(self, email: str, password: str) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -381,6 +384,22 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]:
message = response.json()["detail"]
return sent, message

def send_verification_email(self, access_token: str) -> tuple[bool, str]:
"""
Let the server send an email for verifying the email.
"""
response = self.httpx_client.post(
self.server_endpoints.send_verification_email.path,
headers={"Authorization": f"Bearer {access_token}"},
)
if response.status_code == 200:
sent = True
message = response.json()["message"]
else:
sent = False
message = response.json()["detail"]
return sent, message

def retrieve_greeting_messages(self) -> list[str]:
"""
Retrieve greeting messages that are new for the user.
Expand Down
7 changes: 6 additions & 1 deletion tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ def init(use_server=True):

is_valid_token_set = user_auth_handler.try_reuse_existing_token()

if is_valid_token_set:
if isinstance(is_valid_token_set, bool) and is_valid_token_set:
PromptAgent.prompt_reusing_existing_token()
elif (
isinstance(is_valid_token_set, tuple) and is_valid_token_set[1] is not None
):
print("Your email is not verified. Please verify your email to continue...")
PromptAgent.reverify_email(is_valid_token_set[1], user_auth_handler)
else:
if not PromptAgent.prompt_terms_and_cond():
raise RuntimeError(
Expand Down
11 changes: 3 additions & 8 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass, asdict

import numpy as np
from tabpfn_client import init
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_is_fitted

Expand Down Expand Up @@ -183,10 +184,7 @@ def __init__(

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
raise RuntimeError(
"tabpfn_client.init() must be called before using TabPFNClassifier"
)
init()

if config.g_tabpfn_config.use_server:
try:
Expand Down Expand Up @@ -313,10 +311,7 @@ def __init__(

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
raise RuntimeError(
"tabpfn_client.init() must be called before using TabPFNRegressor"
)
init()

if config.g_tabpfn_config.use_server:
try:
Expand Down
32 changes: 32 additions & 0 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
]
)
choice = cls._choice_with_retries(prompt, ["1", "2"])
email = ""

# Registration
if choice == "1":
Expand Down Expand Up @@ -207,6 +208,37 @@ def prompt_reusing_existing_token(cls):

print(cls.indent(prompt))

@classmethod
def reverify_email(
cls, access_token, user_auth_handler: "UserAuthenticationClient"
):
prompt = "\n".join(
[
"Please check your inbox for the verification email.",
"Note: The email might be in your spam folder or could have expired.",
]
)
print(cls.indent(prompt))
retry_verification = "\n".join(
[
"Do you want to resend email verification link? (y/n): ",
]
)
choice = cls._choice_with_retries(retry_verification, ["y", "n"])
if choice == "y":
# get user email from user_auth_handler and resend verification email
sent, message = user_auth_handler.send_verification_email(access_token)
if not sent:
print(cls.indent("Failed to send verification email: " + message))
else:
print(
cls.indent(
"A verification email has been sent, provided the details are correct!"
)
+ "\n"
)
return

@classmethod
def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]):
for message in greeting_messages:
Expand Down
10 changes: 5 additions & 5 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ endpoints:
methods: [ "POST" ]
description: "User login"

send_verification_email:
path: "/auth/send_verification_email/"
methods: [ "POST" ]
description: "Send verifiaction email or for reverification"

send_reset_password_email:
path: "/auth/send_reset_password_email/"
methods: [ "POST" ]
Expand All @@ -44,11 +49,6 @@ endpoints:
methods: [ "GET" ]
description: "Retrieve new greeting messages"

add_user_information:
path: "/add_user_information/"
methods: [ "POST" ]
description: "Add additional user information to database"

protected_root:
path: "/protected/"
methods: [ "GET" ]
Expand Down
19 changes: 14 additions & 5 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ def set_token_by_registration(
validation_link: str,
additional_info: dict,
) -> tuple[bool, str]:
is_created, message = self.service_client.register(
is_created, message, access_token = self.service_client.register(
email, password, password_confirm, validation_link, additional_info
)
if access_token is not None:
self.set_token(access_token)
return is_created, message

def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
Expand All @@ -58,7 +60,7 @@ def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
self.set_token(access_token)
return True, message

def try_reuse_existing_token(self) -> bool:
def try_reuse_existing_token(self) -> bool | tuple[bool, str]:
if self.service_client.access_token is None:
if not self.CACHED_TOKEN_FILE.exists():
return False
Expand All @@ -68,10 +70,12 @@ def try_reuse_existing_token(self) -> bool:
else:
access_token = self.service_client.access_token

is_valid = self.service_client.try_authenticate(access_token)
if not is_valid:
is_valid = self.service_client.is_auth_token_outdated(access_token)
if is_valid is False:
self._reset_token()
return False
elif is_valid is None:
return False, access_token

logger.debug(f"Reusing existing access token? {is_valid}")
self.set_token(access_token)
Expand All @@ -95,6 +99,10 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]:
sent, message = self.service_client.send_reset_password_email(email)
return sent, message

def send_verification_email(self, access_token: str) -> tuple[bool, str]:
sent, message = self.service_client.send_verification_email(access_token)
return sent, message


class UserDataClient(ServiceClientWrapper):
"""
Expand Down Expand Up @@ -175,7 +183,8 @@ def __init__(self, service_client=ServiceClient()):
def fit(self, X, y) -> None:
if not self.service_client.is_initialized:
raise RuntimeError(
"Either email is not verified or Service client is not initialized. Please Verify your email and try again!"
"Dear TabPFN User, please initialize the client first by verifying your E-mail address sent to your registered E-mail account."
"Please Note: The email verification token expires in 30 minutes."
)

self.last_train_set_uid = self.service_client.upload_train_set(X, y)
Expand Down
22 changes: 16 additions & 6 deletions tabpfn_client/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_validate_email_invalid(self, mock_server):
@with_mock_server()
def test_register_user(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(
200, json={"message": "dummy_message"}
200, json={"message": "dummy_message", "token": "DUMMY_TOKEN"}
)
self.assertTrue(
self.client.register(
Expand All @@ -78,7 +78,7 @@ def test_register_user(self, mock_server):
@with_mock_server()
def test_register_user_with_invalid_email(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(
401, json={"detail": "dummy_message"}
401, json={"detail": "dummy_message", "token": None}
)
self.assertFalse(
self.client.register(
Expand All @@ -98,7 +98,7 @@ def test_register_user_with_invalid_email(self, mock_server):
@with_mock_server()
def test_register_user_with_invalid_validation_link(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(
401, json={"detail": "dummy_message"}
401, json={"detail": "dummy_message", "token": None}
)
self.assertFalse(
self.client.register(
Expand All @@ -118,7 +118,7 @@ def test_register_user_with_invalid_validation_link(self, mock_server):
@with_mock_server()
def test_register_user_with_limit_reached(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(
401, json={"detail": "dummy_message"}
401, json={"detail": "dummy_message", "token": "DUMMY_TOKEN"}
)
self.assertFalse(
self.client.register(
Expand All @@ -138,12 +138,12 @@ def test_register_user_with_limit_reached(self, mock_server):
@with_mock_server()
def test_invalid_auth_token(self, mock_server):
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(401)
self.assertFalse(self.client.try_authenticate("fake_token"))
self.assertFalse(self.client.is_auth_token_outdated("fake_token"))

@with_mock_server()
def test_valid_auth_token(self, mock_server):
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
self.assertTrue(self.client.try_authenticate("true_token"))
self.assertTrue(self.client.is_auth_token_outdated("true_token"))

@with_mock_server()
def test_send_reset_password_email(self, mock_server):
Expand All @@ -155,6 +155,16 @@ def test_send_reset_password_email(self, mock_server):
(True, "Password reset email sent!"),
)

@with_mock_server()
def test_send_verification_email(self, mock_server):
mock_server.router.post(
mock_server.endpoints.send_verification_email.path
).respond(200, json={"message": "Verification Email sent!"})
self.assertEqual(
self.client.send_verification_email("test"),
(True, "Verification Email sent!"),
)

@with_mock_server()
def test_retrieve_greeting_messages(self, mock_server):
mock_server.router.get(
Expand Down
4 changes: 4 additions & 0 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def test_init_remote_classifier(
mock_server.endpoints.retrieve_greeting_messages.path
).respond(200, json={"messages": []})

mock_server.router.get(mock_server.endpoints.protected_root.path).respond(
200, json={"message": "Welcome to the protected zone, user!"}
)

mock_predict_response = [[1, 0.0], [0.9, 0.1], [0.01, 0.99]]
predict_route = mock_server.router.post(mock_server.endpoints.predict.path)
predict_route.respond(200, json={"classification": mock_predict_response})
Expand Down
4 changes: 4 additions & 0 deletions tabpfn_client/tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def test_init_remote_regressor(
mock_server.endpoints.retrieve_greeting_messages.path
).respond(200, json={"messages": []})

mock_server.router.get(mock_server.endpoints.protected_root.path).respond(
200, json={"message": "Welcome to the protected zone, user!"}
)

mock_predict_response = {
"mean": [100, 200, 300],
"median": [110, 210, 310],
Expand Down
Loading