Skip to content

Commit

Permalink
Fixes for event handlers and agents
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerberndt committed Dec 7, 2024
1 parent a35d284 commit 131fbbb
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 100 deletions.
3 changes: 1 addition & 2 deletions src/actions/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from src.util.logging import Logger
from src.actions.result import ActionResult
from src.actions.builtin import get_builtin_actions
from src.actions.semantic_search import SemanticSearchAction


class ActionRegistry:
Expand Down Expand Up @@ -48,7 +47,7 @@ async def handler(*args, **kwargs) -> str:
action = action_class()

# For semantic search, prepend "search" to the query
if action_class == SemanticSearchAction and args:
if action_class.__name__ == "SemanticSearchAction" and args:
query = f"search {' '.join(args)}"
result = await action.execute(query)
else:
Expand Down
15 changes: 8 additions & 7 deletions src/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from src.config.config import Config
from src.util.logging import Logger
from dataclasses import dataclass
from src.actions.registry import ActionRegistry


@dataclass
Expand Down Expand Up @@ -46,6 +45,8 @@ def __init__(self, custom_prompt: Optional[str] = None, command_names: Optional[
self.client = AsyncOpenAI(api_key=self.config.openai_api_key)

# Initialize action registry and wait for it to be ready
from src.actions.registry import ActionRegistry

self.action_registry = ActionRegistry()
self.action_registry.initialize() # Explicitly call initialize

Expand Down Expand Up @@ -73,7 +74,7 @@ def _get_available_commands(self, command_names: Optional[List[str]] = None) ->
"""Get the commands available to this agent
Args:
command_names: Optional list of command names to include. If None, all commands are included.
command_names: List of command names to include. If None or empty, no commands are included.
Returns:
Dict mapping command names to AgentCommand objects
Expand All @@ -84,12 +85,12 @@ def _get_available_commands(self, command_names: Optional[List[str]] = None) ->
actions = self.action_registry.get_actions()
self.logger.info("Found registered actions:", extra_data={"available_actions": list(actions.keys())})

# If command_names is None or empty, include all commands
# If command_names is None or empty, include no commands
if not command_names:
command_names = list(actions.keys())
self.logger.info("Using all available commands")
else:
self.logger.info("Filtering actions by command names:", extra_data={"requested_commands": command_names})
self.logger.info("No commands specified, agent will have no commands")
return commands

self.logger.info("Filtering actions by command names:", extra_data={"requested_commands": command_names})

# Convert actions to commands
for name, (_, spec) in actions.items():
Expand Down
8 changes: 8 additions & 0 deletions src/agents/conversation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ class ConversationAgent(BaseAgent):
def __init__(self, command_names: Optional[List[str]] = None):
self.logger = Logger("ConversationAgent")

# Get all available commands from action registry
from src.actions.registry import ActionRegistry

action_registry = ActionRegistry()
action_registry.initialize()
command_names = list(action_registry.get_actions().keys())
self.logger.info("Using all available commands:", extra_data={"commands": command_names})

# Add specialized prompt for conversation handling
custom_prompt = """You are specialized in having helpful conversations with security researchers.
Expand Down
13 changes: 1 addition & 12 deletions src/agents/github_event_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@ def __init__(self):
- Potential economic attack vectors"""

# Specify commands this agent can use
command_names = [
"semantic_search", # For searching code semantically
"grep_search", # For pattern matching
"db_query", # For database queries
]
command_names = []

super().__init__(custom_prompt=custom_prompt, command_names=command_names)
DBSessionMixin.__init__(self)
Expand Down Expand Up @@ -108,11 +104,6 @@ async def analyze_pr(self, repo_url: str, pr_data: Dict[str, Any]) -> str:
]
)

# Search for similar patterns
similar_patterns = await self.execute_command(
"semantic_search", query=f"Find security issues similar to: {pr_data.get('title')} {pr_data.get('body')}"
)

# Get AI analysis
analysis_prompt = "\n".join(
[
Expand All @@ -124,8 +115,6 @@ async def analyze_pr(self, repo_url: str, pr_data: Dict[str, Any]) -> str:
"4. How does this relate to known vulnerability types?",
"\nContext:",
"\n".join(context),
"\nSimilar patterns found:",
similar_patterns,
]
)

Expand Down
6 changes: 4 additions & 2 deletions src/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Handler(ABC):

def __init__(self):
self.context = {}
self.trigger = None

@classmethod
@abstractmethod
Expand All @@ -47,6 +48,7 @@ def get_triggers(cls) -> List[HandlerTrigger]:
def handle(self) -> None:
"""Handle an event"""

def set_context(self, context: Dict[str, Any]) -> None:
"""Set context for this handler"""
def set_context(self, context: Dict[str, Any], trigger: HandlerTrigger = None) -> None:
"""Set context and trigger for this handler"""
self.context = context
self.trigger = trigger
9 changes: 6 additions & 3 deletions src/handlers/event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,27 @@ def register_handler(self, handler_class: Type[Handler]) -> None:
if trigger not in self._handlers:
self._handlers[trigger] = []
self._handlers[trigger].append(handler_class)
self.logger.debug(f"Registered {handler_class.__name__} for trigger {trigger.name}")

async def trigger_event(self, trigger: HandlerTrigger, context: Dict) -> None:
"""Trigger handlers for a specific event"""
if trigger not in self._handlers:
self.logger.warning(f"No handlers registered for trigger {trigger.name}")
return

self.logger.info(f"Triggering {trigger.value} handlers", extra_data={"context": context})
self.logger.info(f"Triggering {len(self._handlers[trigger])} handlers for {trigger.name}")

handler_tasks = []
for handler_class in self._handlers[trigger]:
try:
self.logger.debug(f"Creating handler instance for {handler_class.__name__}")
handler = handler_class()
handler.set_context(context)
handler.set_context(context, trigger)
handler_tasks.append(asyncio.create_task(handler.handle()))
except Exception as e:
self.logger.error(
f"Handler {handler_class.__name__} failed: {str(e)}",
extra_data={"trigger": trigger.value, "context": context},
extra_data={"trigger": trigger.name, "context": context},
)

# Wait for all handlers to complete
Expand Down
86 changes: 55 additions & 31 deletions src/handlers/github_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timezone
from sqlalchemy import or_
from src.agents.github_event_agent import GitHubSecurityAgent


class GitHubEventJob(Job, DBSessionMixin):
Expand Down Expand Up @@ -57,6 +58,12 @@ async def process_pr(self) -> str:
pr = self.payload.get("pull_request", {})
repo_url = self.payload.get("repo_url", "")

self.logger.debug(f"PROCESS PR: {repo_url} {pr}")

# Get security analysis
security_agent = GitHubSecurityAgent()
security_analysis = await security_agent.analyze_pr(repo_url, pr)

# Basic PR info
summary_lines = [
f"🔍 New PR in {repo_url}\n",
Expand All @@ -67,6 +74,8 @@ async def process_pr(self) -> str:
f"Changed files: {pr.get('changed_files', 0)}",
f"Additions: {pr.get('additions', 0)}",
f"Deletions: {pr.get('deletions', 0)}",
"\n🔒 Security Analysis:",
security_analysis,
]

# Look for related assets
Expand All @@ -93,12 +102,19 @@ async def process_push(self) -> str:
repo_url = self.payload.get("repo_url", "")
commit = self.payload.get("commit", {})

self.logger.debug(f"PROCESS PUSH: {repo_url} {commit}")
# Get security analysis
security_agent = GitHubSecurityAgent()
security_analysis = await security_agent.analyze_commit(repo_url, commit)

# Basic push info
summary_lines = [
f"📦 New commit in {repo_url}\n",
f"Message: {commit.get('commit', {}).get('message', 'No message')}",
f"Author: {commit.get('commit', {}).get('author', {}).get('name', 'Unknown')}",
f"URL: {commit.get('html_url', '')}",
"\n🔒 Security Analysis:",
security_analysis,
]

# Look for related assets
Expand All @@ -123,9 +139,6 @@ async def process_push(self) -> str:
async def start(self) -> None:
"""Process the GitHub event"""
try:
self.status = JobStatus.RUNNING
self.started_at = datetime.now(timezone.utc)

# Process based on event type
if self.event_type == "pull_request":
summary = await self.process_pr()
Expand Down Expand Up @@ -161,38 +174,49 @@ class GitHubEventHandler(Handler):
def __init__(self):
super().__init__()
self.logger = Logger("GitHubEventHandler")
self.logger.debug("GitHubEventHandler initialized")

@classmethod
def get_triggers(cls) -> List[HandlerTrigger]:
"""Get list of triggers this handler listens for"""
return [HandlerTrigger.GITHUB_PR, HandlerTrigger.GITHUB_PUSH]
triggers = [HandlerTrigger.GITHUB_PR, HandlerTrigger.GITHUB_PUSH]
Logger("GitHubEventHandler").debug(f"Registering triggers: {[t.name for t in triggers]}")
return triggers

async def handle(self) -> None:
"""Handle a GitHub event"""
if not self.context:
self.logger.error("No context provided")
return

if not self.context.get("payload"):
self.logger.error("No payload in context")
return

# Get the event data from the context
event_data = self.context
if not event_data:
self.logger.error("No event data in context")
return

# Determine event type from trigger
if self.trigger == HandlerTrigger.GITHUB_PR:
event_type = "pull_request"
elif self.trigger == HandlerTrigger.GITHUB_PUSH:
event_type = "push"
else:
self.logger.error(f"Unsupported trigger: {self.trigger}")
return

# Create and submit job
job = GitHubEventJob(event_type, event_data)
job_manager = JobManager()
await job_manager.submit_job(job)
try:
self.logger.debug("Starting handler")
self.logger.debug(f"Trigger type: {type(self.trigger)}")
self.logger.debug(f"Trigger value: {self.trigger}")
self.logger.debug(f"Trigger dict: {self.trigger.__dict__}")

if not self.context:
self.logger.error("No context provided")
return

payload = self.context.get("payload")

if not payload:
self.logger.error("No payload in context")
return

# Determine event type from trigger
if self.trigger == HandlerTrigger.GITHUB_PR:
event_type = "pull_request"
elif self.trigger == HandlerTrigger.GITHUB_PUSH:
event_type = "push"
else:
self.logger.error(f"Unsupported trigger: {self.trigger}")
return

# Create and submit job
job = GitHubEventJob(event_type, payload)

job_manager = JobManager()

await job_manager.submit_job(job)

except Exception as e:
self.logger.error(f"Error handling GitHub event: {str(e)}", exc_info=True)
raise
2 changes: 1 addition & 1 deletion src/util/etherscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self):
import logging

self.config = Config()
self.logger = logging.getLogger("EVMExplorer")
self.logger = logging.getLogger("EVMExplorer") # pylint: disable=no-member

def is_supported_explorer(self, url: str) -> Tuple[bool, Optional[ExplorerType]]:
"""Check if a URL is from a supported explorer
Expand Down
12 changes: 8 additions & 4 deletions src/watchers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ async def _check_repo_updates(self, repo: Dict[str, Any]) -> None:

try:
commits, prs = await asyncio.gather(commits_task, prs_task)
self.logger.debug(f"Got commits: {commits}")
self.logger.debug(f"Got PRs: {prs}")
# self.logger.debug(f"Got commits: {commits}")
# self.logger.debug(f"Got PRs: {prs}")
except Exception as e:
self.logger.error(f"Failed to fetch updates for {repo_url}: {str(e)}")
return
Expand All @@ -163,7 +163,9 @@ async def _check_repo_updates(self, repo: Dict[str, Any]) -> None:
# Trigger event if this is a new commit
if not last_commit_sha or commit["sha"] != last_commit_sha:
self.logger.info(f"Found new commit: {commit['sha']}")
await self.handler_registry.trigger_event(HandlerTrigger.GITHUB_PUSH, {"repo_url": repo_url, "commit": commit})
await self.handler_registry.trigger_event(
HandlerTrigger.GITHUB_PUSH, {"payload": {"repo_url": repo_url, "commit": commit}}
)

# Process PR updates
last_pr_number = repo.get("last_pr_number") or 0 # Default to 0 if None
Expand All @@ -178,7 +180,9 @@ async def _check_repo_updates(self, repo: Dict[str, Any]) -> None:
# Trigger event for new PRs
if pr_number > last_pr_number:
self.logger.info(f"Found new PR: {pr_number}")
await self.handler_registry.trigger_event(HandlerTrigger.GITHUB_PR, {"repo_url": repo_url, "pull_request": pr})
await self.handler_registry.trigger_event(
HandlerTrigger.GITHUB_PR, {"payload": {"repo_url": repo_url, "pull_request": pr}}
)

# Keep track of the highest PR number
last_pr_number = max(last_pr_number, pr_number)
Expand Down
Loading

0 comments on commit 131fbbb

Please sign in to comment.