Fixes in distributed layers
angeloskath committed Aug 1, 2024
1 parent 62e5a5b commit becb7f8
Showing 1 changed file with 17 additions and 8 deletions.
python/mlx/nn/layers/distributed.py (17 additions, 8 deletions)
@@ -1,5 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
+import math
 from functools import lru_cache
 from typing import Optional
 
@@ -168,7 +169,7 @@ def __call__(self, x: mx.array) -> mx.array:
         if self.group.size() > 1:
             # Perform the local projection and aggregate the results
             x = x @ self["weight"].T
-            x = mx.distributed.all_sum(x, group=group)
+            x = mx.distributed.all_sum(x, group=self.group)
 
         # Add the bias if we have one
         if "bias" in self:
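The hunk above fixes a NameError: the communication group lives on the layer, so the reduction must reference self.group rather than the bare name group. The surrounding code is the usual sharded-to-all pattern: each rank computes a partial projection with its weight shard, then an all_sum reconstitutes the full output. A minimal sketch of that pattern (names and shapes are illustrative, not from this file):

    import mlx.core as mx

    world = mx.distributed.init()      # size-1 group when run single-process
    x_shard = mx.ones((2, 8))          # this rank's slice of the activations
    w_shard = mx.ones((4, 8))          # this rank's shard of the weight
    partial = x_shard @ w_shard.T      # local projection on this rank
    full = mx.distributed.all_sum(partial, group=world)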
@@ -316,9 +317,9 @@ def from_quantized_linear(
             bits=quantized_linear_layer.bits,
             group=group,
         )
-        sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
-        sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
-        sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
+        sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
+        sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
+        sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
         if "bias" in quantized_linear_layer:
             sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
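The replaced slices were malformed: r : step : (r + 1) * step parses as start r, stop step, stride (r + 1) * step, so each rank got at most one row instead of its contiguous block of step rows. A toy illustration of the difference (shapes are illustrative, not from the commit):

    import mlx.core as mx

    w = mx.arange(12).reshape(4, 3)        # 4 output rows, sharded over 2 ranks
    step, r = 2, 1                         # rank 1 should own rows 2..3
    wrong = w[r : step : (r + 1) * step]   # start=1, stop=2, stride=4 -> 1 row
    right = w[r * step : (r + 1) * step]   # rows 2..3, the intended shard
    print(wrong.shape, right.shape)        # (1, 3) (2, 3)

The trailing * 1 presumably forces a materialized copy of the shard so the layer does not keep the full, unsharded parameter alive.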
@@ -413,7 +414,7 @@ def __call__(self, x: mx.array) -> mx.array:
             bits=self.bits,
         )
         if self.group.size() > 1:
-            x = mx.distributed.sum_all(x, group=group)
+            x = mx.distributed.all_sum(x, group=self.group)
         if "bias" in self:
             x = x + self["bias"]
         return x
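Two bugs in one line here: mx.distributed has no sum_all (the reduction op is all_sum), and group was again an undefined bare name inside __call__. For context, a sketch of the quantized projection plus reduction this method presumably performs, using mx.quantize and mx.quantized_matmul with assumed group_size=64, bits=4 settings:

    import mlx.core as mx

    group = mx.distributed.init()
    x = mx.ones((2, 64))
    w, scales, biases = mx.quantize(mx.ones((32, 64)), group_size=64, bits=4)
    y = mx.quantized_matmul(x, w, scales, biases, transpose=True,
                            group_size=64, bits=4)
    y = mx.distributed.all_sum(y, group=group)  # combine per-rank partials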
@@ -428,6 +429,8 @@ def from_quantized_linear(
         N = group.size()
         r = group.rank()
         output_dims, input_dims = quantized_linear_layer.weight.shape
+        step = input_dims // N
+        step_grouped = quantized_linear_layer.scales.shape[1] // N
         input_dims *= (32 // quantized_linear_layer.bits) * N
 
         sl = cls(
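The two new locals make the sharding arithmetic explicit: the packed weight and its quantization metadata are split along the input axis with different column counts. Each uint32 column of weight packs 32 // bits elements, while scales and biases have one column per group_size of elements, so the per-rank steps differ. A worked example under assumed bits=4, group_size=64, logical input_dims=256, and N=2 ranks:

    bits, group_size, N = 4, 64, 2
    packed_cols = 256 * bits // 32     # weight.shape[1] = 32 packed columns
    scale_cols = 256 // group_size     # scales.shape[1] = 4 columns
    step = packed_cols // N            # 16 packed columns per rank
    step_grouped = scale_cols // N     # 2 scale columns per rank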
@@ -438,9 +441,15 @@ def from_quantized_linear(
             bits=quantized_linear_layer.bits,
             group=group,
         )
-        sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
-        sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
-        sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
+        sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
+        sl.scales = (
+            quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped]
+            * 1
+        )
+        sl.biases = (
+            quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped]
+            * 1
+        )
         if "bias" in quantized_linear_layer:
             sl.bias = quantized_linear_layer.bias
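This variant shards the input features, so the slice must run along axis 1 (columns); the earlier lines both used the malformed start:stop:stride slice and sharded the wrong axis. Note also that sl.bias is now copied whole rather than sliced: after the all_sum every rank holds the full output, so the output bias is replicated. A toy shape check under the same assumed settings as above (bits=4, group_size=64, logical input_dims=256, output_dims=8, N=2):

    import mlx.core as mx

    w = mx.zeros((8, 32), dtype=mx.uint32)  # packed weight: 256 * 4 // 32 cols
    scales = mx.zeros((8, 4))               # one column per group of 64
    N, r = 2, 0
    step, step_grouped = w.shape[1] // N, scales.shape[1] // N
    w_shard = w[:, r * step : (r + 1) * step]                       # (8, 16)
    s_shard = scales[:, r * step_grouped : (r + 1) * step_grouped]  # (8, 2)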
