gpu kernel for 1d sparse recat gen (pytorch#179)
Summary:
Pull Request resolved: pytorch/torchrec#179

* add the `expand_into_jagged_permute` GPU kernel callsite for generating 1D sparse data permute
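
For context, `expand_into_jagged_permute` expands a feature-level permute into a row-level (jagged) permute given per-feature offsets. Below is a minimal pure-Python sketch of that expansion under variable batch sizes; the helper name and sample values are illustrative only, not part of this commit:

```python
import itertools
from typing import List


def expand_permute_reference(permute: List[int], lengths: List[int]) -> List[int]:
    # Prefix-sum the per-feature lengths to get each feature's row offsets.
    offsets = [0] + list(itertools.accumulate(lengths))
    expanded: List[int] = []
    for r in permute:
        # Emit the rows of feature r in permuted feature order.
        expanded.extend(range(offsets[r], offsets[r + 1]))
    return expanded


# Permute features [0, 2, 1, 3] with per-feature batch sizes [2, 1, 3, 2]:
# expand_permute_reference([0, 2, 1, 3], [2, 1, 3, 2]) -> [0, 1, 3, 4, 5, 2, 6, 7]
```

This is the host-side loop that the diff below removes from `_get_recat`; the fbgemm kernel is called to produce the same expansion on the device instead.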

Reviewed By: youyou6093

Differential Revision: D34778094

fbshipit-source-id: d14174cea809f3e33b1d860d297c7d318a930e34
YazhiGao authored and facebook-github-bot committed Mar 30, 2022
1 parent 0410c7d commit 860d574
Showing 1 changed file with 68 additions and 55 deletions.

torchrec/distributed/dist_data.py
@@ -41,8 +41,9 @@ def _get_recat(
     local_split: int,
     num_splits: int,
     stagger: int = 1,
+    device: Optional[torch.device] = None,
     batch_size_per_rank: Optional[List[int]] = None,
-) -> List[int]:
+) -> torch.Tensor:
     """
     Calculates relevant recat indices required to reorder AlltoAll collective.
@@ -63,42 +64,58 @@
         _recat(2, 4, 2)
         # [0, 4, 2, 6, 1, 5, 3, 7]
     """
-    recat: List[int] = []
+    with record_function("## all2all_data:recat_permute_gen ##"):
+        recat: List[int] = []
+
+        if local_split == 0:
+            return torch.tensor(recat, device=device, dtype=torch.int32)
 
-    feature_order: List[int] = [
-        x + num_splits // stagger * y
-        for x in range(num_splits // stagger)
-        for y in range(stagger)
-    ]
+        feature_order: List[int] = [
+            x + num_splits // stagger * y
+            for x in range(num_splits // stagger)
+            for y in range(stagger)
+        ]
 
-    for i in range(local_split):
-        for j in feature_order:  # range(num_splits):
-            recat.append(i + j * local_split)
+        for i in range(local_split):
+            for j in feature_order:  # range(num_splits):
+                recat.append(i + j * local_split)
 
-    # variable batch size
-    if batch_size_per_rank is not None:
-        batch_size_per_feature = list(
-            itertools.chain.from_iterable(
-                itertools.repeat(x, local_split) for x in batch_size_per_rank
-            )
-        )
-        batch_size_per_feature_cumsum = [0] + list(
-            itertools.accumulate(batch_size_per_feature)
-        )
-        recat_per_feature = recat
-        recat = []
-        for r in recat_per_feature:
-            recat.extend(
-                list(
-                    range(
-                        batch_size_per_feature_cumsum[r],
-                        batch_size_per_feature_cumsum[r + 1],
-                    )
-                )
-            )
-    return recat
+        # variable batch size
+        if batch_size_per_rank is not None:
+            batch_size_per_feature = list(
+                itertools.chain.from_iterable(
+                    itertools.repeat(x, local_split) for x in batch_size_per_rank
+                )
+            )
+            permuted_batch_size_per_feature = [batch_size_per_feature[r] for r in recat]
+            input_offset = [0] + list(itertools.accumulate(batch_size_per_feature))
+            output_offset = [0] + list(
+                itertools.accumulate(permuted_batch_size_per_feature)
+            )
+            recat_tensor = torch.tensor(
+                recat,
+                device=device,
+                dtype=torch.int32,
+            )
+            input_offset_tensor = torch.tensor(
+                input_offset,
+                device=device,
+                dtype=torch.int32,
+            )
+            output_offset_tensor = torch.tensor(
+                output_offset,
+                device=device,
+                dtype=torch.int32,
+            )
+            recat = torch.ops.fbgemm.expand_into_jagged_permute(
+                recat_tensor,
+                input_offset_tensor,
+                output_offset_tensor,
+                output_offset[-1],
+            )
+            return recat
+        else:
+            return torch.tensor(recat, device=device, dtype=torch.int32)
 
 
 def _split_lengths(
@@ -321,15 +338,12 @@ def __init__(
             )
             self._batch_size_per_rank_tensor = batch_size_per_rank_tensor
             self._batch_size_per_rank = batch_size_per_rank_tensor.cpu().tolist()
-            self._recat = torch.tensor(
-                _get_recat(
-                    local_split=dim_0,
-                    num_splits=len(splits),
-                    stagger=stagger,
-                    batch_size_per_rank=self._batch_size_per_rank,
-                ),
+            self._recat = _get_recat(
+                local_split=dim_0,
+                num_splits=len(splits),
+                stagger=stagger,
                 device=self._device,
-                dtype=torch.int32,
+                batch_size_per_rank=self._batch_size_per_rank,
             )
         else:
             assert self._recat is not None
@@ -341,9 +355,10 @@ def __init__(
                 dtype=in_lengths.dtype,
             )
             self._lengths = out_lengths
-        self._in_lengths_per_worker = _split_lengths(
-            splits, input.keys(), input.offset_per_key()
-        )
+        with record_function("## all2all_data:split length ##"):
+            self._in_lengths_per_worker = _split_lengths(
+                splits, input.keys(), input.offset_per_key()
+            )
 
         self._output_split_sizes: List[int] = [
             dim_0 * B_rank for B_rank in self._batch_size_per_rank
@@ -370,12 +385,13 @@ def _wait_impl(self) -> KJTAllToAllIndicesAwaitable:
         if self._workers > 1:
             self._lengths_awaitable.wait()
             if self._variable_batch_size:
-                lengths_per_rank: List[torch.Tensor] = list(
-                    self._lengths.split(self._output_split_sizes)
-                )
-                out_lengths_per_worker = [
-                    int(length.sum().item()) for length in lengths_per_rank
-                ]
+                with record_function("## all2all_data:split length for a2a ##"):
+                    lengths_per_rank: List[torch.Tensor] = list(
+                        self._lengths.split(self._output_split_sizes)
+                    )
+                    out_lengths_per_worker = [
+                        int(length.sum().item()) for length in lengths_per_rank
+                    ]
             else:
                 out_lengths_per_worker = (
                     self._lengths.view(self._workers, -1).sum(dim=1).cpu().tolist()
@@ -467,14 +483,11 @@ def __init__(
         self._variable_batch_size = variable_batch_size
         self.register_buffer(
             "_recat",
-            torch.tensor(
-                _get_recat(
-                    local_split=splits[pg.rank()],
-                    num_splits=len(splits),
-                    stagger=stagger,
-                ),
+            _get_recat(
+                local_split=splits[pg.rank()],
+                num_splits=len(splits),
+                stagger=stagger,
                 device=device,
-                dtype=torch.int,
             ),
         )
 
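
To make the new variable-batch path concrete, here is a hedged usage sketch of the updated `_get_recat` (values worked out by hand from the diff above; it assumes `fbgemm_gpu` is installed so that `torch.ops.fbgemm.expand_into_jagged_permute` is registered, and a CUDA device for the kernel path):

```python
import torch

from torchrec.distributed.dist_data import _get_recat

# Fixed batch size: the result is the staggered feature permute from the docstring.
recat = _get_recat(local_split=2, num_splits=4, stagger=2, device=torch.device("cpu"))
# tensor([0, 4, 2, 6, 1, 5, 3, 7], dtype=torch.int32)

# Variable batch size: the feature-level permute [0, 2, 1, 3] is expanded row-wise
# on the device by torch.ops.fbgemm.expand_into_jagged_permute.
recat = _get_recat(
    local_split=2,
    num_splits=2,
    device=torch.device("cuda"),
    batch_size_per_rank=[3, 1],
)
# batch_size_per_feature = [3, 3, 1, 1]
# -> expected row-level permute [0, 1, 2, 6, 3, 4, 5, 7]
```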
