Skip to content

Adapt CEBRA-Lens so that compatible with unified CEBRA#50

Open
CeliaBenquet wants to merge 29 commits intomainfrom
celia/test-unified-cebra
Open

Adapt CEBRA-Lens so that compatible with unified CEBRA#50
CeliaBenquet wants to merge 29 commits intomainfrom
celia/test-unified-cebra

Conversation

@CeliaBenquet
Copy link
Copy Markdown
Member

@CeliaBenquet CeliaBenquet commented Jun 24, 2025

Not all work yet but I added some if statements for unified CEBRA.

  • Add unified CEBRA compatibility
  • Add cebra.plot_embedding
  • remove label_ind
  • Add lab style plotting function
  • Remove some unused attributes: bool oracle and output_only are now just parameters of the compute or plotting function (was already the case but used to set the attribute to a different value).

What doesn't work

  • some ploting functions
  • the decoder, because of an issue with padding

--> I discuss with @anandawolz and she might have solved this in #44?

@anandawolz have a look and you can improve on top of this by creating a branch from it ideally.

@CeliaBenquet CeliaBenquet requested a review from anandawolz June 24, 2025 08:34
@CeliaBenquet CeliaBenquet changed the base branch from main to celia/add-package-files June 24, 2025 08:35
Base automatically changed from celia/add-package-files to MMathisLab-actions June 24, 2025 12:54
@MMathisLab MMathisLab deleted the branch main June 24, 2025 15:19
@MMathisLab MMathisLab closed this Jun 24, 2025
@MMathisLab MMathisLab reopened this Jun 24, 2025
@MMathisLab MMathisLab changed the base branch from MMathisLab-actions to main June 24, 2025 15:44
@MMathisLab
Copy link
Copy Markdown
Member

@CeliaBenquet you'll have to rebase to main/fix conflicts as this was branched from git actions

Comment thread cebra_lens/quantification/rdm_metric.py Outdated
rdm = pdist(layer_activation[self.idxs.flatten(), :], metric=self.metric)
rdm = pdist(layer_activation[self.idxs.flatten(), :],
metric=self.metric)
if self.bool_oracle:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CeliaBenquet
Just a quick note: the parameter bool_oracle is currently commented out in the init method — it would be good to also comment it out (or remove it) consistently in the _compute_per_layer and plot functions.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My idea was that we don't want to remove it entirely, just not have it in the init, and have it as a parameter in the plot_metric function.

I didn't test the PR thoroughly but the aim is that you build on it now as there are the key points to adapt to unified CEBRA, so if you see that this is not the behavior now, feel free to change it in your branch :)

@CeliaBenquet
Copy link
Copy Markdown
Member Author

CeliaBenquet commented Jul 3, 2025

@anandawolz, note that conflicts are resolved and the tests should pass now, but I didn't add new tests for unified CEBRA. Please do in your branch :)

Note 2: the tests don't pass because the CEBRA version on pypi is not the latest one with unified CEBRA, you will have the same for future implementations you do. If that's too much of an issue we can update to a given commit on the requirement, but ideally you run the tests locally before committing now.

Note 3: I fixed the issues with decoder for output_only, and updated the notebook, still issues with output_only=False and the rdm metric.

@CeliaBenquet CeliaBenquet changed the title [wip] Adapt CEBRA-Lens so that compatible with unified CEBRA Adapt CEBRA-Lens so that compatible with unified CEBRA Jul 24, 2025
@CeliaBenquet CeliaBenquet self-assigned this Jul 24, 2025
@CeliaBenquet CeliaBenquet added enhancement New feature or request merge last labels Jul 24, 2025
…-fixes

Partial fix activation hook batching, padding trim, and decoding dimensionality.
@anandawolz
Copy link
Copy Markdown

Reviewed – Unified CEBRA compatibility works well. Thanks!
What did not work: the decoder, because of an issue with padding was solved in PR #58 built on top of this code

Comment thread cebra_lens/quantification/decoder.py Outdated
self.dataset_label,
)
})
if i == 0 and not isinstance(model,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BUG: if model is UnifiedSolver, i == 0 falls into the else and uses keys[i-1] == keys[-1],
so the last layer gets decoded twice (layer 0 unintentionally duplicates the final activation)
I change it in PR #59

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anandawolz did you solve it in your PRs?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request merge last

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Check the plot_embeddings function Add the lab style for plotting

3 participants