Skip to content

Commit 530283d

Browse files
kurisusnowdengver217
authored andcommitted
fix object_to_tensor usage when torch>=2.3.0 (#5820)
1 parent 2e28c79 commit 530283d

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

colossalai/pipeline/p2p.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,11 @@ def _broadcast_object_list(
9191
my_rank = dist.get_rank()
9292
# Serialize object_list elements to tensors on src rank.
9393
if my_rank == src:
94-
if Version(torch.__version__) >= Version("1.13.0"):
94+
if Version(torch.__version__) >= Version("2.3.0"):
95+
tensor_list, size_list = zip(
96+
*[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
97+
)
98+
elif Version(torch.__version__) >= Version("1.13.0"):
9599
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
96100
else:
97101
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
@@ -276,7 +280,11 @@ def _send_recv_serialization_object(
276280
send_object_tensor = None
277281
send_object_size_tensor = None
278282
if object is not None and send_dst is not None:
279-
if Version(torch.__version__) >= Version("1.13.0"):
283+
if Version(torch.__version__) >= Version("2.3.0"):
284+
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(
285+
object, device=current_device, group=send_group
286+
)
287+
elif Version(torch.__version__) >= Version("1.13.0"):
280288
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
281289
else:
282290
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)

0 commit comments

Comments
 (0)