Skip to content

Commit

Permalink
put labels on correct device
Browse files Browse the repository at this point in the history
  • Loading branch information
etienneguevel committed Aug 27, 2024
1 parent 2cfb06e commit d8cfab9
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions dinov2/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ def do_train(cfg, model, resume=False):

if do_daino:
labelled_dataset = make_labelled_dataset(
cfg.train.dataset_path,
cfg.daino.labelled_dataset_path,
cfg.train.dataset_path,
)
print(f"{len(labelled_dataset)} elements were found for the labelled dataset")

Expand Down Expand Up @@ -242,8 +242,8 @@ def do_train(cfg, model, resume=False):

# A bit of verbose for information sake

print("There are {} images in the unlabelled dataset used".format(dataset.__len__()))
print("There are {} images in the labelled dataset used".format(labelled_dataset.__len__()))
print("There are {} images in the unlabelled dataset used".format(len(dataset)))
print("There are {} images in the labelled dataset used".format(len(labelled_dataset)))

# training loop

Expand Down Expand Up @@ -279,11 +279,12 @@ def do_train(cfg, model, resume=False):

optimizer.zero_grad(set_to_none=True)
if do_daino:
labelled_data = next(labelled_iterator)
images, labels = next(labelled_iterator)
labels = torch.tensor(labels, device="gpu")
loss_dict = model.forward_backward(
data,
teacher_temp=teacher_temp,
labelled_data=labelled_data
labelled_data=(images, labels)
)

else:
Expand Down

0 comments on commit d8cfab9

Please sign in to comment.