
Commit ea52adf

adding all diff to new branch
1 parent 1275d36 commit ea52adf


8 files changed: +214 −59 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -159,3 +159,4 @@ dmypy.json
 cache/
 local_dataset_cache/
 scratch/
+vllm_olmo3/

Dockerfile

Lines changed: 5 additions & 9 deletions
@@ -1,13 +1,5 @@
 FROM ghcr.io/allenai/cuda:12.8-dev-ubuntu22.04-torch2.7.0-v1.2.170
 
-# Add build arguments for git information
-ARG GIT_COMMIT=""
-ARG GIT_BRANCH=""
-
-# Set them as environment variables
-ENV GIT_COMMIT=${GIT_COMMIT}
-ENV GIT_BRANCH=${GIT_BRANCH}
-
 COPY --from=ghcr.io/astral-sh/uv:0.8.6 /uv /uvx /bin/
 
 # Set default cache directory but allow override from environment
@@ -31,6 +23,9 @@ COPY pyproject.toml uv.lock ./
 # Annoyingly, we need this before `uv run`, or it complains.
 COPY open_instruct open_instruct
 
+# Install custom vllm for olmo3
+RUN git clone -b shanea/olmo2-retrofit https://github.com/2015aroras/vllm.git vllm_olmo2.5
+
 # Install dependencies
 RUN --mount=type=cache,target=${UV_CACHE_DIR} \
     --mount=type=bind,source=uv.lock,target=uv.lock \
@@ -47,6 +42,7 @@ COPY configs configs
 COPY scripts scripts
 COPY oe-eval-internal oe-eval-internal
 COPY mason.py mason.py
+COPY .git/ ./.git/
 
 # Set up the environment
-ENV PATH=/stage/.venv/bin:$PATH
+ENV PATH=/stage/.venv/bin:$PATH

Makefile

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,4 @@
-.PHONY: style quality
+.PHONY: style quality docker
 
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = open_instruct
@@ -16,3 +16,10 @@ style-check: ## *fail* if anything needs rewriting
 
 quality-check: ## *fail* if any rewrite was needed
 	uv run ruff check --exit-non-zero-on-fix $(check_dirs)
+
+docker:
+	DOCKER_BUILDKIT=1 docker build -f Dockerfile --build-arg UV_CACHE_DIR=$(UV_CACHE_DIR) -t open_instruct_dev_uv_olmo3 .
+	# if you are internally at AI2, you can create an image like this:
+	$(eval beaker_user := $(shell beaker account whoami --format json | jq -r '.[0].name'))
+	beaker image delete $(beaker_user)/open_instruct_dev_olmo2.5
+	beaker image create open_instruct_dev_uv_olmo3 -n open_instruct_dev_uv_olmo3 -w ai2/$(beaker_user)

open_instruct/grpo_fast.py

Lines changed: 43 additions & 21 deletions
@@ -73,6 +73,7 @@
 from ray.util.placement_group import PlacementGroup, placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from rich.pretty import pprint
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler
 from transformers.integrations import HfDeepSpeedConfig
@@ -122,7 +123,6 @@
     is_beaker_job,
     launch_ai2_evals_on_weka,
     maybe_get_beaker_config,
-    maybe_update_beaker_description_with_wandb_url,
     maybe_use_ai2_hf_entity,
     maybe_use_ai2_wandb_entity,
     ray_get_with_progress,
@@ -382,6 +382,8 @@ class Args:
    """The beaker evaluation tasks to launch"""
     oe_eval_max_length: int = 4096
    """the max generation length for evaluation for oe-eval"""
+    oe_eval_beaker_image: Optional[str] = None
+    """the docker image for evaluation for oe-eval"""
     eval_priority: Literal["low", "normal", "high", "urgent"] = "normal"
    """the priority of auto-launched evaluation jobs"""
 
@@ -1078,6 +1080,7 @@ def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url
             args.stop_strings,
             args.gs_bucket_path,
             args.eval_priority,
+            args.oe_eval_beaker_image,
         )
 
 
@@ -1648,15 +1651,21 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod
         wandb.init(
             project=args.wandb_project_name,
             entity=args.wandb_entity,
+            sync_tensorboard=True,
             config=all_configs,
             name=args.run_name,
             save_code=True,
             tags=[args.exp_name] + get_wandb_tags(),
         )
         wandb_url = wandb.run.get_url()
-        maybe_update_beaker_description_with_wandb_url(wandb_url)
 
-    return beaker_config, wandb_url
+    writer = SummaryWriter(f"runs/{args.run_name}")
+    writer.add_text(
+        "hyperparameters",
+        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+    )
+
+    return beaker_config, writer, wandb_url
 
 
 def setup_datasets(args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokenizer):
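With this change, per-step metrics go to a TensorBoard SummaryWriter, and they reach Weights & Biases only because wandb.init is now called with sync_tensorboard=True, which patches TensorBoard so that writer.add_scalar and writer.add_histogram events are mirrored into the wandb run. A minimal standalone sketch of that pattern (project, run name, and metric keys are invented; mode="offline" avoids needing a wandb login):

import numpy as np
import wandb
from torch.utils.tensorboard import SummaryWriter

# sync_tensorboard=True makes wandb re-log TensorBoard events into the run.
wandb.init(project="demo", name="sync-demo", sync_tensorboard=True, mode="offline")
writer = SummaryWriter("runs/sync-demo")
for step in range(3):
    writer.add_scalar("train/loss", 1.0 / (step + 1), step)                # scalar metric
    writer.add_histogram("train/scores", np.array([0.1, 0.5, 0.9]), step)  # array metric
writer.close()
wandb.finish()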
@@ -1936,11 +1945,13 @@ def one_training_step(
     collated_data,
     tokenizer,
     data_thread_metrics,
+    average_metrics,
     episode,
     training_step,
     num_total_tokens,
     start_time,
     train_dataset,
+    writer,
     wandb_url,
     chat_template_name,
 ):
@@ -1975,18 +1986,16 @@
         **data_thread_metrics,
         **average_metrics,
     }
-    # Print only scalar metrics
-    scalar_metrics = {k: v for k, v in metrics.items() if isinstance(v, (float, int))}
+    scalar_metrics = {}
+    for key, value in metrics.items():
+        if isinstance(value, float) or isinstance(value, int):
+            writer.add_scalar(key, value, episode)
+            scalar_metrics[key] = value
+        if isinstance(value, np.ndarray) or isinstance(value, list):
+            if len(value) > 0:
+                writer.add_histogram(key, value, episode)
     print_rich_single_line_metrics(scalar_metrics)
 
-    if args.with_tracking:
-        # Convert array/list metrics to wandb histograms for logging
-        for key, value in metrics.items():
-            if isinstance(value, np.ndarray) or isinstance(value, list):
-                if len(value) > 0:
-                    metrics[key] = wandb.Histogram(value)
-        wandb.log(metrics, step=episode)
-
     if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1):
         with Timer("[Main Thread] 🗡️ Saving model"):
             checkpoint_dir = f"{args.output_dir}_checkpoints"
@@ -2036,6 +2045,7 @@ def maybe_evaluate(
     eval_batch: Optional[Batch],
     reward_fn,
     episode,
+    writer,
     eval_pending_queries_map: PendingQueriesMap,
     eval_generation_config,
 ):
@@ -2083,18 +2093,19 @@
             **eval_reward_metrics,
         }
         print_rich_single_line_metrics(eval_metrics)
-
+        for key, value in eval_metrics.items():
+            writer.add_scalar(key, value, episode)
         table = {}
         table["prompt"] = tokenizer.batch_decode(eval_batch.queries if eval_batch else [])
         table["response"] = eval_decoded_responses
         table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]]
         table["scores"] = eval_scores
         table["ground_truth"] = eval_batch.ground_truths if eval_batch else []
         df = pd.DataFrame(table)
-
         if args.with_tracking:
-            eval_metrics["sample_completions"] = wandb.Table(dataframe=df)
-            wandb.log(eval_metrics, step=episode)
+            import wandb
+
+            wandb.log({"sample_completions": wandb.Table(dataframe=df)})
         else:
             print_rich_table(df.iloc[:1])
         del table
@@ -2229,8 +2240,11 @@ async def reward_fn(
 
 def cleanup_judge_clients():
    """Cleans up all LLM judge clients and shutdown Ray."""
-    asyncio.run(cleanup_all_llm_judge_clients())
-    logger.info("✅ LLM judge clients cleaned up")
+    try:
+        asyncio.run(cleanup_all_llm_judge_clients())
+        logger.info("✅ LLM judge clients cleaned up")
+    except Exception as cleanup_error:
+        logger.warning(f"Error during LLM judge cleanup: {cleanup_error}")
     ray.shutdown()
 
 
@@ -2263,7 +2277,12 @@ def cleanup_training_resources(
         queues[0].put(ShutdownSentinel(), timeout=1)
 
     logger.info("Shutting down Ray queues...")
-    [queue.shutdown() for queue in queues]
+    for queue in queues:
+        try:
+            queue.shutdown()
+        except Exception as e:
+            logger.warning(f"Error shutting down Ray queue: {e}")
+
     logger.info("Shutting down thread pool executor...")
     executor.shutdown(wait=True)
 
@@ -2274,7 +2293,7 @@
 def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_samples: int = 32):
     tokenizer = make_tokenizer(tc, model_config)
     args = setup_runtime_variables(args)
-    beaker_config, wandb_url = setup_experiment_tracking(args, tc, model_config)
+    beaker_config, writer, wandb_url = setup_experiment_tracking(args, tc, model_config)
 
     train_dataset, eval_dataset = setup_datasets(args, tc, tokenizer)
     if args.cache_dataset_only:
@@ -2412,11 +2431,13 @@
             collated_data,
             tokenizer,
             data_thread_metrics,
+            {},
             episode,
             training_step,
             num_total_tokens,
             start_time,
             train_dataset,
+            writer,
             wandb_url,
             tc.chat_template_name,
         )
@@ -2429,6 +2450,7 @@
             eval_batch,
             reward_fn,
             episode,
+            writer,
             eval_pending_queries_map,
             generation_configs["eval"],
         )

open_instruct/utils.py

Lines changed: 58 additions & 23 deletions
@@ -48,7 +48,6 @@
 from multiprocessing import resource_tracker as _rt
 from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
 
-import beaker
 import numpy as np
 import ray
 import requests
@@ -650,22 +649,66 @@ def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
 
 # ----------------------------------------------------------------------------
 # Experiment tracking utilities
-def get_wandb_tags() -> List[str]:
-    """Get tags for Weights & Biases (e.g., `no-tag-404-g98dc659,pr-123,branch-main`)"""
-    tags = [t for t in os.environ.get("WANDB_TAGS", "").split(",") if t != ""]
-    if "GIT_COMMIT" in os.environ:
-        git_commit = os.environ["GIT_COMMIT"]
-        tags.append(f"commit: {git_commit}")
+def get_git_tag() -> str:
+    """Try to get the latest Git tag (e.g., `no-tag-404-g98dc659` or `v1.0.0-4-g98dc659`)"""
+    git_tag = ""
+    try:
+        git_tag = (
+            subprocess.check_output(["git", "describe", "--tags"], stderr=subprocess.DEVNULL).decode("ascii").strip()
+        )
+    except subprocess.CalledProcessError as e:
+        logging.debug(f"Failed to get Git tag: {e}")
+
+    # If no Git tag found, create a custom tag based on commit count and hash
+    if len(git_tag) == 0:
+        try:
+            count = int(
+                subprocess.check_output(["git", "rev-list", "--count", "HEAD"], stderr=subprocess.DEVNULL)
+                .decode("ascii")
+                .strip()
+            )
+            hash = (
+                subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], stderr=subprocess.DEVNULL)
+                .decode("ascii")
+                .strip()
+            )
+            git_tag = f"no-tag-{count}-g{hash}"
+        except subprocess.CalledProcessError as e:
+            logging.debug(f"Failed to get commit count and hash: {e}")
+
+    return git_tag
+
+
+def get_pr_tag() -> str:
+    """Try to find associated pull request on GitHub (e.g., `pr-123`)"""
+    pr_tag = ""
+    try:
+        git_commit = (
+            subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"], stderr=subprocess.DEVNULL)
+            .decode("ascii")
+            .strip()
+        )
         # try finding the pull request number on github
         prs = requests.get(f"https://api.github.com/search/issues?q=repo:allenai/open-instruct+is:pr+{git_commit}")
         if prs.status_code == 200:
             prs = prs.json()
-            if len(prs["items"]):
+            if len(prs["items"]) > 0:
                 pr = prs["items"][0]
-                tags.append(f"pr: {pr['number']}")
-    if "GIT_BRANCH" in os.environ:
-        tags.append(f"branch: {os.environ['GIT_BRANCH']}")
-    return tags
+                pr_number = pr["number"]
+                pr_tag = f"pr-{pr_number}"
+    except Exception as e:
+        logging.debug(f"Failed to get PR number: {e}")
+
+    return pr_tag
+
+
+def get_wandb_tags() -> List[str]:
+    """Get tags for Weights & Biases (e.g., `no-tag-404-g98dc659,pr-123`)"""
+    existing_wandb_tags = os.environ.get("WANDB_TAGS", "")
+    git_tag = get_git_tag()
+    pr_tag = get_pr_tag()
+    non_empty_tags = [tag for tag in [existing_wandb_tags, git_tag, pr_tag] if len(tag) > 0]
+    return non_empty_tags
 
 
 # ----------------------------------------------------------------------------
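These helpers replace the old env-var based tags (GIT_COMMIT/GIT_BRANCH, which the Dockerfile no longer injects) with tags read from the .git directory that the Dockerfile now copies into the image. A small usage sketch of the new functions; the printed values are only examples:

# Run inside a git checkout of open-instruct; outputs are illustrative.
from open_instruct.utils import get_git_tag, get_pr_tag, get_wandb_tags

print(get_git_tag())     # e.g. "v1.0.0-4-g98dc659", or "no-tag-404-g98dc659" when no tag exists
print(get_pr_tag())      # e.g. "pr-123", or "" if no PR references HEAD
print(get_wandb_tags())  # e.g. ["no-tag-404-g98dc659", "pr-123"], plus any WANDB_TAGS value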
@@ -923,17 +966,6 @@ def maybe_get_beaker_config():
     )
 
 
-def maybe_update_beaker_description_with_wandb_url(wandb_url: str) -> None:
-    """Update Beaker experiment description with wandb URL if running on Beaker."""
-    if not is_beaker_job() or wandb_url is None:
-        return
-
-    client = beaker.Beaker.from_env()
-    spec = client.experiment.get(os.environ["BEAKER_WORKLOAD_ID"])
-    current_description = spec.description or ""
-    client.experiment.set_description(os.environ["BEAKER_WORKLOAD_ID"], f"{current_description}\n{wandb_url}")
-
-
 def live_subprocess_output(cmd: List[str]) -> str:
     output_lines = []
     process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
@@ -1045,6 +1077,7 @@ def launch_ai2_evals_on_weka(
     stop_strings: Optional[List[str]] = None,
     gs_bucket_path: Optional[str] = None,
     eval_priority: Optional[str] = "normal",
+    beaker_image: Optional[str] = None,
 ) -> None:
     weka_cluster = "ai2/saturn-cirrascale ai2/neptune-cirrascale"
     gcp_cluster = "ai2/augusta-google-1"
@@ -1096,6 +1129,8 @@
         command += f" --oe_eval_tasks {','.join(oe_eval_tasks)}"
     if stop_strings is not None:
         command += f" --oe_eval_stop_sequences '{','.join(stop_strings)}'"
+    if beaker_image is not None:
+        command += f" --beaker_image {beaker_image}"
     print(f"Launching eval jobs with command: {command}")
     process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     stdout, stderr = process.communicate()
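The beaker_image parameter added here is fed by the new Args.oe_eval_beaker_image field in open_instruct/grpo_fast.py above, which launch_ai2_evals_on_weka_wrapper forwards into this function, where it ends up as a --beaker_image flag on the eval submission command. A simplified sketch of that flow (the signature is trimmed, and the base command and image name are hypothetical):

from typing import Optional

def launch_ai2_evals_on_weka(model_dir: str, beaker_image: Optional[str] = None) -> str:
    # Stand-in for the real function: it assembles a shell command for the eval submission script.
    command = f"python scripts/submit_eval_jobs.py --location {model_dir}"  # hypothetical base command
    if beaker_image is not None:
        command += f" --beaker_image {beaker_image}"
    return command

# args.oe_eval_beaker_image travels from grpo_fast.py into this call:
print(launch_ai2_evals_on_weka("/weka/checkpoints/step_100", beaker_image="ai2/oe-eval-custom"))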

open_instruct/vllm_utils3.py

Lines changed: 4 additions & 1 deletion
@@ -306,7 +306,10 @@ def create_vllm_engines(
     results_queue=None,
     eval_results_queue=None,
 ) -> list[LLMRayActor]:
-    assert vllm.__version__ >= "0.8.1", "OpenRLHF only supports vllm >= 0.8.1"
+
+    # if we installed from source, don't worry about it
+    if "dev" not in vllm.__version__:
+        assert vllm.__version__ >= "0.8.1", "OpenRLHF only supports vllm >= 0.8.1"
 
     # Convert max_tool_calls to a dict mapping tool end strings to their limits
     assert len(max_tool_calls) == 1 or len(max_tool_calls) == len(tools), (
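One caveat: vllm.__version__ >= "0.8.1" is a lexicographic string comparison, so a version such as "0.10.0" would sort below "0.8.1". If a stricter gate were ever wanted, a semantic comparison could be sketched with the packaging library (assuming it is available; this is an illustrative alternative, not part of this commit):

import vllm
from packaging.version import Version

# Source installs usually carry a dev/local suffix, e.g. "0.9.2.dev123+gabcdef".
if "dev" not in vllm.__version__:
    assert Version(vllm.__version__) >= Version("0.8.1"), "vllm >= 0.8.1 is required"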

0 commit comments