Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ajet/backbone/main_trinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def patched_trainer_get_actor(cls, config: Config):
Trainer.get_actor = classmethod(patched_trainer_get_actor)

if ajet_config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server
start_interchange_server(ajet_config)


Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/main_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def run(self, config):
from ajet.backbone.trainer_verl import AjetRayPPOTrainer

if config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server
start_interchange_server(config)

# Initialize the PPO trainer.
Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/main_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def main(config):
# atexit.register(lambda: print("Process exiting, performing cleanup..."))

if config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server
start_interchange_server(config)
if config.ajet.enable_swarm_mode:
from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/trainer_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def fit(self): # noqa: C901

# # when the OAI request interchange is enabled, we need to clear the cache from time to time
# if self.config.ajet.enable_interchange_server:
# from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
# from ajet.tuner_lib.experimental.oai_model_server import ensure_dat_interchange_server_cache_clear
# ensure_dat_interchange_server_cache_clear()

if is_last_step:
Expand Down
4 changes: 2 additions & 2 deletions ajet/context_tracker/single_agent_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ def compute_step_level_reward(
def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List:
result = []
for ext_msg in ext_msg_array:
d = {
d: dict = {
"role": ext_msg.role,
"content": ext_msg.content_for_future,
"content": ext_msg.content_for_compare,
}
if ext_msg.tool_calls:
d.update({"tool_calls": ext_msg.tool_calls})
Expand Down
12 changes: 6 additions & 6 deletions ajet/context_tracker/timeline_merging/timeline_merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def is_timeline_mergeable(
for i in range(len(target_timeline)):
if timeline_compare_level == "text":
same = (
source_timeline[i].content_for_future
== target_timeline[i].content_for_future
source_timeline[i].content_for_compare
== target_timeline[i].content_for_compare
)
elif timeline_compare_level == "token":
same = source_timeline[i].token_arr == target_timeline[i].token_arr
Expand Down Expand Up @@ -52,12 +52,12 @@ def is_timeline_mergeable(
# all_msg_match = False
# for i in range(len(target_timeline)):
# d = {}
# d["source"] = source_timeline[i].content_for_future
# d["target"] = target_timeline[i].content_for_future
# d["source"] = source_timeline[i].content_for_compare
# d["target"] = target_timeline[i].content_for_compare
# if timeline_compare_level == "text":
# same = (
# source_timeline[i].content_for_future
# == target_timeline[i].content_for_future
# source_timeline[i].content_for_compare
# == target_timeline[i].content_for_compare
# )
# elif timeline_compare_level == "token":
# same = source_timeline[i].token_arr == target_timeline[i].token_arr
Expand Down
2 changes: 1 addition & 1 deletion ajet/copilot/train-complex-blackbox/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ from ajet.task_reader import RouterTaskReader
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient
from ajet.tuner_lib.experimental.swarm_client import SwarmClient

# python -m tutorial.example_math_swarm.math

Expand Down
4 changes: 2 additions & 2 deletions ajet/copilot/write-swarm-client/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Below are some reference materials.
Now, create a python script and start coding:

```python
from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient
from ajet.tuner_lib.experimental.swarm_client import SwarmClient
REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url
swarm_worker = SwarmClient(REMOTE_SWARM_URL)
```
Expand Down Expand Up @@ -364,7 +364,7 @@ Below are some reference materials.

```python
from ajet.copilot.job import AgentJetJob
from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete
from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
from ajet.task_reader import RouterTaskReader
from tutorial.example_academic_trans_swarm.trans import execute_agent
Expand Down
2 changes: 1 addition & 1 deletion ajet/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def start_swarm_server(env, config):
assert config.ajet.enable_interchange_server, (
"Please enable_interchange_server in config to start swarm server."
)
from ajet.tuner_lib.experimental.as_oai_model_server import (
from ajet.tuner_lib.experimental.oai_model_server import (
start_interchange_server,
)

Expand Down
51 changes: 16 additions & 35 deletions ajet/schema/extended_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def __init__(
token_arr=[],
token_begin_index=-1,
token_end_index=-1,
clip=False,
clip_token_limit=8192,
tokenizer: PreTrainedTokenizer = None, # type: ignore
token_generator="manual",
build_from_uuid="",
Expand All @@ -85,9 +83,8 @@ def __init__(
self.token_begin_index = token_begin_index
self.token_end_index = token_end_index
self.invalid_log_prob_value = INVALID_LOG_PROB_VALUE
self._content_for_future = ""
self._content_for_compare = ""
self._info = ""
self.clip = clip
self.tools = tools
self.tool_calls = tool_calls
self.tool_call_id = tool_call_id
Expand All @@ -101,14 +98,8 @@ def __init__(
self.manual_loss_mask_override = []
self.lack_normal_eos = False

if not clip:
self.generate_content_for_future(tokenizer=None, clip=False)
else:
self.generate_content_for_future(
tokenizer=tokenizer,
clip=True,
clip_token_limit=clip_token_limit,
)
self.generate_content_for_compare(tokenizer=None)

self.eos_token_id = tokenizer.eos_token_id

if token_generator == "auto":
Expand All @@ -127,9 +118,9 @@ def auto_tokenize(self, tokenizer, tools):
if not self.first_message:
self.token_arr = self.auto_tokenize_non_first_message(tokenizer=tokenizer, tools=tools)
else:
auto_tokenize_target = {
auto_tokenize_target:dict = {
"role": self.role,
"content": self.content_for_future,
"content": self.content_for_compare,
}
if self.tool_calls:
auto_tokenize_target.update({"tool_calls": self.tool_calls})
Expand All @@ -144,9 +135,9 @@ def auto_tokenize(self, tokenizer, tools):
def auto_tokenize_non_first_message(self, tokenizer, tools):
try:
# completion_token_arr will contain the generation_prompt header
auto_tokenize_target = {
auto_tokenize_target:dict = {
"role": self.role,
"content": self.content_for_future,
"content": self.content_for_compare,
}
if self.tool_calls:
auto_tokenize_target.update({"tool_calls": self.tool_calls})
Expand All @@ -160,7 +151,7 @@ def auto_tokenize_non_first_message(self, tokenizer, tools):
)
except Exception as e:
raise ValueError(
f"Cannot tokenize {self.role} --- {self.content_for_future}, \n\n Error: {e}"
f"Cannot tokenize {self.role} --- {self.content_for_compare}, \n\n Error: {e}"
)
self.token_arr, _ = self.get_inc_simple(
text_frag_from=ajet_apply_chat_template(
Expand All @@ -175,12 +166,12 @@ def auto_tokenize_non_first_message(self, tokenizer, tools):
return self.token_arr

@property
def content_for_future(self):
if self._content_for_future == "":
def content_for_compare(self):
if self._content_for_compare == "":
if not self.tool_calls:
logger.exception("content_for_future is not set, or previous llm output is empty!")
self._content_for_future
return self._content_for_future
logger.exception("content_for_compare is not set, or previous llm output is empty!")
self._content_for_compare
return self._content_for_compare

@property
def need_training(self):
Expand All @@ -191,19 +182,9 @@ def need_training(self):
), f"author {self.author} is not identified"
return self.author in NEED_TRAIN_AUTHORS

def generate_content_for_future(self, tokenizer, clip, clip_token_limit=-1):
def generate_content_for_compare(self, tokenizer):
_content: str = self.content
if clip:
assert clip_token_limit > 0, "clip_token_limit must be set when clip is True"
n_token = len(tokenizer(_content, return_tensors="pt", padding=False)["input_ids"][0])
if n_token > clip_token_limit:
# 8000 > 4000
n_char = len(_content) # 10,000
eps = 100 # token
preserve_percent = (clip_token_limit - eps) / n_token # 3900 / 8000
n_char_to_preserve = int(n_char * preserve_percent)
_content = _content[:n_char_to_preserve] + "... truncate ..."
self._content_for_future = _content
self._content_for_compare = _content

def get_loss_mask(self, blackout_token_combo):
if self.need_training:
Expand Down Expand Up @@ -315,7 +296,7 @@ def merge_tool_group(group, tokenizer):
)
# re-compute token_arr
auto_tokenize_targets = [
{"role": msg.role, "content": msg.content_for_future} for msg in group
{"role": msg.role, "content": msg.content_for_compare} for msg in group
]
merged.token_arr, _ = merged.get_inc_simple(
text_frag_from=ajet_apply_chat_template(
Expand Down
20 changes: 19 additions & 1 deletion ajet/swarm_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def start_swarm_server(env, config, port):
# Set the port in the config
config.ajet.interchange_server.interchange_server_port = port

from ajet.tuner_lib.experimental.as_oai_model_server import (
from ajet.tuner_lib.experimental.oai_model_server import (
start_interchange_server,
)

Expand Down Expand Up @@ -139,6 +139,24 @@ def main():
)
parser_overwatch.set_defaults(func=cmd_overwatch)

# Subcommand: top (alias for overwatch)
parser_top = subparsers.add_parser("top", help="Monitor the swarm server (alias for overwatch)")
parser_top.add_argument(
"--swarm-url",
type=str,
default="http://localhost:10086",
required=False,
help="Swarm server URL (default: http://localhost:10086)",
)
parser_top.add_argument(
"--refresh-interval",
type=float,
default=2.0,
required=False,
help="Refresh interval in seconds (default: 2.0)",
)
parser_top.set_defaults(func=cmd_overwatch)

args = parser.parse_args()

if not hasattr(args, 'func'):
Expand Down
4 changes: 2 additions & 2 deletions ajet/task_runner/swarm_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def register_episode_and_wait_output(

while True:
# <wait for 1/2>:
# <from_sourcefile>: ajet/tuner_lib/experimental/as_swarm_server.py
# <from_sourcefile>: ajet/tuner_lib/experimental/swarm_server.py
# <from_code>: socket.send_string(workflow_output.model_dump_json())
# <expect>: workflow_output: WorkflowOutput
# <wait for 2/2>:
# <from_sourcefile>: ajet/tuner_lib/experimental/as_swarm_server.py
# <from_sourcefile>: ajet/tuner_lib/experimental/swarm_server.py
# <from_code>: socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER")
# <expect>: "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER"
try:
Expand Down
2 changes: 1 addition & 1 deletion ajet/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def get_context_tracker(self) -> MultiAgentContextTracker:
def _enable_interchange_server(self, llm_inference_fn):
# experimental reverse proxy start
if self.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_client import InterchangeClient
from ajet.tuner_lib.experimental.oai_model_client import InterchangeClient
self.interchange_client = InterchangeClient(
episode_uuid=self.context_tracker.episode_uuid,
context_tracker=self.context_tracker,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from loguru import logger
from typing import TYPE_CHECKING
from ajet.tuner_lib.experimental.as_oai_model_server import InterchangeCompletionRequest
from ajet.tuner_lib.experimental.oai_model_server import InterchangeCompletionRequest
from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor
from ajet.tuner_lib.experimental.interchange_utils import get_zmq_socket
from ajet.tuner_lib.experimental.interchange_utils import DEBUG
Expand Down Expand Up @@ -107,7 +107,7 @@ def _begin_service_threading(self):
try:

# <wait for>:
# <from_sourcefile>: ajet/tuner_lib/experimental/as_oai_model_server.py
# <from_sourcefile>: ajet/tuner_lib/experimental/oai_model_server.py
# <from_code>: socket.send_string(int_req.model_dump_json())
# <expect>: InterchangeCompletionRequest object in JSON string format
message = self.socket.recv_string()
Expand Down Expand Up @@ -165,7 +165,7 @@ def _begin_service_threading(self):
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string (send llm call result)")

# <send to>
# <to_sourcefile>: ajet/tuner_lib/experimental/as_oai_model_server.py
# <to_sourcefile>: ajet/tuner_lib/experimental/oai_model_server.py
# <to_code>: result_str = socket.recv_string()
self.socket.send_string(result)

Expand Down
Loading
Loading