Incompatibility with using the pretrained StyleSAN model (CIFAR-10) #8

Open
rielmvp opened this issue Feb 19, 2025 · 5 comments

@rielmvp

rielmvp commented Feb 19, 2025

I encountered some issues while trying to use the discriminator of the StyleSAN model pretrained on CIFAR-10. Specifically, I got this error:

---------------------------------------------------------------------------
SystemError                               Traceback (most recent call last)
Cell In[7], line 21
     18     return np.array(scores)
     20 # 🔹 Run evaluation
---> 21 discriminator_scores = evaluate_discriminator(cifar10syn_loader, discriminator, device)
     23 # 🔹 Print results
     24 print("Discriminator Scores Shape:", discriminator_scores.shape)

Cell In[7], line 14
     11         labels_one_hot = F.one_hot(labels, num_classes=num_classes).float().to(device)
     13         # 🔹 Pass both image & label to discriminator
---> 14         logits = model(images, labels_one_hot)
     16         scores.extend(logits.cpu().numpy())  # Collect scores
     18 return np.array(scores)

File ~/anaconda3/envs/research/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/research/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/SSL_study/san/stylesan-xl/pg_modules/discriminator.py:231, in ProjectedDiscriminator.forward(self, x, c, flg_train)
    228     x_n = F.interpolate(x_n, 224, mode='bilinear', align_corners=False)
    230 # forward pass
--> 231 features = feat(x_n)
    232 if flg_train:
    233     logit_fun, logit_dir = self.discriminators[bb_name](features, c, flg_train=True)

File ~/anaconda3/envs/research/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/research/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/SSL_study/san/stylesan-xl/pg_modules/projector.py:114, in F_RandomProj.forward(self, x)
    111 def forward(self, x):
    112     # predict feature maps
    113     if self.backbone in VITS:
--> 114         out0, out1, out2, out3 = forward_vit(self.pretrained, x)
    115     else:
    116         out0 = self.pretrained.layer0(x)

File ~/SSL_study/san/stylesan-xl/feature_networks/vit.py:59, in forward_vit(pretrained, x)
     56 def forward_vit(pretrained, x):
     57     b, c, h, w = x.shape
---> 59     _ = pretrained.model.forward_flex(x)
     61     layer_1 = pretrained.activations["1"]
     62     layer_2 = pretrained.activations["2"]

File ~/SSL_study/san/stylesan-xl/feature_networks/vit.py:149, in forward_flex(self, x)
    146 x = self.pos_drop(x)
    148 for blk in self.blocks:
--> 149     x = blk(x)
    151 x = self.norm(x)
    153 return x

File ~/anaconda3/envs/research/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/research/lib/python3.12/site-packages/torch/nn/modules/module.py:1844, in Module._call_impl(self, *args, **kwargs)
   1841     return inner()
   1843 try:
-> 1844     return inner()
   1845 except Exception:
   1846     # run always called hooks if they have not already been run
   1847     # For now only forward hooks have the always_call option but perhaps
   1848     # this functionality should be added to full backward hooks as well.
   1849     for hook_id, hook in _global_forward_hooks.items():

File ~/anaconda3/envs/research/lib/python3.12/site-packages/torch/nn/modules/module.py:1803, in Module._call_impl.<locals>.inner()
   1801     hook_result = hook(self, args, kwargs, result)
   1802 else:
-> 1803     hook_result = hook(self, args, result)
   1805 if hook_result is not None:
   1806     result = hook_result

SystemError: /shared/torch211-cu121/stylesan-xl/feature_networks/vit.py:160: unknown opcode 232

After a brief search, it seems the error is occurring because of a mismatch in Python/PyTorch versions. In my current environment, the Python version is 3.12.8 and the PyTorch version is 2.5.1+cu121, and I installed all packages listed in requirements.txt (including dill==0.3.9). The code (loading the discriminator) worked fine up to this point. I also realize the requirements call for Python 3.8, but when I created a new environment with Python 3.8 and installed everything from requirements.txt, I got the errors below when I tried to run the following code:

# Load the pretrained StyleGAN-XL-SAN model and extract the discriminator
print(f"Loading model from {pkl_path}...")
with dnnlib.util.open_url(pkl_path) as f:
    model = legacy.load_network_pkl(f)  # Load the full model dictionary

#Print the model architecture
print(type(model))
#It is a dict, so print the keys
print(model.keys())

[Two screenshots of the errors attached]

Any insight regarding this would be very helpful. I can also show you the exact code that I am using the discriminator for.
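As a side note on the "unknown opcode" SystemError above: dill serializes the functions inside the .pkl as raw CPython bytecode, and bytecode is only valid for the interpreter version that produced it. A minimal stdlib-only sketch (using marshal, which serializes code objects much like dill does, and a hypothetical `double` function) illustrates the mechanism:

```python
# Sketch: CPython bytecode round-trips fine under the SAME interpreter,
# but is not portable across Python versions. dill-pickled functions
# (as in stylesan-xl_cifar10.pkl) carry such bytecode, so loading them
# under a different Python can raise "unknown opcode".
import marshal
import types

def double(x):
    return x * 2

# Serialize the raw code object (version-specific bytecode).
blob = marshal.dumps(double.__code__)

# Same interpreter: restoring and calling the function works.
restored = types.FunctionType(marshal.loads(blob), globals())
print(restored(21))  # → 42

# A blob produced by a different Python version would instead fail here,
# analogous to the SystemError seen when unpickling with dill.
```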

@TakashiShibuyaSony
Collaborator

Thank you very much for your interest in our work. Much appreciated. And apologies for not replying in a timely manner to your other post, #6 (comment).

Unfortunately, we don't have the resources to address this issue. Is it possible to try using the Dockerfile we provide (with the original setting dill>=0.3.4)? This is the same environment as the one we used.

@rielmvp
Author

rielmvp commented Feb 19, 2025

Thank you for the prompt reply. After setting up Docker, the previous error no longer occurs, but I now get the error shown below when running the following code (the goal is to use the pretrained discriminator to produce confidence scores for synthetic images):

import os
os.environ["CUDA_VISIBLE_DEVICES"]="4,5" # use GPUs 4 and 5
import timm
print(timm.__version__)

import torch
import numpy as np
import pickle
import dill
import legacy
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import PIL.Image
from torch_utils import gen_utils
import torch_utils
import torch.nn.functional as F  # For one-hot encoding

#Load the pre-trained StyleGANXL-SAN Model and set the device (GPU 0)
pkl_path = 'stylesan-xl_cifar10.pkl'  # Path to your .pkl file
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}, GPU number: {torch.cuda.current_device()}')
#Print the number of available GPUs
print(f'Number of GPUs: {torch.cuda.device_count()}')
#Print dill version and pickle version
print(f'dill version: {dill.__version__}')
print(f'pickle version: {pickle.format_version}')
print(f"Pickle default protocol: {pickle.DEFAULT_PROTOCOL}")
print(f"Pickle highest protocol: {pickle.HIGHEST_PROTOCOL}")

# Load the pretrained StyleGAN-XL-SAN model and extract the discriminator
print(f"Loading model from {pkl_path}...")
with dnnlib.util.open_url(pkl_path) as f:
    model = legacy.load_network_pkl(f)  # Load the full model dictionary

#Print the model architecture
print(type(model))
#It is a dict, so print the keys
print(model.keys())

# The discriminator is from the pretrained StyleGAN-XL-SAN model
discriminator = model['D'].to(device)
discriminator.eval()

# Custom Dataset for dynamic resizing
class SyntheticDataset(Dataset):
    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transform:
            img = self.transform(img)
        #print("Transformed image shape:", img.shape)  # Debugging print
        if img.shape[0] != 3:  # Debugging check
            print(f"Error: Non-RGB image at index {idx}, shape: {img.shape}")
        return img, -1 #returns dummy label since labels are not available for synthetic images
    
#Prepare the synthetic dataset
cifar10_synthetic = np.load('cifar10-stf-1m.npz')
#Load the images and labels
cifar10syn_images = cifar10_synthetic['image']
print(f"Loaded synthetic CIFAR-10 images: {cifar10syn_images.shape}")
#Load the labels from an .npy file
cifar10syn_labels = np.load("cifar10_synthetic_preds.npy")
# 🔹 Define a custom dataset that includes labels
class SyntheticDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels  # Store labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]  # Get corresponding label
        
        if self.transform:
            img = self.transform(img)

        # Debugging: Ensure valid image format
        if img.shape[0] != 3:
            print(f"Error: Non-RGB image at index {idx}, shape: {img.shape}")

        return img, label  # Return image + label

# 🔹 Data transformations (normalize, convert to tensor)
synthetic_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

# 🔹 Create DataLoader with labels
cifar10syn_loader = DataLoader(SyntheticDataset(images=cifar10syn_images, labels=cifar10syn_labels, transform=synthetic_transform), batch_size=256, shuffle=False)

# 🔹 Modify function to pass both image & one-hot label to discriminator
def evaluate_discriminator(loader, model, device, num_classes=10):
    scores = []
    model.eval()

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating synthetic images"):
            images, labels = batch  # Extract images & labels

            # 🔹 Move images to GPU explicitly with dtype conversion
            images = images.to(device, dtype=torch.float32)

            # 🔹 Ensure labels are moved to GPU BEFORE one-hot encoding
            labels = labels.to(device)  # Move class labels before one-hot encoding
            labels_one_hot = F.one_hot(labels, num_classes=num_classes).float()  # One-hot encode labels on GPU

            # 🔹 Pass both images & one-hot labels to discriminator
            logits = model(images, labels_one_hot)

            # 🔹 Move logits back to CPU before saving to numpy
            scores.extend(logits.cpu().numpy())

    return np.array(scores)

# 🔹 Run evaluation
discriminator_scores = evaluate_discriminator(cifar10syn_loader, discriminator, device)

# 🔹 Print results
print("Discriminator Scores Shape:", discriminator_scores.shape)

Here is the output I got:

0.4.12
Using device: cuda, GPU number: 0
Number of GPUs: 2
dill version: 0.3.9
pickle version: 4.0
Pickle default protocol: 4
Pickle highest protocol: 5
Loading model from stylesan-xl_cifar10.pkl...
<class 'dict'>
dict_keys(['G', 'D', 'G_ema', 'augment_pipe', 'training_set_kwargs', 'progress'])
Loaded synthetic CIFAR-10 images: (1000000, 32, 32, 3)
Evaluating synthetic images:   0%| | 0/3907 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Evaluating synthetic images:   0%| | 0/3907 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/workspace/discriminatorfiltering.py", line 134, in <module>
    discriminator_scores = evaluate_discriminator(cifar10syn_loader, discriminator, device)
  File "/workspace/discriminatorfiltering.py", line 126, in evaluate_discriminator
    logits = model(images, labels_one_hot)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/pg_modules/discriminator.py", line 231, in forward
    features = feat(x_n)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/pg_modules/projector.py", line 114, in forward
    out0, out1, out2, out3 = forward_vit(self.pretrained, x)
  File "/workspace/feature_networks/vit.py", line 92, in forward_vit
    layer_1 = pretrained.layer1[3 : len(pretrained.layer1)](layer_1)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Any insight into why the discriminator is not working would be helpful.
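For context, this error generally means F.conv2d received a CPU tensor while the conv weights live on CUDA. A generic defensive sketch (a hypothetical helper, not part of the StyleSAN-XL codebase) is to move the inputs to whatever device the model's parameters are on right before the call:

```python
def to_model_device(model, *tensors):
    """Move each tensor to the device of the model's first parameter."""
    dev = next(model.parameters()).device
    return tuple(t.to(dev) for t in tensors)
```

Usage would be images, labels_one_hot = to_model_device(discriminator, images, labels_one_hot). That said, as the traceback shows, the mismatch here happens deeper inside the discriminator's feature network, so moving only the top-level inputs may not be enough.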

@TakashiShibuyaSony (Collaborator) commented Feb 19, 2025

Have you tried printing images.device and images.dtype just after images = images.to(device, dtype=torch.float32) to confirm that the tensor images is actually placed on the GPU?

@rielmvp (Author) commented Feb 20, 2025

Yes. When I printed the device for the images, labels, and model, I got the following output:

Images device: cuda:0, images dtype: torch.float32, labels device: cuda:0, model device: cuda:0, labels_one_hot device: cuda:0
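One caveat (an assumption worth checking, not something confirmed in this thread): printing a single "model device" typically inspects only one parameter, while the error can come from a parameter or buffer of a nested submodule that is still sitting on the CPU. A small sketch (hypothetical helper, standard torch.nn.Module API) to scan all of them:

```python
def find_offdevice_tensors(module, expected="cuda:0"):
    """Return (name, device) pairs for any parameter or buffer of `module`
    that is not on the expected device."""
    offenders = []
    for name, p in module.named_parameters():
        if str(p.device) != expected:
            offenders.append(("param " + name, str(p.device)))
    for name, b in module.named_buffers():
        if str(b.device) != expected:
            offenders.append(("buffer " + name, str(b.device)))
    return offenders
```

Running print(find_offdevice_tensors(discriminator)) should either return an empty list (everything reachable via named_parameters/named_buffers is on cuda:0) or name the stray tensors.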

@TakashiShibuyaSony (Collaborator) commented Feb 21, 2025

The original codebase that ours is based on is tricky, so I'm not sure, but it might be good to trace which device the variable layer_1 is placed on by calling print(layer_1.device) every time it is reassigned.

def forward_vit(pretrained, x):
    b, c, h, w = x.shape

    _ = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    layer_1 = pretrained.layer1[0:2](layer_1)
    layer_2 = pretrained.layer2[0:2](layer_2)
    layer_3 = pretrained.layer3[0:2](layer_3)
    layer_4 = pretrained.layer4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.layer1[3 : len(pretrained.layer1)](layer_1)
    layer_2 = pretrained.layer2[3 : len(pretrained.layer2)](layer_2)
    layer_3 = pretrained.layer3[3 : len(pretrained.layer3)](layer_3)
    layer_4 = pretrained.layer4[3 : len(pretrained.layer4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4
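To make that tracing less repetitive, one option is a tiny pass-through helper (hypothetical, not part of the repo) that prints a tensor's reported device and returns the tensor unchanged, so it can be dropped around each reassignment of layer_1 without altering the data flow:

```python
def trace_device(name, t):
    """Print the device a tensor reports and pass the tensor through unchanged."""
    print(f"{name}: {getattr(t, 'device', 'unknown')}")
    return t
```

For example, layer_1 = trace_device("layer_1 after unflatten", unflatten(layer_1)); a "cpu" printed mid-pipeline would point at the stage where the tensor dropped off the GPU.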
