Is there something wrong with my zero-shot code? #14

Open
drizzle0171 opened this issue Jan 21, 2024 · 1 comment

drizzle0171 commented Jan 21, 2024

Thanks for the great work!

I tried to reproduce your results in my environment. However, your pre-trained model (ResNet18) reached only 3% accuracy with my zero-shot code.
Could you please check what differs from your code?
Thank you so much!

This is my zero-shot classification code.

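import math

import clip
import torch
from tqdm import tqdm

# AudioEncoder, dataset, text, and LABELS are defined elsewhere (not shown here)
# model init from ckpt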
audioencoder = AudioEncoder().cuda()
ckpt = torch.load('./audio-diffusion/resnet18_57.pth')
audioencoder.load_state_dict(ckpt)
textencoder, _ = clip.load("ViT-B/32")
textencoder = textencoder.cuda()

correct = 0

for i in tqdm(range(len(dataset))):
    audio = dataset[i][0][None].cuda() # (B, 1, 128, 512)

    # embeddings
    with torch.no_grad():
        text_token = torch.cat([clip.tokenize(txt) for txt in text])
        text_features = textencoder.encode_text(text_token.cuda())
        audio_features = audioencoder(audio)
        audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
    audio_features = audio_features.float()
    text_features = text_features.float()

    # logit
    proj_per_audio = (audio_features @ text_features.T) * math.exp(0.07)

    # classification
    # confidence = proj_per_audio.softmax(dim=1)
    # conf_values, ids = confidence.topk(1)
    label_idx = torch.argmax(proj_per_audio, axis=1)

    pred = LABELS[label_idx]
    label = dataset[i][-1]

    if label == pred:
        correct += 1
    
    if i%100 == 0:
        print(correct)
        
print(f'Classification accuracy: {correct/len(dataset)*100} %')
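
One thing that may be worth checking: the loop above L2-normalizes only `audio_features`, while CLIP's usual zero-shot recipe normalizes both sides before taking the dot product. Whether this explains the 3% is not confirmed here. Below is a minimal sketch of the difference. It uses plain Python lists instead of torch tensors so it stays self-contained, and the vectors and the helper names (`l2_normalize`, `zero_shot_predict`) are made up for illustration, not taken from the repository.

```python
import math


def l2_normalize(vec):
    """Return vec scaled to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def zero_shot_predict(audio_vec, text_vecs, normalize_text=True):
    """Return the index of the text embedding that scores highest against audio_vec."""
    audio_vec = l2_normalize(audio_vec)
    if normalize_text:
        text_vecs = [l2_normalize(t) for t in text_vecs]
    scores = [dot(audio_vec, t) for t in text_vecs]
    return max(range(len(scores)), key=scores.__getitem__)


# Toy, made-up embeddings (NOT outputs of the real encoders).
audio = [1.0, 0.0]
texts = [
    [0.9, 0.1],  # points almost in the same direction as the audio vector
    [3.0, 3.0],  # points 45 degrees away but has a much larger norm
]

pred_cosine = zero_shot_predict(audio, texts, normalize_text=True)
pred_raw = zero_shot_predict(audio, texts, normalize_text=False)
print(pred_cosine, pred_raw)
```

With unnormalized text embeddings, the large-norm vector wins the dot product even though its direction matches worse, so the two predictions differ. Multiplying the logits by a positive constant such as `math.exp(0.07)` does not change the `argmax`, so that factor alone cannot change the accuracy. Separately, if `AudioEncoder` contains BatchNorm or dropout layers (an assumption, since its definition is not shown), calling `audioencoder.eval()` before inference would also matter.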

drizzle0171 commented Jan 30, 2024

If you don't mind my asking, could you tell me what labels you used?
