
Commit f3efb3a

support pylayer for moe (#70375)
* support pylayer for moe
* polish
* update_doc
1 parent b92ad06 commit f3efb3a
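
For context: the change switches the MoE mesh-conversion PyLayers from stashing state on ctx attributes to ctx.save_for_backward(...) / ctx.saved_tensor(), the PyLayer idiom that also works when the layer is later driven through StaticPyLayer in PIR mode. A minimal sketch of that idiom, assuming a toy layer named ScaleByTwo that is not part of this commit:

import paddle
from paddle.autograd import PyLayer


class ScaleByTwo(PyLayer):
    # Illustrative layer only: y = 2 * x, saving the input for backward.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # stash state through the ctx API ...
        return x * 2

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensor()     # ... and read it back here
        return grad_y * 2


x = paddle.ones([2, 3])
x.stop_gradient = False
y = ScaleByTwo.apply(x)               # dynamic-graph usage
y.sum().backward()
print(x.grad)                          # every element is 2.0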

File tree: 4 files changed, +278 -146 lines

python/paddle/distributed/auto_parallel/api.py

Lines changed: 187 additions & 110 deletions
@@ -369,58 +369,97 @@ def forward(
         local_tensor_list,
         local_mesh_list,
         local_placements,
-        idx,
-        global_dims,
         mesh,
         placements,
+        global_dims,
+        idx=None,
     ):
-        local_tensor = local_tensor_list[idx]
-        if local_tensor.is_dist():
-            local_mesh = local_tensor.process_mesh
-            local_val = local_tensor._local_value()
-        else:
-            local_val = local_tensor
-            local_mesh = None
-
-        ctx.global_mesh = copy.deepcopy(mesh)
-        ctx.placements = copy.deepcopy(placements)
-        ctx.local_dims = local_tensor.shape
-        ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placements = copy.deepcopy(local_placements)
+        # NOTE: _local_value/Paddle.Tensor is only supported in dynamic mode
+        if paddle.in_dynamic_mode():
+            local_tensor = local_tensor_list[idx]
+            if local_tensor.is_dist():
+                local_mesh = local_tensor.process_mesh
+                local_val = local_tensor._local_value()
+            else:
+                local_val = local_tensor
+                local_mesh = None
+
+            ctx.save_for_backward(
+                copy.deepcopy(mesh),  # global_mesh
+                local_tensor.shape,  # local_dims
+                copy.deepcopy(local_mesh_list),  # local_mesh_list
+                copy.deepcopy(local_placements),  # local_placements
+            )

-        place = paddle.framework._current_expected_place()
-        place = paddle.framework._get_paddle_place(place)
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)

-        global_tensor = paddle.Tensor(
-            local_val,
-            dims=global_dims,
-            process_mesh=mesh,
-            placements=placements,
-            place=place,
-        )
-        global_tensor.stop_gradient = local_tensor.stop_gradient
-        return global_tensor
+            global_tensor = paddle.Tensor(
+                local_val,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+            global_tensor.stop_gradient = local_tensor.stop_gradient
+            return global_tensor
+        else:
+            ctx.save_for_backward(
+                copy.deepcopy(mesh),  # global_mesh
+                copy.deepcopy(placements),  # global_placements
+                copy.deepcopy(local_mesh_list),  # local_mesh_list
+                copy.deepcopy(local_placements),  # local_placements
+            )
+            dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
+                local_tensor_list,
+                local_mesh_list,
+                local_placements,
+                mesh,
+                placements,
+                global_dims,
+            )
+            dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
+            dist_tensor.persistable = local_tensor_list[0].persistable
+            return dist_tensor

     @staticmethod
     def backward(ctx, grad_tensor):
-        if ctx.local_mesh_list is None:
-            return grad_tensor._local_value()
-        else:
-            place = paddle.framework._current_expected_place()
-            place = paddle.framework._get_paddle_place(place)
-            out = []
-            for i, local_mesh in enumerate(ctx.local_mesh_list):
-                out.append(
-                    paddle.Tensor(
-                        grad_tensor._local_value(),
-                        dims=ctx.local_dims,
-                        process_mesh=local_mesh,
-                        placements=ctx.local_placements,
-                        place=place,
-                    )
-                )
-            out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
-            return out
+        if paddle.in_dynamic_mode():
+            global_mesh, local_dims, local_mesh_list, local_placements = (
+                ctx.saved_tensor()
+            )
+            if local_mesh_list is None:
+                return grad_tensor._local_value()
+            else:
+                place = paddle.framework._current_expected_place()
+                place = paddle.framework._get_paddle_place(place)
+                out = []
+                for i, local_mesh in enumerate(local_mesh_list):
+                    out.append(
+                        paddle.Tensor(
+                            grad_tensor._local_value(),
+                            dims=local_dims,
+                            process_mesh=local_mesh,
+                            placements=local_placements,
+                            place=place,
+                        )
+                    )
+                out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
+                return out
+        else:
+            (
+                global_mesh,
+                global_placements,
+                local_mesh_list,
+                local_placements,
+            ) = ctx.saved_tensor()
+            return paddle._C_ops.moe_sub_mesh_tensors(
+                grad_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+            )


 def _get_sub_meshes_and_local_placements(
@@ -469,6 +508,7 @@ def moe_global_mesh_tensor(
     local_tensor = local_tensor_list[local_tensor_idx]

     if paddle.in_dynamic_mode():
+        # NOTE: _local_value and Paddle.Tensor() is only supported in dynamic mode
         if local_coord[0].size == 0:
             local_tensor_shape = _cal_local_shape(
                 local_tensor_list[0].shape, local_mesh_list[0], local_placements
@@ -498,27 +538,25 @@ def moe_global_mesh_tensor(
             resharded_local_tensor_list,
             local_mesh_list,
             local_placements,
-            local_tensor_idx,
-            global_dims,
             mesh,
             placements,
+            global_dims,
+            local_tensor_idx,
         )
     elif paddle.framework.in_pir_mode():
         global_dims = _cal_global_shape(
             local_tensor._local_shape, mesh, placements
         )
-        dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
+        return paddle.jit.dy2static.py_layer.StaticPyLayer(
+            _moe_global_mesh_tensor
+        ).apply(
             local_tensor_list,
             local_mesh_list,
             local_placements,
             mesh,
             placements,
             global_dims,
         )
-        dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
-        dist_tensor.persistable = local_tensor_list[0].persistable
-
-        return dist_tensor
     else:
         raise NotImplementedError(
             "dtensor_from_local_list() are only supported in dynamic and pir mode."
@@ -536,75 +574,114 @@ def forward(
         global_mesh=None,
         global_placements=None,
     ):
-        ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placements = local_placements
-        ctx.local_mesh_dim = local_mesh_dim
-        ctx.global_mesh = copy.deepcopy(global_mesh)
-        ctx.global_placements = global_placements
-        ctx.global_shape = dist_tensor.shape
-
-        if global_mesh is None and global_placements is None:
-            return dist_tensor._local_value()
-        else:
-            if global_mesh is None or global_placements is None:
-                raise ValueError(
-                    "the args global_mesh and global_placements should be set together"
-                )
-            ori_mesh = dist_tensor.process_mesh
-            if global_mesh != dist_tensor.process_mesh:
-                raise ValueError(
-                    "the global_mesh should be the same as dist_tensor's process_mesh."
-                )
-            assert check_placements_equal(
-                global_placements, dist_tensor.placements
-            ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
-            local_shape = _cal_local_shape(
-                dist_tensor.shape, global_mesh, global_placements
-            )
-            for idx, placement in enumerate(local_placements):
-                if placement.is_shard():
-                    shard_dim = placement.get_dim()
-                    local_dim_size = local_shape[shard_dim]
-                    local_shape[shard_dim] = (
-                        local_dim_size * local_mesh_list[0].shape[idx]
-                    )
-
-            place = paddle.framework._current_expected_place()
-            place = paddle.framework._get_paddle_place(place)
-            local_tensor_list = []
-            for i, local_mesh in enumerate(local_mesh_list):
-                local_tensor = paddle.Tensor(
-                    dist_tensor._local_value(),
-                    dims=local_shape,
-                    process_mesh=local_mesh,
-                    placements=local_placements,
-                    place=place,
-                )
-                local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
-                local_tensor.stop_gradient = dist_tensor.stop_gradient
-                local_tensor_list.append(local_tensor)
-            return local_tensor_list
+        ctx.save_for_backward(
+            copy.deepcopy(local_mesh_list),  # local_mesh_list,
+            local_placements,  # local_placements,
+            local_mesh_dim,  # local_mesh_dim,
+            copy.deepcopy(global_mesh),  # global_mesh,
+            global_placements,  # global_placements,
+            dist_tensor.shape,  # global_shape,
+        )
+        if paddle.in_dynamic_mode():
+            if global_mesh is None and global_placements is None:
+                return dist_tensor._local_value()
+            else:
+                if global_mesh is None or global_placements is None:
+                    raise ValueError(
+                        "the args global_mesh and global_placements should be set together"
+                    )
+                ori_mesh = dist_tensor.process_mesh
+                if global_mesh != dist_tensor.process_mesh:
+                    raise ValueError(
+                        "the global_mesh should be the same as dist_tensor's process_mesh."
+                    )
+                assert check_placements_equal(
+                    global_placements, dist_tensor.placements
+                ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
+                local_shape = _cal_local_shape(
+                    dist_tensor.shape, global_mesh, global_placements
+                )
+                for idx, placement in enumerate(local_placements):
+                    if placement.is_shard():
+                        shard_dim = placement.get_dim()
+                        local_dim_size = local_shape[shard_dim]
+                        local_shape[shard_dim] = (
+                            local_dim_size * local_mesh_list[0].shape[idx]
+                        )
+
+                place = paddle.framework._current_expected_place()
+                place = paddle.framework._get_paddle_place(place)
+                local_tensor_list = []
+                for i, local_mesh in enumerate(local_mesh_list):
+                    local_tensor = paddle.Tensor(
+                        dist_tensor._local_value(),
+                        dims=local_shape,
+                        process_mesh=local_mesh,
+                        placements=local_placements,
+                        place=place,
+                    )
+                    local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
+                    local_tensor.stop_gradient = dist_tensor.stop_gradient
+                    local_tensor_list.append(local_tensor)
+                return local_tensor_list
+        elif paddle.framework.in_pir_mode():
+            local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
+                dist_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+            )
+            for local_tensor in local_tensors:
+                local_tensor.stop_gradient = dist_tensor.stop_gradient
+                local_tensor.persistable = dist_tensor.persistable
+            return local_tensors

     @staticmethod
     def backward(ctx, *grad_tensor):
+        (
+            local_mesh_list,
+            local_placements,
+            local_mesh_dim,
+            global_mesh,
+            global_placements,
+            global_shape,
+        ) = ctx.saved_tensor()
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
-        mesh = ctx.global_mesh
+        mesh = global_mesh
         process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
         local_coord = np.where(process_ids == dist.get_rank())
         if local_coord[0].size == 0:
             local_tensor_idx = 0
         else:
-            local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
+            local_tensor_idx = local_coord[local_mesh_dim][0]
         local_grad = grad_tensor[local_tensor_idx]
-        global_tensor = paddle.Tensor(
-            local_grad._local_value(),
-            dims=ctx.global_shape,
-            process_mesh=mesh,
-            placements=ctx.global_placements,
-            place=place,
-        )
-        return global_tensor
+
+        if paddle.in_dynamic_mode():
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+            global_tensor = paddle.Tensor(
+                local_grad._local_value(),
+                dims=global_shape,
+                process_mesh=mesh,
+                placements=global_placements,
+                place=place,
+            )
+            return global_tensor
+        elif paddle.framework.in_pir_mode():
+            global_dims = _cal_global_shape(
+                local_grad._local_shape, mesh, global_placements
+            )
+
+            return paddle._C_ops.moe_global_mesh_tensor(
+                grad_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+                global_dims,
+            )


 def moe_sub_mesh_tensors(
@@ -627,17 +704,17 @@ def moe_sub_mesh_tensors(
             global_placements,
         )
     elif paddle.framework.in_pir_mode():
-        local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
+
+        return paddle.jit.dy2static.py_layer.StaticPyLayer(
+            _moe_sub_mesh_tensors
+        ).apply(
             dist_tensor,
             local_mesh_list,
             local_placements,
+            local_mesh_dim,
             global_mesh,
             global_placements,
         )
-        for local_tensor in local_tensors:
-            local_tensor.stop_gradient = dist_tensor.stop_gradient
-            local_tensor.persistable = dist_tensor.persistable
-        return local_tensors
     else:
         raise NotImplementedError(
             "moe_sub_mesh_tensors is only supported in dynamic mode."

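The StaticPyLayer wrapping in the hunks above is the static-graph counterpart of calling .apply on the PyLayer class directly: under PIR, the class is handed to paddle.jit.dy2static.py_layer.StaticPyLayer and the wrapper's apply receives the same positional arguments. A minimal sketch of that dispatch, reusing the illustrative ScaleByTwo layer from the earlier sketch (repeated here so the snippet stands alone); the exact StaticPyLayer argument handling is assumed to mirror PyLayer.apply:

import paddle
from paddle.autograd import PyLayer


class ScaleByTwo(PyLayer):
    # Illustrative layer only, not part of this commit.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensor()
        return grad_y * 2


def scale_by_two(x):
    # Dispatch pattern mirroring moe_global_mesh_tensor / moe_sub_mesh_tensors:
    # plain .apply in dynamic mode, StaticPyLayer in PIR (static-graph) mode.
    if paddle.in_dynamic_mode():
        return ScaleByTwo.apply(x)
    elif paddle.framework.in_pir_mode():
        return paddle.jit.dy2static.py_layer.StaticPyLayer(ScaleByTwo).apply(x)
    else:
        raise NotImplementedError("only dynamic and PIR modes are sketched here")

In the real change, the PIR branches also take over the stop_gradient/persistable propagation that the wrapper functions used to do after calling paddle._C_ops.*, and _moe_global_mesh_tensor.forward gains an optional trailing idx=None that the PIR-mode call simply omits.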