Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RecallObservations for retrieval of prompt extensions #6909

Draft
wants to merge 35 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d80c376
track used tokens
enyst Feb 22, 2025
bd9fc55
add response_id
enyst Feb 22, 2025
c59abb5
test accumulation
enyst Feb 22, 2025
dba25f5
clean up
enyst Feb 22, 2025
38b5198
fix not initialized
enyst Feb 22, 2025
b1a18d5
retrieve tokens usage for an event
enyst Feb 22, 2025
5b063cc
add tests
enyst Feb 22, 2025
801b134
fix tests
enyst Feb 22, 2025
c21ddaf
add recall action and observation
enyst Feb 22, 2025
16da353
refactor prompt extensions
enyst Feb 22, 2025
c26185d
dont want to fight o1 right now, will revisit
enyst Feb 22, 2025
143293d
fix logic
enyst Feb 22, 2025
66781fc
fix subscriber
enyst Feb 22, 2025
956b3b4
create memory
enyst Feb 22, 2025
f109a2a
rename memory to long term memory
enyst Feb 23, 2025
bb5817c
rename main module to memory
enyst Feb 23, 2025
0e54bab
rename to memory
enyst Feb 23, 2025
b95d540
refactor prompt manager to manage the view, memory manages info retri…
enyst Feb 23, 2025
21c2253
fix disabled microagents
enyst Feb 23, 2025
d596fd2
refactor info in the first user message to a recalled observation
enyst Feb 23, 2025
2c5018f
Merge branch 'main' of github.com:All-Hands-AI/OpenHands into enyst/r…
enyst Feb 23, 2025
bec0594
Merge branch 'main' of github.com:All-Hands-AI/OpenHands into enyst/r…
enyst Feb 24, 2025
c25701f
add memory.py
enyst Feb 24, 2025
88ae5d2
tweak name
enyst Feb 25, 2025
83fc613
Merge branch 'main' of github.com:All-Hands-AI/OpenHands into enyst/r…
enyst Feb 25, 2025
6cd9ece
add selected_repo command line arg
enyst Feb 25, 2025
5e495d0
fix init order
enyst Feb 25, 2025
6ef3e1e
add selected_repo command line arg
enyst Feb 25, 2025
338ae2c
fix loading
enyst Feb 26, 2025
69fc6f1
add selected_repo to sandbox_config
enyst Feb 26, 2025
d968b7f
update arg parser
enyst Feb 26, 2025
9ea235d
tweak
enyst Feb 26, 2025
df212ce
fix attr name
enyst Feb 26, 2025
a93e7ca
Merge branch 'enyst/selected-repo' into enyst/retrieve-prompt
enyst Feb 26, 2025
0e9f94e
use sandbox var
enyst Feb 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 1 addition & 17 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from collections import deque

import openhands
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
Expand Down Expand Up @@ -35,7 +34,7 @@ class CodeActAgent(Agent):

### Overview

