Skip to content

Commit

Permalink
[BUGFIX] fix vars for role names (#177)
Browse files Browse the repository at this point in the history
* fix vars for role names

---------

Co-authored-by: TJ Murphy <[email protected]>
  • Loading branch information
teej and teej authored Dec 18, 2024
1 parent cf5a6c3 commit dc61f7f
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 21 deletions.
8 changes: 8 additions & 0 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,14 @@ def test_blueprint_vars_spec(session_ctx):
blueprint.generate_manifest(session_ctx)


def test_blueprint_vars_in_owner(session_ctx):
blueprint = Blueprint(
resources=[res.Schema(name="schema", owner="role_{{ var.role_name }}", database="STATIC_DATABASE")],
vars={"role_name": "role123"},
)
assert blueprint.generate_manifest(session_ctx)


def test_blueprint_allowlist(session_ctx, remote_state):
blueprint = Blueprint(
resources=[res.Role(name="role1")],
Expand Down
11 changes: 11 additions & 0 deletions tests/test_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,14 @@ def test_blueprint_vars_comparison_with_system_names():

schema = res.Schema(name=var.schema_name)
assert isinstance(schema.name, VarString)


def test_vars_in_owner():
schema = res.Schema(name="schema", owner="role_{{ var.role_name }}")
assert isinstance(schema._data.owner, VarString)


def test_vars_database_role():
role = res.DatabaseRole(name="role_{{ var.role_name }}", database="db_{{ var.db_name }}")
assert isinstance(role._data.name, VarString)
assert isinstance(role._data.database, VarString)
7 changes: 7 additions & 0 deletions titan/blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,12 @@ def _resolve_vars(self):
for resource in self._staged:
resource._resolve_vars(self._config.vars)

def _resolve_role_refs(self):
for resource in _walk(self._root):
if isinstance(resource, ResourcePointer):
continue
resource._resolve_role_refs()

def _build_resource_graph(self, session_ctx: SessionContext) -> None:
"""
Convert the staged resources into a directed graph of resources
Expand Down Expand Up @@ -921,6 +927,7 @@ def _finalize(self, session_ctx: SessionContext) -> None:
self._finalized = True
self._resolve_vars()
self._build_resource_graph(session_ctx)
self._resolve_role_refs()
self._create_tag_references()
self._create_ownership_refs(session_ctx)
self._create_grandparent_refs()
Expand Down
5 changes: 1 addition & 4 deletions titan/resources/grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def __init__(
owner=owner,
)

self.requires(self._data.to)
granted_on = None
if on_type:
granted_on = ResourcePointer(name=on, resource_type=on_type)
Expand Down Expand Up @@ -336,7 +335,7 @@ class FutureGrant(Resource):
def __init__(
self,
priv: str,
to: Role,
to: Union[Role, DatabaseRole],
grant_option: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -391,7 +390,6 @@ def __init__(
to=to,
grant_option=grant_option,
)
self.requires(self._data.to)
if granted_in_ref:
self.requires(granted_in_ref)

Expand Down Expand Up @@ -592,7 +590,6 @@ def __init__(
to=to,
grant_option=grant_option,
)
self.requires(self._data.to)

@classmethod
def from_sql(cls, sql):
Expand Down
45 changes: 31 additions & 14 deletions titan/resources/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def _coerce_resource_field(field_value, field_type):
return {k: _coerce_resource_field(v, field_type=dict_types[1]) for k, v in field_value.items()}

elif field_type is RoleRef:
return convert_role_ref(field_value)
if isinstance(field_value, str) and string_contains_var(field_value):
return VarString(field_value)
elif isinstance(field_value, (Resource, VarString, str)):
return convert_role_ref(field_value)
else:
raise TypeError

# Check for field_value's type in a Union
elif get_origin(field_type) == Union:
Expand Down Expand Up @@ -139,14 +144,7 @@ def _coerce_resource_field(field_value, field_type):
elif field_type is ResourceTags:
return ResourceTags(field_value)
elif field_type is str:
if isinstance(field_value, str) and string_contains_var(field_value):
return VarString(field_value)
elif isinstance(field_value, VarString):
return field_value
elif not isinstance(field_value, str):
raise TypeError
else:
return field_value
return convert_to_varstring(field_value)
elif field_type is float:
if isinstance(field_value, float):
return field_value
Expand Down Expand Up @@ -477,6 +475,15 @@ def _render_vars(field_value):
if isinstance(self, NamedResource) and isinstance(self._name, VarString):
self._name = ResourceName(self._name.to_string(vars))

def _resolve_role_refs(self):
for f in fields(self._data):
if f.type == RoleRef:
field_value = getattr(self._data, f.name)
new_value = convert_role_ref(field_value)
setattr(self._data, f.name, new_value)
if new_value.name != "":
self.requires(new_value)

def to_pointer(self):
return ResourcePointer(
name=str(self.fqn),
Expand Down Expand Up @@ -639,7 +646,7 @@ def container(self):
@property
def database(self):
if isinstance(self.scope, DatabaseScope):
return self.container.name
return self.container.name # type: ignore
else:
raise ValueError("ResourcePointer does not have a database")

Expand Down Expand Up @@ -703,16 +710,15 @@ def convert_to_resource(

def convert_role_ref(role_ref: RoleRef) -> Resource:
if role_ref.__class__.__name__ == "Role":
return role_ref
return role_ref # type: ignore
elif role_ref.__class__.__name__ == "DatabaseRole":
return role_ref
return role_ref # type: ignore
elif isinstance(role_ref, ResourcePointer) and role_ref.resource_type in (
ResourceType.DATABASE_ROLE,
ResourceType.ROLE,
):
return role_ref

elif isinstance(role_ref, str) or isinstance(role_ref, ResourceName):
elif isinstance(role_ref, (str, ResourceName)):
return ResourcePointer(name=role_ref, resource_type=infer_role_type_from_name(role_ref))
else:
raise TypeError
Expand All @@ -728,3 +734,14 @@ def infer_role_type_from_name(name: Union[str, ResourceName]) -> ResourceType:
return ResourceType.DATABASE_ROLE
else:
return ResourceType.ROLE


def convert_to_varstring(value: Union[str, ResourceName]) -> Union[VarString, str]:
if isinstance(value, str) and string_contains_var(value):
return VarString(value)
elif isinstance(value, VarString):
return value
elif not isinstance(value, str):
raise TypeError
else:
return value
8 changes: 5 additions & 3 deletions titan/role_ref.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Union

from .var import VarString

if TYPE_CHECKING:
from titan.resources.role import Role, DatabaseRole
from titan.resources.role import DatabaseRole, Role

RoleRef = Union["Role", "DatabaseRole", str]
RoleRef = Union["Role", "DatabaseRole", VarString, str]

0 comments on commit dc61f7f

Please sign in to comment.