9 changes: 8 additions & 1 deletion pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
     "tree_sitter_language_pack>=0.7.0",
     "tree_sitter_languages>=1.9.1",
     "vtk>=9.3.1",
+    "pyyaml>=6.0.0",
 ]
 requires-python = ">=3.10"
 readme = "README.md"
@@ -62,7 +63,13 @@ build-backend = "setuptools.build_meta"
 fallback_version = "0.1.0"
 
 [tool.setuptools.package-data]
-vtk_prompt = ["prompts/*.txt"]
+vtk_prompt = ["prompts/*.yml"]
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
+[tool.setuptools.package-dir]
+"" = "src"
 
 [tool.black]
 include = 'src/.*.py$'
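
For reference, one of the new prompts/*.yml files might look like the sketch below, parsed with the PyYAML dependency added above. This is an illustrative guess, not the repository's actual schema: the field names (name, model, modelParameters, messages) follow the GitHub Models prompt-file convention implied by the GitHubModelYAMLLoader class used in prompt.py, and the values mirror defaults visible in this diff (gpt-4o-mini, temperature 0.3, max_tokens 2000, a {{request}} placeholder, and the <explanation>/<code> tags the caller parses).

    # Hypothetical prompts/vtk_python_code_generation.yml, embedded as a
    # string so the sketch stays self-contained and runnable.
    import yaml

    PROMPT_YML = """
    name: vtk_python_code_generation
    model: gpt-4o-mini
    modelParameters:
      temperature: 0.3
      max_tokens: 2000
    messages:
      - role: system
        content: >
          You are a VTK expert. Reply with an <explanation>...</explanation>
          block followed by a <code>...</code> block of runnable Python.
      - role: user
        content: "{{request}}"
    """

    config = yaml.safe_load(PROMPT_YML)
    print(config["model"])            # gpt-4o-mini
    print(config["modelParameters"])  # {'temperature': 0.3, 'max_tokens': 2000}
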
2 changes: 1 addition & 1 deletion rag-components
Submodule rag-components updated from 317b2f to 32c044
183 changes: 95 additions & 88 deletions src/vtk_prompt/prompt.py
@@ -2,18 +2,16 @@

import ast
import os
import re
import sys
import json
import openai
import click
from dataclasses import dataclass
from pathlib import Path

from .prompts import (
get_no_rag_context,
get_rag_context,
get_python_role,
)
# Using YAML system exclusively
from .yaml_prompt_loader import GitHubModelYAMLLoader


@dataclass
@@ -85,30 +83,32 @@ def run_code(self, code_string):
print(code_string)
return None
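
The load_conversation and save_conversation helpers called later in query_yaml fall outside this diff's hunks. Given the json import kept at the top of the file and the conversation_file field, they presumably round-trip the message list through a JSON file; the sketch below is an assumption about their shape, not the actual implementation.

    import json
    from pathlib import Path

    class ConversationPersistenceSketch:
        def __init__(self, conversation_file=None):
            self.conversation_file = conversation_file
            self.conversation = []

        def load_conversation(self):
            # Return the saved message list, or [] when nothing was saved yet.
            if self.conversation_file and Path(self.conversation_file).exists():
                with open(self.conversation_file) as f:
                    return json.load(f)
            return []

        def save_conversation(self):
            # Persist the running conversation when a file was requested.
            if self.conversation_file:
                with open(self.conversation_file, "w") as f:
                    json.dump(self.conversation, f, indent=2)
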

def query(
def query_yaml(
self,
message="",
api_key=None,
model="gpt-4o",
message,
api_key,
prompt_source="vtk_python_code_generation",
base_url=None,
max_tokens=1000,
temperature=0.1,
top_k=5,
rag=False,
top_k=5,
retry_attempts=1,
override_model=None,
override_temperature=None,
override_max_tokens=None,
):
"""Generate VTK code with optional RAG enhancement and retry logic.
"""Generate VTK code using YAML prompt templates.

Args:
message: The user query
api_key: API key for the service
model: Model name to use
prompt_source: Name of the YAML prompt file to use or binary blob of the prompt file content
base_url: API base URL
max_tokens: Maximum tokens to generate
temperature: Temperature for generation
top_k: Number of RAG examples to retrieve
rag: Whether to use RAG enhancement
retry_attempts: Number of times to retry if AST validation fails
top_k: Number of RAG examples to retrieve
retry_attempts: Number of retry attempts for failed generations

Returns:
Generated code string or None if failed
"""
if not api_key:
api_key = os.environ.get("OPENAI_API_KEY")
@@ -121,13 +121,18 @@ def query(
# Create client with current parameters
client = openai.OpenAI(api_key=api_key, base_url=base_url)

# Load existing conversation if present
if self.conversation_file and not self.conversation:
self.conversation = self.load_conversation()
# Load YAML prompt configuration
from pathlib import Path

prompts_dir = Path(__file__).parent / "prompts"
yaml_loader = GitHubModelYAMLLoader(prompts_dir)
model_params = yaml_loader.get_model_parameters(prompt_source)
model = override_model or yaml_loader.get_model_name(prompt_source)

if not message and not self.conversation:
raise ValueError("No prompt or conversation file provided")
# Prepare variables for template substitution
variables = {"request": message}

# Handle RAG if requested
if rag:
from .rag_chat_wrapper import (
check_rag_components_available,
@@ -148,27 +153,33 @@ def query(
raise ValueError("Failed to load RAG snippets")

context_snippets = "\n\n".join(rag_snippets["code_snippets"])
context = get_rag_context(message, context_snippets)
variables["context_snippets"] = context_snippets

if self.verbose:
print("CONTEXT: " + context)
references = rag_snippets.get("references")
if references:
print("Using examples from:")
for ref in references:
print(f"- {ref}")
else:
context = get_no_rag_context(message)
if self.verbose:
print("CONTEXT: " + context)

# If no conversation exists, start with system role
if not self.conversation:
self.conversation = [{"role": "system", "content": get_python_role()}]
# Load existing conversation or start fresh
conversation_messages = self.load_conversation()

# Add current user message
if message:
self.conversation.append({"role": "user", "content": context})
# Build base messages from YAML template
base_messages = yaml_loader.build_messages(prompt_source, variables)

# If conversation exists, extend it with new user message
if conversation_messages:
# Add the current request as a new user message
conversation_messages.append({"role": "user", "content": message})
self.conversation = conversation_messages
else:
# Use YAML template as starting point
self.conversation = base_messages

# Extract parameters with overrides
temperature = override_temperature or model_params.get("temperature", 0.3)
max_tokens = override_max_tokens or model_params.get("max_tokens", 2000)

# Retry loop for AST validation
for attempt in range(retry_attempts):
@@ -197,58 +208,50 @@ def query(
f"Output was truncated due to max_tokens limit ({max_tokens}). Please increase max_tokens."
)

generated_code = None
if "import vtk" not in content:
generated_code = "import vtk\n" + content
else:
pos = content.find("import vtk")
if pos != -1:
generated_code = content[pos:]
else:
generated_code = content
generated_explanation = re.findall(
"<explanation>(.*?)</explanation>", content, re.DOTALL
)[0]
generated_code = re.findall("<code>(.*?)</code>", content, re.DOTALL)[0]
if "import vtk" not in generated_code:
generated_code = f"import vtk\n{generated_code}"

is_valid, error_msg = self.validate_code_syntax(generated_code)
if is_valid:
if message:
self.conversation.append(
{"role": "assistant", "content": content}
)
self.save_conversation()
return generated_code, response.usage
# Save conversation with assistant response
self.conversation.append({"role": "assistant", "content": content})
self.save_conversation()

elif attempt < retry_attempts - 1: # Don't print on last attempt
if self.verbose:
print(f"AST validation failed: {error_msg}. Retrying...")
# Add error feedback to context for retry
self.conversation.append({"role": "assistant", "content": content})
self.conversation.append(
{
"role": "user",
"content": (
f"The generated code has a syntax error: {error_msg}. "
"Please fix the syntax and generate valid Python code."
),
}
)
print("Code validation successful!")
return generated_code, generated_explanation
else:
# Last attempt failed
if self.verbose:
print(f"Final attempt failed AST validation: {error_msg}")

if message:
self.conversation.append(
{"role": "assistant", "content": content}
print(
f"Code validation failed on attempt {attempt + 1}: {error_msg}"
)
print("Generated code:")
print(generated_code)

if attempt < retry_attempts - 1:
# Add error feedback to messages for retry
error_feedback = (
f"The previous code had a syntax error: {error_msg}. "
"Please fix the syntax and try again."
)
self.conversation.append({"role": "user", "content": error_feedback})
else:
# Save conversation even if final attempt failed
self.conversation.append({"role": "assistant", "content": content})
self.save_conversation()
return (
generated_code,
response.usage,
) # Return anyway, let caller handle
print(
f"All {retry_attempts} attempts failed. Final error: {error_msg}"
)
return generated_code, generated_explanation # Return anyway, let caller handle
else:
if attempt == retry_attempts - 1:
return "No response generated", response.usage
print("No response content received")
return None

return "No response generated"
return None
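
GitHubModelYAMLLoader itself is not part of this diff. A minimal sketch consistent with the three calls query_yaml makes (get_model_name, get_model_parameters, and build_messages with {{variable}} substitution) could look like the following; treat it as an illustration of the expected interface rather than the module's real code.

    import yaml
    from pathlib import Path

    class GitHubModelYAMLLoaderSketch:
        def __init__(self, prompts_dir):
            self.prompts_dir = Path(prompts_dir)

        def _load(self, prompt_source):
            # One YAML document per prompt name, e.g. prompts/rag_context.yml.
            with open(self.prompts_dir / f"{prompt_source}.yml") as f:
                return yaml.safe_load(f)

        def get_model_name(self, prompt_source):
            return self._load(prompt_source).get("model", "gpt-4o-mini")

        def get_model_parameters(self, prompt_source):
            # query_yaml reads "temperature" and "max_tokens" from this dict.
            return self._load(prompt_source).get("modelParameters", {})

        def build_messages(self, prompt_source, variables):
            # Substitute {{key}} placeholders in every message body.
            messages = []
            for msg in self._load(prompt_source)["messages"]:
                content = msg["content"]
                for key, value in variables.items():
                    content = content.replace("{{" + key + "}}", str(value))
                messages.append({"role": msg["role"], "content": content})
            return messages
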


@click.command()
Expand All @@ -259,14 +262,14 @@ def query(
default="openai",
help="LLM provider to use",
)
@click.option("-m", "--model", default="gpt-4o", help="Model name to use")
@click.option("-m", "--model", default="gpt-4o-mini", help="Model name to use")
@click.option(
"-k", "--max-tokens", type=int, default=1000, help="Max # of tokens to generate"
)
@click.option(
"--temperature",
type=float,
default=0.7,
default=0.1,
help="Temperature for generation (0.0-2.0)",
)
@click.option(
@@ -310,7 +313,7 @@ def main(
retry_attempts,
conversation,
):
"""Generate and execute VTK code using LLMs.
"""Generate and execute VTK code using LLMs with YAML prompts.

INPUT_STRING: The code description to generate VTK code for
"""
@@ -340,22 +343,26 @@ def main(
verbose=verbose,
conversation_file=conversation,
)
generated_code, usage = client.query(

# Use YAML system directly
prompt_source = "rag_context" if rag else "no_rag_context"
generated_code = client.query_yaml(
input_string,
api_key=token,
model=model,
prompt_source=prompt_source,
base_url=base_url,
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
rag=rag,
top_k=top_k,
retry_attempts=retry_attempts,
# Override parameters if specified in CLI
override_model=model if model != "gpt-4o-mini" else None,
override_temperature=temperature if temperature != 0.1 else None,
override_max_tokens=max_tokens if max_tokens != 1000 else None,
)

if verbose and usage is not None:
print(
f"Used tokens: input={usage.prompt_tokens} output={usage.completion_tokens}"
)
# Usage tracking not yet implemented for YAML system
if verbose:
print("Token usage tracking not available in YAML mode")

client.run_code(generated_code)
