Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions intent_and_sql_tools/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# intent_and_sql_tools

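A minimal Data Agent SDK scaffold for VeADK/Google ADK, built from three parts: a Brain (intent parsing), Hands (SQL generation and execution), and a Registry (intent-to-tool routing).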

## Quick Start

```bash
python train.py all
python app.py
```

## Structure

- sdk/core_engine.py: IntentVanna (intent parsing) and SQLVanna (SQL generation and execution)
- sdk/registry.py: intent-to-tool routing map
- sdk/compiler.py: compiles an IntentEnvelope into a rich prompt
- sdk/tools/: gateway and worker tools
- train.py: offline training pipeline
- app.py: runtime entry point
1 change: 1 addition & 0 deletions intent_and_sql_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

38 changes: 38 additions & 0 deletions intent_and_sql_tools/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from veadk import Agent

from sdk.tools import execute_api, execute_sql, identify_intent, visualize_data


def create_agent() -> Agent:
    """Build the HeQu data agent wired with the gateway tool and the worker tools.

    The instruction text tells the model to call ``identify_intent`` first,
    read ``next_tool`` from its JSON output, and forward the whole envelope
    unmodified to that tool.
    """
    routing_instruction = (
        "ALWAYS call identify_intent first. "
        "Read `next_tool` from the JSON output. "
        "Pass the ENTIRE JSON envelope to the next tool without modification."
    )
    agent_tools = [identify_intent, execute_sql, execute_api, visualize_data]
    return Agent(
        name="HeQu_Data_Agent_v6_13",
        description="HeQu Data Agent",
        instruction=routing_instruction,
        tools=agent_tools,
    )

# Worker tools addressable by the `next_tool` name found in an envelope.
# The gateway tool (identify_intent) is deliberately not listed: run_query
# always calls it first and only dispatches to these afterwards.
TOOLS = dict(
    execute_sql=execute_sql,
    execute_api=execute_api,
    visualize_data=visualize_data,
)


def run_query(query: str):
    """Run one query through the local chain: gateway first, then the routed tool.

    Args:
        query: Natural-language question handed to ``identify_intent``.

    Returns:
        The routed tool's result, or the envelope itself when ``next_tool``
        does not name an entry in ``TOOLS``.
    """
    envelope = identify_intent(query)
    # A missing or empty `next_tool` is reported under the same placeholder name.
    tool_name = envelope.get("next_tool") or "unknown_tool"
    print(f"[Chain] identify_intent -> {tool_name}")
    handler = TOOLS.get(tool_name)
    return envelope if handler is None else handler(envelope)


if __name__ == "__main__":
    # Run three sample queries through the chain and print each result.
    for sample_query in ("查一下土豪流失", "选出MA多头的票", "画一张最近流水趋势图"):
        print(run_query(sample_query))
5 changes: 5 additions & 0 deletions intent_and_sql_tools/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .compiler import ContextCompiler
from .core_engine import IntentVanna, SQLVanna
from .registry import ToolRegistry

__all__ = ["ContextCompiler", "IntentVanna", "SQLVanna", "ToolRegistry"]
91 changes: 91 additions & 0 deletions intent_and_sql_tools/sdk/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Any


class ContextCompiler:
def compile(self, envelope: dict) -> str:
intent = self._get_str(envelope, "intent") or "unknown"
payload = envelope.get("payload") or {}
if not isinstance(payload, dict):
payload = {"value": payload}
if intent == "query_metric":
return self._compile_query_metric(payload)
if intent == "screening":
return self._compile_screening(payload)
if intent == "plot_chart":
return self._compile_plot_chart(payload)
return self._compile_unknown(intent, payload)

def _compile_query_metric(self, payload: dict) -> str:
metrics = self._listify(payload.get("metrics") or payload.get("metric"))
time_range = payload.get("time_range") or payload.get("timeRange")
filters = self._listify(payload.get("filters"))
dimensions = self._listify(payload.get("dimensions") or payload.get("dims"))
parts = [
"Task: Query metrics with strict schema adherence.",
f"Metrics: {self._fmt_list(metrics)}",
f"TimeRange: {self._fmt_value(time_range)}",
f"Filters: {self._fmt_list(filters)}",
f"Dimensions: {self._fmt_list(dimensions)}",
"Constraints: Use documentation definitions; do not hallucinate columns.",
]
return "\n".join(parts)

def _compile_screening(self, payload: dict) -> str:
universe = payload.get("universe")
factors = self._listify(payload.get("factors") or payload.get("rules"))
sort_by = payload.get("sort_by") or payload.get("sortBy")
limit = payload.get("limit")
parts = [
"Task: Screen entities using structured factors.",
f"Universe: {self._fmt_value(universe)}",
f"Factors: {self._fmt_list(factors)}",
f"SortBy: {self._fmt_value(sort_by)}",
f"Limit: {self._fmt_value(limit)}",
"Constraints: Use documentation definitions; do not hallucinate columns.",
]
return "\n".join(parts)

def _compile_plot_chart(self, payload: dict) -> str:
metric = payload.get("metric") or payload.get("metrics")
time_range = payload.get("time_range") or payload.get("timeRange")
dimension = payload.get("dimension") or payload.get("dimensions")
chart_type = payload.get("chart_type") or payload.get("chartType")
parts = [
"Task: Generate visualization request.",
f"Metric: {self._fmt_value(metric)}",
f"TimeRange: {self._fmt_value(time_range)}",
f"Dimension: {self._fmt_value(dimension)}",
f"ChartType: {self._fmt_value(chart_type)}",
"Constraints: Use documentation definitions; do not hallucinate columns.",
]
return "\n".join(parts)

def _compile_unknown(self, intent: str, payload: dict) -> str:
parts = [
"Task: Unknown intent, provide best-effort structured summary.",
f"Intent: {intent}",
f"RawPayload: {self._fmt_value(payload)}",
"Constraints: Use documentation definitions; do not hallucinate columns.",
]
return "\n".join(parts)

def _listify(self, value: Any) -> list:
if value is None:
return []
if isinstance(value, list):
return [v for v in value if v not in (None, "")]
return [value]

def _fmt_list(self, value: list) -> str:
return ", ".join([str(v) for v in value]) if value else "None"

def _fmt_value(self, value: Any) -> str:
if value in (None, "", []):
return "None"
return str(value)

def _get_str(self, payload: dict, key: str) -> str | None:
value = payload.get(key)
if value in (None, ""):
return None
return str(value)
145 changes: 145 additions & 0 deletions intent_and_sql_tools/sdk/core_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import json
from typing import Any, Callable

from .registry import ToolRegistry


class VannaBase:
def __init__(
self,
config: dict,
llm_stub: Callable[[list[dict]], str] | None = None,
impl: Any | None = None,
):
self._config = config or {}
self._impl = impl
self._llm_stub = llm_stub

def _ensure_impl(self):
if self._impl is None:
from vanna.chromadb import ChromaDB_VectorStore
from vanna.vertexai import VertexAI_Chat

class _Impl(ChromaDB_VectorStore, VertexAI_Chat):
def __init__(self, config: dict):
ChromaDB_VectorStore.__init__(self, config=config)
VertexAI_Chat.__init__(self, config=config)

self._impl = _Impl(self._config)
return self._impl

def _submit_prompt(self, messages: list[dict]) -> str:
if self._llm_stub is not None:
return self._llm_stub(messages)
impl = self._ensure_impl()
return impl.submit_prompt(messages)

def _get_related_documentation(self, question: str):
impl = self._ensure_impl()
return impl.get_related_documentation(question)

def _get_similar_question_sql(self, question: str):
impl = self._ensure_impl()
return impl.get_similar_question_sql(question)


class IntentVanna(VannaBase):
    """Semantic parser that maps a user question to an intent envelope dict."""

    def generate_envelope(self, question: str) -> dict[str, Any]:
        """Ask the LLM for a JSON intent envelope and attach the routed tool name.

        Related documentation and similar examples are fetched for the
        question and embedded in the system prompt. The reply is parsed as
        JSON after removing a surrounding Markdown code fence, if present.

        Args:
            question: The user's natural-language query.

        Returns:
            A dict with keys ``intent``, ``payload``, ``next_tool``, ``error``,
            ``confidence`` and ``debug_context``. This method never raises: on
            any failure it returns an envelope with ``intent="unknown"``,
            ``next_tool="unknown_tool"`` and ``error`` set to the exception text.
        """
        try:
            docs = self._get_related_documentation(question)
            examples = self._get_similar_question_sql(question)
            system_prompt = (
                "Role: Semantic Parser.\n"
                "Task: Map query to JSON based on Knowledge.\n"
                f"Knowledge: {docs}\n"
                f"Examples: {examples}\n"
                "Output: JSON (IntentEnvelope)\n"
            )
            raw_resp = self._submit_prompt(
                [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": question},
                ]
            )
            envelope = json.loads(self._strip_code_fence(raw_resp))
            if not isinstance(envelope, dict):
                raise ValueError("IntentEnvelope must be a dict")
            intent = envelope.get("intent")
            if not intent:
                raise ValueError("Missing intent")
            payload = envelope.get("payload") or {}
            debug_context = {
                "docs": _summarize_debug(docs),
                "examples": _summarize_debug(examples),
            }
            # get_tool_name already falls back to "unknown_tool" for
            # unregistered intents, so no second lookup is needed.
            tool_name = ToolRegistry.get_tool_name(intent)
            return {
                "intent": intent,
                "payload": payload,
                "next_tool": tool_name,
                "error": None,
                "confidence": envelope.get("confidence"),
                "debug_context": debug_context,
            }
        except Exception as exc:
            # Deliberate catch-all: callers always receive an envelope and
            # inspect `error` instead of handling exceptions.
            return {
                "intent": "unknown",
                "payload": {},
                "next_tool": "unknown_tool",
                "error": str(exc),
                "confidence": None,
                "debug_context": None,
            }

    @staticmethod
    def _strip_code_fence(raw: Any) -> Any:
        """Remove one surrounding Markdown code fence from an LLM reply.

        Non-string input and unfenced text are returned unchanged, so
        ``json.loads`` reports the same errors it would have reported before.
        """
        if not isinstance(raw, str):
            return raw
        text = raw.strip()
        if not text.startswith("```"):
            return raw
        # The opening fence line may carry a language tag (e.g. ```json).
        _, newline, body = text.partition("\n")
        if not newline:
            return raw
        body = body.rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body


class SQLVanna(VannaBase):
    """SQL generation and execution, delegated to the backing implementation."""

    def generate_sql(self, question: str) -> str:
        """Return the SQL the implementation generates for ``question``."""
        return self._ensure_impl().generate_sql(question=question)

    def generate_sql_from_context(self, context: str) -> str:
        """Generate SQL using ``context`` as the question text."""
        sql = self.generate_sql(question=context)
        return sql

    def run_sql(self, sql: str):
        """Execute ``sql`` through the implementation and return its result."""
        return self._ensure_impl().run_sql(sql)


class MockVannaImpl:
    """In-memory stand-in for the backing implementation; returns canned values."""

    def __init__(
        self,
        docs: str = "DOCS",
        examples: str = "EXAMPLES",
        response: str = '{"intent":"query_metric","payload":{"metric":"revenue"}}',
    ):
        self._response = response
        self._examples = examples
        self._docs = docs

    # --- retrieval / LLM surface: each ignores its argument ---

    def get_related_documentation(self, question: str):
        """Return the canned documentation value."""
        return self._docs

    def get_similar_question_sql(self, question: str):
        """Return the canned examples value."""
        return self._examples

    def submit_prompt(self, messages: list[dict]):
        """Return the canned LLM response."""
        return self._response

    # --- SQL surface: fixed results ---

    def generate_sql(self, question: str) -> str:
        """Always return ``SELECT 1``."""
        return "SELECT 1"

    def run_sql(self, sql: str):
        """Always return a single ``{"ok": True}`` row."""
        return [{"ok": True}]

    def train(self, **kwargs):
        """Accept any keyword arguments and report success."""
        return True

def _summarize_debug(value: Any, limit: int = 400):
text = str(value)
if len(text) <= limit:
return text
return text[:limit]
14 changes: 14 additions & 0 deletions intent_and_sql_tools/sdk/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
class ToolRegistry:
    """Process-wide map from intent name to the name of the tool that handles it."""

    # Shared across all users of the class; filled by the `register` decorator.
    _intent_map: dict[str, str] = {}

    @classmethod
    def register(cls, intent: str, tool_name: str):
        """Return a decorator that records ``intent -> tool_name``.

        The mapping is written when the decorator is applied; the decorated
        function itself is returned unchanged.
        """

        def decorator(func):
            cls._intent_map.update({intent: tool_name})
            return func

        return decorator

    @classmethod
    def get_tool_name(cls, intent: str) -> str:
        """Return the tool registered for ``intent``, or ``"unknown_tool"``."""
        mapping = cls._intent_map
        return mapping[intent] if intent in mapping else "unknown_tool"
6 changes: 6 additions & 0 deletions intent_and_sql_tools/sdk/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .api_tool import execute_api
from .gateway_tool import identify_intent
from .runsql_tool import execute_sql
from .visualize_tool import visualize_data

__all__ = ["identify_intent", "execute_sql", "execute_api", "visualize_data"]
20 changes: 20 additions & 0 deletions intent_and_sql_tools/sdk/tools/api_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from sdk.registry import ToolRegistry


@ToolRegistry.register(intent="screening", tool_name="execute_api")
def execute_api(envelope: dict) -> dict:
    """Normalise a screening envelope into a mock API request.

    Args:
        envelope: Intent envelope; only its ``payload`` entry is read.

    Returns:
        A dict with ``status="mock"``, a fixed message, and the normalised
        ``request`` (which also carries the payload under ``raw_payload``).
    """
    raw = envelope.get("payload") or {}
    # A non-dict payload is wrapped so the lookups below stay uniform.
    payload = raw if isinstance(raw, dict) else {"value": raw}
    request = {
        "universe": payload.get("universe"),
        "factors": payload.get("factors") or payload.get("rules") or [],
        "sort_by": payload.get("sort_by") or payload.get("sortBy"),
        "limit": payload.get("limit"),
        "raw_payload": payload,
    }
    return {
        "status": "mock",
        "message": "Screening request normalized",
        "request": request,
    }
6 changes: 6 additions & 0 deletions intent_and_sql_tools/sdk/tools/gateway_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from sdk.tools.runtime import get_brain, init_engine


def identify_intent(query: str) -> dict:
    """Gateway tool: return the envelope the brain generates for ``query``."""
    return get_brain().generate_envelope(query)
15 changes: 15 additions & 0 deletions intent_and_sql_tools/sdk/tools/runsql_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sdk.compiler import ContextCompiler
from sdk.registry import ToolRegistry
from sdk.tools.runtime import get_hands


@ToolRegistry.register(intent="query_metric", tool_name="execute_sql")
def execute_sql(envelope: dict) -> str:
    """Compile the envelope to a prompt, generate SQL from it, and run that SQL.

    Args:
        envelope: Intent envelope passed to ``ContextCompiler.compile``.

    Returns:
        ``result.to_markdown()`` when the query result has that method,
        otherwise ``str(result)``.
    """
    prompt = ContextCompiler().compile(envelope)
    engine = get_hands()
    result = engine.run_sql(engine.generate_sql(question=prompt))
    if not hasattr(result, "to_markdown"):
        return str(result)
    return result.to_markdown()
Loading