@@ -369,58 +369,97 @@ def forward(
         local_tensor_list,
         local_mesh_list,
         local_placements,
-        idx,
-        global_dims,
         mesh,
         placements,
+        global_dims,
+        idx=None,
     ):
-        local_tensor = local_tensor_list[idx]
-        if local_tensor.is_dist():
-            local_mesh = local_tensor.process_mesh
-            local_val = local_tensor._local_value()
-        else:
-            local_val = local_tensor
-            local_mesh = None
-
-        ctx.global_mesh = copy.deepcopy(mesh)
-        ctx.placements = copy.deepcopy(placements)
-        ctx.local_dims = local_tensor.shape
-        ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placements = copy.deepcopy(local_placements)
+        # NOTE: _local_value/Paddle.Tensor is only supported in dynamic mode
+        if paddle.in_dynamic_mode():
+            local_tensor = local_tensor_list[idx]
+            if local_tensor.is_dist():
+                local_mesh = local_tensor.process_mesh
+                local_val = local_tensor._local_value()
+            else:
+                local_val = local_tensor
+                local_mesh = None
+
+            ctx.save_for_backward(
+                copy.deepcopy(mesh),  # global_mesh
+                local_tensor.shape,  # local_dims
+                copy.deepcopy(local_mesh_list),  # local_mesh_list
+                copy.deepcopy(local_placements),  # local_placements
+            )

-        place = paddle.framework._current_expected_place()
-        place = paddle.framework._get_paddle_place(place)
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)

-        global_tensor = paddle.Tensor(
-            local_val,
-            dims=global_dims,
-            process_mesh=mesh,
-            placements=placements,
-            place=place,
-        )
-        global_tensor.stop_gradient = local_tensor.stop_gradient
-        return global_tensor
+            global_tensor = paddle.Tensor(
+                local_val,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+            global_tensor.stop_gradient = local_tensor.stop_gradient
+            return global_tensor
+        else:
+            ctx.save_for_backward(
+                copy.deepcopy(mesh),  # global_mesh
+                copy.deepcopy(placements),  # global_placements
+                copy.deepcopy(local_mesh_list),  # local_mesh_list
+                copy.deepcopy(local_placements),  # local_placements
+            )
+            dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
+                local_tensor_list,
+                local_mesh_list,
+                local_placements,
+                mesh,
+                placements,
+                global_dims,
+            )
+            dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
+            dist_tensor.persistable = local_tensor_list[0].persistable
+            return dist_tensor

     @staticmethod
     def backward(ctx, grad_tensor):
-        if ctx.local_mesh_list is None:
-            return grad_tensor._local_value()
-        else:
-            place = paddle.framework._current_expected_place()
-            place = paddle.framework._get_paddle_place(place)
-            out = []
-            for i, local_mesh in enumerate(ctx.local_mesh_list):
-                out.append(
-                    paddle.Tensor(
-                        grad_tensor._local_value(),
-                        dims=ctx.local_dims,
-                        process_mesh=local_mesh,
-                        placements=ctx.local_placements,
-                        place=place,
+        if paddle.in_dynamic_mode():
+            global_mesh, local_dims, local_mesh_list, local_placements = (
+                ctx.saved_tensor()
+            )
+            if local_mesh_list is None:
+                return grad_tensor._local_value()
+            else:
+                place = paddle.framework._current_expected_place()
+                place = paddle.framework._get_paddle_place(place)
+                out = []
+                for i, local_mesh in enumerate(local_mesh_list):
+                    out.append(
+                        paddle.Tensor(
+                            grad_tensor._local_value(),
+                            dims=local_dims,
+                            process_mesh=local_mesh,
+                            placements=local_placements,
+                            place=place,
+                        )
                     )
-                )
-                out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
-            return out
+                    out[-1].get_tensor()._unsafe_set_skip_check_mesh(True)
+                return out
+        else:
+            (
+                global_mesh,
+                global_placements,
+                local_mesh_list,
+                local_placements,
+            ) = ctx.saved_tensor()
+            return paddle._C_ops.moe_sub_mesh_tensors(
+                grad_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+            )


 def _get_sub_meshes_and_local_placements(
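The hunk above changes how the PyLayer carries state from `forward` to `backward`: instead of stashing ad-hoc attributes on `ctx` (`ctx.global_mesh`, `ctx.local_dims`, ...), it packs everything into `ctx.save_for_backward(...)` and unpacks it with `ctx.saved_tensor()`, and it branches on `paddle.in_dynamic_mode()` so the same class can also be driven by `StaticPyLayer` under PIR (see the later hunks). A minimal, self-contained sketch of that save/restore pattern, using a hypothetical `_cube` layer that is not part of this patch:

```python
import paddle
from paddle.autograd import PyLayer


class _cube(PyLayer):
    @staticmethod
    def forward(ctx, x):
        # Stash whatever backward will need on the context object.
        ctx.save_for_backward(x)
        return x**3

    @staticmethod
    def backward(ctx, grad):
        # saved_tensor() returns the saved objects in the order they were stored.
        (x,) = ctx.saved_tensor()
        return 3.0 * x * x * grad


x = paddle.to_tensor([2.0], stop_gradient=False)
y = _cube.apply(x)
y.backward()
print(x.grad)  # expect [12.0]
```

In the patch the saved objects are meshes, placements, and shapes rather than tensors; the mechanism is the same, which is presumably what makes the class reusable from the `StaticPyLayer` wrapper in the PIR branches below.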
@@ -469,6 +508,7 @@ def moe_global_mesh_tensor(
     local_tensor = local_tensor_list[local_tensor_idx]

     if paddle.in_dynamic_mode():
+        # NOTE: _local_value and Paddle.Tensor() is only supported in dynamic mode
         if local_coord[0].size == 0:
             local_tensor_shape = _cal_local_shape(
                 local_tensor_list[0].shape, local_mesh_list[0], local_placements
@@ -498,27 +538,25 @@ def moe_global_mesh_tensor(
             resharded_local_tensor_list,
             local_mesh_list,
             local_placements,
-            local_tensor_idx,
-            global_dims,
             mesh,
             placements,
+            global_dims,
+            local_tensor_idx,
         )
     elif paddle.framework.in_pir_mode():
         global_dims = _cal_global_shape(
             local_tensor._local_shape, mesh, placements
         )
-        dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
+        return paddle.jit.dy2static.py_layer.StaticPyLayer(
+            _moe_global_mesh_tensor
+        ).apply(
             local_tensor_list,
             local_mesh_list,
             local_placements,
             mesh,
             placements,
             global_dims,
         )
-        dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
-        dist_tensor.persistable = local_tensor_list[0].persistable
-
-        return dist_tensor
     else:
         raise NotImplementedError(
             "dtensor_from_local_list() are only supported in dynamic and pir mode."
@@ -536,75 +574,114 @@ def forward(
         global_mesh=None,
         global_placements=None,
     ):
-        ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
-        ctx.local_placements = local_placements
-        ctx.local_mesh_dim = local_mesh_dim
-        ctx.global_mesh = copy.deepcopy(global_mesh)
-        ctx.global_placements = global_placements
-        ctx.global_shape = dist_tensor.shape
-
-        if global_mesh is None and global_placements is None:
-            return dist_tensor._local_value()
-        else:
-            if global_mesh is None or global_placements is None:
-                raise ValueError(
-                    "the args global_mesh and global_placements should be set together"
-                )
-            ori_mesh = dist_tensor.process_mesh
-            if global_mesh != dist_tensor.process_mesh:
-                raise ValueError(
-                    "the global_mesh should be the same as dist_tensor's process_mesh."
-                )
-            assert check_placements_equal(
-                global_placements, dist_tensor.placements
-            ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
-            local_shape = _cal_local_shape(
-                dist_tensor.shape, global_mesh, global_placements
-            )
-            for idx, placement in enumerate(local_placements):
-                if placement.is_shard():
-                    shard_dim = placement.get_dim()
-                    local_dim_size = local_shape[shard_dim]
-                    local_shape[shard_dim] = (
-                        local_dim_size * local_mesh_list[0].shape[idx]
+        ctx.save_for_backward(
+            copy.deepcopy(local_mesh_list),  # local_mesh_list,
+            local_placements,  # local_placements,
+            local_mesh_dim,  # local_mesh_dim,
+            copy.deepcopy(global_mesh),  # global_mesh,
+            global_placements,  # global_placements,
+            dist_tensor.shape,  # global_shape,
+        )
+        if paddle.in_dynamic_mode():
+            if global_mesh is None and global_placements is None:
+                return dist_tensor._local_value()
+            else:
+                if global_mesh is None or global_placements is None:
+                    raise ValueError(
+                        "the args global_mesh and global_placements should be set together"
                     )
-
-            place = paddle.framework._current_expected_place()
-            place = paddle.framework._get_paddle_place(place)
-            local_tensor_list = []
-            for i, local_mesh in enumerate(local_mesh_list):
-                local_tensor = paddle.Tensor(
-                    dist_tensor._local_value(),
-                    dims=local_shape,
-                    process_mesh=local_mesh,
-                    placements=local_placements,
-                    place=place,
+                ori_mesh = dist_tensor.process_mesh
+                if global_mesh != dist_tensor.process_mesh:
+                    raise ValueError(
+                        "the global_mesh should be the same as dist_tensor's process_mesh."
+                    )
+                assert check_placements_equal(
+                    global_placements, dist_tensor.placements
+                ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
+                local_shape = _cal_local_shape(
+                    dist_tensor.shape, global_mesh, global_placements
                 )
-                local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
+                for idx, placement in enumerate(local_placements):
+                    if placement.is_shard():
+                        shard_dim = placement.get_dim()
+                        local_dim_size = local_shape[shard_dim]
+                        local_shape[shard_dim] = (
+                            local_dim_size * local_mesh_list[0].shape[idx]
+                        )
+
+                place = paddle.framework._current_expected_place()
+                place = paddle.framework._get_paddle_place(place)
+                local_tensor_list = []
+                for i, local_mesh in enumerate(local_mesh_list):
+                    local_tensor = paddle.Tensor(
+                        dist_tensor._local_value(),
+                        dims=local_shape,
+                        process_mesh=local_mesh,
+                        placements=local_placements,
+                        place=place,
+                    )
+                    local_tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
+                    local_tensor.stop_gradient = dist_tensor.stop_gradient
+                    local_tensor_list.append(local_tensor)
+                return local_tensor_list
+        elif paddle.framework.in_pir_mode():
+            local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
+                dist_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+            )
+            for local_tensor in local_tensors:
                 local_tensor.stop_gradient = dist_tensor.stop_gradient
-                local_tensor_list.append(local_tensor)
-            return local_tensor_list
+                local_tensor.persistable = dist_tensor.persistable
+            return local_tensors

     @staticmethod
     def backward(ctx, *grad_tensor):
+        (
+            local_mesh_list,
+            local_placements,
+            local_mesh_dim,
+            global_mesh,
+            global_placements,
+            global_shape,
+        ) = ctx.saved_tensor()
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
-        mesh = ctx.global_mesh
+        mesh = global_mesh
         process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
         local_coord = np.where(process_ids == dist.get_rank())
         if local_coord[0].size == 0:
             local_tensor_idx = 0
         else:
-            local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
+            local_tensor_idx = local_coord[local_mesh_dim][0]
         local_grad = grad_tensor[local_tensor_idx]
-        global_tensor = paddle.Tensor(
-            local_grad._local_value(),
-            dims=ctx.global_shape,
-            process_mesh=mesh,
-            placements=ctx.global_placements,
-            place=place,
-        )
-        return global_tensor
+
+        if paddle.in_dynamic_mode():
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+            global_tensor = paddle.Tensor(
+                local_grad._local_value(),
+                dims=global_shape,
+                process_mesh=mesh,
+                placements=global_placements,
+                place=place,
+            )
+            return global_tensor
+        elif paddle.framework.in_pir_mode():
+            global_dims = _cal_global_shape(
+                local_grad._local_shape, mesh, global_placements
+            )
+
+            return paddle._C_ops.moe_global_mesh_tensor(
+                grad_tensor,
+                local_mesh_list,
+                local_placements,
+                global_mesh,
+                global_placements,
+                global_dims,
+            )


 def moe_sub_mesh_tensors(
@@ -627,17 +704,17 @@ def moe_sub_mesh_tensors(
             global_placements,
         )
     elif paddle.framework.in_pir_mode():
-        local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
+
+        return paddle.jit.dy2static.py_layer.StaticPyLayer(
+            _moe_sub_mesh_tensors
+        ).apply(
             dist_tensor,
             local_mesh_list,
             local_placements,
+            local_mesh_dim,
             global_mesh,
             global_placements,
         )
-        for local_tensor in local_tensors:
-            local_tensor.stop_gradient = dist_tensor.stop_gradient
-            local_tensor.persistable = dist_tensor.persistable
-        return local_tensors

     else:
         raise NotImplementedError(
             "moe_sub_mesh_tensors is only supported in dynamic mode."