diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 38b330de7b768..d7efc797d4ab5 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -369,58 +369,97 @@ def forward(
         local_tensor_list,
         local_mesh_list,
         local_placements,
-        idx,
-        global_dims,
         mesh,
         placements,
+        global_dims,
+        idx=-1,
     ):
-        local_tensor = local_tensor_list[idx]
-        if local_tensor.is_dist():
-            local_mesh = local_tensor.process_mesh
-            local_val = local_tensor._local_value()
-        else:
-            local_val = local_tensor
-            local_mesh = None
-
-        ctx.global_mesh = copy.deepcopy(mesh)
-        ctx.placements = copy.deepcopy(placements)
-        ctx.local_dims = local_tensor.shape
-        ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placements = copy.deepcopy(local_placements)
+        # NOTE: _local_value/Paddle.Tensor is only supported in dynamic mode
+        if paddle.in_dynamic_mode():
+            local_tensor = local_tensor_list[idx]
+            if local_tensor.is_dist():
+                local_mesh = local_tensor.process_mesh
+                local_val = local_tensor._local_value()
+            else:
+                local_val = local_tensor
+                local_mesh = None
+
+            ctx.save_for_backward(
+                copy.deepcopy(mesh),  # global_mesh
+                local_tensor.shape,  # local_dims
+                copy.deepcopy(local_mesh_list),  # local_mesh_list
+                copy.deepcopy(local_placements),  # local_placements
+            )

-        place = paddle.framework._current_expected_place()
-        place = paddle.framework._get_paddle_place(place)
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)

-        global_tensor = paddle.Tensor(
-            local_val,
-            dims=global_dims,
-            process_mesh=mesh,
-            placements=placements,
-            place=place,
-        )
-        global_tensor.stop_gradient = local_tensor.stop_gradient
-        return global_tensor
+            global_tensor = paddle.Tensor(
+                local_val,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+            global_tensor.stop_gradient = local_tensor.stop_gradient
+            return global_tensor
+        else:
+            ctx.save_for_backward(
+                copy.deepcopy(mesh),  # global_mesh
+                copy.deepcopy(placements),  # global_placements
+                copy.deepcopy(local_mesh_list),  # local_mesh_list
+                copy.deepcopy(local_placements),  # local_placements
+            )
+            dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
+                local_tensor_list,
+                local_mesh_list,
+                local_placements,
+                mesh,
+                placements,
+                global_dims,
+            )
+            dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
+            dist_tensor.persistable = local_tensor_list[0].persistable
+            return dist_tensor

     @staticmethod
     def backward(ctx, grad_tensor):
-        if ctx.local_mesh_list is None:
-            return grad_tensor._local_value()
-        else:
-            place = paddle.framework._current_expected_place()
-            place = paddle.framework._get_paddle_place(place)
-            out = []
-            for i, local_mesh in enumerate(ctx.local_mesh_list):
-                out.append(
-                    paddle.Tensor(
-                        grad_tensor._local_value(),
-                        dims=ctx.local_dims,
-                        process_mesh=local_mesh,
-                        placements=ctx.local_placements,
-                        place=place,
+        if paddle.in_dynamic_mode():
+            global_mesh, local_dims, local_mesh_list, local_placements = (
+                ctx.saved_tensor()
+            )
+            if local_mesh_list is None:
+                return grad_tensor._local_value()
+            else:
+                place = paddle.framework._current_expected_place()
+                place = paddle.framework._get_paddle_place(place)
+                out = []
+                for i, local_mesh in enumerate(local_mesh_list):
+                    out.append(
+                        paddle.Tensor(
+                            grad_tensor._local_value(),
+                            dims=local_dims,
+                            process_mesh=local_mesh,
+                            placements=local_placements,
+                            place=place,
+                        )
                     )
-                )
-                out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
-            return out
+                    out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
+                return out
+        else:
+            (
+                global_mesh,
+                global_placements,
+                local_mesh_list,
+                local_placements,
+            ) = ctx.saved_tensor()
+            return paddle._C_ops.moe_sub_mesh_tensors(
+                grad_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+            )


 def _get_sub_meshes_and_local_placements(
@@ -469,6 +508,7 @@ def moe_global_mesh_tensor(
     local_tensor = local_tensor_list[local_tensor_idx]

     if paddle.in_dynamic_mode():
+        # NOTE: _local_value and Paddle.Tensor() is only supported in dynamic mode
         if local_coord[0].size == 0:
             local_tensor_shape = _cal_local_shape(
                 local_tensor_list[0].shape, local_mesh_list[0], local_placements
@@ -498,16 +538,18 @@ def moe_global_mesh_tensor(
             resharded_local_tensor_list,
             local_mesh_list,
             local_placements,
-            local_tensor_idx,
-            global_dims,
             mesh,
             placements,
+            global_dims,
+            local_tensor_idx,
         )
     elif paddle.framework.in_pir_mode():
         global_dims = _cal_global_shape(
             local_tensor._local_shape, mesh, placements
         )
-        dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
+        return paddle.jit.dy2static.py_layer.StaticPyLayer(
+            _moe_global_mesh_tensor
+        ).apply(
             local_tensor_list,
             local_mesh_list,
             local_placements,
@@ -515,10 +557,6 @@ def moe_global_mesh_tensor(
             placements,
             global_dims,
         )
-        dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
-        dist_tensor.persistable = local_tensor_list[0].persistable
-
-        return dist_tensor
     else:
         raise NotImplementedError(
             "dtensor_from_local_list() are only supported in dynamic and pir mode."
@@ -536,75 +574,114 @@ def forward(
         global_mesh=None,
         global_placements=None,
     ):
-        ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placements = local_placements
-        ctx.local_mesh_dim = local_mesh_dim
-        ctx.global_mesh = copy.deepcopy(global_mesh)
-        ctx.global_placements = global_placements
-        ctx.global_shape = dist_tensor.shape
-
-        if global_mesh is None and global_placements is None:
-            return dist_tensor._local_value()
-        else:
-            if global_mesh is None or global_placements is None:
-                raise ValueError(
-                    "the args global_mesh and global_placements should be set together"
-                )
-            ori_mesh = dist_tensor.process_mesh
-            if global_mesh != dist_tensor.process_mesh:
-                raise ValueError(
-                    "the global_mesh should be the same as dist_tensor's process_mesh."
-                )
-            assert check_placements_equal(
-                global_placements, dist_tensor.placements
-            ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
-            local_shape = _cal_local_shape(
-                dist_tensor.shape, global_mesh, global_placements
-            )
-            for idx, placement in enumerate(local_placements):
-                if placement.is_shard():
-                    shard_dim = placement.get_dim()
-                    local_dim_size = local_shape[shard_dim]
-                    local_shape[shard_dim] = (
-                        local_dim_size * local_mesh_list[0].shape[idx]
+        ctx.save_for_backward(
+            copy.deepcopy(local_mesh_list),  # local_mesh_list,
+            local_placements,  # local_placements,
+            local_mesh_dim,  # local_mesh_dim,
+            copy.deepcopy(global_mesh),  # global_mesh,
+            global_placements,  # global_placements,
+            dist_tensor.shape,  # global_shape,
+        )
+        if paddle.in_dynamic_mode():
+            if global_mesh is None and global_placements is None:
+                return dist_tensor._local_value()
+            else:
+                if global_mesh is None or global_placements is None:
+                    raise ValueError(
+                        "the args global_mesh and global_placements should be set together"
                     )
-
-            place = paddle.framework._current_expected_place()
-            place = paddle.framework._get_paddle_place(place)
-            local_tensor_list = []
-            for i, local_mesh in enumerate(local_mesh_list):
-                local_tensor = paddle.Tensor(
-                    dist_tensor._local_value(),
-                    dims=local_shape,
-                    process_mesh=local_mesh,
-                    placements=local_placements,
-                    place=place,
+                ori_mesh = dist_tensor.process_mesh
+                if global_mesh != dist_tensor.process_mesh:
+                    raise ValueError(
+                        "the global_mesh should be the same as dist_tensor's process_mesh."
+                    )
+                assert check_placements_equal(
+                    global_placements, dist_tensor.placements
+                ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
+                local_shape = _cal_local_shape(
+                    dist_tensor.shape, global_mesh, global_placements
                 )
-                local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
+                for idx, placement in enumerate(local_placements):
+                    if placement.is_shard():
+                        shard_dim = placement.get_dim()
+                        local_dim_size = local_shape[shard_dim]
+                        local_shape[shard_dim] = (
+                            local_dim_size * local_mesh_list[0].shape[idx]
+                        )
+
+                place = paddle.framework._current_expected_place()
+                place = paddle.framework._get_paddle_place(place)
+                local_tensor_list = []
+                for i, local_mesh in enumerate(local_mesh_list):
+                    local_tensor = paddle.Tensor(
+                        dist_tensor._local_value(),
+                        dims=local_shape,
+                        process_mesh=local_mesh,
+                        placements=local_placements,
+                        place=place,
+                    )
+                    local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
+                    local_tensor.stop_gradient = dist_tensor.stop_gradient
+                    local_tensor_list.append(local_tensor)
+                return local_tensor_list
+        elif paddle.framework.in_pir_mode():
+            local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
+                dist_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+            )
+            for local_tensor in local_tensors:
                 local_tensor.stop_gradient = dist_tensor.stop_gradient
-                local_tensor_list.append(local_tensor)
-            return local_tensor_list
+                local_tensor.persistable = dist_tensor.persistable
+            return local_tensors

     @staticmethod
     def backward(ctx, *grad_tensor):
+        (
+            local_mesh_list,
+            local_placements,
+            local_mesh_dim,
+            global_mesh,
+            global_placements,
+            global_shape,
+        ) = ctx.saved_tensor()
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
-        mesh = ctx.global_mesh
+        mesh = global_mesh
         process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
         local_coord = np.where(process_ids == dist.get_rank())
         if local_coord[0].size == 0:
             local_tensor_idx = 0
         else:
-            local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
+            local_tensor_idx = local_coord[local_mesh_dim][0]
         local_grad = grad_tensor[local_tensor_idx]
-        global_tensor = paddle.Tensor(
-            local_grad._local_value(),
-            dims=ctx.global_shape,
-            process_mesh=mesh,
-            placements=ctx.global_placements,
-            place=place,
-        )
-        return global_tensor
+
+        if paddle.in_dynamic_mode():
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+            global_tensor = paddle.Tensor(
+                local_grad._local_value(),
+                dims=global_shape,
+                process_mesh=mesh,
+                placements=global_placements,
+                place=place,
+            )
+            return global_tensor
+        elif paddle.framework.in_pir_mode():
+            global_dims = _cal_global_shape(
+                local_grad._local_shape, mesh, global_placements
+            )
+
+            return paddle._C_ops.moe_global_mesh_tensor(
+                grad_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+                global_dims,
+            )


 def moe_sub_mesh_tensors(
@@ -627,17 +704,17 @@ def moe_sub_mesh_tensors(
             global_placements,
         )
     elif paddle.framework.in_pir_mode():
-        local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
+
+        return paddle.jit.dy2static.py_layer.StaticPyLayer(
+            _moe_sub_mesh_tensors
+        ).apply(
             dist_tensor,
             local_mesh_list,
             local_placements,
+            local_mesh_dim,
             global_mesh,
             global_placements,
         )
-        for local_tensor in local_tensors:
-            local_tensor.stop_gradient = dist_tensor.stop_gradient
-            local_tensor.persistable = dist_tensor.persistable
-        return local_tensors
     else:
         raise NotImplementedError(
             "moe_sub_mesh_tensors is only supported in dynamic mode."
diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py
index 080a9015c4d79..c340e39b918b8 100644
--- a/python/paddle/distributed/auto_parallel/static/pir_pass.py
+++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -405,46 +405,79 @@ def replace_moe_sub_mesh_tensors(op):
         )
     )

+    # update pylayer op by removing the unused outputs
+    def update_pylayer_output(trival_value):
+        define_op = trival_value.get_defining_op()
+        if define_op.get_parent_block().parent_op.name() != "pd_op.pylayer":
+            return
+        paddle.pir.set_insertion_point(define_op)
+        fake_value = paddle.static.data(
+            name="_fake_pylayer_out",
+            shape=trival_value.shape,
+            dtype=trival_value.dtype,
+        )
+        fake_value.set_type(trival_value.type())
+        trival_value.replace_all_uses_with(fake_value)
+
+    for val in op.results():
+        if not val.use_empty():
+            update_pylayer_output(val)
+
     assert all(val.use_empty() for val in op.results())
     op.erase()


+def remove_sub_block_unused_inputs(op):
+    inputs_size = op.num_operands()
+    inputs = [op.operand_source(i) for i in range(inputs_size)]
+    # remove unused inputs
+
+
 class RemovePasses:
+    @staticmethod
     def remove_other_rank_op_pass(dist_program):
         # pruning op and value not belong to cur rank
-        cur_rank = paddle.distributed.get_rank()
+        def prune_op(block):
+            cur_rank = paddle.distributed.get_rank()
+            for op in block.ops[::-1]:
+                if op.name() == "dist_op.moe_sub_mesh_tensors":
+                    replace_moe_sub_mesh_tensors(op)
+                    continue
+                elif op.name() == "dist_op.moe_global_mesh_tensor":
+                    replace_moe_global_mesh_tensor(op)
+                    continue
+                elif op.name() == "cf.tuple_push":
+                    stack_create_op = op.operand_source(0).get_defining_op()
+                    if stack_create_op.result(2).use_empty():
+                        op.erase()
+                    continue
+                elif op.name() == "cf.yield":
+                    continue
+                elif op.name() == "pd_op.pylayer":
+                    for pylayer_block in list(op.blocks())[::-1]:
+                        prune_op(pylayer_block)
+                    # update pylayer op's inputs
+                    op.as_pylayer_op().update_input()
+                    continue
+                elif op.name() in partition_skip_op_list:
+                    can_delete = True
+                    for val in op.results():
+                        if not val.use_empty():
+                            can_delete = False
+                    if can_delete:
+                        op.erase()
+                    continue

-        for op in dist_program.global_block().ops[::-1]:
-            if op.name() == "dist_op.moe_sub_mesh_tensors":
-                replace_moe_sub_mesh_tensors(op)
-                continue
-            elif op.name() == "dist_op.moe_global_mesh_tensor":
-                replace_moe_global_mesh_tensor(op)
-                continue
-            elif op.name() == "cf.tuple_push":
-                stack_create_op = op.operand_source(0).get_defining_op()
-                if stack_create_op.result(2).use_empty():
+                if cur_rank not in op.dist_attr.process_mesh.process_ids:
                     op.erase()
-                continue
-            elif op.name() == "cf.yield":
-                continue
-            elif op.name() in partition_skip_op_list:
-                can_delete = True
-                for val in op.results():
-                    if not val.use_empty():
-                        can_delete = False
-                if can_delete:
+                elif op.name() == "dist_op.reshard":
+                    assert op.result(
+                        0
+                    ).use_empty(), f'There should not have useful dist.reshard op in remove_other_rank_op_pass. but find : {op}'
                     op.erase()
-                continue
-            if cur_rank not in op.dist_attr.process_mesh.process_ids:
-                op.erase()
-            elif op.name() == "dist_op.reshard":
-                assert op.result(
-                    0
-                ).use_empty(), f'There should not have useful dist.reshard op in remove_other_rank_op_pass. but find : {op}'
-                op.erase()
+        prune_op(dist_program.global_block())

         # merge pd.data ops for
         lr_ops = []
diff --git a/test/auto_parallel/pir/test_moe_api.py b/test/auto_parallel/pir/test_moe_api.py
index ceeb2d3c8104d..59b8cc19e7629 100644
--- a/test/auto_parallel/pir/test_moe_api.py
+++ b/test/auto_parallel/pir/test_moe_api.py
@@ -128,16 +128,13 @@ def check_results(
         local_dims_mapping,
     ):
         # local_tensors_from_dtensor op
-        self.check_dist_attr(ops[2], local_meshes, local_dims_mapping)
-
+        self.check_dist_attr(ops[4], local_meshes, local_dims_mapping)
         # dtensor_from_local_list op
-        self.check_dist_attr(ops[3], [global_mesh], global_dims_mapping)
-
+        self.check_dist_attr(ops[5], [global_mesh], global_dims_mapping)
         # grad op for dtensor_from_local_list
-        self.check_dist_attr(ops[8], local_meshes, local_dims_mapping)
-
+        self.check_dist_attr(ops[10], local_meshes, local_dims_mapping)
         # grad op for local_tensors_from_dtensor op
-        self.check_dist_attr(ops[9], [global_mesh], global_dims_mapping)
+        self.check_dist_attr(ops[11], [global_mesh], global_dims_mapping)


 if __name__ == "__main__":
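
Note (illustration, not part of the patch): the api.py changes above move the PyLayer bookkeeping from ad-hoc ctx.* attributes to ctx.save_for_backward(...) / ctx.saved_tensor(), so the same forward/backward bodies can run eagerly in dynamic mode or through the paddle.jit.dy2static.py_layer.StaticPyLayer wrapper in PIR mode. A minimal sketch of that save/restore pattern with paddle.autograd.PyLayer follows; the ScaleByAlpha layer and its alpha argument are invented for the example and do not appear in the PR.

import paddle
from paddle.autograd import PyLayer


class ScaleByAlpha(PyLayer):
    @staticmethod
    def forward(ctx, x, alpha):
        # Stash what backward needs via save_for_backward instead of
        # assigning attributes directly on ctx.
        ctx.save_for_backward(paddle.to_tensor(alpha))
        return x * alpha

    @staticmethod
    def backward(ctx, grad_out):
        # Retrieve saved objects in the order they were stored.
        (alpha,) = ctx.saved_tensor()
        return grad_out * alpha


x = paddle.rand([2, 3])
x.stop_gradient = False
y = ScaleByAlpha.apply(x, 2.0)
y.sum().backward()
print(x.grad)  # every element equals 2.0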