@@ -30,24 +30,27 @@ def all_gather_embeddings_and_labels(emb, labels):
3030 if not is_distributed ():
3131 return None , None
3232 ref_emb = all_gather (emb )
33- ref_labels = all_gather (labels )
33+ ref_labels = all_gather (labels ) if labels is not None else None
3434 return ref_emb , ref_labels
3535
3636
3737def gather (emb , labels ):
3838 device = emb .device
39- labels = c_f .to_device (labels , device = device )
39+ if labels is not None :
40+ labels = c_f .to_device (labels , device = device )
4041 dist_emb , dist_labels = all_gather_embeddings_and_labels (emb , labels )
4142 all_emb = torch .cat ([emb , dist_emb ], dim = 0 )
42- all_labels = torch .cat ([labels , dist_labels ], dim = 0 )
43+ all_labels = (
44+ torch .cat ([labels , dist_labels ], dim = 0 ) if dist_labels is not None else None
45+ )
4346 return all_emb , all_labels , labels
4447
4548
4649def gather_emb_and_ref (emb , labels , ref_emb = None , ref_labels = None ):
4750 all_emb , all_labels , labels = gather (emb , labels )
4851 all_ref_emb , all_ref_labels = None , None
4952
50- if ref_emb is not None and ref_labels is not None :
53+ if ref_emb is not None :
5154 all_ref_emb , all_ref_labels , _ = gather (ref_emb , ref_labels )
5255
5356 return all_emb , all_labels , all_ref_emb , all_ref_labels , labels
@@ -81,7 +84,9 @@ def __init__(self, loss, efficient=False):
8184 self .loss = loss
8285 self .efficient = efficient
8386
84- def forward (self , emb , labels , indices_tuple = None , ref_emb = None , ref_labels = None ):
87+ def forward (
88+ self , emb , labels = None , indices_tuple = None , ref_emb = None , ref_labels = None
89+ ):
8590 world_size = torch .distributed .get_world_size ()
8691 common_args = [emb , labels , indices_tuple , ref_emb , ref_labels , world_size ]
8792 if isinstance (self .loss , CrossBatchMemory ):
@@ -99,7 +104,8 @@ def forward_regular_loss(
99104 )
100105
101106 if self .efficient :
102- all_labels = select_ref_or_regular (all_labels , all_ref_labels )
107+ if all_labels is not None :
108+ all_labels = select_ref_or_regular (all_labels , all_ref_labels )
103109 all_emb = select_ref_or_regular (all_emb , all_ref_emb )
104110 if indices_tuple is None :
105111 indices_tuple = get_indices_tuple (labels , all_labels )
0 commit comments