
Investigate whether current KNN imputation default (FKNNI) works for all sparseness and AnnData objects #803

Open · 2 tasks done
Zethson opened this issue Oct 1, 2024 · 12 comments
Labels: bug (Something isn't working), enhancement (New feature or request)

@Zethson (Member) commented Oct 1, 2024

Description of feature

We ran into cases in the past where users suddenly reported that it didn't impute everything, and we need to understand this.

  • Generate a loooot of different AnnData objects with different amounts of missing values etc. An LLM can easily do that for us.
  • Run KNN imputation with backend=fknni (default) on all of them -> check whether any of them break
@Zethson added the bug (Something isn't working) and enhancement (New feature or request) labels on Oct 1, 2024
@eroell (Collaborator) commented Oct 1, 2024

We ran into cases in the past where users suddenly reported that it didn't impute everything

Does this relate to the same reports as #734?

@Zethson (Member, Author) commented Oct 2, 2024

Yes! Sorry this is a duplicate now :(

@nicolassidoux marked this as a duplicate of #734 on Jan 10, 2025
@nicolassidoux (Collaborator) commented Jan 10, 2025

I found some interesting things after a batch of tests yesterday.

  1. It was surprisingly easy to generate an AnnData that triggers the issue. The script below loops: it generates a dataset of random shape filled with random floats, then keeps each value as is or replaces it with NaN based on a random probability. knn_impute is then called, and the resulting AnnData is scanned for NaN. If any is found, the loop breaks and the datasets are saved for further inspection. If you try it, you'll find that it doesn't take long to get results.
import os
import numpy as np
import pandas as pd
import anndata as ad
import ehrapy as ep

root = "1"

if not os.path.exists(root):
    os.makedirs(root)

rng = np.random.default_rng()
failed = False
iteration = 0

while not failed:
    iteration += 1

    # Build an AnnData of random shape filled with random floats
    n_obs = rng.integers(10, 100000)
    n_vars = rng.integers(1, 100)
    adata = ad.AnnData(pd.DataFrame(
        rng.uniform(0.0, 100.0, size=(n_obs, n_vars)),
        columns=[f'Var{i + 1}' for i in range(n_vars)]))

    # Replace each value with NaN with a random per-dataset probability
    proba_missing = rng.uniform()
    missing_count = 0
    for i in range(adata.shape[0]):
        for j in range(adata.shape[1]):
            if rng.uniform() < proba_missing:
                adata.X[i, j] = np.nan
                missing_count += 1

    print(f"Iteration {iteration}: n_obs={n_obs}, n_vars={n_vars}, size={n_obs * n_vars}, missing={missing_count}, "
          f"missing ratio={missing_count / (n_obs * n_vars):0.3f}")

    # Impute and check whether any NaN survived
    imputed_adata = ep.pp.knn_impute(adata, copy=True)
    failed = np.isnan(imputed_adata.X).any()

# Save both datasets for further inspection
print("NaN found!!! Saving AnnDatas!")
adata.write_h5ad(f"{root}/before imputation.h5ad")
pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names).to_csv(f"{root}/before imputation.csv")
imputed_adata.write_h5ad(f"{root}/after imputation.h5ad")
pd.DataFrame(imputed_adata.X, index=adata.obs_names, columns=adata.var_names).to_csv(f"{root}/after imputation.csv")
  2. Inspection of the contents of adata and imputed_adata shows that only some columns are fully imputed while the others are left untouched, despite knn_impute being called without var_names.
  3. Calling knn_impute again with the original adata from the previous test (by loading before imputation.h5ad instead of randomly creating a dataset) leads to the same output.
  4. Calling knn_impute again with the already imputed imputed_adata (by loading after imputation.h5ad instead of randomly creating a dataset) works as expected this time.

I'm sharing these early observations before digging deeper because they may ring a bell for you guys (and save me from useless work 😁).
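For reference, a minimal sketch of how observations 3 and 4 can be replayed from the saved files (assuming root = "1" from the script above):

import anndata as ad
import ehrapy as ep
import numpy as np

# Observation 3: re-imputing the saved failing dataset reproduces the failure
adata = ad.read_h5ad("1/before imputation.h5ad")
imputed = ep.pp.knn_impute(adata, copy=True)
print(np.isnan(imputed.X).any())  # True: the same NaNs survive

# Observation 4: re-imputing the partially imputed output succeeds
adata = ad.read_h5ad("1/after imputation.h5ad")
imputed = ep.pp.knn_impute(adata, copy=True)
print(np.isnan(imputed.X).any())  # False: no NaN left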

@Zethson (Member, Author) commented Jan 10, 2025

Ohhh, this is as interesting as it is disturbing haha. I don't have any further ideas yet, but you're doing amazing work, thank you!

@nicolassidoux (Collaborator) commented

Actually, further tests show the problem has nothing to do with columns. The imputation silently fails at some point and leaves the rest untouched, and this "rest" may contain full columns.

I'll dig into that asap.

@eroell (Collaborator) commented Jan 11, 2025

Huh, nice! Great reproducible example... How the adata fails to be imputed in 3., but succeeds if this adata is written to disk and then loaded again as in 4., is very strange. I can reproduce this! You're on a very important track here!

@nicolassidoux (Collaborator) commented Jan 12, 2025

I don't think writing and loading has anything to do with our problem: calling knn_impute twice in a row works as well.

@nicolassidoux (Collaborator) commented Jan 12, 2025

Digging further:

  1. In fknni\faiss\faiss.py, function transform, line 99:

            if self.strategy == "mean":
                imputed_values = np.nanmean(neighbor_values, axis=0)

After the np.nanmean call, imputed_values contains some np.nan. There is a suppressed warning in this function, warnings.warn("Mean of empty slice", RuntimeWarning, stacklevel=2), that is correctly displayed if warnings.simplefilter("always", RuntimeWarning) is called in the calling code (see the sketch after this list).
  2. A few lines earlier, line 93:

            distances, neighbor_indices = self.index_.search(sample_data.reshape(1, -1), self.n_neighbors)

With a failing dataset, the neighbor_indices array contains only -1 values, which is highly suspicious since it is supposed to hold indices.
  3. Tracking down subsequent calls, in faiss\swigfaiss_avx2.py, function search, line 2350:

            return _swigfaiss_avx2.IndexFlat_search(self, n, x, k, distances, labels, params)

IndexFlat_search is a call to a native library, right? With all the pointers and deletion functions lying around in faiss\_swigfaiss_avx2.py... Then -1 would be a typical placeholder for an error. In my test case:

  • n = 1
  • x is an array of n_vars = 87 float32 values from 0 to 100
  • k = 5
  • distances and labels are the output arrays
  • params is None
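As a side note, here is a minimal way to surface that suppressed warning from the calling code (a sketch assuming the failing dataset saved by the script above):

import warnings

import anndata as ad
import ehrapy as ep

# Make RuntimeWarnings visible again; on a failing dataset, the
# "Mean of empty slice" warning from fknni's transform then shows up
warnings.simplefilter("always", RuntimeWarning)

adata = ad.read_h5ad("1/before imputation.h5ad")
imputed_adata = ep.pp.knn_impute(adata, copy=True)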

Same questions as before: do you guys have an idea at this stage? Do the input parameters look suspicious? I will have to inspect the native code if not. Do you know where I can find it?

I already have a few suggestions:

  1. We need to revise the way we generate the fixtures we use to test imputation.
  2. For all native code we use, we need to check what errors can be returned and wrap them in exceptions (see the sketch below).
  3. Make sure all warnings are displayed.
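For suggestion 2, a sketch of what such a guard could look like (check_search_result is a hypothetical helper, not part of fknni or FAISS):

import numpy as np

def check_search_result(neighbor_indices: np.ndarray) -> None:
    """Raise instead of silently propagating FAISS's -1 "not found" placeholder."""
    n_invalid = int((neighbor_indices == -1).sum())
    if n_invalid:
        raise RuntimeError(
            f"FAISS returned {n_invalid} invalid (-1) neighbor indices; "
            "the index probably holds fewer points than n_neighbors."
        )

# Usage, right after the search call from point 2 above:
# distances, neighbor_indices = self.index_.search(sample_data.reshape(1, -1), self.n_neighbors)
# check_search_result(neighbor_indices)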

@Zethson (Member, Author) commented Jan 12, 2025

Great digging, thank you!

Do you guys have an idea at this stage?

facebookresearch/faiss#3830 — could it be that we are searching for more neighbors than there are samples in the population? Can we return more sane values here using FAISS? Edit: ChatGPT seems to suggest that this is indeed the case.
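A quick self-contained check of this hypothesis against FAISS itself (illustrative dimensions; a flat index, matching the IndexFlat_search trace above):

import faiss
import numpy as np

d = 4
index = faiss.IndexFlatL2(d)                       # dense L2 index
index.add(np.random.rand(3, d).astype("float32"))  # only 3 points in the index

query = np.random.rand(1, d).astype("float32")
distances, labels = index.search(query, 5)         # ask for k=5 > ntotal=3
print(labels)  # e.g. [[ 2  0  1 -1 -1]] -> missing neighbors are padded with -1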

Do the input parameters look suspicious?

Uhmm not to me...

I will have to inspect the native code if not. Do you know where I can find it?

We're using FAISS here https://github.com/facebookresearch/faiss which uhm should not have such big issues. It might be a configuration or number of neighbors issue. Maybe they say something in their wiki? https://github.com/facebookresearch/faiss/wiki

The fknni code is here: https://github.com/Zethson/fknni/blob/main/src/fknni/faiss/faiss.py

ChatGPT suggests (I cleaned it up a bit):

To ensure FAISS does not return `-1` for indices when performing KNN search (important for tasks like KNN imputation), you can address the issue by enforcing the return of valid neighbors for every query. Here are potential solutions:

1. Decrease k

If the number of neighbors (k) requested exceeds the dataset size, FAISS may return -1. Use a value of k that is less than or equal to the number of dataset points:

k = min(k, index.ntotal)

2. Configure Approximate Search Parameters

For indices like IVF or HNSW, adjust search parameters like nprobe (IVF) or efSearch (HNSW) to increase the likelihood of finding neighbors:

index.nprobe = 10  # IVF
index.hnsw.efSearch = 50  # HNSW

3. Fill Missing Neighbors with Defaults

If you cannot guarantee that FAISS will find enough neighbors, post-process the results to replace -1 indices. For example:

  • Replace -1 with the index of the closest valid neighbor.
  • Use a random imputation strategy for missing values.

Example:

import numpy as np

# Replace `-1` indices with the closest valid neighbor
def fix_invalid_indices(indices, distances):
    for i in range(len(indices)):
        valid_mask = indices[i] != -1
        if not valid_mask.any():
            indices[i] = [0] * len(indices[i])  # Default to the first index
        else:
            valid_indices = indices[i][valid_mask]
            valid_distances = distances[i][valid_mask]
            indices[i][~valid_mask] = valid_indices[np.argmin(valid_distances)]
    return indices

indices, distances = index.search(query_vectors, k)
indices = fix_invalid_indices(indices, distances)

4. Add Synthetic or Fallback Points

If your dataset cannot satisfy the search criteria, consider adding synthetic points to ensure valid neighbors are always found. For example, add a point with all zeros or a global mean:

synthetic_point = np.zeros((1, vector_dim), dtype='float32')
index.add(synthetic_point)

5. Switch to a Dense Search Index

If feasible, use a dense search index like Flat instead of an approximate one (IVF, HNSW). This ensures all points are considered, eliminating the chance of missing neighbors:

index = faiss.IndexFlatL2(d)  # L2 distance
index.add(data)

6. Use FAISS with Brute-Force Fallback

If -1 is returned, re-query for missing neighbors using brute force (e.g., IndexFlatL2).


Combining Techniques

For robust KNN imputation, combine approaches:

  • Use dense indices if feasible.
  • Post-process -1 indices to ensure valid neighbors.
  • Add fallback points or synthetic data for safety.

I'm not sure whether I like a fallback, because for big data that can be expensive, but we can consider the solutions listed here.

@nicolassidoux (Collaborator) commented Jan 13, 2025

I found an entirely different case which can lead to NaNs in the imputed dataset.

It all takes place in the transform function in fknni\faiss\faiss.py.

This time, FAISS correctly returned the nearest neighbour indices, but check which rows got extracted from the original dataset based on these indices (2 values were missing in this specific row):

[screenshot: the extracted neighbor rows; one column is NaN in every neighbor]

You got it: one column is filled with NaNs, bad luck! The subsequent call to np.nanmean returns:

[screenshot: the np.nanmean output, containing NaN for that column]

And now we have a NaN in the imputed dataset!
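A self-contained numpy sketch of this case (illustrative values; the shape mirrors a 5-neighbor search):

import warnings

import numpy as np

# 5 nearest-neighbor rows, 3 features; feature 1 is NaN in every neighbor
neighbor_values = np.array([
    [1.0, np.nan, 3.0],
    [2.0, np.nan, 1.0],
    [4.0, np.nan, 2.0],
    [3.0, np.nan, 5.0],
    [5.0, np.nan, 4.0],
])

with warnings.catch_warnings():
    warnings.simplefilter("always", RuntimeWarning)
    imputed_values = np.nanmean(neighbor_values, axis=0)  # "Mean of empty slice"

print(imputed_values)  # [3. nan 3.] -> the all-NaN column stays NaN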

@nicolassidoux (Collaborator) commented

I think I figured out what is going on for the first case I described.

In fknni\faiss\faiss.py, function fit, line 60:

        mask = ~np.isnan(X).any(axis=1)
        X_non_missing = X[mask]

X_non_missing includes only the rows of X that do not contain NaN values. With a dataset containing numerous features, even a low probability of missingness per feature makes it highly improbable for any row to be free of NaNs. As a result, X_non_missing is likely to be empty.

        index.train(X_non_missing)
        index.add(X_non_missing)

Here, we train the index with no data. Consequently, _swigfaiss_avx2.IndexFlat_search returns an array of -1, as there is no data to return results from, as explained in the documentation.
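To make the scale of this concrete, a small sketch (illustrative numbers, with n_vars = 87 as in my earlier test case): with per-value missingness probability p, a row survives the mask with probability (1 - p)^n_vars.

import numpy as np

rng = np.random.default_rng(0)
n_obs, n_vars, p = 1000, 87, 0.3

X = rng.uniform(0.0, 100.0, size=(n_obs, n_vars))
X[rng.uniform(size=X.shape) < p] = np.nan  # knock out ~30% of the values

# Same masking as in fit: keep only rows without any NaN
mask = ~np.isnan(X).any(axis=1)
X_non_missing = X[mask]

print((1 - p) ** n_vars)       # ~3e-14: a fully observed row is essentially impossible
print(X_non_missing.shape[0])  # 0 -> the FAISS index gets trained on no data at all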

@Zethson (Member, Author) commented Jan 13, 2025

Ahhh, this makes a lot of sense. Hmm, how would you go about solving this problem?
