diff --git a/ajet/backbone/main_trinity.py b/ajet/backbone/main_trinity.py index dc06c21..2e44974 100644 --- a/ajet/backbone/main_trinity.py +++ b/ajet/backbone/main_trinity.py @@ -53,7 +53,7 @@ def patched_trainer_get_actor(cls, config: Config): Trainer.get_actor = classmethod(patched_trainer_get_actor) if ajet_config.ajet.enable_interchange_server: - from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server + from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server start_interchange_server(ajet_config) diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index 8eebb95..0fe845c 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -251,7 +251,7 @@ def run(self, config): from ajet.backbone.trainer_verl import AjetRayPPOTrainer if config.ajet.enable_interchange_server: - from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server + from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server start_interchange_server(config) # Initialize the PPO trainer. diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index 4e8b717..bb79eb5 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -187,7 +187,7 @@ def main(config): # atexit.register(lambda: print("Process exiting, performing cleanup...")) if config.ajet.enable_interchange_server: - from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server + from ajet.tuner_lib.experimental.oai_model_server import start_interchange_server start_interchange_server(config) if config.ajet.enable_swarm_mode: from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 28f09f9..00caaa6 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -859,7 +859,7 @@ def fit(self): # noqa: C901 # # when enabled oai request interchange, we need to clear the cache from time to time # if self.config.ajet.enable_interchange_server: - # from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear + # from ajet.tuner_lib.experimental.oai_model_server import ensure_dat_interchange_server_cache_clear # ensure_dat_interchange_server_cache_clear() if is_last_step: diff --git a/ajet/context_tracker/single_agent_tracking.py b/ajet/context_tracker/single_agent_tracking.py index c49828a..775abf3 100644 --- a/ajet/context_tracker/single_agent_tracking.py +++ b/ajet/context_tracker/single_agent_tracking.py @@ -185,9 +185,9 @@ def compute_step_level_reward( def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List: result = [] for ext_msg in ext_msg_array: - d = { + d: dict = { "role": ext_msg.role, - "content": ext_msg.content_for_future, + "content": ext_msg.content_for_compare, } if ext_msg.tool_calls: d.update({"tool_calls": ext_msg.tool_calls}) diff --git a/ajet/context_tracker/timeline_merging/timeline_merging.py b/ajet/context_tracker/timeline_merging/timeline_merging.py index 86e7f8b..fcc3b05 100644 --- a/ajet/context_tracker/timeline_merging/timeline_merging.py +++ b/ajet/context_tracker/timeline_merging/timeline_merging.py @@ -21,8 +21,8 @@ def is_timeline_mergeable( for i in range(len(target_timeline)): if timeline_compare_level == "text": same = ( - source_timeline[i].content_for_future - == target_timeline[i].content_for_future + source_timeline[i].content_for_compare + == target_timeline[i].content_for_compare ) elif timeline_compare_level == "token": same = source_timeline[i].token_arr == target_timeline[i].token_arr @@ -52,12 +52,12 @@ def is_timeline_mergeable( # all_msg_match = False # for i in range(len(target_timeline)): # d = {} - # d["source"] = source_timeline[i].content_for_future - # d["target"] = target_timeline[i].content_for_future + # d["source"] = source_timeline[i].content_for_compare + # d["target"] = target_timeline[i].content_for_compare # if timeline_compare_level == "text": # same = ( - # source_timeline[i].content_for_future - # == target_timeline[i].content_for_future + # source_timeline[i].content_for_compare + # == target_timeline[i].content_for_compare # ) # elif timeline_compare_level == "token": # same = source_timeline[i].token_arr == target_timeline[i].token_arr diff --git a/ajet/copilot/train-complex-blackbox/SKILL.md b/ajet/copilot/train-complex-blackbox/SKILL.md index 032eb47..c33b527 100644 --- a/ajet/copilot/train-complex-blackbox/SKILL.md +++ b/ajet/copilot/train-complex-blackbox/SKILL.md @@ -55,7 +55,7 @@ from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient # python -m tutorial.example_math_swarm.math diff --git a/ajet/copilot/write-swarm-client/SKILL.md b/ajet/copilot/write-swarm-client/SKILL.md index 4fe28c5..3fd6fc4 100644 --- a/ajet/copilot/write-swarm-client/SKILL.md +++ b/ajet/copilot/write-swarm-client/SKILL.md @@ -136,7 +136,7 @@ Below are some reference materials. Now, create a python script and start coding: ```python - from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + from ajet.tuner_lib.experimental.swarm_client import SwarmClient REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url swarm_worker = SwarmClient(REMOTE_SWARM_URL) ``` @@ -364,7 +364,7 @@ Below are some reference materials. ```python from ajet.copilot.job import AgentJetJob - from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete + from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/ajet/launcher.py b/ajet/launcher.py index d3a5b4c..71d8683 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -153,7 +153,7 @@ def start_swarm_server(env, config): assert config.ajet.enable_interchange_server, ( "Please enable_interchange_server in config to start swarm server." ) - from ajet.tuner_lib.experimental.as_oai_model_server import ( + from ajet.tuner_lib.experimental.oai_model_server import ( start_interchange_server, ) diff --git a/ajet/schema/extended_msg.py b/ajet/schema/extended_msg.py index 5e78907..3f31418 100644 --- a/ajet/schema/extended_msg.py +++ b/ajet/schema/extended_msg.py @@ -65,8 +65,6 @@ def __init__( token_arr=[], token_begin_index=-1, token_end_index=-1, - clip=False, - clip_token_limit=8192, tokenizer: PreTrainedTokenizer = None, # type: ignore token_generator="manual", build_from_uuid="", @@ -85,9 +83,8 @@ def __init__( self.token_begin_index = token_begin_index self.token_end_index = token_end_index self.invalid_log_prob_value = INVALID_LOG_PROB_VALUE - self._content_for_future = "" + self._content_for_compare = "" self._info = "" - self.clip = clip self.tools = tools self.tool_calls = tool_calls self.tool_call_id = tool_call_id @@ -101,14 +98,8 @@ def __init__( self.manual_loss_mask_override = [] self.lack_normal_eos = False - if not clip: - self.generate_content_for_future(tokenizer=None, clip=False) - else: - self.generate_content_for_future( - tokenizer=tokenizer, - clip=True, - clip_token_limit=clip_token_limit, - ) + self.generate_content_for_compare(tokenizer=None) + self.eos_token_id = tokenizer.eos_token_id if token_generator == "auto": @@ -127,9 +118,9 @@ def auto_tokenize(self, tokenizer, tools): if not self.first_message: self.token_arr = self.auto_tokenize_non_first_message(tokenizer=tokenizer, tools=tools) else: - auto_tokenize_target = { + auto_tokenize_target:dict = { "role": self.role, - "content": self.content_for_future, + "content": self.content_for_compare, } if self.tool_calls: auto_tokenize_target.update({"tool_calls": self.tool_calls}) @@ -144,9 +135,9 @@ def auto_tokenize(self, tokenizer, tools): def auto_tokenize_non_first_message(self, tokenizer, tools): try: # completion_token_arr will contain generation_prompt header - auto_tokenize_target = { + auto_tokenize_target:dict = { "role": self.role, - "content": self.content_for_future, + "content": self.content_for_compare, } if self.tool_calls: auto_tokenize_target.update({"tool_calls": self.tool_calls}) @@ -160,7 +151,7 @@ def auto_tokenize_non_first_message(self, tokenizer, tools): ) except Exception as e: raise ValueError( - f"Cannot tokenize {self.role} --- {self.content_for_future}, \n\n Error: {e}" + f"Cannot tokenize {self.role} --- {self.content_for_compare}, \n\n Error: {e}" ) self.token_arr, _ = self.get_inc_simple( text_frag_from=ajet_apply_chat_template( @@ -175,12 +166,12 @@ def auto_tokenize_non_first_message(self, tokenizer, tools): return self.token_arr @property - def content_for_future(self): - if self._content_for_future == "": + def content_for_compare(self): + if self._content_for_compare == "": if not self.tool_calls: - logger.exception("content_for_future is not set, or previous llm output is empty!") - self._content_for_future - return self._content_for_future + logger.exception("content_for_compare is not set, or previous llm output is empty!") + self._content_for_compare + return self._content_for_compare @property def need_training(self): @@ -191,19 +182,9 @@ def need_training(self): ), f"author {self.author} is not identified" return self.author in NEED_TRAIN_AUTHORS - def generate_content_for_future(self, tokenizer, clip, clip_token_limit=-1): + def generate_content_for_compare(self, tokenizer): _content: str = self.content - if clip: - assert clip_token_limit > 0, "clip_token_limit must be set when clip is True" - n_token = len(tokenizer(_content, return_tensors="pt", padding=False)["input_ids"][0]) - if n_token > clip_token_limit: - # 8000 > 4000 - n_char = len(_content) # 10,000 - eps = 100 # token - preserve_percent = (clip_token_limit - eps) / n_token # 3900 / 8000 - n_char_to_preserve = int(n_char * preserve_percent) - _content = _content[:n_char_to_preserve] + "... truncate ..." - self._content_for_future = _content + self._content_for_compare = _content def get_loss_mask(self, blackout_token_combo): if self.need_training: @@ -315,7 +296,7 @@ def merge_tool_group(group, tokenizer): ) # re-compute token_arr auto_tokenize_targets = [ - {"role": msg.role, "content": msg.content_for_future} for msg in group + {"role": msg.role, "content": msg.content_for_compare} for msg in group ] merged.token_arr, _ = merged.get_inc_simple( text_frag_from=ajet_apply_chat_template( diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index 723d9b5..7f65e9f 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -31,7 +31,7 @@ def start_swarm_server(env, config, port): # Set the port in the config config.ajet.interchange_server.interchange_server_port = port - from ajet.tuner_lib.experimental.as_oai_model_server import ( + from ajet.tuner_lib.experimental.oai_model_server import ( start_interchange_server, ) @@ -139,6 +139,24 @@ def main(): ) parser_overwatch.set_defaults(func=cmd_overwatch) + # Subcommand: top (alias for overwatch) + parser_top = subparsers.add_parser("top", help="Monitor the swarm server (alias for overwatch)") + parser_top.add_argument( + "--swarm-url", + type=str, + default="http://localhost:10086", + required=False, + help="Swarm server URL (default: http://localhost:10086)", + ) + parser_top.add_argument( + "--refresh-interval", + type=float, + default=2.0, + required=False, + help="Refresh interval in seconds (default: 2.0)", + ) + parser_top.set_defaults(func=cmd_overwatch) + args = parser.parse_args() if not hasattr(args, 'func'): diff --git a/ajet/task_runner/swarm_runner.py b/ajet/task_runner/swarm_runner.py index 9e4a5c9..5810d3a 100644 --- a/ajet/task_runner/swarm_runner.py +++ b/ajet/task_runner/swarm_runner.py @@ -66,11 +66,11 @@ def register_episode_and_wait_output( while True: # : - # : ajet/tuner_lib/experimental/as_swarm_server.py + # : ajet/tuner_lib/experimental/swarm_server.py # : socket.send_string(workflow_output.model_dump_json()) # : workflow_output: WorkflowOutput # : - # : ajet/tuner_lib/experimental/as_swarm_server.py + # : ajet/tuner_lib/experimental/swarm_server.py # : socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") # : "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" try: diff --git a/ajet/tuner.py b/ajet/tuner.py index 45a5442..54b2b35 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -171,7 +171,7 @@ def get_context_tracker(self) -> MultiAgentContextTracker: def _enable_interchange_server(self, llm_inference_fn): # experimental reverse proxy start if self.enable_interchange_server: - from ajet.tuner_lib.experimental.as_oai_model_client import InterchangeClient + from ajet.tuner_lib.experimental.oai_model_client import InterchangeClient self.interchange_client = InterchangeClient( episode_uuid=self.context_tracker.episode_uuid, context_tracker=self.context_tracker, diff --git a/ajet/tuner_lib/experimental/as_oai_model_client.py b/ajet/tuner_lib/experimental/oai_model_client.py similarity index 98% rename from ajet/tuner_lib/experimental/as_oai_model_client.py rename to ajet/tuner_lib/experimental/oai_model_client.py index aaecde5..89b3866 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/experimental/oai_model_client.py @@ -9,7 +9,7 @@ from loguru import logger from typing import TYPE_CHECKING -from ajet.tuner_lib.experimental.as_oai_model_server import InterchangeCompletionRequest +from ajet.tuner_lib.experimental.oai_model_server import InterchangeCompletionRequest from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor from ajet.tuner_lib.experimental.interchange_utils import get_zmq_socket from ajet.tuner_lib.experimental.interchange_utils import DEBUG @@ -107,7 +107,7 @@ def _begin_service_threading(self): try: # : - # : ajet/tuner_lib/experimental/as_oai_model_server.py + # : ajet/tuner_lib/experimental/oai_model_server.py # : socket.send_string(int_req.model_dump_json()) # : InterchangeCompletionRequest object in JSON string format message = self.socket.recv_string() @@ -165,7 +165,7 @@ def _begin_service_threading(self): if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string (send llm call result)") # - # : ajet/tuner_lib/experimental/as_oai_model_server.py + # : ajet/tuner_lib/experimental/oai_model_server.py # : result_str = socket.recv_string() self.socket.send_string(result) diff --git a/ajet/tuner_lib/experimental/as_oai_model_server.py b/ajet/tuner_lib/experimental/oai_model_server.py similarity index 87% rename from ajet/tuner_lib/experimental/as_oai_model_server.py rename to ajet/tuner_lib/experimental/oai_model_server.py index 367c808..cc41597 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/experimental/oai_model_server.py @@ -33,7 +33,7 @@ from typing import Coroutine, Optional, Tuple from vllm.entrypoints.openai.protocol import ChatCompletionRequest -from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion import ChatCompletion, CompletionUsage from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction @@ -70,7 +70,6 @@ def ep_key(episode_uuid: str) -> str: def get_app(max_fastapi_threads: int = 512, enable_swarm_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]: - @asynccontextmanager async def lifespan(app: FastAPI): # Startup @@ -96,7 +95,7 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") # - # : ajet/tuner_lib/experimental/as_oai_model_client.py + # : ajet/tuner_lib/experimental/oai_model_client.py # : message = self.socket.recv_string() socket.send_string(int_req.model_dump_json()) @@ -116,7 +115,7 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") # : - # : ajet/tuner_lib/experimental/as_oai_model_client.py + # : ajet/tuner_lib/experimental/oai_model_client.py # : self.socket.send_string(result) # : ChatCompletion object in JSON string format result_str = socket.recv_string() @@ -152,6 +151,8 @@ async def mock_as_stream_response(result: ChatCompletion): """ content = result.choices[0].message.content if result.choices else "" role = result.choices[0].message.role if result.choices else "assistant" + result_id = result.id if result.id else uuid.uuid4().hex + result.id = "chatcmpl-" + result_id if not result_id.startswith("chatcmpl-") else result_id # try: # thinking = result.choices[0].message.reasoning_content # except: @@ -170,6 +171,18 @@ async def mock_as_stream_response(result: ChatCompletion): type=tc.type ) for index, tc in enumerate(tool_calls)] + def dump_chunk(chunk: ChatCompletionChunk) -> str: + dump = chunk.model_dump() + dump.pop("service_tier", None) + dump.pop("system_fingerprint", None) + if "usage" in dump and dump["usage"] is None: + dump.pop("usage", None) + # for each choice delta, if field (such as tool_calls) is empty, remove it from the delta to avoid confusion + for key in list(dump["choices"][0]["delta"].keys()): + if not dump["choices"][0]["delta"][key] and key != "content": # keep content even if it's empty + dump["choices"][0]["delta"].pop(key, None) + return f"data: {json.dumps(dump)}\n\n" + # First chunk with role first_chunk = ChatCompletionChunk( id=result.id, @@ -184,8 +197,7 @@ async def mock_as_stream_response(result: ChatCompletion): ) ] ) - dat = f"data: {first_chunk.model_dump_json()}\n\n" - yield dat + yield dump_chunk(first_chunk) # Content chunk content_chunk = ChatCompletionChunk( @@ -196,30 +208,28 @@ async def mock_as_stream_response(result: ChatCompletion): choices=[ ChunkChoice( index=0, - delta=ChoiceDelta(role=role, content=content, tool_calls=delta_tool_calls), + delta=ChoiceDelta(content=content, tool_calls=delta_tool_calls), finish_reason=None ) ] ) - dat = f"data: {content_chunk.model_dump_json()}\n\n" - yield dat - + yield dump_chunk(content_chunk) # Final chunk with finish_reason final_chunk = ChatCompletionChunk( id=result.id, model=result.model, created=result.created, object="chat.completion.chunk", + usage=CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0), choices=[ ChunkChoice( index=0, - delta=ChoiceDelta(), - finish_reason=finish_reason + delta=ChoiceDelta(content=""), + finish_reason='stop' if tool_calls is None else 'tool_calls', ) ] ) - dat = f"data: {final_chunk.model_dump_json()}\n\n" - yield dat + yield dump_chunk(final_chunk) yield "data: [DONE]\n\n" @@ -269,7 +279,7 @@ async def chat_completions(request: Request, authorization: str = Header(None)): # enable_swarm_mode if enable_swarm_mode: - from ajet.tuner_lib.experimental.as_swarm_server import ep_key + from ajet.tuner_lib.experimental.swarm_server import ep_key assert shared_mem_dict is not None assert shared_mem_dict_lock is not None @@ -308,18 +318,37 @@ async def chat_completions(request: Request, authorization: str = Header(None)): loop = asyncio.get_running_loop() result = await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) + if enable_swarm_mode: + assert shared_mem_dict is not None + shared_mem_dict["latest_llm_call"] = { + "input": body, + "output": result, + } + if original_stream: + result.model = "unknown_model" if not new_req.model else new_req.model return StreamingResponse(mock_as_stream_response(result), media_type="text/event-stream") return result if enable_swarm_mode: - from ajet.tuner_lib.experimental.as_swarm_server import register_enable_swarm_mode_routes + from ajet.tuner_lib.experimental.swarm_server import register_enable_swarm_mode_routes + + @app.post("/replay_latest_llm_call") + async def replay_latest_llm_call(): + """Return the buffered latest LLM call result.""" + assert shared_mem_dict is not None + if ("latest_llm_call" not in shared_mem_dict) or shared_mem_dict["latest_llm_call"] is None: + raise HTTPException(status_code=404, detail="No LLM call has been made yet") + return shared_mem_dict["latest_llm_call"] + assert shared_mem_dict is not None, "shared_mem_dict must not be None when enable_swarm_mode is True." assert shared_mem_dict_lock is not None, "shared_mem_dict_lock must not be None when enable_swarm_mode is True." app, additional_coro = register_enable_swarm_mode_routes(app, zmq_context=context, shared_mem_dict=shared_mem_dict, shared_mem_dict_lock=shared_mem_dict_lock) + else: + additional_coro = None @@ -481,6 +510,6 @@ def start_interchange_server(config, blocking=False, env={}) -> int: if interchange_server: interchange_server.terminate() if enable_swarm_mode: - from ajet.tuner_lib.experimental.as_swarm_server import kill_process_tree + from ajet.tuner_lib.experimental.swarm_server import kill_process_tree kill_process_tree(None, None) return -1 diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/swarm_client.py similarity index 96% rename from ajet/tuner_lib/experimental/as_swarm_client.py rename to ajet/tuner_lib/experimental/swarm_client.py index d4c3ff4..9776302 100644 --- a/ajet/tuner_lib/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/experimental/swarm_client.py @@ -67,12 +67,12 @@ def __init__(self, server_url: str): # better logging management self._last_second_print_buffer: dict[str, float] = {} self._begin_episode_lock = threading.Lock() + self._http_client_lock = threading.Lock() + self._http_client = self._refresh_http_client() # record last registered AgentJetJob self._agent_jet_job = None # throttle self._recent_seen_tasks = [] - # reuse httpx client to avoid creating SSL context repeatedly - self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT) def logger_info(self, message): # logger with de-duplication within 1 second to prevent log flooding @@ -96,21 +96,26 @@ def logger_info(self, message): def _refresh_http_client(self): """Refresh the HTTP client by closing the old one and creating a new one.""" - try: - self._http_client.close() - except Exception: - pass # Ignore errors when closing - self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT) - logger.info("HTTP client refreshed due to connection error") + with self._http_client_lock: + try: + self._http_client.close() + except Exception: + pass # Ignore errors when closing + try: + self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT, http2=True) + except: + self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT, http2=False) + logger.warning("HTTP client refreshed due to connection error") + return self._http_client def _should_refresh_client_on_error(self, error: Exception) -> bool: """Check if an error suggests the HTTP client should be refreshed.""" error_msg = str(error).lower() return any(keyword in error_msg for keyword in [ + "broken pipe", "disconnected", "connection reset", "connection closed", - "broken pipe", "connection aborted" ]) @@ -320,6 +325,8 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="t continue except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error claiming episode: {e}. Retrying ...") retry_delay = START_EPISODE_RETRY_DELAY continue @@ -398,6 +405,8 @@ def abort_episode(self, episode_uuid: str): logger.error(f"Failed to end episode {episode_uuid}") except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error ending episode: {e}") def sync_train_config(self, agent_jet_job: AgentJetJob): @@ -498,6 +507,8 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= raise e except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error polling engine status: {e}") time.sleep(5) @@ -511,8 +522,8 @@ def get_engine_status(self) -> Tuple[str, dict]: raise_for_status_with_detail(resp) resp_json = resp.json() result = resp_json.get("engine_status", "unknown") - engine_status_detail = resp_json.get("engine_status_detail", None) - global_step = resp_json.get("global_step", None) + # engine_status_detail = resp_json.get("engine_status_detail", None) + # global_step = resp_json.get("global_step", None) if result == "unknown": logger.warning("get_engine_status: " + str(resp_json)) return result, resp_json @@ -525,7 +536,6 @@ def get_engine_status(self) -> Tuple[str, dict]: def can_continue_episode(self, episode_uuid: str) -> bool: if not episode_uuid: return False - try: req_obj = CanContinueEpisodeRequest( client_uuid=self.client_uuid, @@ -540,6 +550,8 @@ def can_continue_episode(self, episode_uuid: str) -> bool: data = CanContinueEpisodeResponse.model_validate(resp.json()) return data.can_continue except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error checking can_continue_episode: {e}") return False @@ -554,6 +566,8 @@ def get_episode_buffer(self) -> List[EpisodeStatus]: data = EpisodeBufferResponse.model_validate(resp.json()) return data.buffer except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error getting episode buffer: {e}") return [] @@ -632,6 +646,8 @@ def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation: data = CurrentBatchRolloutPoolInformation.model_validate(resp.json()) return data except Exception as e: + if self._should_refresh_client_on_error(e): + self._refresh_http_client() logger.error(f"Error getting rollout statistics: {e}") return CurrentBatchRolloutPoolInformation() diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/swarm_server.py similarity index 100% rename from ajet/tuner_lib/experimental/as_swarm_server.py rename to ajet/tuner_lib/experimental/swarm_server.py diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py index 9f8b771..5c8370a 100644 --- a/ajet/utils/swarm_overwatch.py +++ b/ajet/utils/swarm_overwatch.py @@ -392,7 +392,7 @@ def create_logo_panel(self, info: CurrentBatchRolloutPoolInformation) -> Text: return content def create_dashboard( - self, info: Optional[CurrentBatchRolloutPoolInformation] + self, info: Optional[CurrentBatchRolloutPoolInformation], init=False ) -> Layout: """Create the main dashboard layout""" layout = Layout() @@ -400,7 +400,7 @@ def create_dashboard( # Create header header = self.create_header(info) - if info is None: + if (info is None) and (not init): # Show error state error_panel = Panel( "[bold red]Failed to fetch data from server, please check your connection or simply wait a moment...[/bold red]\n" @@ -409,8 +409,19 @@ def create_dashboard( padding=(1, 2), ) layout.split_column(Layout(header, size=8), Layout(error_panel)) + elif (info is None) and (init): + # Initial state before first successful data fetch + welcome_panel = Panel( + "[bold green]Welcome to AgentJet Swarm Overwatch![/bold green]\n\n" + "Attempting to connect to server and fetch data...\n" + f"[dim]Target server: {self.server_url}[/dim]\n", + border_style="green", + padding=(1, 2), + ) + layout.split_column(Layout(header, size=8), Layout(welcome_panel)) else: # Check engine status and show logo for OFFLINE or BOOTING states + assert info is not None # for type checker if info.engine_status in ["ENGINE.OFFLINE", "ENGINE.BOOTING"]: # Hide tables and show logo logo_display = self.create_logo_panel(info) @@ -439,46 +450,88 @@ def create_dashboard( return layout - def run(self): - """Start the monitoring interface""" - self.console.clear() - try: - with Live( - self.create_dashboard(None), - console=self.console, - refresh_per_second=1, - screen=True, - ) as live: + def display_latest_llm_call(self): + while True: + response = httpx.post(f"{self.server_url}/replay_latest_llm_call", timeout=30.0) + structured_response = response.json() + if "input" not in structured_response or "output" not in structured_response: + self.console.print(f"[bold red]{structured_response}[/bold red]") + time.sleep(5) + continue + else: + input = structured_response["input"] + output = structured_response["output"] + self.console.print(f"\n[bold green]Input:[/bold green]\n{input}") + self.console.print(f"\n[bold green]Output:[/bold green]\n{output}") + time.sleep(5) + + def choose_run(self) -> str: + mode = "overwatch" + # mode = "replay_latest_llm_call" + while True: + self.console.clear() + try: + if mode == "overwatch": + self.run() + elif mode == "replay_latest_llm_call": + self.display_latest_llm_call() + + except KeyboardInterrupt: + self.console.clear() + self.console.print("\n[bold yellow]Overwatch stopped by user[/bold yellow]") self.console.print( - "[bold green]Starting Swarm Overwatch...[/bold green]" + f"[dim]Total requests: {self.total_requests}, Errors: {self.error_count}[/dim]\n" ) - self.console.print(f"[dim]Press Ctrl+C to exit[/dim]\n") - time.sleep(1) - - while True: - try: - # Fetch latest data - info = self.fetch_pool_info() - # Update display - live.update(self.create_dashboard(info)) + self.console.print("\n[bold]Choose action:[/bold]") + self.console.print(" [bold cyan]o[/bold cyan] - Return to overwatch") + self.console.print(" [bold cyan]t[/bold cyan] - Show replay_latest_llm_call") + self.console.print(" [bold cyan]ctrl+c[/bold cyan] - Exit") + choice = input("\n> ").strip().lower() + + if choice == "o": + mode = "overwatch" + self.console.clear() + continue + elif choice == "t": + mode = "replay_latest_llm_call" + self.console.clear() + continue + else: + self.console.print("[yellow]Invalid choice. Please enter 'o' or 't'.[/yellow]") - # Wait for next refresh - time.sleep(self.refresh_interval) - - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(f"Error in monitoring loop: {e}") - time.sleep(self.refresh_interval) + def run(self): + """Start the monitoring interface""" - except KeyboardInterrupt: - self.console.clear() - self.console.print("\n[bold yellow]Overwatch stopped by user[/bold yellow]") + with Live( + self.create_dashboard(None, init=True), + console=self.console, + refresh_per_second=1, + screen=True, + ) as live: self.console.print( - f"[dim]Total requests: {self.total_requests}, Errors: {self.error_count}[/dim]\n" + "[bold green]Starting Swarm Overwatch...[/bold green]" ) + self.console.print(f"[dim]Press Ctrl+C to exit[/dim]\n") + time.sleep(1) + + while True: + try: + # Fetch latest data + info = self.fetch_pool_info() + + # Update display + live.update(self.create_dashboard(info)) + + # Wait for next refresh + time.sleep(self.refresh_interval) + + except KeyboardInterrupt: + raise + except Exception as e: + logger.error(f"Error in monitoring loop: {e}") + time.sleep(self.refresh_interval) def start_overwatch(server_url: str, refresh_interval: float = 2.0): @@ -490,7 +543,7 @@ def start_overwatch(server_url: str, refresh_interval: float = 2.0): refresh_interval: Refresh interval in seconds (default: 2.0) """ overwatch = SwarmOverwatch(server_url, refresh_interval) - overwatch.run() + overwatch.choose_run() if __name__ == "__main__": diff --git a/docs/en/ajet-swarm-docker.md b/docs/en/ajet-swarm-docker.md index 15c054a..f2c51cb 100644 --- a/docs/en/ajet-swarm-docker.md +++ b/docs/en/ajet-swarm-docker.md @@ -123,7 +123,7 @@ Meanwhile, all VERL and training logs stream into `./swarmlog/swarm_server.log` From any machine (no GPU required) that can reach the server on port `10086`, run your Swarm Client: ```python -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient from ajet.copilot.job import AgentJetJob swarm_worker = SwarmClient("http://:10086") diff --git a/docs/en/example_train_multi_model.md b/docs/en/example_train_multi_model.md new file mode 100644 index 0000000..6e3617a --- /dev/null +++ b/docs/en/example_train_multi_model.md @@ -0,0 +1,204 @@ +# 非共享参数多智能体强化学习:学术翻译实战 + +在传统的多智能体强化学习(MARL)系统中,所有智能体通常共享同一套模型参数——这意味着无论有多少个智能体,它们都共用一个"大脑"。这种设计虽然简单,但在实际应用中存在明显的局限性:不同智能体可能需要不同规模的模型来执行不同复杂度的任务。AgentJet 的 Swarm 训练模式突破了这一限制,实现了真正的**非共享参数多智能体强化学习**。 + +## 背景:从"共享大脑"到"异构团队" + +在传统框架中训练多智能体系统时,研究者面临一个隐含假设:所有智能体必须共享同一个底层模型。这种设计源于大多数训练后端(如 VERL 和 TRL)的架构限制——它们通常只支持对单个 LLM 模型进行微调训练。 + +然而,这种"共享大脑"的设计在很多场景下并不经济: + +- **能力错配**:一个负责高层规划的 Agent 可能需要 32B 的大模型来保证推理质量,而负责具体执行的 Agent 用一个 7B 的小模型就足够了 +- **资源浪费**:用大模型处理简单任务是对计算资源的浪费 +- **训练信号单一**:所有智能体接受相同的奖励信号,难以针对各自的任务进行专门优化 + +AgentJet Swarm 模式通过部署多个独立的 Swarm Server,每个 Server 承载不同大小的模型,实现了真正的**异构多模型训练**。每个模型可以拥有独立的训练配置、奖励函数和优化目标。 + +## 示例场景:学术论文翻译工作流 + +让我们通过一个具体的例子来理解非共享参数多智能体强化学习的工作方式。本示例实现了一个三阶段的学术论文翻译工作流: + +```mermaid +graph LR + A[输入英文论文] --> B1(Agent 1: 粗翻译) + B1 --> B2(Agent 2: 检测专有名词) + B2 --> B3(Agent 3: 最终翻译) + B3 --> D[中文论文] + + B1 -.-> R1[7B 模型] + B2 -.-> R2[14B 模型] + B3 -.-> R3[7B 模型] + + R1 --> T1[翻译质量奖励] + R2 --> T2[检测质量奖励] + R3 --> T3[翻译质量奖励] +``` + +在这个工作流中: + +- **Agent 1(粗翻译)**:使用 7B 模型将英文论文初步翻译为中文 +- **Agent 2(检测专有名词)**:使用 14B 模型检测翻译中的专有名词错误(如术语翻译、缩写处理等) +- **Agent 3(最终翻译)**:使用 7B 模型根据检测结果生成最终的中文翻译 + +## 核心创新:独立奖励函数 + +传统方案中,所有智能体共享同一个奖励信号——无论哪个 Agent 产生输出,奖励都基于最终翻译质量来计算。这种设计存在一个根本问题:Agent 2(14B 模型)实际上是在为"最终翻译质量"而不是"检测质量"负责,这导致训练信号模糊,模型难以学到真正的检测能力。 + +本示例的创新之处在于为每个模型配置了**独立的奖励函数**: + +| 模型 | Agent 角色 | 奖励函数 | 评估重点 | +|------|-----------|---------|---------| +| 7B | Agent 1 & 3 | TranslationQualityGrader | 最终翻译质量(人称代词、缩写、语序、主语清晰度) | +| 14B | Agent 2 | ProperNounDetectionGrader | 专有名词检测质量(完整性、准确性、误报率) | + +这种设计的优势在于: + +1. **任务特异性训练**:每个模型学习其特定角色的最佳策略 +2. **信号清晰**:14B 模型直接学习"如何检测错误",而非"如何让最终翻译看起来更好" +3. **资源优化**:简单翻译任务使用小模型,复杂检测任务使用大模型 +4. **独立演化**:7B 和 14B 模型可以独立优化,互不干扰 + +## 系统架构 + +AgentJet 通过部署**两个独立的 Swarm Server** 来实现非共享参数训练: + +```mermaid +graph TB + subgraph "客户端 (Swarm Client)" + C[训练脚本] + end + + subgraph "Server 1: 7B 模型" + S1[Swarm Server
:10086] + M1[Qwen2.5-7B-Instruct] + S1 --> M1 + end + + subgraph "Server 2: 14B 模型" + S2[Swarm Server
:10087] + M2[Qwen2.5-14B-Instruct] + S2 --> M2 + end + + C -->|begin_episode| S1 + C -->|begin_episode| S2 + S1 -->|api_base_url + api_key| C + S2 -->|api_base_url + api_key| C + C -->|end_episode + reward_7b| S1 + C -->|end_episode + reward_14b| S2 +``` + +**架构说明**: + +- **Swarm Server 1 (端口 10086)**:承载 7B 模型,负责 Agent 1 和 Agent 3 的推理与训练 +- **Swarm Server 2 (端口 10087)**:承载 14B 模型,负责 Agent 2 的推理与训练 +- **Swarm Client**:运行在任何设备上,负责工作流编排和奖励计算 + +客户端代码只需要传入两个不同的 `api_baseurl_key`,分别对应两个模型: + +```python +def rollout(task): + # 从两个 Swarm Server 获取独立的 API 凭证 + episode_uuid_7b, api_baseurl_key_7b = swarm_worker_7b.begin_episode() + episode_uuid_14b, api_baseurl_key_14b = swarm_worker_14b.begin_episode() + + # 使用两个模型执行工作流 + workflow_output_7b, workflow_output_14b = execute_agent( + task, + api_baseurl_key_7b, + api_baseurl_key_14b + ) + + # 分别向两个 Server 报告各自对应的奖励 + swarm_worker_7b.end_episode(task, episode_uuid_7b, workflow_output_7b) + swarm_worker_14b.end_episode(task, episode_uuid_14b, workflow_output_14b) +``` + +## 奖励函数设计 + +### 7B 模型奖励:翻译质量评估 + +7B 模型(Agent 1 和 Agent 3)的奖励由 `TranslationQualityGrader` 计算,评估标准包括: + +- **第一人称代词使用**:禁止使用"我们",应使用"本研究"、"本文"等 +- **缩写翻译**:当有简洁中文表达时使用缩写(如 GWs 而非"引力波") +- **语序调整**:未按中文习惯调整句子结构 +- **主语清晰度**:主语缺失或不明确 +- **专有名词翻译**:领域术语翻译错误 + +评分范围 0-2 分,归一化到 [0, 1]。 + +### 14B 模型奖励:检测质量评估 + +14B 模型(Agent 2)的奖励由 `ProperNounDetectionGrader` 计算,评估标准包括: + +- **完整性**:是否检测到所有关键错误(第一人称代词、缩写问题、专有名词错误等) +- **准确性**:检测到的错误是否准确,纠正建议是否合理 +- **误报率**:是否将正确的翻译标记为错误 +- **JSON 格式**:输出是否为有效的 JSON 格式 + +同样采用 0-2 分的评分体系,归一化到 [0, 1]。 + +## 训练流程 + +整个训练流程如下: + +```mermaid +sequenceDiagram + participant Client as Swarm Client + participant Server7B as 7B Swarm Server + participant Server14B as 14B Swarm Server + + Client->>Server7B: begin_episode() + Client->>Server14B: begin_episode() + Server7B-->>Client: api_baseurl_key_7b + Server14B-->>Client: api_baseurl_key_14b + + Note over Client: Agent 1 (7B): 粗翻译 + Note over Client: Agent 2 (14B): 检测错误 + Note over Client: Agent 3 (7B): 最终翻译 + + Note over Client: 计算 7B 奖励: 翻译质量 + Note over Client: 计算 14B 奖励: 检测质量 + + Client->>Server7B: end_episode(reward_7b) + Client->>Server14B: end_episode(reward_14b) + + Server7B-->>Server7B: 策略梯度更新 (7B) + Server14B-->>Server14B: 策略梯度更新 (14B) +``` + +每个训练周期中: + +1. 客户端同时向两个 Swarm Server 请求 episode 资源 +2. 执行完整的工作流,获取两个模型的输出 +3. 分别计算两个奖励:7B 基于最终翻译质量,14B 基于检测质量 +4. 将各自的奖励汇报给对应的 Swarm Server +5. 两个 Server 独立执行策略梯度更新 + +## 优势总结 + +与传统的单模型共享参数训练相比,非共享参数多智能体强化学习具有显著优势: + +| 特性 | 共享参数 | 非共享参数(本示例) | +|------|---------|-------------------| +| 模型配置 | 单一模型 | 7B + 14B 异构组合 | +| 奖励信号 | 统一奖励 | 任务特异性奖励 | +| 资源利用 | 低效(大模型处理简单任务) | 高效(按需分配) | +| 训练目标 | 所有 Agent 优化同一目标 | 每个 Agent 优化各自目标 | +| 扩展性 | 受限于单一模型容量 | 可独立扩展各组件 | + +## 延伸阅读 + +### 交叉引用 + +- **[AgentJet Swarm 训练模式](../swarm.md)**:深入了解 AgentJet 蜂群架构的设计理念和核心优势 +- **[可训练工作流](../workflow.md)**:学习如何在 AgentJet 中定义多智能体工作流 +- **[任务评判器](../task_judger.md)**:了解奖励函数的设计原理和自定义方法 +- **[数学 Agent 示例](../example_math_agent.md)**:学习单智能体训练的基础示例 + +### 接下来推荐阅读 + +1. **[Werewolves 狼人杀游戏](../example_werewolves.md)**:了解如何在 AgentJet 中训练多智能体协作与竞争 +2. **[学术翻译蜂群训练](../example_academic_trans_swarm/README.md)**:了解更简单的单模型版本实现 +3. **[蜂群训练博客](swarm_intro_blog_zh.md)**:深入理解非共享参数训练的更多应用场景 diff --git a/docs/en/swarm.md b/docs/en/swarm.md index 11960d6..f2e8d33 100644 --- a/docs/en/swarm.md +++ b/docs/en/swarm.md @@ -61,7 +61,7 @@ The primary objective of swarm client is to make sure network connection is good Now, create a python script and start coding: ```python -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url swarm_worker = SwarmClient(REMOTE_SWARM_URL) ``` diff --git a/docs/en/swarm_best_practice.md b/docs/en/swarm_best_practice.md index 2f8eb28..3557a11 100644 --- a/docs/en/swarm_best_practice.md +++ b/docs/en/swarm_best_practice.md @@ -132,7 +132,7 @@ Hint: you do not have to use `run_episodes_until_all_complete`, you are free to ```python from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/docs/en/swarm_deepdive.md b/docs/en/swarm_deepdive.md index 6120887..f00ac38 100644 --- a/docs/en/swarm_deepdive.md +++ b/docs/en/swarm_deepdive.md @@ -91,7 +91,7 @@ In code, the most common pattern is: ```python from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient swarm_client = SwarmClient("http://your-swarm-server:10086") yaml_job = AgentJetJob( diff --git a/docs/en/tune_your_first_agent.md b/docs/en/tune_your_first_agent.md index 98701ae..1103a11 100644 --- a/docs/en/tune_your_first_agent.md +++ b/docs/en/tune_your_first_agent.md @@ -497,7 +497,7 @@ Create your client script. The client reads the dataset, runs the agent workflow from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo - from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + from ajet.tuner_lib.experimental.swarm_client import SwarmClient # Configuration GRPO_N = 4 # grpo group size @@ -650,7 +650,7 @@ The server handles gradient computation and model updates automatically. from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo - from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + from ajet.tuner_lib.experimental.swarm_client import SwarmClient GRPO_N = 4 # grpo group size NUM_EPOCH = 10000 diff --git a/pyproject.toml b/pyproject.toml index bb13ce1..7aebcbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ requires-python = ">=3.10,<3.13" dependencies = [ "agentscope==1.0.8", "chromadb", - "httpx", + "httpx[http2]", "tenacity", "loguru", "debugpy", diff --git a/scripts/deploy_model.py b/scripts/deploy_model.py index e2b0786..4212a5f 100644 --- a/scripts/deploy_model.py +++ b/scripts/deploy_model.py @@ -16,13 +16,13 @@ parser.add_argument( "--target", # default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", - default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-Coder-480B-A35B-Instruct", + default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct", type=str, help="Model path", ) parser.add_argument( "--alias", - default="Qwen/Qwen3-Coder-480B-A35B-Instruct", + default="Qwen/Qwen2.5-14B-Instruct", type=str, help="Model alias", ) diff --git a/tutorial/example_academic_trans_swarm/trans_roll.py b/tutorial/example_academic_trans_swarm/trans_roll.py index 538f609..c6a55cb 100644 --- a/tutorial/example_academic_trans_swarm/trans_roll.py +++ b/tutorial/example_academic_trans_swarm/trans_roll.py @@ -1,5 +1,5 @@ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/tutorial/example_frozenlake_swarm/frozen_lake_roll.py b/tutorial/example_frozenlake_swarm/frozen_lake_roll.py index 1b5569c..e3365f4 100644 --- a/tutorial/example_frozenlake_swarm/frozen_lake_roll.py +++ b/tutorial/example_frozenlake_swarm/frozen_lake_roll.py @@ -1,5 +1,5 @@ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader from ajet.task_reader import RouterTaskReader from .frozenlake import FrozenLake diff --git a/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py b/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py index c1331ce..151274e 100644 --- a/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py +++ b/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py @@ -1,5 +1,5 @@ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader from ajet.task_reader import RouterTaskReader from .frozenlake import FrozenLake diff --git a/tutorial/example_math_swarm/math.py b/tutorial/example_math_swarm/math.py index c1351a9..2174076 100644 --- a/tutorial/example_math_swarm/math.py +++ b/tutorial/example_math_swarm/math.py @@ -10,7 +10,7 @@ from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient # python -m tutorial.example_math_swarm.math @@ -21,6 +21,8 @@ REMOTE_BATCH_SIZE = 32 REMOTE_ALLOCATE_GPU_PER_NODE = 8 +assert AJET_SWARM_URL != "http://swarm-server-ip:10086", "Please set the environment variable AJET_SWARM_URL to your swarm server's URL, e.g., http://localhost:10086 or http://your-swarm-server-ip:10086" + def main(): # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) diff --git a/tutorial/example_train_multi_model/trans_roll.py b/tutorial/example_train_multi_model/trans_roll.py index 0454c92..7e8e28a 100644 --- a/tutorial/example_train_multi_model/trans_roll.py +++ b/tutorial/example_train_multi_model/trans_roll.py @@ -1,5 +1,5 @@ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py index ac85925..5107ab8 100644 --- a/tutorial/example_werewolves_swarm/agent_roll.py +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -7,7 +7,7 @@ from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.default_config.ajet_default import AjetTaskReader -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient +from ajet.tuner_lib.experimental.swarm_client import SwarmClient NUM_EPOCH = 10000 AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") diff --git a/tutorial/opencode_build_appworld_react/agent_roll.py b/tutorial/opencode_build_appworld_react/agent_roll.py index 84fad56..443c956 100644 --- a/tutorial/opencode_build_appworld_react/agent_roll.py +++ b/tutorial/opencode_build_appworld_react/agent_roll.py @@ -10,7 +10,7 @@ import os import subprocess from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete from ajet.utils.env_service_client.env_client_ng import EnvClient from ajet.schema.task import Task from tutorial.opencode_build_appworld_react.agent_run import run_agent_and_compute_reward diff --git a/tutorial/opencode_build_countdown_agent/agent_roll.py b/tutorial/opencode_build_countdown_agent/agent_roll.py index 09fd868..d6b7e09 100644 --- a/tutorial/opencode_build_countdown_agent/agent_roll.py +++ b/tutorial/opencode_build_countdown_agent/agent_roll.py @@ -16,7 +16,7 @@ """ from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.experimental.as_swarm_client import ( +from ajet.tuner_lib.experimental.swarm_client import ( SwarmClient, run_episodes_until_all_complete, )