Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added functions to add command and retain order #1

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 22 additions & 32 deletions mirascope_cli/generic/prompt_template.j2
Original file line number Diff line number Diff line change
@@ -1,36 +1,15 @@
{%- if comments -%}
"""{{ comments }}"""
{%- endif -%}
{%- if imports -%}
{%- for import, alias in imports -%}
{%- if alias %}
import {{ import }} as {{alias}}
{%- else %}
import {{ import }}
{%- endif -%}
{%- endfor -%}
{%- endif %}

{% if from_imports -%}
{%- set from_import_groups = {} -%}
{%- for module, name, alias in from_imports -%}
{% if module not in from_import_groups %}
{%- set _ = from_import_groups.update({module: [(name, alias)]}) -%}
{%- else -%}
{%- set _ = from_import_groups[module].append((name, alias)) -%}
{%- endif -%}
{%- endfor -%}
{% for module, names in from_import_groups.items() -%}
from {{ module }} import {% for name, alias in names %}{% if alias %}{{ name }} as {{ alias }}{% else %}{{ name }}{% endif %}{% if not loop.last %}, {% endif %}{% endfor %}
{% endfor %}
{% endif -%}

{% if variables %}
{%- for var_name, var_value in variables.items() -%}
{%- for item in order -%}
{%- if item.type == "comment" %}
"""{{ item.render }}"""
{%- elif item.type == "import" %}
import {{ item.render[0] }}{% if item.render[1] %} as {{ item.render[1] }}{% endif %}
{%- elif item.type == "from_import" %}
from {{ item.render[0] }} import {% if item.render[2] %}{{ item.render[1] }} as {{ item.render[2] }}{% else %}{{ item.render[1] }}{% endif %}
{%- elif item.type == "variable" %}
{%- set var_name, var_value = item.render %}
{{ var_name }} = {{ var_value }}
{% endfor %}
{% endif -%}
{% for class in classes -%}
{%- elif item.type == "class" %}
{%- set class = item.render -%}
{%- for decorator in class.decorators %}
@{{ decorator }}
{%- endfor %}
Expand All @@ -50,4 +29,15 @@ class {{ class.name }}({{ class.bases | join(', ') }}):
{%- endif %}
{{ line }}
{%- endfor %}
{%- elif item.type == "function" %}
{%- set function = item.render -%}
{%- for decorator in function.decorators %}
@{{ decorator }}
{%- endfor %}
{% if function.is_async %}async {% endif %}def {{ function.name }}({{ function.args | join(', ') }}){% if function.returns %} -> {{ function.returns }}{% endif %}:
{%- if function.docstring %}
"""{{ function.docstring }}"""
{% endif %}
{{ function.body | indent(8) }}
{%- endif -%}
{% endfor %}
17 changes: 16 additions & 1 deletion mirascope_cli/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Contains the schema for files created by the mirascope cli."""
from typing import Optional
from typing import Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -47,3 +47,18 @@ class FunctionInfo(BaseModel):
decorators: list[str]
docstring: Optional[str]
is_async: bool


class ASTOrder(BaseModel):
type: Literal["class", "function", "import", "from_import", "variable", "comment"]
order: int
render: Optional[
Union[
ClassInfo,
FunctionInfo,
tuple[str, Optional[str]],
tuple[str, str, Optional[str]],
tuple[str, str],
str,
]
] = None
119 changes: 89 additions & 30 deletions mirascope_cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Utility functions for the mirascope library."""

from __future__ import annotations

import ast
Expand All @@ -14,9 +13,11 @@

from jinja2 import Environment, FileSystemLoader

from .enums import MirascopeCommand
from mirascope_cli.enums import MirascopeCommand

