support pylayer for moe #70375

Open · wants to merge 1 commit into base: develop
297 changes: 187 additions & 110 deletions python/paddle/distributed/auto_parallel/api.py
@@ -369,58 +369,97 @@ def forward(
local_tensor_list,
local_mesh_list,
local_placements,
idx,
global_dims,
mesh,
placements,
global_dims,
idx=-1,
):
local_tensor = local_tensor_list[idx]
if local_tensor.is_dist():
local_mesh = local_tensor.process_mesh
local_val = local_tensor._local_value()
else:
local_val = local_tensor
local_mesh = None

ctx.global_mesh = copy.deepcopy(mesh)
ctx.placements = copy.deepcopy(placements)
ctx.local_dims = local_tensor.shape
ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
ctx.local_placements = copy.deepcopy(local_placements)
# NOTE: _local_value and paddle.Tensor are only supported in dynamic mode
if paddle.in_dynamic_mode():
local_tensor = local_tensor_list[idx]
if local_tensor.is_dist():
local_mesh = local_tensor.process_mesh
local_val = local_tensor._local_value()
else:
local_val = local_tensor
local_mesh = None

ctx.save_for_backward(
copy.deepcopy(mesh), # global_mesh
local_tensor.shape, # local_dims
copy.deepcopy(local_mesh_list), # local_mesh_list
copy.deepcopy(local_placements), # local_placements
)

place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)

global_tensor = paddle.Tensor(
local_val,
dims=global_dims,
process_mesh=mesh,
placements=placements,
place=place,
)
global_tensor.stop_gradient = local_tensor.stop_gradient
return global_tensor
global_tensor = paddle.Tensor(
local_val,
dims=global_dims,
process_mesh=mesh,
placements=placements,
place=place,
)
global_tensor.stop_gradient = local_tensor.stop_gradient
return global_tensor
else:
ctx.save_for_backward(
copy.deepcopy(mesh), # global_mesh
copy.deepcopy(placements), # global_placements
copy.deepcopy(local_mesh_list), # local_mesh_list
copy.deepcopy(local_placements), # local_placements
)
dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
local_tensor_list,
local_mesh_list,
local_placements,
mesh,
placements,
global_dims,
)
dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
dist_tensor.persistable = local_tensor_list[0].persistable
return dist_tensor

@staticmethod
def backward(ctx, grad_tensor):
if ctx.local_mesh_list is None:
return grad_tensor._local_value()
else:
place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
out = []
for i, local_mesh in enumerate(ctx.local_mesh_list):
out.append(
paddle.Tensor(
grad_tensor._local_value(),
dims=ctx.local_dims,
process_mesh=local_mesh,
placements=ctx.local_placements,
place=place,
if paddle.in_dynamic_mode():
global_mesh, local_dims, local_mesh_list, local_placements = (
ctx.saved_tensor()
)
if local_mesh_list is None:
return grad_tensor._local_value()
else:
place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
out = []
for i, local_mesh in enumerate(local_mesh_list):
out.append(
paddle.Tensor(
grad_tensor._local_value(),
dims=local_dims,
process_mesh=local_mesh,
placements=local_placements,
place=place,
)
)
)
out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
return out
out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
return out
else:
(
global_mesh,
global_placements,
local_mesh_list,
local_placements,
) = ctx.saved_tensor()
return paddle._C_ops.moe_sub_mesh_tensors(
grad_tensor,
local_mesh_list,
local_placements,
global_mesh,
global_placements,
)


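For reference, the ctx.save_for_backward / ctx.saved_tensor idiom this hunk moves to, shown as a minimal standalone PyLayer (a toy tanh layer, not the PR's MoE code):

import paddle
from paddle.autograd import PyLayer

class CusTanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        # keep what backward needs via save_for_backward instead of ad-hoc
        # ctx attributes
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        (y,) = ctx.saved_tensor()
        return dy * (1 - paddle.square(y))

x = paddle.rand([2, 3])
x.stop_gradient = False
out = CusTanh.apply(x)
out.sum().backward()

The hunk above applies the same idiom, saving the mesh and placement metadata in forward and unpacking it with ctx.saved_tensor() in the backward of both branches.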
def _get_sub_meshes_and_local_placements(
@@ -469,6 +508,7 @@ def moe_global_mesh_tensor(
local_tensor = local_tensor_list[local_tensor_idx]

if paddle.in_dynamic_mode():
# NOTE: _local_value and paddle.Tensor() are only supported in dynamic mode
if local_coord[0].size == 0:
local_tensor_shape = _cal_local_shape(
local_tensor_list[0].shape, local_mesh_list[0], local_placements
Expand Down Expand Up @@ -498,27 +538,25 @@ def moe_global_mesh_tensor(
resharded_local_tensor_list,
local_mesh_list,
local_placements,
local_tensor_idx,
global_dims,
mesh,
placements,
global_dims,
local_tensor_idx,
)
elif paddle.framework.in_pir_mode():
global_dims = _cal_global_shape(
local_tensor._local_shape, mesh, placements
)
dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
return paddle.jit.dy2static.py_layer.StaticPyLayer(
_moe_global_mesh_tensor
).apply(
local_tensor_list,
local_mesh_list,
local_placements,
mesh,
placements,
global_dims,
)
dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
dist_tensor.persistable = local_tensor_list[0].persistable

return dist_tensor
else:
raise NotImplementedError(
"dtensor_from_local_list() are only supported in dynamic and pir mode."
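For reference, the mode dispatch this function now follows, reduced to a toy layer (DoubleIt and double are hypothetical; only the paddle.in_dynamic_mode / paddle.framework.in_pir_mode checks and the StaticPyLayer wrapping mirror the hunk above):

import paddle
from paddle.autograd import PyLayer

class DoubleIt(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, dy):
        return dy * 2

def double(x):
    if paddle.in_dynamic_mode():
        # eager execution: call the PyLayer directly
        return DoubleIt.apply(x)
    elif paddle.framework.in_pir_mode():
        # static (PIR) program: wrap the same class
        return paddle.jit.dy2static.py_layer.StaticPyLayer(DoubleIt).apply(x)
    else:
        raise NotImplementedError("only dynamic and PIR modes are handled")

x = paddle.ones([4])
x.stop_gradient = False
y = double(x)      # takes the dynamic branch when run eagerly
y.sum().backward()

In the diff, the PIR branch now routes through StaticPyLayer(_moe_global_mesh_tensor).apply(...) instead of calling paddle._C_ops.moe_global_mesh_tensor directly.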
@@ -536,75 +574,114 @@ def forward(
global_mesh=None,
global_placements=None,
):
ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
ctx.local_placements = local_placements
ctx.local_mesh_dim = local_mesh_dim
ctx.global_mesh = copy.deepcopy(global_mesh)
ctx.global_placements = global_placements
ctx.global_shape = dist_tensor.shape

if global_mesh is None and global_placements is None:
return dist_tensor._local_value()
else:
if global_mesh is None or global_placements is None:
raise ValueError(
"the args global_mesh and global_placements should be set together"
)
ori_mesh = dist_tensor.process_mesh
if global_mesh != dist_tensor.process_mesh:
raise ValueError(
"the global_mesh should be the same as dist_tensor's process_mesh."
)
assert check_placements_equal(
global_placements, dist_tensor.placements
), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
local_shape = _cal_local_shape(
dist_tensor.shape, global_mesh, global_placements
)
for idx, placement in enumerate(local_placements):
if placement.is_shard():
shard_dim = placement.get_dim()
local_dim_size = local_shape[shard_dim]
local_shape[shard_dim] = (
local_dim_size * local_mesh_list[0].shape[idx]
ctx.save_for_backward(
copy.deepcopy(local_mesh_list), # local_mesh_list,
local_placements, # local_placements,
local_mesh_dim, # local_mesh_dim,
copy.deepcopy(global_mesh), # global_mesh,
global_placements, # global_placements,
dist_tensor.shape, # global_shape,
)
if paddle.in_dynamic_mode():
if global_mesh is None and global_placements is None:
return dist_tensor._local_value()
else:
if global_mesh is None or global_placements is None:
raise ValueError(
"the args global_mesh and global_placements should be set together"
)

place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
local_tensor_list = []
for i, local_mesh in enumerate(local_mesh_list):
local_tensor = paddle.Tensor(
dist_tensor._local_value(),
dims=local_shape,
process_mesh=local_mesh,
placements=local_placements,
place=place,
ori_mesh = dist_tensor.process_mesh
if global_mesh != dist_tensor.process_mesh:
raise ValueError(
"the global_mesh should be the same as dist_tensor's process_mesh."
)
assert check_placements_equal(
global_placements, dist_tensor.placements
), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
local_shape = _cal_local_shape(
dist_tensor.shape, global_mesh, global_placements
)
local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
for idx, placement in enumerate(local_placements):
if placement.is_shard():
shard_dim = placement.get_dim()
local_dim_size = local_shape[shard_dim]
local_shape[shard_dim] = (
local_dim_size * local_mesh_list[0].shape[idx]
)

place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
local_tensor_list = []
for i, local_mesh in enumerate(local_mesh_list):
local_tensor = paddle.Tensor(
dist_tensor._local_value(),
dims=local_shape,
process_mesh=local_mesh,
placements=local_placements,
place=place,
)
local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
local_tensor.stop_gradient = dist_tensor.stop_gradient
local_tensor_list.append(local_tensor)
return local_tensor_list
elif paddle.framework.in_pir_mode():
local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
dist_tensor,
local_mesh_list,
local_placements,
global_mesh,
global_placements,
)
for local_tensor in local_tensors:
local_tensor.stop_gradient = dist_tensor.stop_gradient
local_tensor_list.append(local_tensor)
return local_tensor_list
local_tensor.persistable = dist_tensor.persistable
return local_tensors

@staticmethod
def backward(ctx, *grad_tensor):
(
local_mesh_list,
local_placements,
local_mesh_dim,
global_mesh,
global_placements,
global_shape,
) = ctx.saved_tensor()
place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
mesh = ctx.global_mesh
mesh = global_mesh
process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
local_coord = np.where(process_ids == dist.get_rank())
if local_coord[0].size == 0:
local_tensor_idx = 0
else:
local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
local_tensor_idx = local_coord[local_mesh_dim][0]
local_grad = grad_tensor[local_tensor_idx]
global_tensor = paddle.Tensor(
local_grad._local_value(),
dims=ctx.global_shape,
process_mesh=mesh,
placements=ctx.global_placements,
place=place,
)
return global_tensor

if paddle.in_dynamic_mode():
place = paddle.framework._current_expected_place()
place = paddle.framework._get_paddle_place(place)
global_tensor = paddle.Tensor(
local_grad._local_value(),
dims=global_shape,
process_mesh=mesh,
placements=global_placements,
place=place,
)
return global_tensor
elif paddle.framework.in_pir_mode():
global_dims = _cal_global_shape(
local_grad._local_shape, mesh, global_placements
)

return paddle._C_ops.moe_global_mesh_tensor(
grad_tensor,
local_mesh_list,
local_placements,
global_mesh,
global_placements,
global_dims,
)


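For reference, the shape bookkeeping above reduces to scaling every sharded dimension by the size of the mesh axis it is split over; a self-contained sketch (infer_global_shape is a made-up helper, not a Paddle API, and placements are modeled as plain tuples):

def infer_global_shape(local_shape, mesh_shape, placements):
    """Scale each dim of local_shape that is sharded over a mesh axis."""
    global_shape = list(local_shape)
    for mesh_dim, (kind, tensor_dim) in enumerate(placements):
        if kind == "shard":
            global_shape[tensor_dim] *= mesh_shape[mesh_dim]
    return global_shape

# a [4, 8] shard on an 8-way mesh axis, sharded along dim 0 -> [32, 8]
print(infer_global_shape([4, 8], [8], [("shard", 0)]))
# a replicated placement leaves the shape unchanged
print(infer_global_shape([4, 8], [8], [("replicate", None)]))

The forward above does the same scaling with real Placement objects (placement.is_shard() / placement.get_dim()), and the PIR backward recovers global_dims via _cal_global_shape.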
def moe_sub_mesh_tensors(
@@ -627,17 +704,17 @@ def moe_sub_mesh_tensors(
global_placements,
)
elif paddle.framework.in_pir_mode():
local_tensors = paddle._C_ops.moe_sub_mesh_tensors(

return paddle.jit.dy2static.py_layer.StaticPyLayer(
_moe_sub_mesh_tensors
).apply(
dist_tensor,
local_mesh_list,
local_placements,
local_mesh_dim,
global_mesh,
global_placements,
)
for local_tensor in local_tensors:
local_tensor.stop_gradient = dist_tensor.stop_gradient
local_tensor.persistable = dist_tensor.persistable
return local_tensors
else:
raise NotImplementedError(
"moe_sub_mesh_tensors is only supported in dynamic mode."