Changes from all commits
84 commits
19ad728
fix: remove wrong trl imports
michaelbenayoun Oct 15, 2025
34f4698
feat: align to latest trl release
michaelbenayoun Oct 15, 2025
07437bc
chore: update pyproject.toml
michaelbenayoun Oct 15, 2025
e5256bf
style
michaelbenayoun Oct 15, 2025
954cfdf
feat: sync with SFTTrainer
michaelbenayoun Oct 15, 2025
3f1f700
Merge branch 'main' into sync_trl
michaelbenayoun Oct 31, 2025
3c72216
fix: minor issues
michaelbenayoun Oct 31, 2025
d286f50
chore: sync with trl==0.24.0
michaelbenayoun Oct 31, 2025
0992738
chore: sync sft_trainer
michaelbenayoun Oct 31, 2025
5a847ec
chore: sync sft_trainer
michaelbenayoun Oct 31, 2025
cddbf5f
chore: sync sft_trainer
michaelbenayoun Nov 3, 2025
0200820
fix: sft trainer
michaelbenayoun Nov 4, 2025
5bd79e6
Merge branch 'main' into sync_trl
michaelbenayoun Nov 4, 2025
2c8c1d1
chore: update dependency version for trl
michaelbenayoun Nov 4, 2025
7eda163
chore: cleanup and fix no-packing test
michaelbenayoun Nov 4, 2025
b6ee2a3
chore: restore finetune_qwen3.sh
michaelbenayoun Nov 4, 2025
72b338a
feat: add model card creation when saving a checkpoint
michaelbenayoun Nov 4, 2025
98a6210
chore: remove model card support
michaelbenayoun Nov 4, 2025
ee6caeb
doc: align with trl==0.24.0
michaelbenayoun Nov 4, 2025
ac0c9f2
test: fix broken sft + peft test
michaelbenayoun Nov 4, 2025
8892d51
chore: add GRPO imports in optimum.neuron
michaelbenayoun Nov 4, 2025
f26497d
chore: add GRPO imports in optimum.neuron.trainers
michaelbenayoun Nov 4, 2025
f574d3e
chore: add skeleton for GRPO trainer
michaelbenayoun Nov 4, 2025
b105f91
feat: add mock class for vLLM
michaelbenayoun Nov 4, 2025
b2d45f0
Merge branch 'main' into grpo
michaelbenayoun Nov 5, 2025
781b27f
fix: add is_vllm_available imports
michaelbenayoun Nov 5, 2025
e932e28
chore: add data loading
michaelbenayoun Nov 5, 2025
cef6d30
chore: add _prepare_inputs
michaelbenayoun Nov 5, 2025
a567289
chore: keep replacing stub methods
michaelbenayoun Nov 6, 2025
c8a7ed8
chore: add mock specific comment
michaelbenayoun Nov 6, 2025
0ddc40f
chore: wip, full training cycle with mocks
michaelbenayoun Nov 13, 2025
8cf2842
wip: grpo trainer almost working with mocks (recompilation issues)
michaelbenayoun Nov 14, 2025
26bdd70
fix: gradient checkpointing
michaelbenayoun Nov 17, 2025
e9ca881
temp: added the example script, temporary
michaelbenayoun Nov 17, 2025
8eed696
chore: wip, added torch.sync()
michaelbenayoun Nov 18, 2025
f5ad213
Merge branch 'main' into grpo
michaelbenayoun Nov 18, 2025
02b6253
wip: fix computation device in
michaelbenayoun Nov 18, 2025
e844af9
wip: fix computation device in
michaelbenayoun Nov 18, 2025
745674d
precompilation
michaelbenayoun Nov 18, 2025
b82025d
make ops XLA friendly
michaelbenayoun Nov 19, 2025
a224bed
add torch_xla.sync() to break the graphs in the for loops
michaelbenayoun Nov 19, 2025
fc18366
add DistributedRepeatSampler
michaelbenayoun Nov 20, 2025
caf238a
merge for lora.ParallelLinear
michaelbenayoun Nov 20, 2025
026a237
merge for ParallelEmbedding
michaelbenayoun Nov 20, 2025
a15843a
merge for peft models
michaelbenayoun Nov 20, 2025
b052efc
merge for peft models
michaelbenayoun Nov 21, 2025
b2ff310
merge for peft models
michaelbenayoun Nov 21, 2025
21dc065
fix test
michaelbenayoun Nov 21, 2025
58c438c
trainer runs with mock but produces NaNs
michaelbenayoun Nov 24, 2025
b0d3056
add vllm file
michaelbenayoun Nov 25, 2025
4f42f21
add collectives for python objects
michaelbenayoun Nov 26, 2025
23b0090
add vllm_client for CPU
michaelbenayoun Nov 26, 2025
1e15a27
add VLLMClient and collectives on python objects
michaelbenayoun Nov 26, 2025
fd1239b
add MockVLLMClient
michaelbenayoun Nov 28, 2025
5e8af49
wip, recompilations
michaelbenayoun Dec 2, 2025
dd4041d
collectives work
michaelbenayoun Dec 2, 2025
56011c4
fix clamping bug
michaelbenayoun Dec 2, 2025
a7f58a3
make use_vllm the default
michaelbenayoun Dec 3, 2025
544cadf
update with torch_xla.sync and peft
michaelbenayoun Dec 4, 2025
b0aa60c
wip training
michaelbenayoun Dec 8, 2025
ba62cf4
wip training
michaelbenayoun Dec 9, 2025
688de3a
chore: fix pyproject.toml for uv
michaelbenayoun Dec 9, 2025
a2880ae
chore: update pyproject.toml for SDK 2.26.1
michaelbenayoun Dec 10, 2025
7d8914b
feat: improve nan functions for XLA
michaelbenayoun Dec 18, 2025
ee083cb
feat: compute rewards more XLA friendly
michaelbenayoun Dec 18, 2025
998fd64
feat: optimization for XLA
michaelbenayoun Dec 19, 2025
07cd31b
Merge branch 'main' into grpo
michaelbenayoun Dec 23, 2025
bf000df
Merge branch 'main' into grpo
michaelbenayoun Jan 5, 2026
4407687
debug: training produces NaNs
michaelbenayoun Jan 29, 2026
131df14
fix: no NaNs anymore
michaelbenayoun Jan 29, 2026
7a0167e
rewrite _get_per_token_logps_and_entropies for better breaks
michaelbenayoun Jan 30, 2026
ac36687
optimize _compute_loss
michaelbenayoun Jan 30, 2026
8c816f6
optimize _generate_and_score_completions
michaelbenayoun Jan 30, 2026
4fa42ba
fix: use separate model for ref model to avoid XLA NaN issues
michaelbenayoun Feb 4, 2026
8fc448a
fix: use separate model for ref model to avoid XLA NaN issues
michaelbenayoun Feb 4, 2026
39660dc
chore: vllm_client.py remove unused functions
michaelbenayoun Feb 4, 2026
6828a9b
chore: remove useless docstrings in vllm_client.py
michaelbenayoun Feb 4, 2026
69252e8
chore: add safeguard for the GRPO feature
michaelbenayoun Feb 4, 2026
c2e6582
chore: grpo_trainer.py cleanup
michaelbenayoun Feb 4, 2026
22e1c22
chore: untrack example
michaelbenayoun Feb 4, 2026
ba1ac45
chore: clean trl_utils.py
michaelbenayoun Feb 4, 2026
5a19bd6
chore: clean trl_utils.py
michaelbenayoun Feb 4, 2026
51b68be
chore: clean trl_utils.py
michaelbenayoun Feb 4, 2026
bce138b
fix: add training extra for doc building
michaelbenayoun Feb 4, 2026
2 changes: 1 addition & 1 deletion .github/actions/install_optimum_neuron/action.yml
@@ -7,4 +7,4 @@ runs:
shell: bash
run: |
source aws_neuron_venv_pytorch/bin/activate
- python -m pip install .[neuronx,tests]
+ python -m pip install .[neuronx,tests,training]
4 changes: 4 additions & 0 deletions optimum/neuron/__init__.py
@@ -37,8 +37,10 @@
"trainers": [
"NeuronTrainer",
"NeuronSFTTrainer",
"NeuronGRPOTrainer",
"NeuronTrainingArguments",
"NeuronSFTConfig",
"NeuronGRPOConfig",
],
"modeling_traced": ["NeuronTracedModel"],
"modeling": [
@@ -156,6 +158,8 @@
from .models.inference.yolos import NeuronYolosForObjectDetection
from .pipelines import pipeline
from .trainers import (
+ NeuronGRPOConfig,
+ NeuronGRPOTrainer,
NeuronSFTConfig,
NeuronSFTTrainer,
NeuronTrainer,
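For context, the exports added above make the new GRPO entry points importable from the package root as well as from optimum.neuron.trainers. A minimal import sanity check, as a sketch that only uses names appearing in this diff:

# Sketch: the GRPO symbols added in this PR should resolve identically from both paths.
from optimum.neuron import NeuronGRPOConfig, NeuronGRPOTrainer
from optimum.neuron.trainers import NeuronGRPOConfig as TrainersGRPOConfig
from optimum.neuron.trainers import NeuronGRPOTrainer as TrainersGRPOTrainer

assert NeuronGRPOConfig is TrainersGRPOConfig
assert NeuronGRPOTrainer is TrainersGRPOTrainer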
16 changes: 9 additions & 7 deletions optimum/neuron/models/training/transformations_utils.py
@@ -101,7 +101,7 @@ def peft_type(self) -> str | None:
return self._peft_type

@peft_type.setter
- def peft_type(self, value: str):
+ def peft_type(self, value: str | None):
self._peft_type = value

@abstractmethod
@@ -533,6 +533,9 @@ def _lora_adapt_state_dict(
f"{module_fully_qualified_name}.{name}.lora_A.{param_name}" for name in self.linear_names
]

+ if not all(name in state_dict for name in lora_A_weight_names):
+ continue

logger.warning("Taking the mean of the LoRA A weights since there is only one LoRA A weight after fusing.")
lora_A_weight = torch.mean(
torch.stack([state_dict.pop(name) for name in lora_A_weight_names], dim=0),
@@ -650,9 +653,7 @@ def _lora_to_original_weights(
break

if weight_name is None or to_duplicate_name is None:
- raise ValueError(
- f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}."
- )
+ continue

# When saved, the name of the adapter is removed in the weight qualified name since weights for each
# adapter are saved separately.
@@ -700,9 +701,7 @@ def _lora_to_original_weights(
if to_concat_and_duplicate_name is not None and to_unfuse_name is not None:
break
if to_concat_and_duplicate_name is None or to_unfuse_name is None:
- raise ValueError(
- f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}."
- )
+ continue

weight_name_without_adapter_name = remove_adapter_name(to_concat_and_duplicate_name)
linear_sharded_weights = sharded_state_dicts[weight_name_without_adapter_name]
@@ -1100,6 +1099,9 @@ def _lora_adapt_state_dict(

lora_A_weight_names = [lora_A_q_name, lora_A_k_name, lora_A_v_name]

+ if not all(name in state_dict for name in lora_A_weight_names):
+ continue

logger.warning("Taking the mean of the LoRA A weights since there is only one LoRA A weight after fusing.")
lora_A_weight = torch.mean(
torch.stack([state_dict.pop(name) for name in lora_A_weight_names], dim=0),
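As a side note on the fused-LoRA handling above, the "mean of the LoRA A weights" mentioned in the warning boils down to stacking the per-projection A matrices and averaging over the new axis, mirroring the torch.mean(torch.stack(...)) call in the diff. A standalone sketch with made-up names and shapes (rank and hidden size are illustrative assumptions):

import torch

# Illustrative only: three per-projection LoRA A matrices (e.g. q/k/v) of rank r.
# After the target linears are fused into a single layer, only one LoRA A weight
# remains, so the separate A matrices are averaged.
r, hidden_size = 8, 64
lora_A_q = torch.randn(r, hidden_size)
lora_A_k = torch.randn(r, hidden_size)
lora_A_v = torch.randn(r, hidden_size)

fused_lora_A = torch.mean(torch.stack([lora_A_q, lora_A_k, lora_A_v], dim=0), dim=0)
assert fused_lora_A.shape == (r, hidden_size)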
2 changes: 2 additions & 0 deletions optimum/neuron/trainers/__init__.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+ from .grpo_config import NeuronGRPOConfig
+ from .grpo_trainer import NeuronGRPOTrainer
from .sft_config import NeuronSFTConfig
from .sft_trainer import NeuronSFTTrainer
from .training_args import NeuronTrainingArguments
19 changes: 19 additions & 0 deletions optimum/neuron/trainers/extras/__init__.py
@@ -0,0 +1,19 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vllm_client import MockVLLMClient, VLLMClient


__all__ = ["VLLMClient", "MockVLLMClient"]
213 changes: 213 additions & 0 deletions optimum/neuron/trainers/extras/vllm_client.py
@@ -0,0 +1,213 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import random
import time
from collections import namedtuple
from typing import Union

import requests
import torch
import torch_xla
from optimum.utils import logging
from trl.extras.vllm_client import VLLMClient as TRLVLLMClient
from trl.import_utils import is_vllm_available


if is_vllm_available():
from vllm.distributed.utils import StatelessProcessGroup
else:

class StatelessProcessGroup:
pass


logger = logging.get_logger()

# Set up the communication group for weight broadcasting using CPU communicator
Group = namedtuple("Group", "barrier")


class CPUCommunicator:
def __init__(self, store, rank):
self.rank = rank
self.store = store
self.group = Group(barrier=self.barrier)

def broadcast(self, tensor, src):
# Move tensor to CPU to ensure compatibility with vLLM server
if tensor.device.type == "xla":
tensor = tensor.cpu()
torch_xla.sync()
self.store.broadcast_obj(tensor, src=self.rank)

def barrier(self):
self.store.barrier()

def __del__(self):
del self.store


class VLLMClient(TRLVLLMClient):
"""VLLMClient with CPU-based communication for Neuron environments."""

def __init__(
self,
base_url: str | None = None,
host: str = "0.0.0.0",
server_port: int = 8000,
group_port: int = 51216,
connection_timeout: float = 0.0,
):
super().__init__(
base_url=base_url,
host=host,
server_port=server_port,
group_port=group_port,
connection_timeout=connection_timeout,
)

def init_communicator(self, device: Union[torch.device, str, int] = 0):
# Get the world size from the server
url = f"{self.base_url}/get_world_size/"
response = requests.get(url)
if response.status_code == 200:
vllm_world_size = response.json()["world_size"]
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

world_size = vllm_world_size + 1 # add the client to the world
self.rank = vllm_world_size # the client's rank is the last process

# Initialize weight update group
url = f"{self.base_url}/init_communicator/"

# Use dummy UUID for CPU/Neuron environments
client_device_uuid = "42"

# In the server side, the host is set to 0.0.0.0
response = self.session.post(
url,
json={
"host": "0.0.0.0",
"port": self.group_port,
"world_size": world_size,
"client_device_uuid": client_device_uuid,
},
)
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

# Brief delay to allow server initialization. While not strictly required (client socket will retry on
# connection failure), this prevents log warnings like:
# [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
time.sleep(0.1)

pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
self.communicator = CPUCommunicator(pg, self.rank)

# When the client object is deleted, close the weight update group
atexit.register(self.close_communicator)


class MockVLLMClient(VLLMClient):
"""
Mock VLLMClient that generates completions without a vLLM server.

Used for neuron_parallel_compile and testing. Generates completions by cycling
through prompt tokens (echo mode), producing deterministic, non-garbage output.
"""

def __init__(self, tokenizer, max_completion_length=256, min_completion_length=10, seed=None):
self.tokenizer = tokenizer
self.max_completion_length = max_completion_length
self.min_completion_length = min(min_completion_length, max_completion_length)
self.random = random.Random(seed)

logger.warning(
"Using MockVLLMClient for neuron_parallel_compile or testing. "
"This generates echo completions and should only be used for compilation/testing."
)

def generate(
self,
prompts: list[str],
images=None,
n: int = 1,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 256,
repetition_penalty: float = 1.0,
truncate_prompt_tokens=None,
guided_decoding_regex=None,
generation_kwargs=None,
):
prompt_ids = []
completion_ids = []
logprobs = []

# Fallback tokens if prompt is empty
vocab_size = self.tokenizer.vocab_size
fallback_token_id = min(100, vocab_size - 1)

for prompt in prompts:
# Tokenize prompt
prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)

# Truncate if needed
if truncate_prompt_tokens is not None and len(prompt_tokens) > truncate_prompt_tokens:
prompt_tokens = prompt_tokens[-truncate_prompt_tokens:]

prompt_ids.append(prompt_tokens)

# Generate n completions per prompt
for _ in range(n):
# Random completion length within bounds
max_len = min(max_tokens, self.max_completion_length)
completion_length = self.random.randint(self.min_completion_length, max_len)

# Echo mode: cycle through prompt tokens
if len(prompt_tokens) > 0:
completion = [prompt_tokens[i % len(prompt_tokens)] for i in range(completion_length)]
else:
# Fallback if prompt is empty
completion = [fallback_token_id] * completion_length

completion_ids.append(completion)

# Logprobs: simulate higher confidence for echoed tokens
completion_logprobs = [-self.random.uniform(0.5, 2.0) for _ in range(completion_length)]
logprobs.append(completion_logprobs)

return {
"prompt_ids": prompt_ids,
"completion_ids": completion_ids,
"logprobs": logprobs,
}

def init_communicator(self, device):
pass

def update_named_param(self, name, weights):
pass

def reset_prefix_cache(self):
pass

def close_communicator(self):
pass
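To make the mock client above concrete, here is a small usage sketch. The tokenizer checkpoint is an assumption picked for illustration; the constructor arguments and the keys of the returned dictionary follow the MockVLLMClient code in this diff:

from transformers import AutoTokenizer

from optimum.neuron.trainers.extras import MockVLLMClient

# Assumption: any Hugging Face tokenizer works here; this checkpoint is only an example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

client = MockVLLMClient(tokenizer, max_completion_length=32, min_completion_length=4, seed=0)
out = client.generate(prompts=["The capital of France is"], n=2, max_tokens=32)

# Echo mode: completions cycle through the prompt tokens, so the output is deterministic
# for a fixed seed and never needs a running vLLM server.
assert len(out["completion_ids"]) == 2  # n completions for the single prompt
print(tokenizer.decode(out["completion_ids"][0]))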