from .constants import CURRENT_REVISION_KEY, LATEST_REVISION_KEY
from .schemas import (
ASTOrder,
ClassInfo,
FunctionInfo,
MirascopeCliVariables,
Expand Down Expand Up @@ -66,24 +67,37 @@ def __init__(self) -> None:
self.classes: list[ClassInfo] = []
self.functions: list[FunctionInfo] = []
self.comments: str = ""
self.order: list[ASTOrder] = []

def visit_Import(self, node) -> None:
"""Extracts imports from the given node."""
for alias in node.names:
self.imports.append((alias.name, alias.asname))
import_str = (alias.name, alias.asname)
self.imports.append(import_str)
self.order.append(ASTOrder(type="import", order=len(self.imports) - 1))

self.generic_visit(node)

def visit_ImportFrom(self, node) -> None:
"""Extracts from imports from the given node."""
for alias in node.names:
self.from_imports.append((node.module, alias.name, alias.asname))
from_import_str = (node.module, alias.name, alias.asname)
self.from_imports.append(from_import_str)
self.order.append(
ASTOrder(type="from_import", order=len(self.from_imports) - 1)
)
self.generic_visit(node)

def visit_Assign(self, node) -> None:
"""Extracts variables from the given node."""
target = node.targets[0]
if isinstance(target, ast.Name):
already_exists = target.id in self.variables
self.variables[target.id] = ast.unparse(node.value)
if not already_exists:
self.order.append(
ASTOrder(type="variable", order=len(self.variables.keys()) - 1)
)
self.generic_visit(node)

def visit_ClassDef(self, node) -> None:
Expand Down Expand Up @@ -127,6 +141,7 @@ def visit_ClassDef(self, node) -> None:
class_info.body = "\n".join(body)

self.classes.append(class_info)
self.order.append(ASTOrder(type="class", order=len(self.classes) - 1))

def visit_AsyncFunctionDef(self, node):
"""Extracts async functions from the given node."""
Expand Down Expand Up @@ -160,11 +175,13 @@ def _visit_Function(self, node, is_async):

# Assuming you have a list to store functions
self.functions.append(function_info)
self.order.append(ASTOrder(type="function", order=len(self.functions) - 1))

def visit_Module(self, node) -> None:
"""Extracts comments from the given node."""
comments = ast.get_docstring(node, False)
self.comments = "" if comments is None else comments
self.order.append(ASTOrder(type="comment", order=0))
self.generic_visit(node)

def check_function_changed(self, other: PromptAnalyzer) -> bool:
Expand Down Expand Up @@ -475,25 +492,29 @@ def _update_tag_decorator_with_version(
return import_name


def _update_mirascope_imports(imports: list[tuple[str, Optional[str]]]):
def _update_mirascope_imports(analyzer: PromptAnalyzer):
"""Updates the mirascope import.

Args:
imports: The imports from the PromptAnalyzer class
"""
imports = analyzer.imports
if not any(import_name == "mirascope" for import_name, _ in imports):
imports.append(("mirascope", None))
index = 0
if analyzer.comments:
index = 1
analyzer.order.insert(index, ASTOrder(type="import", order=len(imports) - 1))


def _update_mirascope_from_imports(
member: str, from_imports: list[tuple[str, str, Optional[str]]]
):
def _update_mirascope_from_imports(member: str, analyzer: PromptAnalyzer):
"""Updates the mirascope from imports.

Args:
member: The member to import.
from_imports: The from imports from the PromptAnalyzer class
"""
from_imports = analyzer.from_imports
if not any(
(
module_name == "mirascope"
Expand All @@ -504,6 +525,12 @@ def _update_mirascope_from_imports(
for module_name, import_name, _ in from_imports
):
from_imports.append(("mirascope", member, None))
index = 0
if analyzer.comments:
index = 1
analyzer.order.insert(
index, ASTOrder(type="from_import", order=len(from_imports) - 1)
)


def write_prompt_to_template(
Expand Down Expand Up @@ -537,20 +564,6 @@ def write_prompt_to_template(
if variables is None:
variables = MirascopeCliVariables()

if command == MirascopeCommand.ADD:
# double quote revision ids to match how `ast.unparse()` formats strings
new_variables = {
k: f"'{v}'" if isinstance(v, str) else None
for k, v in variables.__dict__.items()
} | analyzer.variables
else: # command == MirascopeCommand.USE
ignore_variable_keys = dict.fromkeys(ignore_variables, None)
new_variables = {
k: analyzer.variables[k]
for k in analyzer.variables
if k not in ignore_variable_keys
}

if auto_tag:
import_tag_name: Optional[str] = None
mirascope_alias = "mirascope"
Expand All @@ -569,18 +582,64 @@ def write_prompt_to_template(
import_tag_name = _update_tag_decorator_with_version(
decorators, variables, mirascope_alias
)

if import_tag_name == "tags":
_update_mirascope_from_imports(import_tag_name, analyzer.from_imports)
_update_mirascope_from_imports(import_tag_name, analyzer)
elif import_tag_name == f"{mirascope_alias}.tags":
_update_mirascope_imports(analyzer.imports)
_update_mirascope_imports(analyzer)

for item in analyzer.order:
if item.type == "import":
item.render = analyzer.imports[item.order]
elif item.type == "from_import":
item.render = analyzer.from_imports[item.order]
elif item.type == "variable":
variable_name = list(analyzer.variables.keys())[item.order]
variable_value = analyzer.variables[variable_name]
item.render = (variable_name, variable_value)
elif item.type == "class":
item.render = analyzer.classes[item.order]
elif item.type == "function":
item.render = analyzer.functions[item.order]
elif item.type == "comment":
item.render = analyzer.comments

if command == MirascopeCommand.ADD:
# double quote revision ids to match how `ast.unparse()` formats strings
new_variables = {
k: f"'{v}'" if isinstance(v, str) else None
for k, v in variables.__dict__.items()
}
first_class_func_var_index = next(
(
i
for i, item in enumerate(analyzer.order)
if item.type in ["class", "function", "variable"]
),
None,
)
new_variables_order = [
ASTOrder(type="variable", order=i, render=(k, v))
for i, (k, v) in enumerate(new_variables.items())
]
if first_class_func_var_index is not None:
analyzer.order[
first_class_func_var_index:first_class_func_var_index
] = new_variables_order
else:
analyzer.order += new_variables_order
else: # command == MirascopeCommand.USE
ignore_variable_keys = dict.fromkeys(ignore_variables, None)
analyzer.order = [
item
for item in analyzer.order
if not (
item.type == "variable"
and item.render is not None
and item.render[0] in ignore_variable_keys
)
]
data = {
"comments": analyzer.comments,
"variables": new_variables,
"imports": analyzer.imports,
"from_imports": analyzer.from_imports,
"classes": analyzer.classes,
"order": analyzer.order,
}
return template.render(**data)

Expand Down
1 change: 0 additions & 1 deletion tests/commands/golden/base_prompt/0001_base_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""A prompt for recommending movies of a particular genre."""

from mirascope import BasePrompt

prev_revision_id = None
Expand Down
1 change: 0 additions & 1 deletion tests/commands/golden/base_prompt/0002_base_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""A prompt for recommending movies of a particular genre."""

from mirascope import BasePrompt

prev_revision_id = "0001"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""A prompt for recommending movies of a particular genre."""

from mirascope import BasePrompt, tags

prev_revision_id = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""A prompt for recommending movies of a particular genre."""

from mirascope import BasePrompt, tags

prev_revision_id = "0001"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""A prompt for recommending movies of a particular genre."""

from mirascope import BasePrompt, tags

prev_revision_id = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""A prompt for recommending movies of a particular genre."""

from mirascope import BasePrompt, tags

prev_revision_id = "0001"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""A call for recommending movies of a particular genre."""

from mirascope import tags
from mirascope.openai import OpenAICall, OpenAICallParams

prev_revision_id = None
revision_id = "0001"
number = 1
chat = OpenAICall()
a_list = [1, 2, 3]


def foo(a: int, b: str) -> int:
"""ABC"""

return a + int(b)


@tags(["movie_project", "version:0001"])
Expand All @@ -23,9 +27,14 @@ class MovieRecommender(OpenAICall):
include succinct and clear descriptions of the movie. You also make sure to pique
their interest by mentioning any famous actors in the movie that might be of
interest.



USER:
Please recommend 3 movies in the {genre} cetegory.
"""

genre: str
call_params = OpenAICallParams(model="gpt-3.5-turbo")


a_list = [1, 2, 3]
Loading
Loading