Commit 3f11d44

Pedro/fix vertex cache leak (#4135)
1 parent 2eeca1c commit 3f11d44

File tree

4 files changed (+282, -88 lines)


skyvern/forge/agent.py

Lines changed: 104 additions & 9 deletions
@@ -1,5 +1,6 @@
 import asyncio
 import base64
+import hashlib
 import json
 import os
 import random
@@ -2467,8 +2468,35 @@ async def build_and_record_step_prompt(
 
         return scraped_page, extract_action_prompt, use_caching
 
+    @staticmethod
+    def _build_extract_action_cache_variant(
+        verification_code_check: bool,
+        has_magic_link_page: bool,
+        complete_criterion: str | None,
+    ) -> str:
+        """
+        Build a short-but-unique cache variant identifier so extract-action prompts that
+        differ meaningfully (OTP, magic link flows, complete criteria) do not reuse the
+        same Vertex cache object.
+        """
+        variant_parts: list[str] = []
+        if verification_code_check:
+            variant_parts.append("vc")
+        if has_magic_link_page:
+            variant_parts.append("ml")
+        if complete_criterion:
+            normalized = " ".join(complete_criterion.split())
+            digest = hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:6]
+            variant_parts.append(f"cc{digest}")
+        return "-".join(variant_parts) if variant_parts else "std"
+
     async def _create_vertex_cache_for_task(
-        self, task: Task, static_prompt: str, context: SkyvernContext, llm_key_override: str | None
+        self,
+        task: Task,
+        static_prompt: str,
+        context: SkyvernContext,
+        llm_key_override: str | None,
+        prompt_variant: str | None = None,
     ) -> None:
         """
         Create a Vertex AI cache for the task's static prompt.
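
Note: the helper above composes the variant from short flags plus a truncated SHA-1 of the complete criterion. A minimal standalone sketch of the same scheme (the example outputs are illustrative; only the logic mirrors the diff):

import hashlib

def build_variant(vc: bool, ml: bool, criterion: str | None) -> str:
    # Mirrors _build_extract_action_cache_variant: short flags plus a truncated criterion digest.
    parts: list[str] = []
    if vc:
        parts.append("vc")
    if ml:
        parts.append("ml")
    if criterion:
        normalized = " ".join(criterion.split())
        parts.append("cc" + hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:6])
    return "-".join(parts) if parts else "std"

print(build_variant(False, False, None))               # std
print(build_variant(True, False, None))                # vc
print(build_variant(True, True, "Order  was placed"))  # vc-ml-cc<first 6 hex chars of the SHA-1>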
@@ -2479,9 +2507,9 @@ async def _create_vertex_cache_for_task(
             task: The task to create cache for
             static_prompt: The static prompt content to cache
             context: The Skyvern context to store the cache name in
+            llm_key_override: Optional override when we explicitly pick an LLM key
+            prompt_variant: Cache variant identifier (std/vc/ml/etc.)
         """
-        # Early return if task doesn't have an llm_key
-        # This should not happen given the guard at the call site, but being defensive
         resolved_llm_key = llm_key_override or task.llm_key
 
         if not resolved_llm_key:
@@ -2491,17 +2519,20 @@ async def _create_vertex_cache_for_task(
             )
             return
 
+        cache_variant = prompt_variant or "std"
+
         try:
             LOG.info(
                 "Attempting Vertex AI cache creation",
                 task_id=task.task_id,
                 llm_key=resolved_llm_key,
+                cache_variant=cache_variant,
             )
             cache_manager = get_cache_manager()
 
-            # Use llm_key as cache_key so all tasks with the same model share the same cache
-            # This maximizes cache reuse and reduces cache storage costs
-            cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}-{resolved_llm_key}"
+            variant_suffix = f"-{cache_variant}" if cache_variant else ""
+
+            cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}{variant_suffix}-{resolved_llm_key}"
 
             # Get the actual model name from LLM config to ensure correct format
             # (e.g., "gemini-2.5-flash" with decimal, not "gemini-2-5-flash")
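
Note: with the variant folded into the key, tasks whose static prompts differ (OTP, magic link, complete criterion) no longer collide on one cache object. A sketch of the resulting key shapes, assuming purely hypothetical values for EXTRACT_ACTION_CACHE_KEY_PREFIX and the resolved LLM key (neither value appears in this diff):

# Hypothetical values for illustration only.
EXTRACT_ACTION_CACHE_KEY_PREFIX = "extract-action"
resolved_llm_key = "GEMINI_2_5_FLASH"

for cache_variant in ("std", "vc", "ml-cc1a2b3c"):
    variant_suffix = f"-{cache_variant}" if cache_variant else ""
    print(f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}{variant_suffix}-{resolved_llm_key}")
# extract-action-std-GEMINI_2_5_FLASH
# extract-action-vc-GEMINI_2_5_FLASH
# extract-action-ml-cc1a2b3c-GEMINI_2_5_FLASH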
@@ -2565,15 +2596,18 @@ async def _create_vertex_cache_for_task(
                 ttl_seconds=3600,  # 1 hour
             )
 
-            # Store cache resource name in context
+            # Store cache metadata in context
             context.vertex_cache_name = cache_data["name"]
+            context.vertex_cache_key = cache_key
+            context.vertex_cache_variant = cache_variant
 
             LOG.info(
                 "Created Vertex AI cache for task",
                 task_id=task.task_id,
                 cache_key=cache_key,
                 cache_name=cache_data["name"],
                 model_name=model_name,
+                cache_variant=cache_variant,
             )
         except Exception as e:
             LOG.warning(
@@ -2653,7 +2687,7 @@ async def _build_extract_action_prompt(
 
         # Check if prompt caching is enabled for extract-action
         use_caching = False
-        prompt_caching_settings = LLMAPIHandlerFactory._prompt_caching_settings or {}
+        prompt_caching_settings = await self._get_prompt_caching_settings(context)
         effective_llm_key = task.llm_key
         if not effective_llm_key:
             handler_for_key = LLMAPIHandlerFactory.get_override_llm_api_handler(
@@ -2701,6 +2735,11 @@ async def _build_extract_action_prompt(
             "parse_select_feature_enabled": context.enable_parse_select_in_extract,
             "has_magic_link_page": context.has_magic_link_page(task.task_id),
         }
+        cache_variant = self._build_extract_action_cache_variant(
+            verification_code_check=verification_code_check,
+            has_magic_link_page=context.has_magic_link_page(task.task_id),
+            complete_criterion=task.complete_criterion.strip() if task.complete_criterion else None,
+        )
         static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs)
         dynamic_prompt = prompt_engine.load_prompt(
             f"{template}-dynamic",
@@ -2718,14 +2757,21 @@ async def _build_extract_action_prompt(
 
         # Create Vertex AI cache for Gemini models
         if effective_llm_key and "GEMINI" in effective_llm_key:
-            await self._create_vertex_cache_for_task(task, static_prompt, context, effective_llm_key)
+            await self._create_vertex_cache_for_task(
+                task,
+                static_prompt,
+                context,
+                effective_llm_key,
+                prompt_variant=cache_variant,
+            )
 
         combined_prompt = f"{static_prompt.rstrip()}\n\n{dynamic_prompt.lstrip()}"
 
         LOG.info(
             "Using cached prompt",
             task_id=task.task_id,
             prompt_name=EXTRACT_ACTION_PROMPT_NAME,
+            cache_variant=cache_variant,
         )
         return combined_prompt, use_caching
 
@@ -2755,6 +2801,55 @@ async def _build_extract_action_prompt(
 
         return full_prompt, use_caching
 
+    async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
+        """
+        Resolve prompt caching settings for the current run.
+
+        We prefer explicit overrides via LLMAPIHandlerFactory.set_prompt_caching_settings(), which
+        are mostly used by scripts/tests. When no override exists, evaluate the PostHog experiment
+        once per context and cache the result on the context to avoid repeated lookups.
+        """
+        if LLMAPIHandlerFactory._prompt_caching_settings is not None:
+            return LLMAPIHandlerFactory._prompt_caching_settings
+
+        if context.prompt_caching_settings is not None:
+            return context.prompt_caching_settings
+
+        distinct_id = context.run_id or context.workflow_run_id or context.task_id
+        organization_id = context.organization_id
+        context.prompt_caching_settings = {}
+
+        if not distinct_id or not organization_id:
+            return context.prompt_caching_settings
+
+        try:
+            enabled = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
+                "PROMPT_CACHING_OPTIMIZATION",
+                distinct_id,
+                properties={"organization_id": organization_id},
+            )
+        except Exception as exc:
+            LOG.warning(
+                "Failed to evaluate prompt caching experiment; defaulting to disabled",
+                distinct_id=distinct_id,
+                organization_id=organization_id,
+                error=str(exc),
+            )
+            return context.prompt_caching_settings
+
+        if enabled:
+            context.prompt_caching_settings = {
+                EXTRACT_ACTION_PROMPT_NAME: True,
+                EXTRACT_ACTION_TEMPLATE: True,
+            }
+            LOG.info(
+                "Prompt caching optimization enabled",
+                distinct_id=distinct_id,
+                organization_id=organization_id,
+            )
+
+        return context.prompt_caching_settings
+
     def _should_process_totp(self, scraped_page: ScrapedPage | None) -> bool:
         """Detect TOTP pages by checking for multiple input fields or verification keywords."""
         if not scraped_page:

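Note: the new _get_prompt_caching_settings helper resolves settings in a fixed order: explicit override on LLMAPIHandlerFactory, then the value already cached on the context, then a single PostHog experiment evaluation. A simplified sketch of that precedence (names and the "extract-action" key are placeholders, not the Skyvern API; the real method also requires a distinct_id and organization_id before evaluating the experiment):

from typing import Optional

def resolve_prompt_caching(
    override: Optional[dict[str, bool]],        # LLMAPIHandlerFactory._prompt_caching_settings
    context_cached: Optional[dict[str, bool]],  # context.prompt_caching_settings
    experiment_enabled: bool,                   # PROMPT_CACHING_OPTIMIZATION flag result
) -> dict[str, bool]:
    # 1. Explicit override (scripts/tests) wins outright.
    if override is not None:
        return override
    # 2. A value already stored on the context is reused, so the experiment
    #    runs at most once per run.
    if context_cached is not None:
        return context_cached
    # 3. Otherwise the experiment decides; disabled or failed evaluation means {}.
    return {"extract-action": True} if experiment_enabled else {}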