|
3 | 3 | It defines the workflow graph, state, tools, nodes and edges. |
4 | 4 | """ |
5 | 5 |
|
6 | | -import json |
7 | | -import urllib.parse |
8 | | -import urllib.request |
9 | | -from collections.abc import Mapping |
10 | 6 | from typing import Any, List, TypedDict |
11 | 7 |
|
12 | 8 | from langchain.agents import create_agent |
13 | | -from langchain.agents.structured_output import ProviderStrategy |
14 | | -from langchain.agents.middleware import wrap_model_call, before_model |
15 | | -from langchain.tools import tool |
16 | 9 | from copilotkit import CopilotKitMiddleware, CopilotKitState |
17 | | - |
18 | | -@wrap_model_call |
19 | | -async def apply_structured_output_schema(request, handler): |
20 | | - """ |
21 | | - If the frontend forwards a JSON schema in runtime context, use it as the |
22 | | - structured output response_format for the model call. |
23 | | - """ |
24 | | - schema = None |
25 | | - runtime = getattr(request, "runtime", None) |
26 | | - runtime_context = getattr(runtime, "context", None) |
27 | | - if isinstance(runtime_context, Mapping): |
28 | | - schema = runtime_context.get("output_schema") |
29 | | - |
30 | | - if schema is None: |
31 | | - copilot_context = None |
32 | | - if isinstance(getattr(request, "state", None), dict): |
33 | | - copilot_context = request.state.get("copilotkit", {}).get("context") |
34 | | - if isinstance(copilot_context, list): |
35 | | - for item in copilot_context: |
36 | | - if isinstance(item, dict) and item.get("description") == "output_schema": |
37 | | - schema = item.get("value") |
38 | | - break |
39 | | - |
40 | | - if isinstance(schema, str): |
41 | | - try: |
42 | | - schema = json.loads(schema) |
43 | | - except json.JSONDecodeError: |
44 | | - schema = None |
45 | | - |
46 | | - if isinstance(schema, dict): |
47 | | - if not schema.get("title"): |
48 | | - schema["title"] = "StructuredOutput" |
49 | | - if not schema.get("description"): |
50 | | - schema["description"] = "Structured response schema for the CopilotKit agent." |
51 | | - request = request.override( |
52 | | - response_format=ProviderStrategy(schema=schema, strict=True), |
53 | | - ) |
54 | | - return await handler(request) |
55 | | - |
56 | | -@tool |
57 | | -def get_weather(location: str): |
58 | | - """ |
59 | | - Get the current weather for a given location. |
60 | | -
|
61 | | - Preferred input format: "City, State, Country" (e.g., "Huntsville, Alabama, USA"). |
62 | | - US shorthand is allowed: "City, ST" (e.g., "Huntsville, AL"). The tool will |
63 | | - expand state abbreviations and bias geocoding to the US when a state is detected. |
64 | | - """ |
65 | | - if not location or not location.strip(): |
66 | | - return { |
67 | | - "status": "error", |
68 | | - "message": "Please provide a location in the format \"City, State, Country\".", |
69 | | - "suggestedQueries": [], |
70 | | - } |
71 | | - |
72 | | - us_state_map = { |
73 | | - "AL": "Alabama", |
74 | | - "AK": "Alaska", |
75 | | - "AZ": "Arizona", |
76 | | - "AR": "Arkansas", |
77 | | - "CA": "California", |
78 | | - "CO": "Colorado", |
79 | | - "CT": "Connecticut", |
80 | | - "DE": "Delaware", |
81 | | - "FL": "Florida", |
82 | | - "GA": "Georgia", |
83 | | - "HI": "Hawaii", |
84 | | - "ID": "Idaho", |
85 | | - "IL": "Illinois", |
86 | | - "IN": "Indiana", |
87 | | - "IA": "Iowa", |
88 | | - "KS": "Kansas", |
89 | | - "KY": "Kentucky", |
90 | | - "LA": "Louisiana", |
91 | | - "ME": "Maine", |
92 | | - "MD": "Maryland", |
93 | | - "MA": "Massachusetts", |
94 | | - "MI": "Michigan", |
95 | | - "MN": "Minnesota", |
96 | | - "MS": "Mississippi", |
97 | | - "MO": "Missouri", |
98 | | - "MT": "Montana", |
99 | | - "NE": "Nebraska", |
100 | | - "NV": "Nevada", |
101 | | - "NH": "New Hampshire", |
102 | | - "NJ": "New Jersey", |
103 | | - "NM": "New Mexico", |
104 | | - "NY": "New York", |
105 | | - "NC": "North Carolina", |
106 | | - "ND": "North Dakota", |
107 | | - "OH": "Ohio", |
108 | | - "OK": "Oklahoma", |
109 | | - "OR": "Oregon", |
110 | | - "PA": "Pennsylvania", |
111 | | - "RI": "Rhode Island", |
112 | | - "SC": "South Carolina", |
113 | | - "SD": "South Dakota", |
114 | | - "TN": "Tennessee", |
115 | | - "TX": "Texas", |
116 | | - "UT": "Utah", |
117 | | - "VT": "Vermont", |
118 | | - "VA": "Virginia", |
119 | | - "WA": "Washington", |
120 | | - "WV": "West Virginia", |
121 | | - "WI": "Wisconsin", |
122 | | - "WY": "Wyoming", |
123 | | - "DC": "District of Columbia", |
124 | | - } |
125 | | - |
126 | | - raw_location = location.strip() |
127 | | - suggested_queries: list[str] = [] |
128 | | - country_bias: str | None = None |
129 | | - |
130 | | - normalized_location = raw_location |
131 | | - city_only: str | None = None |
132 | | - state_full: str | None = None |
133 | | - parts = [part.strip() for part in raw_location.split(",") if part.strip()] |
134 | | - if len(parts) == 2: |
135 | | - state = parts[1].upper() |
136 | | - if state in us_state_map: |
137 | | - city_only = parts[0] |
138 | | - state_full = us_state_map[state] |
139 | | - normalized_location = f"{parts[0]}, {us_state_map[state]}" |
140 | | - suggested_queries.append(f"{parts[0]}, {us_state_map[state]}, USA") |
141 | | - country_bias = "US" |
142 | | - if len(parts) >= 3: |
143 | | - tail = parts[-1].lower() |
144 | | - if tail in {"usa", "us", "united states", "united states of america"}: |
145 | | - country_bias = "US" |
146 | | - normalized_location = ", ".join(parts[:-1]) |
147 | | - suggested_queries.append(f"{normalized_location}, USA") |
148 | | - if len(parts) >= 2: |
149 | | - state_full = parts[-2].strip() |
150 | | - city_only = ", ".join(parts[:-2]).strip() |
151 | | - |
152 | | - if raw_location not in suggested_queries: |
153 | | - suggested_queries.append(raw_location) |
154 | | - if normalized_location not in suggested_queries: |
155 | | - suggested_queries.append(normalized_location) |
156 | | - |
157 | | - def geocode(name: str, country_code: str | None): |
158 | | - query = urllib.parse.urlencode( |
159 | | - { |
160 | | - "name": name, |
161 | | - "count": 5, |
162 | | - "language": "en", |
163 | | - "format": "json", |
164 | | - **({"countryCode": country_code} if country_code else {}), |
165 | | - } |
166 | | - ) |
167 | | - geo_url = f"https://geocoding-api.open-meteo.com/v1/search?{query}" |
168 | | - with urllib.request.urlopen(geo_url, timeout=10) as response: |
169 | | - return json.loads(response.read().decode("utf-8")) |
170 | | - |
171 | | - match = None |
172 | | - last_error: Exception | None = None |
173 | | - candidates = [ |
174 | | - (normalized_location, country_bias), |
175 | | - (raw_location, country_bias), |
176 | | - (normalized_location, None), |
177 | | - (raw_location, None), |
178 | | - ] |
179 | | - if city_only: |
180 | | - candidates.insert(0, (city_only, country_bias)) |
181 | | - candidates.append((city_only, None)) |
182 | | - for candidate, bias in candidates: |
183 | | - try: |
184 | | - cleaned = candidate |
185 | | - if bias and "," in cleaned: |
186 | | - tail = cleaned.split(",")[-1].strip().lower() |
187 | | - if tail in {"usa", "us", "united states", "united states of america"}: |
188 | | - cleaned = ", ".join(part.strip() for part in cleaned.split(",")[:-1]) |
189 | | - geo_data = geocode(cleaned, bias) |
190 | | - results = geo_data.get("results") or [] |
191 | | - if results: |
192 | | - if state_full: |
193 | | - filtered = [ |
194 | | - result |
195 | | - for result in results |
196 | | - if (result.get("admin1") or "").lower() |
197 | | - == state_full.lower() |
198 | | - ] |
199 | | - if filtered: |
200 | | - results = filtered |
201 | | - match = results[0] |
202 | | - for result in results: |
203 | | - name = result.get("name") |
204 | | - admin1 = result.get("admin1") |
205 | | - country = result.get("country") |
206 | | - formatted = ", ".join( |
207 | | - part for part in [name, admin1, country] if part |
208 | | - ) |
209 | | - if formatted and formatted not in suggested_queries: |
210 | | - suggested_queries.append(formatted) |
211 | | - break |
212 | | - except Exception as exc: |
213 | | - last_error = exc |
214 | | - |
215 | | - if match is None: |
216 | | - if last_error: |
217 | | - return { |
218 | | - "status": "error", |
219 | | - "message": f"Sorry, I couldn't look up the location \"{raw_location}\" right now.", |
220 | | - "suggestedQueries": suggested_queries, |
221 | | - } |
222 | | - return { |
223 | | - "status": "not_found", |
224 | | - "message": f"Sorry, I couldn't find a location match for \"{raw_location}\".", |
225 | | - "suggestedQueries": suggested_queries, |
226 | | - } |
227 | | - |
228 | | - latitude = match.get("latitude") |
229 | | - longitude = match.get("longitude") |
230 | | - name = match.get("name") or raw_location |
231 | | - admin1 = match.get("admin1") |
232 | | - country = match.get("country") |
233 | | - place = ", ".join(part for part in [name, admin1, country] if part) |
234 | | - |
235 | | - forecast_params = urllib.parse.urlencode( |
236 | | - { |
237 | | - "latitude": latitude, |
238 | | - "longitude": longitude, |
239 | | - "current": "temperature_2m,apparent_temperature,relative_humidity_2m,wind_speed_10m,weather_code", |
240 | | - "temperature_unit": "fahrenheit", |
241 | | - "windspeed_unit": "mph", |
242 | | - } |
243 | | - ) |
244 | | - forecast_url = f"https://api.open-meteo.com/v1/forecast?{forecast_params}" |
245 | | - |
246 | | - try: |
247 | | - with urllib.request.urlopen(forecast_url, timeout=10) as response: |
248 | | - forecast = json.loads(response.read().decode("utf-8")) |
249 | | - except Exception: |
250 | | - return { |
251 | | - "status": "error", |
252 | | - "message": f"Sorry, I couldn't fetch the weather for {place} right now.", |
253 | | - "suggestedQueries": suggested_queries, |
254 | | - } |
255 | | - |
256 | | - current = forecast.get("current") or {} |
257 | | - temperature = current.get("temperature_2m") |
258 | | - feels_like = current.get("apparent_temperature") |
259 | | - humidity = current.get("relative_humidity_2m") |
260 | | - windspeed = current.get("wind_speed_10m") |
261 | | - weather_code = current.get("weather_code") |
262 | | - |
263 | | - if temperature is None: |
264 | | - return { |
265 | | - "status": "error", |
266 | | - "message": f"Sorry, I couldn't read the current weather for {place}.", |
267 | | - "suggestedQueries": suggested_queries, |
268 | | - } |
269 | | - |
270 | | - details = [] |
271 | | - if feels_like is not None: |
272 | | - details.append(f"feels like {feels_like}°F") |
273 | | - if humidity is not None: |
274 | | - details.append(f"humidity {humidity}%") |
275 | | - if windspeed is not None: |
276 | | - details.append(f"wind {windspeed} mph") |
277 | | - if weather_code is not None: |
278 | | - details.append(f"code {weather_code}") |
279 | | - extra = f" ({', '.join(details)})" if details else "" |
280 | | - |
281 | | - return { |
282 | | - "status": "ok", |
283 | | - "location": place, |
284 | | - "temperatureF": temperature, |
285 | | - "feelsLikeF": feels_like, |
286 | | - "humidityPercent": humidity, |
287 | | - "summary": f"The weather for {place} is {temperature}°F{extra}.", |
288 | | - "suggestedQueries": suggested_queries, |
289 | | - } |
| 10 | +from src.middleware import apply_structured_output_schema |
| 11 | +from src.weather import get_weather |
290 | 12 |
|
291 | 13 | class AgentState(CopilotKitState): |
292 | 14 | proverbs: List[str] |
293 | 15 |
|
294 | | - |
295 | 16 | class AgentContext(TypedDict, total=False): |
296 | 17 | output_schema: dict[str, Any] |
297 | 18 |
|
298 | | - |
299 | | - |
300 | 19 | agent = create_agent( |
301 | 20 | model="openai:gpt-5.2", |
302 | 21 | tools=[get_weather], |
|
0 commit comments