Skip to content

Commit

Permalink
fix: strongly type get_configure_view
Browse files Browse the repository at this point in the history
  • Loading branch information
reneluria committed Aug 20, 2024
1 parent d880ca4 commit 6c7a68c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
19 changes: 15 additions & 4 deletions oidc/provider.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from __future__ import annotations

from collections.abc import Callable

from django.http import HttpRequest

import time

import requests
from sentry.auth.provider import MigratingIdentityId
from sentry.auth.providers.oauth2 import OAuth2Callback, OAuth2Login, OAuth2Provider
from sentry.auth.services.auth.model import RpcAuthProvider
from sentry.organizations.services.organization.model import RpcOrganization
from sentry.plugins.base.response import DeferredResponse

from .constants import (
AUTHORIZATION_ENDPOINT,
Expand All @@ -14,7 +23,7 @@
TOKEN_ENDPOINT,
USERINFO_ENDPOINT,
)
from .views import FetchUser, OIDCConfigureView
from .views import FetchUser, oidc_configure_view


class OIDCLogin(OAuth2Login):
Expand All @@ -37,7 +46,7 @@ def get_authorize_params(self, state, redirect_uri):


class OIDCProvider(OAuth2Provider):
name = ISSUER
name = ISSUER if ISSUER else "oidc"

def __init__(self, domain=None, domains=None, version=None, **config):
if domain:
Expand All @@ -63,8 +72,10 @@ def get_client_id(self):
def get_client_secret(self):
return CLIENT_SECRET

def get_configure_view(self):
return OIDCConfigureView.as_view()
def get_configure_view(
self,
) -> Callable[[HttpRequest, RpcOrganization, RpcAuthProvider], DeferredResponse]:
return oidc_configure_view

def get_auth_pipeline(self):
return [
Expand Down
35 changes: 22 additions & 13 deletions oidc/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from __future__ import annotations

import logging

from sentry.auth.view import AuthView, ConfigureView
from django.http import HttpRequest
from rest_framework.response import Response

from sentry.auth.services.auth.model import RpcAuthProvider
from sentry.auth.view import AuthView
from sentry.utils import json
from sentry.organizations.services.organization.model import RpcOrganization
from sentry.plugins.base.response import DeferredResponse
from sentry.utils.signing import urlsafe_b64decode

from .constants import ERR_INVALID_RESPONSE, ISSUER
Expand All @@ -15,7 +23,7 @@ def __init__(self, domains, version, *args, **kwargs):
self.version = version
super().__init__(*args, **kwargs)

def dispatch(self, request, helper):
def dispatch(self, request: HttpRequest, helper) -> Response: # type: ignore
data = helper.fetch_state("data")

try:
Expand Down Expand Up @@ -52,17 +60,18 @@ def dispatch(self, request, helper):
return helper.next_step()


class OIDCConfigureView(ConfigureView):
def dispatch(self, request, organization, auth_provider):
config = auth_provider.config
if config.get("domain"):
domains = [config["domain"]]
else:
domains = config.get("domains")
return self.render(
"oidc/configure.html",
{"provider_name": ISSUER or "", "domains": domains or []},
)
def oidc_configure_view(
request: HttpRequest, organization: RpcOrganization, auth_provider: RpcAuthProvider
) -> DeferredResponse:
config = auth_provider.config
if config.get("domain"):
domains: list[str] | None
domains = [config["domain"]]
else:
domains = config.get("domains")
return DeferredResponse(
"oidc/configure.html", {"provider_name": ISSUER or "", "domains": domains or []}
)


def extract_domain(email):
Expand Down

0 comments on commit 6c7a68c

Please sign in to comment.