|
18 | 18 |
|
19 | 19 | import re
|
20 | 20 | from typing import AsyncGenerator
|
21 |
| -from typing import Generator |
22 | 21 | from typing import TYPE_CHECKING
|
23 | 22 |
|
24 | 23 | from typing_extensions import override
|
@@ -77,37 +76,90 @@ def _populate_values(
|
77 | 76 | instruction_template: str,
|
78 | 77 | context: InvocationContext,
|
79 | 78 | ) -> str:
|
80 |
| - """Populates values in the instruction template, e.g. state, artifact, etc.""" |
| 79 | + """Populates values in the instruction template, e.g. state, artifact, etc. |
| 80 | +
|
| 81 | + Supports nested dot-separated references like: |
| 82 | + - state.user.name |
| 83 | + - artifact.config.settings |
| 84 | + - user.profile.email |
| 85 | + - user?.profile?.name? (optional markers at any level) |
| 86 | + """ |
| 87 | + |
| 88 | + def _get_nested_value( |
| 89 | + obj, paths: list[str], is_optional: bool = False |
| 90 | + ) -> str: |
| 91 | + """Gets a nested value from an object using a list of path segments. |
| 92 | +
|
| 93 | + Args: |
| 94 | + obj: The object to get the value from |
| 95 | + paths: List of path segments to traverse |
| 96 | + is_optional: Whether the entire path is optional |
| 97 | +
|
| 98 | + Returns: |
| 99 | + The value as a string |
| 100 | +
|
| 101 | + Raises: |
| 102 | + KeyError: If the path doesn't exist and the reference is not optional |
| 103 | + """ |
| 104 | + if not paths: |
| 105 | + return str(obj) |
| 106 | + |
| 107 | + # Get current part and remaining paths |
| 108 | + current_part = paths[0] |
| 109 | + |
| 110 | + # Handle optional markers |
| 111 | + is_current_optional = current_part.endswith('?') or is_optional |
| 112 | + clean_part = current_part.removesuffix('?') |
| 113 | + |
| 114 | + # Get value for current part |
| 115 | + if isinstance(obj, dict) and clean_part in obj: |
| 116 | + return _get_nested_value(obj[clean_part], paths[1:], is_current_optional) |
| 117 | + elif hasattr(obj, clean_part): |
| 118 | + return _get_nested_value( |
| 119 | + getattr(obj, clean_part), paths[1:], is_current_optional |
| 120 | + ) |
| 121 | + elif is_current_optional: |
| 122 | + return '' |
| 123 | + else: |
| 124 | + raise KeyError(f'Key not found: {clean_part}') |
81 | 125 |
|
82 | 126 | def _replace_match(match) -> str:
|
83 | 127 | var_name = match.group().lstrip('{').rstrip('}').strip()
|
84 | 128 | optional = False
|
85 | 129 | if var_name.endswith('?'):
|
86 | 130 | optional = True
|
87 | 131 | var_name = var_name.removesuffix('?')
|
88 |
| - if var_name.startswith('artifact.'): |
89 |
| - var_name = var_name.removeprefix('artifact.') |
90 |
| - if context.artifact_service is None: |
91 |
| - raise ValueError('Artifact service is not initialized.') |
92 |
| - artifact = context.artifact_service.load_artifact( |
93 |
| - app_name=context.session.app_name, |
94 |
| - user_id=context.session.user_id, |
95 |
| - session_id=context.session.id, |
96 |
| - filename=var_name, |
97 |
| - ) |
98 |
| - if not var_name: |
99 |
| - raise KeyError(f'Artifact {var_name} not found.') |
100 |
| - return str(artifact) |
101 |
| - else: |
102 |
| - if not _is_valid_state_name(var_name): |
103 |
| - return match.group() |
104 |
| - if var_name in context.session.state: |
105 |
| - return str(context.session.state[var_name]) |
| 132 | + |
| 133 | + try: |
| 134 | + if var_name.startswith('artifact.'): |
| 135 | + var_name = var_name.removeprefix('artifact.') |
| 136 | + if context.artifact_service is None: |
| 137 | + raise ValueError('Artifact service is not initialized.') |
| 138 | + artifact = context.artifact_service.load_artifact( |
| 139 | + app_name=context.session.app_name, |
| 140 | + user_id=context.session.user_id, |
| 141 | + session_id=context.session.id, |
| 142 | + filename=var_name, |
| 143 | + ) |
| 144 | + if not var_name: |
| 145 | + raise KeyError(f'Artifact {var_name} not found.') |
| 146 | + return str(artifact) |
106 | 147 | else:
|
107 |
| - if optional: |
108 |
| - return '' |
109 |
| - else: |
| 148 | + if not _is_valid_state_name(var_name.split('.')[0].removesuffix('?')): |
| 149 | + return match.group() |
| 150 | + # Try to resolve nested path |
| 151 | + try: |
| 152 | + return _get_nested_value( |
| 153 | + context.session.state, var_name.split('.'), optional |
| 154 | + ) |
| 155 | + except KeyError: |
| 156 | + if not _is_valid_state_name(var_name): |
| 157 | + return match.group() |
110 | 158 | raise KeyError(f'Context variable not found: `{var_name}`.')
|
| 159 | + except Exception as e: |
| 160 | + if optional: |
| 161 | + return '' |
| 162 | + raise e |
111 | 163 |
|
112 | 164 | return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
113 | 165 |
|
|
0 commit comments