Skip to content

Commit

Permalink
block_to_split supports list str in CaptureSplitInfo pass (#1609)
Browse files Browse the repository at this point in the history
## Describe your changes

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [x] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [x] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
  block_to_split now also accepts a list of str in the CaptureSplitInfo pass
- [x] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link

Co-authored-by: hualxie <[email protected]>
  • Loading branch information
xieofxie and hualxie authored Feb 14, 2025
1 parent 00415b6 commit f0d9d77
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
24 changes: 16 additions & 8 deletions olive/passes/pytorch/capture_split_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
description="Number of splits to divide the model layers into.",
),
"block_to_split": PassConfigParam(
type_=str,
type_=Union[str, list[str]],
default_value=None,
description=(
"Name of the model block to split. Children of the block will be divided into the splits. For"
"Names of the model blocks to split. Children of the blocks will be divided into the splits. For"
" supported transformers models, the default value is the transformers layer block name."
),
),
Expand Down Expand Up @@ -114,18 +114,26 @@ def split_using_num_splits(
# check for None specifically since "" is a valid value
if block_to_split is None and isinstance(model, HfModelHandler):
model_wrapper = ModelWrapper.from_model(loaded_model)
block, block_to_split = model_wrapper.get_layers()
blocks = [model_wrapper.get_layers()]
elif block_to_split is None:
raise ValueError("block_to_split is not set and could not be inferred. Please set it manually.")
else:
block = get_attr(loaded_model, block_to_split, fail_on_not_found=True)

block_members = [child_name for child_name, _ in block.named_children()]
block_to_splits = block_to_split if isinstance(block_to_split, list) else [block_to_split]
blocks = [
(get_attr(loaded_model, block_to_split, fail_on_not_found=True), block_to_split)
for block_to_split in block_to_splits
]

block_members = [
f"{block_to_split}.{child_name}".lstrip(".")
for block, block_to_split in blocks
for child_name, _ in block.named_children()
]

split_assignments = {}
for split_idx, split_members in enumerate(np.array_split(block_members, config.num_splits)):
for child_name in split_members:
split_assignments[f"{block_to_split}.{child_name}".lstrip(".")] = split_idx
for member_name in split_members:
split_assignments[member_name] = split_idx

return split_assignments

Expand Down
21 changes: 18 additions & 3 deletions test/unit_test/passes/pytorch/test_capture_split_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ def __init__(self):
super().__init__()
self.before_layer = torch.nn.Linear(2, 4)
self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(4)])
self.after_layer = torch.nn.Linear(4, 2)
self.after_layers = torch.nn.ModuleList([torch.nn.Linear(4, 2) for _ in range(2)])

def forward(self, x):
x = self.before_layer(x)
for layer in self.layers:
x = layer(x)
return self.after_layer(x)
for layer in self.after_layers:
x = layer(x)
return x


@pytest.mark.parametrize(
Expand All @@ -34,11 +36,24 @@ def forward(self, x):
2,
{"layers.0": 0, "layers.1": 0, "layers.2": 1, "layers.3": 1},
),
# Test splitting the axis into unequal parts
(
PyTorchModelHandler(model_loader=lambda _: CustomModel()),
"layers",
3,
{"layers.0": 0, "layers.1": 0, "layers.2": 1, "layers.3": 2},
),
(
PyTorchModelHandler(model_loader=lambda _: CustomModel()),
"",
3,
{"before_layer": 0, "layers": 1, "after_layer": 2},
{"before_layer": 0, "layers": 1, "after_layers": 2},
),
(
PyTorchModelHandler(model_loader=lambda _: CustomModel()),
["layers", "after_layers"],
2,
{"layers.0": 0, "layers.1": 0, "layers.2": 0, "layers.3": 1, "after_layers.0": 1, "after_layers.1": 1},
),
(
HfModelHandler(model_path="hf-internal-testing/tiny-random-LlamaForCausalLM"),
Expand Down

0 comments on commit f0d9d77

Please sign in to comment.