Which miner to use for ProxyAnchor (and ProxyNCA) loss? #561
Greetings, Kevin (and the PML community),

First, thank you for contributing to the field and making this framework available. Your work has made my first steps into metric learning much more accessible and pushed me to move from Keras to Torch earlier than I thought possible. Many thanks!

Now to the question, which may be relatively simple: which miner would you use together with ProxyAnchorLoss? I took the following approach, which runs without errors, but I'm curious whether I am following the original work or making up a new approach 😇.

import torch
from pytorch_metric_learning import distances, losses, miners, reducers

# N_CLASSES, N_EMBEDDINGS, MARGIN, ALPHA and LR are hyperparameters defined elsewhere
distance = distances.CosineSimilarity()
reducer = reducers.DivisorReducer()
loss_fn = losses.ProxyAnchorLoss(
    num_classes=N_CLASSES,
    embedding_size=N_EMBEDDINGS,
    margin=MARGIN,
    alpha=ALPHA,
    distance=distance,
    reducer=reducer,
)
# the original paper scales the learning rate by 100 for the proxies
loss_fn_optim = torch.optim.AdamW(loss_fn.parameters(), lr=LR * 100)
miner = miners.TripletMarginMiner(
    margin=MARGIN,
    distance=distance,
    type_of_triplets="semihard",
)

I then pass these to the training function below:

def train(model, miner, dataloader, optimizer, loss_fn, device, loss_fn_optim=None):
    # make sure the model is in train mode
    model.train()
    running_loss = 0.0
    for batch in dataloader:
        inputs, labels = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        # the proxies are parameters of loss_fn, so their gradients must be reset too
        if loss_fn_optim is not None:
            loss_fn_optim.zero_grad()
        outputs = model(inputs)
        indices_tuple = miner(outputs, labels)
        loss = loss_fn(outputs, labels, indices_tuple)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
        # step the proxy optimizer if the loss needs one (e.g. ProxyAnchorLoss)
        if loss_fn_optim is not None:
            loss_fn_optim.step()
    return running_loss
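For reference, this is the Proxy Anchor loss as I read it in the original paper. Here $s(x,p)$ is the cosine similarity between embedding $x$ and proxy $p$, $\delta$ is the margin (`MARGIN` above) and $\alpha$ is the scale (`ALPHA` above). $P$ is the set of all proxies and $P^{+}$ is the set of proxies with at least one positive sample in the batch $X$. $X_p^{+}$ and $X_p^{-}$ are the batch embeddings with and without class $p$. As far as I can tell, nothing in the formula refers to mined pairs or triplets:

```latex
\ell(X) = \frac{1}{|P^{+}|} \sum_{p \in P^{+}} \log\Big(1 + \sum_{x \in X_p^{+}} e^{-\alpha\,(s(x,p) - \delta)}\Big)
        + \frac{1}{|P|} \sum_{p \in P} \log\Big(1 + \sum_{x \in X_p^{-}} e^{\alpha\,(s(x,p) + \delta)}\Big)
```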
I may be wrong, but... what's the point of using a miner with a proxy-based loss? You are not comparing data points against each other; you are comparing data points to class proxies.
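The reply's point can be illustrated with a minimal pure-Python sketch of a Proxy Anchor style loss. It uses no PyTorch and no pytorch-metric-learning. The names `cosine` and `proxy_anchor_loss` are hypothetical, and the `margin` and `alpha` defaults are illustrative values only. Every term is computed from the batch-by-class table of sample-to-proxy similarities, so the loss never looks at a pair of batch samples.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def proxy_anchor_loss(embeddings, labels, proxies, margin=0.1, alpha=32.0):
    """Hypothetical sketch of a Proxy Anchor style loss.

    Returns (loss, sims), where sims[i][c] is the similarity of sample i
    to the proxy of class c. No sample-to-sample similarity is computed.
    """
    sims = [[cosine(x, p) for p in proxies] for x in embeddings]
    num_classes = len(proxies)

    # positive term: only proxies whose class appears in the batch
    pos_classes = sorted(set(labels))
    pos_term = 0.0
    for c in pos_classes:
        s = sum(math.exp(-alpha * (sims[i][c] - margin))
                for i, y in enumerate(labels) if y == c)
        pos_term += math.log1p(s)
    pos_term /= len(pos_classes)

    # negative term: every proxy, pushed away from samples of other classes
    neg_term = 0.0
    for c in range(num_classes):
        s = sum(math.exp(alpha * (sims[i][c] + margin))
                for i, y in enumerate(labels) if y != c)
        neg_term += math.log1p(s)
    neg_term /= num_classes

    return pos_term + neg_term, sims


if __name__ == "__main__":
    proxies = [[1.0, 0.0], [0.0, 1.0]]
    embeddings = [[1.0, 0.0], [0.0, 1.0]]
    labels = [0, 1]
    loss, sims = proxy_anchor_loss(embeddings, labels, proxies)
    print(len(sims), len(sims[0]))  # 2 2: one row per sample, one column per proxy
    print(round(loss, 2))  # 3.24
```

With each embedding sitting exactly on its own proxy, the positive term is essentially zero. The loss then comes almost entirely from the negative term, log(1 + e^(alpha * margin)), which is about 3.24 here.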