Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion strix/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ async def _wait_for_input(self) -> None:

return

await asyncio.sleep(0.5)
await self.state.wait_for_wake(timeout=0.5)

async def _enter_waiting_state(
self,
Expand Down
22 changes: 22 additions & 0 deletions strix/agents/state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import uuid
from datetime import UTC, datetime
from typing import Any
Expand All @@ -10,6 +11,8 @@ def _generate_agent_id() -> str:


class AgentState(BaseModel):
model_config = {"arbitrary_types_allowed": True}

agent_id: str = Field(default_factory=_generate_agent_id)
agent_name: str = "Strix Agent"
parent_id: str | None = None
Expand Down Expand Up @@ -39,6 +42,9 @@ class AgentState(BaseModel):

errors: list[str] = Field(default_factory=list)

# Event for signaling state changes (excluded from serialization)
_wake_event: asyncio.Event = Field(default_factory=asyncio.Event, exclude=True)

def increment_iteration(self) -> None:
    """Advance the agent's iteration counter and refresh the last-updated timestamp."""
    self.iteration += 1
    # Timestamps are stored as ISO-8601 strings in UTC so the model serializes cleanly.
    self.last_updated = datetime.now(UTC).isoformat()
Expand All @@ -49,6 +55,8 @@ def add_message(self, role: str, content: Any, thinking_blocks: list[dict[str, A
message["thinking_blocks"] = thinking_blocks
self.messages.append(message)
self.last_updated = datetime.now(UTC).isoformat()
if self.waiting_for_input:
self._wake_event.set()

def add_action(self, action: dict[str, Any]) -> None:
self.actions_taken.append(
Expand Down Expand Up @@ -106,6 +114,20 @@ def resume_from_waiting(self, new_task: str | None = None) -> None:
if new_task:
self.task = new_task
self.last_updated = datetime.now(UTC).isoformat()
self._wake_event.set()

def signal_wake(self) -> None:
    """Set the wake event so a coroutine blocked in wait_for_wake() resumes immediately.

    Safe to call repeatedly: setting an already-set event is a no-op.
    """
    self._wake_event.set()

async def wait_for_wake(self, timeout: float = 0.5) -> bool:
"""Wait for a wake signal with timeout. Returns True if signaled, False on timeout."""
try:
await asyncio.wait_for(self._wake_event.wait(), timeout=timeout)
self._wake_event.clear()
return True
except TimeoutError:
return False

def has_reached_max_iterations(self) -> bool:
return self.iteration >= self.max_iterations
Expand Down
61 changes: 60 additions & 1 deletion strix/llm/dedupe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import re
from difflib import SequenceMatcher
from typing import Any

import litellm
Expand All @@ -10,6 +11,8 @@

logger = logging.getLogger(__name__)

MAX_COMPARISON_CANDIDATES = 10

DEDUPE_SYSTEM_PROMPT = """You are an expert vulnerability report deduplication judge.
Your task is to determine if a candidate vulnerability report describes the SAME vulnerability
as any existing report.
Expand Down Expand Up @@ -138,6 +141,59 @@ def _parse_dedupe_response(content: str) -> dict[str, Any]:
}


def _compute_similarity(report_a: dict[str, Any], report_b: dict[str, Any]) -> float:
"""Compute lightweight string similarity between two reports for pre-filtering."""
fields = ["title", "endpoint", "method", "target", "description"]

total_score = 0.0
weights_sum = 0.0

# Weighted fields: title and endpoint matter most for duplicate detection
field_weights = {
"title": 3.0,
"endpoint": 3.0,
"method": 1.5,
"target": 2.0,
"description": 1.0,
}

for field in fields:
val_a = str(report_a.get(field, "")).lower().strip()
val_b = str(report_b.get(field, "")).lower().strip()
weight = field_weights.get(field, 1.0)

if not val_a or not val_b:
continue

ratio = SequenceMatcher(None, val_a, val_b).ratio()
total_score += ratio * weight
weights_sum += weight

return total_score / weights_sum if weights_sum > 0 else 0.0


def _prefilter_candidates(
    candidate: dict[str, Any],
    existing_reports: list[dict[str, Any]],
    max_candidates: int = MAX_COMPARISON_CANDIDATES,
) -> list[dict[str, Any]]:
    """Pre-filter existing reports using string similarity to reduce LLM calls.

    Only the top-N most similar reports are sent to the LLM for detailed
    comparison, avoiding sending hundreds of reports in a single prompt.
    Ties keep their original order (the sort is stable).
    """
    if len(existing_reports) <= max_candidates:
        return existing_reports

    ranked = sorted(
        existing_reports,
        key=lambda report: _compute_similarity(candidate, report),
        reverse=True,
    )
    return ranked[:max_candidates]


def check_duplicate(
candidate: dict[str, Any], existing_reports: list[dict[str, Any]]
) -> dict[str, Any]:
Expand All @@ -150,8 +206,11 @@ def check_duplicate(
}

try:
# Pre-filter to only compare against the most similar existing reports
filtered_reports = _prefilter_candidates(candidate, existing_reports)

candidate_cleaned = _prepare_report_for_comparison(candidate)
existing_cleaned = [_prepare_report_for_comparison(r) for r in existing_reports]
existing_cleaned = [_prepare_report_for_comparison(r) for r in filtered_reports]

comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}

Expand Down
20 changes: 15 additions & 5 deletions strix/llm/memory_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
logger = logging.getLogger(__name__)


MAX_TOTAL_TOKENS = 100_000
MAX_TOTAL_TOKENS = 60_000
MIN_RECENT_MESSAGES = 15
COMPRESSION_CHUNK_SIZE = 20

SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context
condensation for a security agent. Your job is to compress scan data while preserving
Expand Down Expand Up @@ -43,13 +44,22 @@
keeping the summary concise and to the point."""


# Memoized token counts keyed by hash((model, text)); bounded so long-running
# agents with many unique messages do not grow memory without limit.
_token_cache: dict[int, int] = {}
_TOKEN_CACHE_MAX_ENTRIES = 4096


def _count_tokens(text: str, model: str) -> int:
    """Count the tokens in *text* for *model*, with caching and a heuristic fallback.

    The cache key includes the model, because token counts are
    tokenizer-specific: the same text can tokenize differently per model.
    When the tokenizer fails, falls back to the rough len(text) // 4 estimate.

    Args:
        text: The string to measure.
        model: Model identifier passed through to litellm's tokenizer.

    Returns:
        The (possibly estimated) token count, never negative.
    """
    cache_key = hash((model, text))
    if cache_key in _token_cache:
        return _token_cache[cache_key]

    try:
        count = int(litellm.token_counter(model=model, text=text))
    except Exception:
        logger.exception("Failed to count tokens")
        count = len(text) // 4  # Rough estimate

    if len(_token_cache) >= _TOKEN_CACHE_MAX_ENTRIES:
        # FIFO eviction: dicts preserve insertion order, so the first key is
        # the oldest entry. Cheap and good enough for a pre-filter cache.
        _token_cache.pop(next(iter(_token_cache)))
    _token_cache[cache_key] = count
    return count
Comment on lines +47 to +62
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_token_cache grows unbounded. For long-running agents with many unique messages, this will consume increasing memory. Consider adding LRU eviction or size limits.

Prompt To Fix With AI
This is a comment left during a code review.
Path: strix/llm/memory_compressor.py
Line: 47:62

Comment:
`_token_cache` grows unbounded. For long-running agents with many unique messages, this will consume increasing memory. Consider adding LRU eviction or size limits.

How can I resolve this? If you propose a fix, please make it concise.



def _get_message_tokens(msg: dict[str, Any], model: str) -> int:
Expand Down Expand Up @@ -215,7 +225,7 @@ def compress_history(
return messages

compressed = []
chunk_size = 10
chunk_size = COMPRESSION_CHUNK_SIZE
for i in range(0, len(old_msgs), chunk_size):
chunk = old_msgs[i : i + chunk_size]
summary = _summarize_messages(chunk, model_name, self.timeout)
Expand Down
112 changes: 88 additions & 24 deletions strix/tools/executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import os
from typing import Any
Expand Down Expand Up @@ -25,6 +26,31 @@
SANDBOX_EXECUTION_TIMEOUT = _SERVER_TIMEOUT + 30
SANDBOX_CONNECT_TIMEOUT = float(Config.get("strix_sandbox_connect_timeout") or "10")

# Connection pool: reuse HTTP clients per sandbox instead of creating one per call
_sandbox_clients: dict[str, httpx.AsyncClient] = {}


def _get_sandbox_client(sandbox_id: str) -> httpx.AsyncClient:
    """Get or create a persistent HTTP client for a sandbox, enabling connection reuse."""
    client = _sandbox_clients.get(sandbox_id)
    if client is None:
        # Lazily build one client per sandbox; keep-alive connections are
        # reused across tool calls instead of opening a socket per request.
        client = httpx.AsyncClient(
            trust_env=False,
            timeout=httpx.Timeout(
                timeout=SANDBOX_EXECUTION_TIMEOUT,
                connect=SANDBOX_CONNECT_TIMEOUT,
            ),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=5),
        )
        _sandbox_clients[sandbox_id] = client
    return client


async def close_sandbox_client(sandbox_id: str) -> None:
    """Close and remove the HTTP client for a sandbox when it's torn down.

    A no-op when no client was ever created for *sandbox_id*.
    """
    stale = _sandbox_clients.pop(sandbox_id, None)
    if stale is not None:
        await stale.aclose()
Comment on lines +48 to +52
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

close_sandbox_client is defined but never called in the codebase. Connection pool clients accumulate without cleanup when sandboxes are torn down, leading to resource leaks.

Prompt To Fix With AI
This is a comment left during a code review.
Path: strix/tools/executor.py
Line: 48:52

Comment:
`close_sandbox_client` is defined but never called in the codebase. Connection pool clients accumulate without cleanup when sandboxes are torn down, leading to resource leaks.

How can I resolve this? If you propose a fix, please make it concise.



async def execute_tool(tool_name: str, agent_state: Any | None = None, **kwargs: Any) -> Any:
execute_in_sandbox = should_execute_in_sandbox(tool_name)
Expand Down Expand Up @@ -71,31 +97,27 @@ async def _execute_tool_in_sandbox(tool_name: str, agent_state: Any, **kwargs: A
"Content-Type": "application/json",
}

timeout = httpx.Timeout(
timeout=SANDBOX_EXECUTION_TIMEOUT,
connect=SANDBOX_CONNECT_TIMEOUT,
)
client = _get_sandbox_client(agent_state.sandbox_id)

async with httpx.AsyncClient(trust_env=False) as client:
try:
response = await client.post(
request_url, json=request_data, headers=headers, timeout=timeout
)
response.raise_for_status()
response_data = response.json()
if response_data.get("error"):
posthog.error("tool_execution_error", f"{tool_name}: {response_data['error']}")
raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
return response_data.get("result")
except httpx.HTTPStatusError as e:
posthog.error("tool_http_error", f"{tool_name}: HTTP {e.response.status_code}")
if e.response.status_code == 401:
raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
except httpx.RequestError as e:
error_type = type(e).__name__
posthog.error("tool_request_error", f"{tool_name}: {error_type}")
raise RuntimeError(f"Request error calling tool server: {error_type}") from e
try:
response = await client.post(
request_url, json=request_data, headers=headers
)
response.raise_for_status()
response_data = response.json()
if response_data.get("error"):
posthog.error("tool_execution_error", f"{tool_name}: {response_data['error']}")
raise RuntimeError(f"Sandbox execution error: {response_data['error']}")
return response_data.get("result")
except httpx.HTTPStatusError as e:
posthog.error("tool_http_error", f"{tool_name}: HTTP {e.response.status_code}")
if e.response.status_code == 401:
raise RuntimeError("Authentication failed: Invalid or missing sandbox token") from e
raise RuntimeError(f"HTTP error calling tool server: {e.response.status_code}") from e
except httpx.RequestError as e:
error_type = type(e).__name__
posthog.error("tool_request_error", f"{tool_name}: {error_type}")
raise RuntimeError(f"Request error calling tool server: {error_type}") from e


async def _execute_tool_locally(tool_name: str, agent_state: Any | None, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -310,6 +332,13 @@ def _get_tracer_and_agent_id(agent_state: Any | None) -> tuple[Any | None, str]:
return tracer, agent_id


# Tools that modify shared state and must run sequentially
# (agent lifecycle and inter-agent messaging, where ordering matters);
# tools not listed here may be dispatched concurrently.
_SEQUENTIAL_TOOLS = frozenset({
    "finish_scan", "agent_finish", "delegate_task", "send_message",
    "wait_for_message", "create_agent",
})


async def process_tool_invocations(
tool_invocations: list[dict[str, Any]],
conversation_history: list[dict[str, Any]],
Expand All @@ -321,7 +350,42 @@ async def process_tool_invocations(

tracer, agent_id = _get_tracer_and_agent_id(agent_state)

# Partition into parallelizable and sequential tools
parallel_batch: list[dict[str, Any]] = []
sequential_queue: list[dict[str, Any]] = []

for tool_inv in tool_invocations:
tool_name = tool_inv.get("toolName", "unknown")
if tool_name in _SEQUENTIAL_TOOLS:
sequential_queue.append(tool_inv)
else:
parallel_batch.append(tool_inv)

# Execute parallelizable tools concurrently
if parallel_batch:
tasks = [
_execute_single_tool(tool_inv, agent_state, tracer, agent_id)
for tool_inv in parallel_batch
]
results = await asyncio.gather(*tasks, return_exceptions=True)

for i, result in enumerate(results):
if isinstance(result, Exception):
tool_name = parallel_batch[i].get("toolName", "unknown")
error_xml = (
f"<tool_result>\n<tool_name>{tool_name}</tool_name>\n"
f"<result>Error executing {tool_name}: {result!s}</result>\n</tool_result>"
)
observation_parts.append(error_xml)
else:
observation_xml, images, tool_should_finish = result
observation_parts.append(observation_xml)
all_images.extend(images)
if tool_should_finish:
should_agent_finish = True

# Execute sequential tools one at a time (order matters)
for tool_inv in sequential_queue:
observation_xml, images, tool_should_finish = await _execute_single_tool(
tool_inv, agent_state, tracer, agent_id
)
Expand Down