58 changes: 56 additions & 2 deletions open_instruct/grpo_fast.py
@@ -78,6 +78,8 @@
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rich.pretty import pprint
from tensordict import tensorclass
from torchrl.data import RayReplayBuffer, ReplayBuffer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler
from transformers.integrations import HfDeepSpeedConfig
@@ -487,6 +489,32 @@ def __post_init__(self):
assert len(self.tools) == len(set(self.tools)), "Duplicate tools are not allowed"


@tensorclass
class PackedLogProbSequence:
query_response: torch.Tensor
"""packed query and response (batch_size, pack_length)"""
attention_mask: torch.Tensor
"""3D attention mask for packed sequences (batch_size, pack_length, pack_length);
it uses an intra-document mask for each query/response pair;
see https://huggingface.co/blog/sirluk/llm-sequence-packing for more details
"""
response_mask: torch.Tensor
"""response mask for packed sequences (batch_size, pack_length)"""
ref_logprob: Optional[torch.Tensor]
"""packed reference-policy logprobs (batch_size, pack_length)"""
old_logprob: Optional[torch.Tensor]
"""packed old-policy logprobs (batch_size, pack_length)"""

tool_mask: Optional[torch.Tensor] = None
"""tool mask for packed sequences (batch_size, pack_length)"""
position_id: Optional[torch.Tensor] = None
"""packed position ids (batch_size, pack_length)"""
advantage: Optional[torch.Tensor] = None
"""packed advantages (batch_size, pack_length) (to be filled in by the main process)"""
reward: Optional[torch.Tensor] = None
"""packed rewards (batch_size, pack_length)"""


def next_batch(dataset_indices: List[int], dataset: datasets.Dataset) -> Batch:
"""Extract next batch of data based on indices."""
data_next = dataset[dataset_indices]
@@ -593,6 +621,7 @@ def from_pretrained(
beaker_config: BeakerRuntimeConfig,
wandb_url: str,
tokenizer: PreTrainedTokenizer,
replay_buffer: ReplayBuffer,
):
# ------------------------------------------------------------
# Monkey patch to load checkpoints with `weights_only=False`
@@ -615,6 +644,7 @@ def load(self, path: str, map_location=None):
self.model_config = model_config
self.beaker_config = beaker_config
self.wandb_url = wandb_url
self.replay_buffer = replay_buffer
torch.cuda.set_device(self.local_rank)
self.device = torch.device(self.local_rank)

@@ -946,6 +976,7 @@ def train(
# if we have multiple minibatches, we need to calculate the old logprobs for each minibatch
# following trl scripts in just doing this on the current active policy, rather than using the logprobs
# from the generator (note that async mode means these are a bit different!)
replay_buffer_list = []
old_logprobs = [None for _ in range(len(collated_query_responses))]
if num_mini_batches > 1:
with Timer("Old logprobs Calculation", noop=self.rank != 0):
@@ -973,6 +1004,20 @@
old_logprobs[i] = old_logprob
torch.cuda.empty_cache()

replay_buffer_list.append(
PackedLogProbSequence(
query_response=query_response,
attention_mask=attention_mask,
response_mask=response_mask,
tool_mask=tool_mask,
advantage=collated_advantages[i],
position_id=position_id,
old_logprob=old_logprob,
ref_logprob=ref_logprob,
)
)
self.replay_buffer.extend(replay_buffer_list)

local_step = 0
# Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
with Timer("[Training Processes] Loss calculation", noop=self.rank != 0):
@@ -1965,6 +2010,7 @@ def create_model_and_optimizer(
inference_results_Q: ray_queue.Queue,
param_prompt_Q: ray_queue.Queue,
evaluation_inference_results_Q: ray_queue.Queue,
replay_buffer: ReplayBuffer,
) -> tuple[ModelGroup, list[vllm_utils3.LLMRayActor], dict, int, int]:
"""Create the model, optimizer, and vLLM engines."""
# Create placement group
@@ -1975,7 +2021,7 @@
policy_group = ModelGroup(pg, PolicyTrainerRayProcess, args.num_learners_per_node, args.single_gpu_mode)
wandb_url = wandb.run.get_url() if args.with_tracking else None
inits.extend(
model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer)
model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer, replay_buffer)
for model in policy_group.models
)

@@ -2109,7 +2155,11 @@ def split_and_insert_batch(


def load_data_from_packing_thread(
packed_sequences_Q: Queue, num_total_tokens: int, stop_event: threading.Event, health_check_fn: Callable[[], None]
packed_sequences_Q: Queue,
num_total_tokens: int,
stop_event: threading.Event,
health_check_fn: Callable[[], None],
replay_buffer: ReplayBuffer,
):
"""Get the packed sequences with advantages from the packing thread."""
with Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer:
@@ -2595,6 +2645,7 @@ def run_training(
generate_metrics_Q,
weight_sync_metrics_Q,
actor_manager: ActorManager,
replay_buffer: ReplayBuffer,
checkpoint_state=None,
):
if resume_training_step > 1:
@@ -2808,6 +2859,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
param_prompt_Q = ray_queue.Queue(maxsize=queue_size)
# We don't care if we ever hit the max, so we let the queue be unbounded.
evaluation_inference_results_Q = ray_queue.Queue()
replay_buffer = RayReplayBuffer(batch_size=128)

policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager = (
create_model_and_optimizer(
@@ -2820,6 +2872,7 @@
inference_results_Q,
param_prompt_Q,
evaluation_inference_results_Q,
replay_buffer,
)
)
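
The buffer above is created once on the driver and threaded through create_model_and_optimizer into every learner process, so all trainers write into the same remote storage. Below is a minimal sketch of that wiring, under the assumption that torchrl's RayReplayBuffer hosts the storage in its own Ray actor and proxies extend()/sample() calls; TrainerStub is a made-up stand-in for PolicyTrainerRayProcess.

import ray
import torch
from torchrl.data import RayReplayBuffer

@ray.remote
class TrainerStub:
    def __init__(self, replay_buffer):
        self.replay_buffer = replay_buffer

    def push(self):
        # Stand-in for the PackedLogProbSequence list built inside train().
        self.replay_buffer.extend([torch.arange(4), torch.arange(4, 8)])

ray.init()
replay_buffer = RayReplayBuffer(batch_size=2)  # mirrors RayReplayBuffer(batch_size=128) in main()
trainer = TrainerStub.remote(replay_buffer)
ray.get(trainer.push.remote())
batch = replay_buffer.sample()                 # consumed outside the trainer actor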

@@ -2886,6 +2939,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
generate_metrics_Q,
weight_sync_metrics_Q,
actor_manager,
replay_buffer,
checkpoint_state,
)
finally:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"immutabledict==1.2.0",
"flash-attn>=2.8.0.post2; platform_system != 'Darwin'",
"liger-kernel>=0.5.4; platform_system != 'Darwin'",
"torchrl>=0.9.2",
]

[build-system]