Skip to content

Commit

Permalink
feat: update service account auth not to require rbac enabled org (#5360
Browse files Browse the repository at this point in the history
)

Related to grafana/oncall-private#2826

RBAC enabled or not (OSS or cloud), it is possible to get service
account permissions, enabling perm check (for service account tokens) in
public API.

Also allow empty value for users in sync (instead of returning a 400
response).
  • Loading branch information
matiasb authored Dec 12, 2024
1 parent b8dc7af commit 132bdf2
Show file tree
Hide file tree
Showing 17 changed files with 142 additions and 111 deletions.
2 changes: 1 addition & 1 deletion engine/apps/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def user_is_authorized(user: "User", required_permissions: LegacyAccessControlCo
`required_permissions` - A list of permissions that a user must have to be considered authorized
"""
organization = user.organization
if organization.is_rbac_permissions_enabled:
if organization.is_rbac_permissions_enabled or user.is_service_account:
user_permissions = [u["action"] for u in user.permissions]
required_permission_values = get_required_permission_values(organization, required_permissions)
return all(permission in user_permissions for permission in required_permission_values)
Expand Down
4 changes: 0 additions & 4 deletions engine/apps/auth_token/models/service_account_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ def organization(self):

@classmethod
def validate_token(cls, organization, token):
# require RBAC enabled to allow service account auth
if not organization.is_rbac_permissions_enabled:
raise InvalidToken

# Grafana API request: get permissions and confirm token is valid
permissions = get_service_account_token_permissions(organization, token)
if not permissions:
Expand Down
41 changes: 6 additions & 35 deletions engine/apps/auth_token/tests/test_grafana_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from apps.auth_token.auth import X_GRAFANA_INSTANCE_ID, GrafanaServiceAccountAuthentication
from apps.auth_token.models import ServiceAccountToken
from apps.auth_token.tests.helpers import setup_service_account_api_mocks
from apps.user_management.models import Organization, ServiceAccountUser
from apps.user_management.models import Organization
from common.constants.plugin_ids import PluginID
from settings.base import CLOUD_LICENSE_NAME, OPEN_SOURCE_LICENSE_NAME, SELF_HOSTED_SETTINGS

Expand Down Expand Up @@ -115,31 +115,10 @@ def test_grafana_authentication_invalid_grafana_url():
assert exc.value.detail == "Organization not found."


@pytest.mark.django_db
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_rbac_disabled_fails(make_organization):
organization = make_organization(grafana_url="http://grafana.test")
if organization.is_rbac_permissions_enabled:
return

token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz"
headers = {
"HTTP_AUTHORIZATION": token,
"HTTP_X_GRAFANA_URL": organization.grafana_url,
}
request = APIRequestFactory().get("/", **headers)

with pytest.raises(exceptions.AuthenticationFailed) as exc:
GrafanaServiceAccountAuthentication().authenticate(request)
assert exc.value.detail == "Invalid token."


@pytest.mark.django_db
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_permissions_call_fails(make_organization):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return

token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz"
headers = {
Expand Down Expand Up @@ -170,8 +149,6 @@ def test_grafana_authentication_existing_token(
make_organization, make_service_account_for_organization, make_token_for_service_account
):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return
service_account = make_service_account_for_organization(organization)
token_string = "glsa_the-token"
token = make_token_for_service_account(service_account, token_string)
Expand All @@ -187,7 +164,7 @@ def test_grafana_authentication_existing_token(

user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request)

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
assert user.service_account == service_account
assert user.public_primary_key == service_account.public_primary_key
assert user.username == service_account.username
Expand All @@ -206,8 +183,6 @@ def test_grafana_authentication_existing_token(
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_token_created(make_organization):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return
token_string = "glsa_the-token"

headers = {
Expand All @@ -223,7 +198,7 @@ def test_grafana_authentication_token_created(make_organization):

user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request)

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
service_account = user.service_account
assert service_account.organization == organization
assert user.public_primary_key == service_account.public_primary_key
Expand All @@ -248,8 +223,6 @@ def test_grafana_authentication_token_created(make_organization):
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_token_created_older_grafana(make_organization):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return
token_string = "glsa_the-token"

headers = {
Expand All @@ -265,7 +238,7 @@ def test_grafana_authentication_token_created_older_grafana(make_organization):

user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request)

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
service_account = user.service_account
assert service_account.organization == organization
# use fallback data
Expand All @@ -278,8 +251,6 @@ def test_grafana_authentication_token_created_older_grafana(make_organization):
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_token_reuse_service_account(make_organization, make_service_account_for_organization):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return
service_account = make_service_account_for_organization(organization)
token_string = "glsa_the-token"

Expand All @@ -299,7 +270,7 @@ def test_grafana_authentication_token_reuse_service_account(make_organization, m

user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request)

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
assert user.service_account == service_account
assert auth_token.service_account == service_account

Expand Down Expand Up @@ -335,7 +306,7 @@ def sync_org():

mock_setup_org.assert_called_once()

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
service_account = user.service_account
# organization is created
organization = Organization.objects.filter(grafana_url=grafana_url).get()
Expand Down
6 changes: 5 additions & 1 deletion engine/apps/grafana_plugin/serializers/sync_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class SyncOnCallSettingsSerializer(serializers.Serializer):
labels_enabled = serializers.BooleanField()
irm_enabled = serializers.BooleanField(default=False)

def validate_grafana_url(self, value):
# remove trailing slash for URL consistency
return value.rstrip("/")

def create(self, validated_data):
return SyncSettings(**validated_data)

Expand All @@ -81,7 +85,7 @@ def to_representation(self, instance):


class SyncDataSerializer(serializers.Serializer):
users = serializers.ListField(child=SyncUserSerializer())
users = serializers.ListField(child=SyncUserSerializer(), allow_null=True, allow_empty=True)
teams = serializers.ListField(child=SyncTeamSerializer(), allow_null=True, allow_empty=True)
team_members = TeamMemberMappingField()
settings = SyncOnCallSettingsSerializer()
Expand Down
65 changes: 64 additions & 1 deletion engine/apps/grafana_plugin/tests/test_sync_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rest_framework.test import APIClient

from apps.api.permissions import LegacyAccessControlRole
from apps.grafana_plugin.serializers.sync_data import SyncTeamSerializer
from apps.grafana_plugin.serializers.sync_data import SyncOnCallSettingsSerializer, SyncTeamSerializer
from apps.grafana_plugin.sync_data import SyncData, SyncSettings, SyncUser
from apps.grafana_plugin.tasks.sync_v2 import start_sync_organizations_v2, sync_organizations_v2
from common.constants.plugin_ids import PluginID
Expand Down Expand Up @@ -197,6 +197,47 @@ def test_sync_v2_irm_enabled(
assert organization.is_grafana_irm_enabled == expected


@patch("apps.grafana_plugin.helpers.client.GrafanaAPIClient.check_token", return_value=(None, {"connected": True}))
@pytest.mark.django_db
def test_sync_v2_none_values(
# mock this out so that we're not making a real network call, the sync v2 endpoint ends up calling
# user_management.sync._sync_organization which calls GrafanaApiClient.check_token
_mock_grafana_api_client_check_token,
make_organization_and_user_with_plugin_token,
make_user_auth_headers,
settings,
):
settings.LICENSE = settings.CLOUD_LICENSE_NAME
organization, _, token = make_organization_and_user_with_plugin_token()

client = APIClient()
headers = make_user_auth_headers(None, token, organization=organization)
url = reverse("grafana-plugin:sync-v2")

data = SyncData(
users=None,
teams=None,
team_members={},
settings=SyncSettings(
stack_id=organization.stack_id,
org_id=organization.org_id,
license=settings.CLOUD_LICENSE_NAME,
oncall_api_url="http://localhost",
oncall_token="",
grafana_url="http://localhost",
grafana_token="fake_token",
rbac_enabled=False,
incident_enabled=False,
incident_backend_url="",
labels_enabled=False,
irm_enabled=False,
),
)

response = client.post(url, format="json", data=asdict(data), **headers)
assert response.status_code == status.HTTP_200_OK


@pytest.mark.parametrize(
"test_team, validation_pass",
[
Expand All @@ -218,6 +259,28 @@ def test_sync_team_serialization(test_team, validation_pass):
assert (validation_error is None) == validation_pass


@pytest.mark.django_db
def test_sync_grafana_url_serialization():
data = {
"stack_id": 123,
"org_id": 321,
"license": "OSS",
"oncall_api_url": "http://localhost",
"oncall_token": "",
"grafana_url": "http://localhost/",
"grafana_token": "fake_token",
"rbac_enabled": False,
"incident_enabled": False,
"incident_backend_url": "",
"labels_enabled": False,
"irm_enabled": False,
}
serializer = SyncOnCallSettingsSerializer(data=data)
serializer.is_valid(raise_exception=True)
cleaned_data = serializer.save()
assert cleaned_data.grafana_url == "http://localhost"


@pytest.mark.django_db
def test_sync_batch_tasks(make_organization, settings):
settings.SYNC_V2_MAX_TASKS = 2
Expand Down
5 changes: 2 additions & 3 deletions engine/apps/public_api/serializers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from apps.alerts.models import AlertReceiveChannel
from apps.base.messaging import get_messaging_backends
from apps.integrations.legacy_prefix import has_legacy_prefix, remove_legacy_prefix
from apps.user_management.models import ServiceAccountUser
from common.api_helpers.custom_fields import TeamPrimaryKeyRelatedField
from common.api_helpers.exceptions import BadRequest
from common.api_helpers.mixins import PHONE_CALL, SLACK, SMS, TELEGRAM, WEB, EagerLoadingMixin
Expand Down Expand Up @@ -129,8 +128,8 @@ def create(self, validated_data):
try:
instance = AlertReceiveChannel.create(
**validated_data,
author=user if not isinstance(user, ServiceAccountUser) else None,
service_account=user.service_account if isinstance(user, ServiceAccountUser) else None,
author=user if not user.is_service_account else None,
service_account=user.service_account if user.is_service_account else None,
organization=organization,
)
except AlertReceiveChannel.DuplicateDirectPagingError:
Expand Down
3 changes: 1 addition & 2 deletions engine/apps/public_api/serializers/resolution_notes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from rest_framework import serializers

from apps.alerts.models import AlertGroup, ResolutionNote
from apps.user_management.models import ServiceAccountUser
from common.api_helpers.custom_fields import OrganizationFilteredPrimaryKeyRelatedField, UserIdField
from common.api_helpers.exceptions import BadRequest
from common.api_helpers.mixins import EagerLoadingMixin
Expand Down Expand Up @@ -36,7 +35,7 @@ class Meta:

def create(self, validated_data):
user = self.context["request"].user
if not isinstance(user, ServiceAccountUser) and user.pk:
if not user.is_service_account and user.pk:
validated_data["author"] = user
validated_data["source"] = ResolutionNote.Source.WEB
return super().create(validated_data)
Expand Down
1 change: 1 addition & 0 deletions engine/apps/public_api/serializers/webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def validate_preset(self, preset):
raise serializers.ValidationError(PRESET_VALIDATION_MESSAGE)

def validate_user(self, user):
# user may also be a string when handling requests from the deprecated custom action API
if isinstance(user, ServiceAccountUser):
return None
return user
Expand Down
9 changes: 3 additions & 6 deletions engine/apps/public_api/tests/test_escalation_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,9 @@ def test_create_escalation_chain_via_service_account(
HTTP_AUTHORIZATION=f"{token_string}",
HTTP_X_GRAFANA_URL=organization.grafana_url,
)
if not organization.is_rbac_permissions_enabled:
assert response.status_code == status.HTTP_403_FORBIDDEN
else:
assert response.status_code == status.HTTP_201_CREATED
escalation_chain = organization.escalation_chains.get(name="test")
assert escalation_chain.team == team
assert response.status_code == status.HTTP_201_CREATED
escalation_chain = organization.escalation_chains.get(name="test")
assert escalation_chain.team == team


@pytest.mark.django_db
Expand Down
9 changes: 3 additions & 6 deletions engine/apps/public_api/tests/test_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,9 @@ def test_create_integration_via_service_account(
HTTP_AUTHORIZATION=f"{token_string}",
HTTP_X_GRAFANA_URL=organization.grafana_url,
)
if not organization.is_rbac_permissions_enabled:
assert response.status_code == status.HTTP_403_FORBIDDEN
else:
assert response.status_code == status.HTTP_201_CREATED
integration = AlertReceiveChannel.objects.get(public_primary_key=response.data["id"])
assert integration.service_account == service_account
assert response.status_code == status.HTTP_201_CREATED
integration = AlertReceiveChannel.objects.get(public_primary_key=response.data["id"])
assert integration.service_account == service_account


@pytest.mark.django_db
Expand Down
34 changes: 15 additions & 19 deletions engine/apps/public_api/tests/test_rbac_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,14 @@ def test_rbac_permissions(


@pytest.mark.parametrize(
"rbac_enabled,role,give_perm",
"rbac_enabled,give_perm",
[
# rbac disabled: auth is disabled
(False, LegacyAccessControlRole.ADMIN, None),
# rbac enabled: having role None, check the perm is required
(True, LegacyAccessControlRole.NONE, False),
(True, LegacyAccessControlRole.NONE, True),
# rbac enabled: check the perm is required
(True, False),
(True, True),
# rbac disabled: we still check for perms
(False, False),
(False, True),
],
)
@pytest.mark.django_db
Expand All @@ -124,7 +125,6 @@ def test_service_account_auth(
make_service_account_for_organization,
make_token_for_service_account,
rbac_enabled,
role,
give_perm,
):
# APIView default actions
Expand Down Expand Up @@ -155,18 +155,14 @@ def test_service_account_auth(
continue
for viewset_method_name, required_perms in viewset.rbac_permissions.items():
# setup Grafana API permissions response
if rbac_enabled:
permissions = {"perm": "value"}
expected = status.HTTP_403_FORBIDDEN
if give_perm:
permissions = {perm.value: "value" for perm in required_perms}
expected = status.HTTP_200_OK
mock_response = httpretty.Response(status=200, body=json.dumps(permissions))
perms_url = f"{organization.grafana_url}/api/access-control/user/permissions"
httpretty.register_uri(httpretty.GET, perms_url, responses=[mock_response])
else:
# service account auth is disabled
expected = status.HTTP_403_FORBIDDEN
permissions = {"perm": "value"}
expected = status.HTTP_403_FORBIDDEN
if give_perm:
permissions = {perm.value: "value" for perm in required_perms}
expected = status.HTTP_200_OK
mock_response = httpretty.Response(status=200, body=json.dumps(permissions))
perms_url = f"{organization.grafana_url}/api/access-control/user/permissions"
httpretty.register_uri(httpretty.GET, perms_url, responses=[mock_response])

# iterate over all viewset actions, making an API request for each,
# using the user's token and confirming the response status code
Expand Down
Loading

0 comments on commit 132bdf2

Please sign in to comment.