Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Login in the browser when possible #50

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions tabpfn_client/browser_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from threading import Event
import http.server
import socketserver
import webbrowser
import urllib.parse
from typing import Optional, Tuple


class BrowserAuthHandler:
def __init__(self, gui_url: str):
self.gui_url = gui_url

def try_browser_login(self) -> Tuple[bool, Optional[str]]:
"""
Attempts to perform browser-based login
Returns (success: bool, token: Optional[str])
"""
auth_event = Event()
received_token = None

class CallbackHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
nonlocal received_token

parsed = urllib.parse.urlparse(self.path)
query = urllib.parse.parse_qs(parsed.query)

if "token" in query:
received_token = query["token"][0]

self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
success_html = """
<html>
<body style="text-align: center; font-family: Arial, sans-serif; padding: 50px;">
<h2>Login successful!</h2>
<p>You can close this window and return to your application.</p>
</body>
</html>
"""
self.wfile.write(success_html.encode())
auth_event.set()

def log_message(self, format, *args):
pass

try:
with socketserver.TCPServer(("", 0), CallbackHandler) as httpd:
port = httpd.server_address[1]
callback_url = f"http://localhost:{port}"

login_url = f"{self.gui_url}/login?callback={callback_url}"

print(
"\nOpening browser for login. Please complete the login/registration process in your browser and return here.\n"
)

if not webbrowser.open(login_url):
print(
"\nCould not open browser automatically. Falling back to command-line login...\n"
)
return False, None

while not auth_event.is_set():
httpd.handle_request()

return received_token is not None, received_token

except Exception:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide more info on how this could fail? Is there e.g. times when the user doesn't have gui? in which situations might it be more comfortable in the cli?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main thing I had in mind is when the user is on a remote cluster. (in this case, there could also be a different workflow where we give the user the url and ask the user to paste the access token once he's logged in)

print("\n Browser auth failed. Falling back to command-line login...\n")
return False, None
15 changes: 15 additions & 0 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from tabpfn_client.tabpfn_common_utils import utils as common_utils
from tabpfn_client.constants import CACHE_DIR
from tabpfn_client.browser_auth import BrowserAuthHandler


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -704,3 +705,17 @@ def delete_user_account(self, confirm_pass: str) -> None:
)

self._validate_response(response, "delete_user_account")

def try_browser_login(self) -> tuple[bool, str]:
"""
Attempts browser-based login flow
Returns (success: bool, message: str)
"""
browser_auth = BrowserAuthHandler(self.server_config.gui_url)
success, token = browser_auth.try_browser_login()

if success and token:
# Don't authorize directly, let UserAuthenticationClient handle it
return True, token

return False, "Browser login failed or was cancelled"
5 changes: 0 additions & 5 deletions tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,6 @@ def init(use_server=True):
PromptAgent.reverify_email(is_valid_token_set[1], user_auth_handler)
else:
PromptAgent.prompt_welcome()
if not PromptAgent.prompt_terms_and_cond():
raise RuntimeError(
"You must agree to the terms and conditions to use TabPFN"
)

# prompt for login / register
PromptAgent.prompt_and_set_token(user_auth_handler)

Expand Down
15 changes: 14 additions & 1 deletion tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,20 @@ def prompt_welcome(cls):

@classmethod
def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
# Choose between registration and login
# Try browser login first
success, message = user_auth_handler.try_browser_login()
if success:
print(cls.indent("Login via browser successful!"))
return

# Fall back to CLI login if browser login failed
# Show terms and conditions for CLI login
if not cls.prompt_terms_and_cond():
raise RuntimeError(
"You must agree to the terms and conditions to use TabPFN"
)

# Rest of the existing CLI login code
prompt = "\n".join(
[
"Please choose one of the following options:",
Expand Down
1 change: 1 addition & 0 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ protocol: "https"
host: "tabpfn-server-wjedmz7r5a-ez.a.run.app"
# host: tabpfn-server-preprod-wjedmz7r5a-ez.a.run.app # preprod
port: "443"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before merging the PR change the logo and title of the UI

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gui_url: "https://ux.priorlabs.ai"
endpoints:
root:
path: "/"
Expand Down
7 changes: 7 additions & 0 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def send_verification_email(self, access_token: str) -> tuple[bool, str]:
sent, message = self.service_client.send_verification_email(access_token)
return sent, message

def try_browser_login(self) -> tuple[bool, str]:
"""Try to authenticate using browser-based login"""
success, token_or_message = self.service_client.try_browser_login()
if success:
self.set_token(token_or_message)
return success, token_or_message


class UserDataClient(ServiceClientWrapper):
"""
Expand Down
Loading
Loading