Enable extraction of gene embeddings from geneformer (averaging of gene embeddings across all cells) #452

Open · jstjohn opened this issue Nov 19, 2024 · 9 comments

@jstjohn (Collaborator) commented Nov 19, 2024

A potential design:

  1. Add an argparse option --num-layers-override to infer.py (https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L235) with a default of None.
  2. Add logic to https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L34 so that if the override is unset, nothing different happens.
  3. If the override is set, do the following to make it impact the model:
    1. Import OVERRIDE_BIOBERT_CONFIG_DEFAULTS: https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py#L105
    2. Add override_parent_fields=['num_layers'] + OVERRIDE_BIOBERT_CONFIG_DEFAULTS to the config_class (around https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L116), but only if the user set num_layers_override != None. This tells the checkpoint loader not to pull this field out of the trained model config in the checkpoint and to use the user-supplied value for this field instead.
    3. Also add num_layers=num_layers_override to the config at that point, again only if the user set it to something other than None.

What will happen then is that the model will be initialized with the user-requested number of layers rather than the num_layers it was originally trained with. So if you know the model was trained with 6 layers and you want to drop the last layer and get inference results from the second-to-last layer, you could set --num-layers-override 5 and get back a 5-layer model with the final layer left off.
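Roughly, the wiring could look like this (an untested sketch; the parser and helper names are placeholders for whatever infer_geneformer.py actually uses):

import argparse

from bionemo.llm.model.biobert.model import OVERRIDE_BIOBERT_CONFIG_DEFAULTS

parser = argparse.ArgumentParser()
# 1. New CLI option; None means "keep the checkpoint's num_layers as trained".
parser.add_argument("--num-layers-override", type=int, default=None)


def config_override_kwargs(num_layers_override):
    """Extra kwargs to pass to the config_class when truncating the model."""
    if num_layers_override is None:
        return {}  # 2. Unset: load the checkpoint exactly as trained.
    return {
        # 3. Tell the checkpoint loader not to pull num_layers from the trained
        #    model config stored in the checkpoint, and to use the CLI value instead.
        "override_parent_fields": ["num_layers"] + OVERRIDE_BIOBERT_CONFIG_DEFAULTS,
        "num_layers": num_layers_override,
    }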

Side note: these steps are generally how you would override any setting in the loaded model. The same pattern works for fine-tuning as well as inference whenever you want to change something about the model as you load it. Note that in the fine-tuning case (not here), if you add a new layer you also need to tell the checkpoint loader not to look for that new layer in the checkpoint; otherwise you get a confusing-looking error about the layer not being found at checkpoint load time.

@skothenhill-nv (Collaborator)

More context: this is about getting "gene embeddings" out of Geneformer. Right now we can pull the hiddens from the last layer, but we will need to be able to pull them from an arbitrary layer:

Our inference code:

https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L38

Description of the problem:

For each single-cell transcriptome presented to Geneformer, the model embeds each gene into a 256-dimensional space that encodes the gene's characteristics specific to the context of that cell. Contextual Geneformer gene embeddings are extracted as the hidden-state weights for the 256 embedding dimensions for each gene within the given single-cell transcriptome, evaluated by a forward pass through the Geneformer model. Gene embeddings analyzed in this study were extracted from the second-to-last layer of the models, as the final layer is known to encompass features more directly related to the learning-objective prediction, while the second-to-last layer is a more generalizable representation.

(The second-to-last layer part is handled by @jstjohn's description above.)

Reference Geneformer Hugging Face code:

https://geneformer.readthedocs.io/en/latest/_modules/geneformer/emb_extractor.html#EmbExtractor

We will need to be able to do this kind of aggregation in a memory-efficient way, and also make sure we have access to the cell labels from sc-memmap (if we want to aggregate by cell).
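One memory-efficient option for the gene-level averaging is to stream over cells and keep only running per-gene sums and counts, rather than materializing the full cell x gene embedding tensor. A rough, untested sketch; the padding id of 0, the sizes, and the random stand-in batch are assumptions:

import numpy as np


def accumulate_gene_sums(hidden_states, input_ids, sums, counts, pad_id=0):
    """Add one batch of hidden states into running per-gene sums and counts.

    hidden_states: (cells, seq_len, emb_dim); input_ids: (cells, seq_len).
    Call this once per dumped batch, then divide sums by counts at the end.
    """
    for cell_hiddens, cell_tokens in zip(hidden_states, input_ids):
        for tok, emb in zip(cell_tokens, cell_hiddens):
            if tok == pad_id:  # skip padding positions
                continue
            sums[int(tok)] += emb
            counts[int(tok)] += 1


# Toy example with random data standing in for one dumped batch:
num_tokens, emb_dim = 25472, 256  # example sizes; use the real vocab/hidden sizes
sums = np.zeros((num_tokens, emb_dim))
counts = np.zeros(num_tokens, dtype=np.int64)
batch_hiddens = np.random.rand(4, 512, emb_dim)
batch_ids = np.random.randint(0, num_tokens, size=(4, 512))
accumulate_gene_sums(batch_hiddens, batch_ids, sums, counts)
gene_embeddings = sums / np.maximum(counts, 1)[:, None]  # (num_tokens, emb_dim)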

@isabel-wilkinson (Collaborator)

Context from Birkan Gökbağ:

Geneformer's embedding extractions rely on the input datasets: for every cell, we obtain the generated embeddings of each of that cell's expressed genes.

- Gene embedding extraction: gene embeddings are obtained by averaging each gene's embeddings across all cells. Since the architecture is NLP-based, the ordering of the gene tokens influences the gene embeddings, so a gene's embedding in one cell may not be the same as in another. The index location of the gene token carries that embedding in the output. Make sure to index with the CLS token position in mind (usually at index 0, so you may need to shift by 1).
  - i.e., average gene embeddings across all cells
- Cell embedding extraction: since the input is a cell (i.e., a sorted series of tokens), the output is already a representation of the cell. These token embeddings are averaged to represent the cell embedding (not including the CLS token embedding).
  - i.e., average the embeddings of the input cell directly
- Optional aggregation by cell annotation: the previous analyses are applied per cell-type annotation. Since the embedding process is limited to the selected annotation subsets, the embeddings will already be representative of that state only. Those embeddings are then aggregated using the mean/median to represent the state. This is the scenario where you basically take the mean, or median, of the means.
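For the cell-level side of this description, a minimal sketch (untested; the CLS-at-index-0 convention and the padding id of 0 are assumptions to verify against the tokenizer):

import numpy as np


def cell_embedding(cell_hiddens, cell_tokens, pad_id=0):
    """Mean of a cell's gene-token embeddings, excluding the CLS position (index 0) and padding."""
    mask = cell_tokens[1:] != pad_id  # drop CLS, ignore padding
    return cell_hiddens[1:][mask].mean(axis=0)


def aggregate_by_annotation(cell_embs, labels, how="mean"):
    """Aggregate per-cell embeddings within each annotation: the 'mean (or median) of the means'."""
    reducer = np.median if how == "median" else np.mean
    return {lab: reducer(cell_embs[labels == lab], axis=0) for lab in np.unique(labels)}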

@isabel-wilkinson (Collaborator)

Ideally a test would be added as well.

@isabel-wilkinson changed the title from "Add option to geneformer to get the N-1th layer embeddings out of a model" to "Enable extraction of gene embeddings from geneformer (averaging of gene embeddings across all cells)" on Nov 20, 2024
@jyin-bst

I would suggest adding a "--include-geneembeddings" option and renaming the current "--include-embeddings" to "--include-cellembeddings" to avoid confusion: https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L235

Here is an example of the original implementation by Geneformer for your reference.
https://geneformer.readthedocs.io/en/latest/geneformer.emb_extractor.html

@jstjohn (Collaborator, Author) commented Nov 22, 2024

My preference would be to do this as a post-processing step, to avoid OOM issues. Basically, my recommendation is to dump the cell x gene embeddings, plus any necessary metadata, to disk, and then do whatever averaging/grouping/etc. you need downstream. I am guessing you would need:

  1. cell x gene embeddings in the order they appear in the anndata
  2. token vector
  3. the tokenizer for gene id lookup later on
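The dump itself could be as simple as the following (file names, shapes, and the token-to-gene mapping below are illustrative stand-ins, not what infer_geneformer.py writes today):

import json

import numpy as np

# Stand-ins for the arrays the inference run would produce:
# hiddens: (cells, seq_len, emb_dim) from the requested layer, in the same cell
# order as the AnnData; ids: (cells, seq_len) token ids from the same run.
hiddens = np.zeros((462, 512, 256), dtype=np.float32)
ids = np.zeros((462, 512), dtype=np.int64)
token_to_gene_id = {}  # hypothetical mapping taken from the tokenizer

np.savez("geneformer_dump.npz", hidden_states=hiddens, input_ids=ids)

# Keep the tokenizer's token-id -> gene-id mapping next to it for later lookups.
with open("gene_token_vocab.json", "w") as f:
    json.dump(token_to_gene_id, f)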

Then from there you could load those entities and place them in a cell x gene_token shaped (sparse) structure ordered by gene_token (this needs testing):

import numpy as np
from scipy.sparse import coo_matrix, csr_matrix


def construct_sparse_matrices(token_matrix, embedding_matrix, num_tokens, pad_id=0):
    """
    Constructs sparse matrices for embeddings and observation tracking.

    Args:
        token_matrix (np.ndarray): A (samples, seq_len) matrix of token indices (int).
        embedding_matrix (np.ndarray): A (samples, seq_len, emb_dim) matrix of embeddings.
        num_tokens (int): The total number of unique tokens.
        pad_id (int): Token id used for padding (assumed 0; check the tokenizer).

    Returns:
        tuple:
            sparse_embeddings (scipy.sparse.csr_matrix): Sparse embeddings with shape
                (samples * num_tokens, emb_dim), a 2-D flattening of the conceptual
                (samples, num_tokens, emb_dim) tensor; row s * num_tokens + t holds
                the embedding of token t in sample s.
            sparse_boolean (scipy.sparse.csr_matrix): Sparse boolean matrix of
                token observations with shape (samples, num_tokens).
    """
    samples, seq_len, emb_dim = embedding_matrix.shape

    # Flatten tokens and embeddings, dropping padding positions so the pad
    # token's row does not accumulate garbage.
    keep = (token_matrix != pad_id).ravel()
    flat_tokens = token_matrix.ravel()[keep]
    flat_embeddings = embedding_matrix.reshape(-1, emb_dim)[keep]

    # Row (sample) and column (token) indices for each kept position
    row_indices = np.repeat(np.arange(samples), seq_len)[keep]
    col_indices = flat_tokens

    # Construct sparse boolean matrix of which tokens appear in which sample
    data_boolean = np.ones_like(flat_tokens, dtype=bool)
    sparse_boolean = csr_matrix(
        (data_boolean, (row_indices, col_indices)),
        shape=(samples, num_tokens),
    )

    # Construct sparse embedding matrix with row = sample * num_tokens + token
    emb_row_indices = np.repeat(row_indices * num_tokens + col_indices, emb_dim)
    emb_col_indices = np.tile(np.arange(emb_dim), flat_tokens.size)
    emb_data = flat_embeddings.ravel()
    sparse_embeddings = coo_matrix(
        (emb_data, (emb_row_indices, emb_col_indices)),
        shape=(samples * num_tokens, emb_dim),
    ).tocsr()  # CSR so rows can be sliced when grouping downstream

    return sparse_embeddings, sparse_boolean

Here sparse_boolean tells you which genes were present at all in a particular cell (up to the model's context length). You could then select/average with something like the following (again, this needs testing):

def compute_grouped_means(sparse_embeddings, sparse_boolean, group_indices):
    """
    Compute the mean embedding per token over a selected group of samples.

    Args:
        sparse_embeddings (scipy.sparse.csr_matrix): (samples * num_tokens, emb_dim)
            embeddings from construct_sparse_matrices.
        sparse_boolean (scipy.sparse.csr_matrix): (samples, num_tokens) boolean matrix.
        group_indices (np.ndarray): Indices of samples to include in the group.

    Returns:
        np.ndarray: (num_tokens, emb_dim) mean embedding per token for the group.
    """
    num_tokens = sparse_boolean.shape[1]
    emb_dim = sparse_embeddings.shape[1]
    group_indices = np.asarray(group_indices)

    # Sum the per-sample (num_tokens, emb_dim) blocks for every sample in the group
    token_sums = np.zeros((num_tokens, emb_dim))
    for s in group_indices:
        block = sparse_embeddings[s * num_tokens:(s + 1) * num_tokens, :]
        token_sums += block.toarray()

    # Count how many cells in the group contained each token
    token_counts = sparse_boolean[group_indices, :].sum(axis=0).A1  # convert to 1D array

    # Compute means, avoiding division by zero for tokens never observed in the group
    token_counts = np.maximum(token_counts, 1)
    return token_sums / token_counts[:, None]

A further benefit of this approach is that you can use anndata.obs to decide how you want to group, and your shapes will all match it.
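For example (the paths, the obs column name, and the vocab size below are placeholders):

import anndata as ad
import numpy as np

adata = ad.read_h5ad("cells.h5ad")  # same cells, same order as the inference input
dump = np.load("geneformer_dump.npz")

emb, present = construct_sparse_matrices(dump["input_ids"], dump["hidden_states"], num_tokens=25472)

# Group by a cell annotation from anndata.obs, e.g. all cells labeled "T cell".
t_cells = np.flatnonzero((adata.obs["cell_type"] == "T cell").to_numpy())
t_cell_gene_means = compute_grouped_means(emb, present, t_cells)  # (num_tokens, emb_dim)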

@jyin-bst

Thanks @jstjohn, I think this is getting close to the desired results. The workflow seems good to me: first extract (samples, genes, emb_dim), then average across samples to get mean embeddings of shape (genes, emb_dim).

Could you please further explain what the inputs are for construct_sparse_matrices(token_matrix, embedding_matrix, num_tokens)? How can I apply this function to the infer_geneformer.py output?

@jstjohn (Collaborator, Author) commented Nov 23, 2024 via email

@jstjohn (Collaborator, Author) commented Nov 23, 2024 via email

@jyin-bst commented Dec 5, 2024

After playing with BioNeMo Geneformer as suggested by John, I got the following results by running inference on a 462-sample by 25,429-gene single-cell sequencing dataset with the modified infer_geneformer.py. It has an impressive interface for controlling the parallel computing processes, which the original Hugging Face implementation does not offer.

https://github.com/jyin-bst/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py

In the saved model output, using a sequence length of 512 (--seq-len 512), I got:

token_logits, torch.Size([512, 462, 25472])
hidden_states, torch.Size([462, 512, 256])
input_ids, torch.Size([462, 512])
embeddings, torch.Size([462, 256])

"input_ids" e.g. by using --include-logits, doesn't contain any gene information. It outputs a data matrix of (num_cells, sequence_length). So are "hidden_states" and "embeddings". All of them are cell related information.

"token logits" are the closest results to gene embeddings, e.g. by using --include-logits. It saves a "token logits" data matrix of (sequence_length, num_cells, num_genes). It contains predictions for genes from the last layer of the model, which are still different from the gene embeddings used by Geneformer. Gene embeddings should be extracted from the second-to-last layer.

@jstjohn @skothenhill-nv @isabel-wilkinson Do you have any suggestions on how to move forward from this?