This agent implements the CodeAct idea ([paper](https://arxiv.org/abs/2402.01030), [tweet](https://twitter.com/xingyaow_/status/1754556835703751087)) that consolidates LLM agents **act**ions into a unified **code** action space for both *simplicity* and *performance* (see paper for more details).
This agent implements the CodeAct idea ([paper](https://arxiv.org/abs/2402.01030), [tweet](https://twitter.com/xingyaow_/status/1754556835703751087)) that consolidates LLM agents' **act**ions into a unified **code** action space for both *simplicity* and *performance* (see paper for more details).

The conceptual idea is illustrated below. At each turn, the agent can:

Expand Down Expand Up @@ -80,14 +79,7 @@ def __init__(
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
)
self.prompt_manager = PromptManager(
microagent_dir=os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'microagents',
)
if self.config.enable_prompt_extensions
else None,
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
disabled_microagents=self.config.disabled_microagents,
)

self.condenser = Condenser.from_config(self.config.condenser)
Expand Down Expand Up @@ -223,14 +215,6 @@ def _enhance_messages(self, messages: list[Message]) -> list[Message]:
# compose the first user message with examples
self.prompt_manager.add_examples_to_initial_message(msg)

# and/or repo/runtime info
if self.config.enable_prompt_extensions:
self.prompt_manager.add_info_to_initial_message(msg)

# enhance the user message with additional context based on keywords matched
if msg.role == 'user':
self.prompt_manager.enhance_message(msg)

results.append(msg)

return results
7 changes: 4 additions & 3 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ async def main(loop: asyncio.AbstractEventLoop):
sid = str(uuid4())
display_message(f'Session ID: {sid}')

runtime = create_runtime(config, sid=sid, headless_mode=True)
await runtime.connect()
agent = create_agent(runtime, config)
agent = create_agent(config)

runtime = create_runtime(config, sid=sid, headless_mode=True, agent=agent)

controller, _ = create_controller(agent, runtime, config)

event_stream = runtime.event_stream
Expand Down
5 changes: 5 additions & 0 deletions openhands/core/config/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class AppConfig(BaseModel):
file_uploads_allowed_extensions: Allowed file extensions. `['.*']` allows all.
cli_multiline_input: Whether to enable multiline input in CLI. When disabled,
input is read line by line. When enabled, input continues until /exit command.
microagents_dir: Directory containing global microagents.
"""

llms: dict[str, LLMConfig] = Field(default_factory=dict)
Expand Down Expand Up @@ -82,6 +83,10 @@ class AppConfig(BaseModel):
daytona_target: str = Field(default='us')
cli_multiline_input: bool = Field(default=False)
conversation_max_age_seconds: int = Field(default=864000) # 10 days in seconds
microagents_dir: str = Field(
default='microagents',
description='Directory containing global microagents',
)

defaults_dict: ClassVar[dict] = {}

Expand Down
1 change: 1 addition & 0 deletions openhands/core/config/sandbox_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,6 @@ class SandboxConfig(BaseModel):
remote_runtime_resource_factor: int = Field(default=1)
enable_gpu: bool = Field(default=False)
docker_runtime_kwargs: str | None = Field(default=None)
selected_repo: str | None = Field(default=None)

model_config = {'extra': 'forbid'}
17 changes: 14 additions & 3 deletions openhands/core/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,9 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument(
'-n',
'--name',
default='',
help='Session name',
type=str,
help='Name for the session',
default='',
)
parser.add_argument(
'--eval-ids',
Expand All @@ -525,8 +525,15 @@ def get_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
'--no-auto-continue',
help='Disable auto-continue responses in headless mode (i.e. headless will read from stdin instead of auto-continuing)',
action='store_true',
help='Disable automatic "continue" responses in headless mode. Will read from stdin instead.',
default=False,
)
parser.add_argument(
'--selected-repo',
help='GitHub repository to clone (format: owner/repo)',
type=str,
default=None,
)
return parser

Expand Down Expand Up @@ -593,4 +600,8 @@ def setup_config_from_args(args: argparse.Namespace) -> AppConfig:
if args.max_budget_per_task is not None:
config.max_budget_per_task = args.max_budget_per_task

# Read selected repository in config for use by CLI and main.py
if args.selected_repo is not None:
config.sandbox.selected_repo = args.selected_repo

return config
31 changes: 27 additions & 4 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from openhands.core.setup import (
create_agent,
create_controller,
create_memory,
create_runtime,
generate_sid,
)
Expand Down Expand Up @@ -88,14 +89,36 @@ async def run_controller(
"""
sid = sid or generate_sid(config)

if agent is None:
agent = create_agent(config)

# when the runtime is created, it will be connected and clone the selected repository
if runtime is None:
runtime = create_runtime(config, sid=sid, headless_mode=headless_mode)
await runtime.connect()
runtime = create_runtime(
config,
sid=sid,
headless_mode=headless_mode,
agent=agent,
selected_repository=config.sandbox.selected_repo,
)

event_stream = runtime.event_stream

if agent is None:
agent = create_agent(runtime, config)
# when memory is created, it will load the microagents from the selected repository
memory = create_memory(
microagents_dir=config.microagents_dir,
agent=agent,
runtime=runtime,
event_stream=event_stream,
selected_repository=config.sandbox.selected_repo,
)

# trick for testing
if agent.prompt_manager:
memory.set_prompt_manager(agent.prompt_manager)

microagents = runtime.get_microagents_from_selected_repo(None)
memory.load_user_workspace_microagents(microagents)

replay_events: list[Event] | None = None
if config.replay_trajectory_path:
Expand Down
3 changes: 3 additions & 0 deletions openhands/core/schema/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,8 @@ class ActionTypeSchema(BaseModel):
SEND_PR: str = Field(default='send_pr')
"""Send a PR to github."""

RECALL: str = Field(default='recall')
"""Retrieves data from a file or other storage."""


ActionType = ActionTypeSchema()
3 changes: 3 additions & 0 deletions openhands/core/schema/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,8 @@ class ObservationTypeSchema(BaseModel):
CONDENSE: str = Field(default='condense')
"""Result of a condensation operation."""

RECALL: str = Field(default='recall')
"""Result of a recall operation."""


ObservationType = ObservationTypeSchema()
81 changes: 71 additions & 10 deletions openhands/core/setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import hashlib
import os
import uuid
from typing import Tuple, Type

from pydantic import SecretStr

import openhands.agenthub # noqa F401 (we import this to get the agents registered)
from openhands.controller import AgentController
from openhands.controller.agent import Agent
Expand All @@ -13,16 +16,22 @@
from openhands.events import EventStream
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroAgent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.storage import get_file_store
from openhands.utils.async_utils import call_async_from_sync


def create_runtime(
config: AppConfig,
sid: str | None = None,
headless_mode: bool = True,
agent: Agent | None = None,
selected_repository: str | None = None,
github_token: SecretStr | None = None,
) -> Runtime:
"""Create a runtime for the agent to run on.

Expand All @@ -31,6 +40,8 @@ def create_runtime(
Setting it to an incompatible value will cause unexpected behavior on RemoteRuntime.
headless_mode: Whether the agent is run in headless mode. `create_runtime` is typically called within evaluation scripts,
where we don't want to have the VSCode UI open, so it defaults to True.
selected_repository: (optional) The GitHub repository to use.
github_token: (optional) The GitHub token to use.
"""
# if sid is provided on the command line, use it as the name of the event stream
# otherwise generate it on the basis of the configured jwt_secret
Expand All @@ -41,8 +52,17 @@ def create_runtime(
file_store = get_file_store(config.file_store, config.file_store_path)
event_stream = EventStream(session_id, file_store)

# set up the security analyzer
if config.security.security_analyzer:
options.SecurityAnalyzers.get(
config.security.security_analyzer, SecurityAnalyzer
)(event_stream)

# agent class
agent_cls = openhands.agenthub.Agent.get_cls(config.default_agent)
if agent:
agent_cls = type(agent)
else:
agent_cls = openhands.agenthub.Agent.get_cls(config.default_agent)

# runtime and tools
runtime_cls = get_runtime_cls(config.runtime)
Expand All @@ -55,25 +75,66 @@ def create_runtime(
headless_mode=headless_mode,
)

call_async_from_sync(runtime.connect)

# clone selected repository if provided
github_token = os.environ.get('GITHUB_TOKEN') if not github_token else github_token
if selected_repository and github_token:
logger.debug(f'Selected repository {selected_repository}.')
runtime.clone_repo(
github_token,
selected_repository,
None,
)

logger.debug(
f'Runtime initialized with plugins: {[plugin.name for plugin in runtime.plugins]}'
)

return runtime


def create_agent(runtime: Runtime, config: AppConfig) -> Agent:
def create_memory(
microagents_dir: str,
agent: Agent,
runtime: Runtime,
event_stream: EventStream,
selected_repository: str | None = None,
) -> Memory:
# Purpose: construct a Memory bound to `event_stream` and `microagents_dir`,
# optionally seed it with runtime hosts, microagents obtained from the runtime,
# and repository info, then return it.
#
# Args:
#   microagents_dir: forwarded unchanged to Memory(microagents_dir=...).
#   agent: read for `agent.config.disabled_microagents` and `agent.prompt_manager`.
#   runtime: read for `runtime.web_hosts` and
#       `runtime.get_microagents_from_selected_repo(...)`.
#   event_stream: forwarded unchanged to Memory(event_stream=...).
#   selected_repository: optional repository identifier; the code below splits it
#       on '/' and uses the second component as the repository directory.
#
# Returns:
#   The Memory instance created here.
#
# NOTE(review): indentation is lost in this view, so it is not visible whether the
# `if selected_repository:` step is nested under the `agent.prompt_manager and runtime`
# check or runs unconditionally -- confirm against the real file.

# The agent's list of disabled microagents is passed through to Memory as-is.
disabled_microagents = agent.config.disabled_microagents

memory = Memory(
event_stream=event_stream,
microagents_dir=microagents_dir,
disabled_microagents=disabled_microagents,
)

# Runtime-derived data is only added when the agent has a prompt manager and a
# runtime object is present (both must be truthy).
if agent.prompt_manager and runtime:
# sets available hosts
memory.set_runtime_info(runtime.web_hosts)

# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo(
selected_repository
)
memory.load_user_workspace_microagents(microagents)

if selected_repository:
# NOTE(review): `split('/')[1]` raises IndexError when `selected_repository`
# contains no '/'; this presumably relies on an "owner/repo" value -- confirm
# that callers always pass that format.
repo_directory = selected_repository.split('/')[1]
# An empty second component (e.g. a trailing '/') skips the repository info.
if repo_directory:
memory.set_repository_info(selected_repository, repo_directory)
return memory


def create_agent(config: AppConfig) -> Agent:
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
agent = agent_cls(
llm=LLM(config=llm_config),
config=agent_config,
)
if agent.prompt_manager:
microagents = runtime.get_microagents_from_selected_repo(None)
agent.prompt_manager.load_microagents(microagents)

if config.security.security_analyzer:
options.SecurityAnalyzers.get(
config.security.security_analyzer, SecurityAnalyzer
)(runtime.event_stream)

return agent

Expand Down
12 changes: 12 additions & 0 deletions openhands/events/action/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,15 @@ class AgentDelegateAction(Action):
@property
def message(self) -> str:
return f"I'm asking {self.agent} for help with this task."


@dataclass
class RecallAction(Action):
# This action is used for retrieving data, e.g., from memory or a knowledge base.
# `query` is a string-keyed dict describing what to retrieve; each instance gets
# its own empty dict by default (default_factory), so instances never share one.
# NOTE(review): the expected keys of `query` are not visible here -- confirm with the consumer.
query: dict[str, Any] = field(default_factory=dict)
# `thought` defaults to the empty string.
thought: str = ''
# Tags the action with the RECALL action type.
action: str = ActionType.RECALL

@property
def message(self) -> str:
# Human-readable summary that embeds the query dict.
return f'Retrieved data for: {self.query}'
11 changes: 11 additions & 0 deletions openhands/events/observation/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,14 @@ class AgentCondensationObservation(Observation):
@property
def message(self) -> str:
return self.content


@dataclass
class RecallObservation(Observation):
"""The output of a recall action."""

# Tags the observation with the RECALL observation type.
observation: str = ObservationType.RECALL

@property
def message(self) -> str:
# The message is `self.content` returned unchanged; `content` is not declared in
# this class, so it presumably comes from the Observation base -- confirm.
return self.content
2 changes: 2 additions & 0 deletions openhands/events/serialization/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AgentFinishAction,
AgentRejectAction,
ChangeAgentStateAction,
RecallAction,
)
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
from openhands.events.action.commands import (
Expand Down Expand Up @@ -35,6 +36,7 @@
AgentDelegateAction,
ChangeAgentStateAction,
MessageAction,
RecallAction,
)

ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
Expand Down
1 change: 1 addition & 0 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class EventStreamSubscriber(str, Enum):
RESOLVER = 'openhands_resolver'
SERVER = 'server'
RUNTIME = 'runtime'
MEMORY = 'memory'
MAIN = 'main'
TEST = 'test'

Expand Down
Loading
Loading