Would it be possible to use a pretrained model (such as Topview Mouse from DLC animal zoo)? #158

Open
hummuscience opened this issue May 24, 2024 · 39 comments

@hummuscience
Contributor

I have been playing around with lightning pose the past few days and quite impressed with the training speed and performance!

Coming from DeepLabCut, I am testing LP on videos of mice captured from the top view. As you probably know, the DLC animal zoo had a pretrained model for this scenario.

Would it be possible to use that as a backbone for LP instead of the typical resnets?

I am still new to your codebase, so I might not have understood it in depth yet and might be missing something here...

@themattinthehatt
Collaborator

@hummuscience thanks for the kind words!

We currently do not offer any pretrained models, though we hope to start providing some within the year. Once DLC releases the dataset used to train their TopViewMouse network we can use that to train an LP version. I'll leave this issue open for now and update it once we start working on this.

We are also considering providing a pretrained model from the Facemap dataset (mouse face from different angles) and the CRIM13 dataset (top-down view of two mice, one black and one white).

If anybody has additional pretrained models they would like to see (and importantly, a pointer to a labeled dataset), please let us know!

@themattinthehatt themattinthehatt self-assigned this May 24, 2024
@themattinthehatt themattinthehatt added enhancement New feature or request app Question about the app rather than lightning-pose labels May 24, 2024
@hummuscience
Contributor Author

I just checked the AnimalZoo preprint and the mentioned references. It doesn't seem that any of the actual labeled datasets are available, but in some cases, the videos are (see here for example: https://zenodo.org/records/3608658).

I wonder if it would make sense to run the TopViewMouse model on these videos, extract some output frames (refine them in case of errors), and then convert the project to a DLC project. It might take some time, but could be worth the effort (for me, at least). How many frames would you aim for in this case?

The facemap model would be quite useful, since many people working with head-fixed mice could benefit from it.

The CRIM13 model is also interesting, but wouldn't have that many applications since it's specific to assays with two animals (one white, the other black). Unless this is a more common assay than I am aware of.

On the subject of multi-animal tracking: are there currently any possibilities with LP?

@themattinthehatt
Collaborator

themattinthehatt commented May 28, 2024

We've been in touch with the DLC folks and they plan to release the labeled datasets once the paper is out. However, it is not clear how long that will take, so your suggestion is probably the quickest way to get something workable. This isn't something we have the bandwidth to work on right now, but if you're up for trying it that would be great! I'd suggest labeling all the video frames with the TopViewMouse model, then doing the following:

  1. compute motion energy (absolute difference between keypoints on consecutive frames, averaged over all keypoints)
  2. remove frames with low motion energy (when the mouse is sitting still)
  3. for the remaining frames, run a clustering algorithm (kmeans is fine) on the poses and select 1000 clusters. Then you can take 1 frame from each cluster to get 1000 labeled frames with a decent amount of pose variability (see a related implementation here, though note this runs kmeans on a pca embedding of the frames rather than the predicted keypoints). A rough sketch of this selection step is included after this list.
  4. refine errors in selected frames if necessary
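
For concreteness, steps 1-3 could look roughly like the sketch below. This isn't an LP utility - it assumes the video predictions have already been loaded into a (n_frames, n_keypoints, 2) array, and the function name, motion-energy quantile, and cluster count are arbitrary placeholders:

    import numpy as np
    from sklearn.cluster import KMeans

    def select_diverse_frames(keypoints, n_clusters=1000, motion_quantile=0.25):
        """keypoints: (n_frames, n_keypoints, 2) array of predicted (x, y) positions."""
        # 1. motion energy: absolute keypoint displacement between consecutive frames,
        #    averaged over all keypoints
        motion = np.abs(np.diff(keypoints, axis=0)).mean(axis=(1, 2))
        motion = np.concatenate([[motion[0]], motion])  # pad to n_frames

        # 2. drop low-motion frames (animal sitting still)
        keep = np.where(motion > np.quantile(motion, motion_quantile))[0]

        # 3. k-means on the remaining poses; take the frame closest to each cluster center
        poses = keypoints[keep].reshape(len(keep), -1)
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(poses)
        selected = []
        for c in range(n_clusters):
            members = np.where(kmeans.labels_ == c)[0]
            dists = np.linalg.norm(poses[members] - kmeans.cluster_centers_[c], axis=1)
            selected.append(keep[members[np.argmin(dists)]])
        return np.sort(np.array(selected))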

I would also note that we find, even with the LP bells and whistles, that more labeled frames is always better. So if you're not too daunted by the refinement step, selecting 2k or 3k frames will almost certainly result in a more robust model than 1k frames. But 1k will certainly do a good job.

If you go down this route please let me know! Happy to keep discussing it with you.

Re: multi-animal tracking - this is something that we are working towards, but it will be a while before we have these features built into LP (unless the animals are visibly distinct, like in the CRIM13 dataset)

@hummuscience
Contributor Author

So, I started to implement this for some of my own videos and am currently refining the predictions on 300 images to test if it improves things.

Meanwhile, it seems like the datasets from the SuperAnimal paper are public now (or at least, I found them on zenodo: MausHaus and the whole TopViewMouse dataset).

I tried the pretrained model on some images from the MausHaus and the BlackMice datasets, but the results were actually not as good as I expected. My expectation was that the pretrained model would reproduce the labels from its own training set, but it didn't (or rather, it performed poorly).

Now there is also the entire TopViewMouse dataset. There, not all keypoints are labelled in all images (since they come from different datasets). I wonder if I can just go ahead and train LP with it like that.

The other issue is that the annotations are in a JSON file. I think I could manage to convert it to LP format though.

@themattinthehatt
Collaborator

I looked at the TopViewMouse demo a year or so back and also found that it did not perform as well as expected on a top-view mouse video that looked very similar to one in the training dataset. Good to know that you could replicate this finding.

I still haven't had a chance to look into the TopViewMouse dataset - I do remember that not all keypoints are labeled in all images, but are the keypoints at least named the same when they are in the same location across datasets? If so then you can definitely train an LP model on these; you would just leave the ground truth label empty where it doesn't exist and then LP ignores this keypoint during training.

It shouldn't be too hard to convert the annotations from JSON to LP format. If you end up doing this please let me know and we can discuss best ways to train the model!
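
In case it helps, here's a minimal sketch of such a conversion. It assumes the annotations are COCO-style JSON (an images list plus annotations with flattened (x, y, visibility) keypoint triplets) and that the target is a DLC-style CSV with scorer/bodyparts/coords header rows; the file names and scorer label are made up:

    import json
    import numpy as np
    import pandas as pd

    KEYPOINT_NAMES = ["nose", "left_ear", "right_ear"]  # ...the full keypoint list for the dataset

    with open("annotations.json") as f:  # assumed COCO-style annotation file
        coco = json.load(f)

    img_id_to_file = {im["id"]: im["file_name"] for im in coco["images"]}

    rows = {}
    for ann in coco["annotations"]:
        kpts = np.array(ann["keypoints"], dtype=float).reshape(-1, 3)  # (x, y, visibility) per keypoint
        xy = kpts[:, :2]
        xy[kpts[:, 2] == 0] = np.nan  # absent/undefined keypoints -> NaN
        rows[img_id_to_file[ann["image_id"]]] = xy.flatten()

    # DLC/LP-style CSV: three header rows (scorer / bodyparts / coords)
    columns = pd.MultiIndex.from_product(
        [["converted"], KEYPOINT_NAMES, ["x", "y"]],
        names=["scorer", "bodyparts", "coords"],
    )
    df = pd.DataFrame.from_dict(rows, orient="index", columns=columns)
    df.to_csv("CollectedData_all.csv")

The missing/undefined keypoints end up as NaN in the CSV, which is exactly what gets ignored during training.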

@hummuscience
Contributor Author

The results are a bit better when one uses spatio-temporal adaptation, but still not as good as I would expect.

Yes, the positions are the same. Then I will go ahead and try it out. I will report back :)

@themattinthehatt
Collaborator

Awesome, excited to see the results! I take it the labeled dataset doesn't have the associated videos as well? Maybe I can ask the Mathises for that data, then we could extract the context frames and test out a context model as well.

@hummuscience
Contributor Author

Training is running 👍 will post once it's done.

Yeah, the dataset doesn't contain any videos. I think some of the origin datasets could have videos (pranav 2018 maybe?). But yeah, it could be easier to ask the Mathises for the videos, even though it is possible that they won't have them...

Would it be possible to inform the context model with a non-context model somehow? Maybe one could use the PCA?

Btw, I wrote a script that automatically extracts the context frames for each image. Could that be useful to add as a utility in the scripts folder?
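
The core of it is roughly the sketch below (simplified; it assumes labeled frames are named like img0123.png, that the number in the file name matches the frame index in the source video, and that OpenCV is available - the zero-padding width is a guess):

    import re
    from pathlib import Path
    import cv2

    def extract_context_frames(video_path, labeled_img_path, n_context=2):
        """Save the +/- n_context frames surrounding a labeled frame (e.g. img0123.png)."""
        idx = int(re.search(r"(\d+)", Path(labeled_img_path).stem).group(1))
        out_dir = Path(labeled_img_path).parent
        cap = cv2.VideoCapture(str(video_path))
        for offset in range(-n_context, n_context + 1):
            frame_idx = idx + offset
            if offset == 0 or frame_idx < 0:
                continue  # labeled frame already exists / before start of video
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ok, frame = cap.read()
            if ok:
                cv2.imwrite(str(out_dir / f"img{frame_idx:04d}.png"), frame)
        cap.release()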

@hummuscience
Contributor Author

First look at the training. It seems like I should be stopping the training earlier (150k?) or at least saving more checkpoints.

Test videos look very good :) I am thinking of training a DLC model with the same dataset (maybe some shuffles?) to compare output. There are 300 additional frames not contained in the original TopViewMouse5k that come from my own datasets.

Will check the evaluation

I am quite new to LP so I am not so sure about the choices in the config.yaml file. I added it below.

[Screenshots: TensorBoard training curves]

data:
  image_orig_dims:
    height: 480
    width: 640
  image_resize_dims:
    height: 512
    width: 512
  data_dir: /mnt/hpc_slurm/home/abdelhaym/freely-moving/lp-models/all_mighty_mouse/
  video_dir: /mnt/hpc_slurm/home/abdelhaym/freely-moving/lp-models/all_mighty_mouse/videos/
  csv_file: CollectedData_all.csv
  downsample_factor: 2
  num_keypoints: 27
  keypoint_names:
  - nose
  - left_ear
  - right_ear
  - left_ear_tip
  - right_ear_tip
  - left_eye
  - right_eye
  - neck
  - mid_back
  - mouse_center
  - mid_backend
  - mid_backend2
  - mid_backend3
  - tail_base
  - tail1
  - tail2
  - tail3
  - tail4
  - tail5
  - left_shoulder
  - left_midside
  - left_hip
  - right_shoulder
  - right_midside
  - right_hip
  - tail_end
  - head_midpoint
  mirrored_column_matches: null
  columns_for_singleview_pca:
  - 1
  - 2
  - 7
  - 9
  - 13
  - 19
  - 20
  - 21
  - 22
  - 23
  - 24
  - 26
training:
  imgaug: dlc-top-down
  train_batch_size: 8
  val_batch_size: 48
  test_batch_size: 48
  train_prob: 0.8
  val_prob: 0.1
  train_frames: 1
  num_gpus: 1
  num_workers: 4
  early_stop_patience: 3
  unfreezing_epoch: 20
  min_epochs: 300
  max_epochs: 750
  log_every_n_steps: 10
  check_val_every_n_epoch: 5
  gpu_id: 0
  rng_seed_data_pt: 0
  rng_seed_model_pt: 0
  lr_scheduler: multisteplr
  lr_scheduler_params:
    multisteplr:
      milestones:
      - 150
      - 200
      - 250
      gamma: 0.5
model:
  losses_to_use: []
  backbone: resnet50_animal_ap10k
  model_type: heatmap
  heatmap_loss_type: mse
  model_name: test
  checkpoint: null
  lightning_pose_version: 1.4.0
dali:
  general:
    seed: 123456
  base:
    train:
      sequence_length: 64
    predict:
      sequence_length: 128
  context:
    train:
      batch_size: 16
    predict:
      sequence_length: 96
losses:
  pca_multiview:
    log_weight: 5.0
    components_to_keep: 3
    epsilon: null
  pca_singleview:
    log_weight: 5.0
    components_to_keep: 0.99
    epsilon: null
  temporal:
    log_weight: 5.0
    epsilon: 20.0
    prob_threshold: 0.05
eval:
  hydra_paths:
  - ' '
  predict_vids_after_training: true
  save_vids_after_training: true
  fiftyone:
    dataset_name: freemovetest
    model_display_names:
    - test_freemoving
    launch_app_from_script: false
    remote: true
    address: 127.0.0.1
    port: 5151
  test_videos_directory: /mnt/hpc_slurm/home/abdelhaym/freely-moving/lp-models/all_mighty_mouse/test_videos/
  saved_vid_preds_dir: null
  confidence_thresh_for_vid: 0.9
  video_file_to_plot: null
  pred_csv_files_to_plot:
  - ' '
callbacks:
  anneal_weight:
    attr_name: total_unsupervised_importance
    init_val: 0.0
    increase_factor: 0.01
    final_val: 1.0
    freeze_until_epoch: 0

@hummuscience
Contributor Author

I tried to run a semi-supervised model (PCA or temporal), but it fails due to the input images being of different sizes.

I could rescale the images and the keypoints; do you think that would make sense?

@themattinthehatt
Collaborator

themattinthehatt commented Jul 15, 2024

ah very cool, glad you were able to get this working!

I tried to run a semi-supervised model (PCA or temporal), but it fails due to the input images being of different sizes.

Yes, PCA will require the frames to be the same size; but actually this is an interesting use case, because even if you resized the frames to the same size, the size of the animal would vary a lot from dataset to dataset, so maybe PCA wouldn't work so well anyway. I'll have to give this some more thought.

config options

  • image_resize_dims: 512x512 is quite large; you might also try 256x256 or 384x384 and see if they work as well or better. Model training/inference will certainly be faster with smaller resize dims. Training can actually be more accurate in some cases too, because the number of trainable parameters increases as you increase the resize dims - I've seen (through other users) that even with 1000x1000 pixel images of a freely moving mouse, resizing to something smaller than 512x512 is more accurate.
  • train_prob/val_prob: I would up the train_prob so you use more of the data for training; maybe use something like 0.95/0.05? That doesn't leave any leftover for test data, but it depends on how you want to test the model (it could be on held-out data you labeled that isn't in the CollectedData.csv file). Also, there are a ton of frames in this dataset (>5k, right?) so you could even go further to 0.98/0.02 or something (this would still give you >100 validation frames).
  • train_batch_size: if you decrease the resize dims you could probably double this to 16 without memory issues; then the model will train faster.

everything else looks good! I see you set imgaug to dlc-top-down already, which is great 👌

Would it be possible to inform the context model with a non-context model somehow? Maybe one could use the PCA?

The lack of context frames and the frames being different sizes mean it is difficult to play with the context/semi-supervised features of the model. There's no real easy workaround to either of these with this dataset that I can think of off the top of my head. So I guess I'd say use the fully supervised model first and see how that works for you?

First look at the training. It seems like I should be stopping the training earlier (150k?) or at least saving more checkpoints.

I'm kinda surprised the validation loss starts going up so early - but my intuition is that this is related to the big resize dim (512x512); I'm curious whether this goes away with smaller resize dims.

@hummuscience
Contributor Author

Re-training now with your suggestions.

I have an RTX A6000, so I increased the batch size to 32 without any memory error (after 40 epochs).
Went with 256x256 resize dims.
And 0.97 train / 0.02 val, which ends up with 100+ validation images and 50+ test images.

Interestingly, the model is not much faster; an epoch takes 1 minute.

I was checking the loss development with Tensorboard, and it seems like the loss is decreasing more quickly than for the first model I trained.

It could be that, due to the larger training data size and larger batch size, the loss is decreasing faster and it's generalizing better.

[Screenshots: TensorBoard loss curves comparing the two models]

@hummuscience
Contributor Author

Going through the evaluation results of the first model I trained, it seems like the model doesn't generalize across datasets where fewer keypoints are labelled.

For example, in this dataset, only 4 keypoints were labelled. When we predict all 26 keypoints, the model does quite a bad job at it, yet with high confidence.

Am I doing something wrong?

In the SuperAnimal paper, they mention using gradient masking of the heatmaps to deal with this issue. I have no clear idea what that means, though.

[Screenshots: example frames with predicted keypoints overlaid]

@themattinthehatt
Collaborator

Interestingly, the model is not much faster, an epoch takes 1 minute

I guess this makes sense: you have fewer batches, but each batch takes longer to process. Is your GPU utilization at or near 100% while training? If not, you could also increase training.num_workers to something like 6 or 8 (depending on the number of cores you have) and that would speed training up a bit.

I was checking the loss development with Tensorboard, and it seems like the loss is reducing quicker than the first model I trained.

You are correct that this is due to the increased batch size - you can see in the first model that there is a big dip around 10k iterations. This is actually due to the unfreezing of the backbone weights. When you increase the batch size you take fewer steps per epoch, and so the backbone is unfrozen earlier (in terms of number of gradient steps). Actually in your case because there are so many frames you could probably reduce the parameter training.unfreeze_epoch to something like 5 instead of 20, but that doesn't invalidate any of your current results.

For example, in this dataset, only 4 keypoints were labelled. When we predict all 26 keypoints, the model does quite a bad job at it, with a high confidence.

Huh, this is a bit unexpected; I would have guessed the model could generalize better than this. I'm curious if the 256x256 model generalizes better. This is a great test though! Do you see this kind of issue with other datasets that have few labeled keypoints? I forget how exactly the SuperAnimal paper does the gradient masking, but I'll look into that and get back to you.

@hummuscience
Contributor Author

hummuscience commented Jul 16, 2024

I found the code that DLC uses for gradient masking in their dlc_pytorch branch

        for b in range(batch_size):
            for heatmap_idx, group_keypoints in enumerate(coords[b]):
                for keypoint in group_keypoints:
                    # FIXME: Gradient masking weights should be parameters
                    if keypoint[-1] == 0:
                        # full gradient masking
                        weights[b, heatmap_idx] = 0.0
                    elif keypoint[-1] == -1:
                        # full gradient masking
                        weights[b, heatmap_idx] = 0.0

                    elif keypoint[-1] > 0:
                        # keypoint visible
                        self.update(
                            heatmap=heatmap[b, :, :, heatmap_idx],
                            grid=grid,
                            keypoint=keypoint[..., :2],
                            locref_map=self.get_locref(locref_map, b, heatmap_idx),
                            locref_mask=self.get_locref(locref_mask, b, heatmap_idx),
                        )

If I understand this correctly, the gradient masking is applied to the weights before the backward pass? But I am not sure, since the SuperAnimal paper mentions that it is applied before the loss calculation.

In LP, the masking is done during the loss calculation, as in, all NaNs are removed before the loss is computed. Is that the same approach, though?

Checking the SuperAnimal paper, there are some example images comparing with vs. without masking:

[Screenshots: example predictions from the SuperAnimal paper, with vs. without masking]

The methods part mentions the following:

Training naively on these projected annotations would harm the training stability, as the loss function penalizes undefined keypoints, as if they were not visible (i.e., occluded).

For stable training of our panoptic pose estimation model, we mask components of the loss function across keypoints. The keypoint mask $n_k$ is set to 1 if the keypoint $k$ is present in the annotation of the image and set to 0 if the keypoint is absent. We denote the predicted probability for keypoint $k$ at pixel $(i, j)$ as $p_k(i, j) \in [0, 1]$ and the respective label as $t_k(i, j) \in \{0, 1\}$, and formulate the masked $L_k$ error loss function as

$$ L_k = \sum_{k=1}^{m} n_k \sum_{i,j} |p_k(i, j) - t_k(i, j)|_z, $$

with $z=2$ for mean square error and $z=1$ for L1 loss (e.g., used for locref maps in DLCRNet) and the masked cross-entropy loss function as

$$ L_{CE} = - \sum_{k=1}^{m} \sum_{i,j} n_k t_k(i, j) \log p_k(i, j). $$

Note that we make distinct the difference between not annotated and not defined in the original dataset and we only mask undefined keypoints. This is important as, in the case of sideview animals, "not annotated" could also mean occluded/invisible. Adding masking to not annotated keypoints will encourage the model to assign high likelihood to occluded keypoints.

@hummuscience
Contributor Author

Does LP distinguish between "unlabelled" keypoints (because of occlusions) and keypoints that were not labelled at all?

@themattinthehatt
Collaborator

No, currently LP does not distinguish between the two - if a ground truth label is missing it is dropped from the loss function with the remove_nans function you pointed out (though see "Why does the network produce high confidence values..." here: https://lightning-pose.readthedocs.io/en/latest/source/faqs.html).

So yes, on a first pass it appears that LP by default does the same gradient masking that the SuperAnimal paper implements. Looking at the DLC function you linked, weights does not refer to the weights of the model, but rather the mask applied to the loss ($n_k$ in their notation above).
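
To make the comparison concrete, the masking idea boils down to something like the sketch below. This is not LP's actual implementation, just an illustration of an $n_k$-style mask applied to a heatmap MSE; encoding missing targets as all-NaN heatmaps is an assumption for the example:

    import torch

    def masked_heatmap_mse(pred, target):
        """MSE over heatmaps, ignoring keypoints whose ground-truth label is missing.

        pred, target: (batch, n_keypoints, H, W); a missing keypoint is encoded here
        as an all-NaN target heatmap (conceptually the n_k mask quoted above).
        """
        mask = ~torch.isnan(target).flatten(start_dim=2).any(dim=2)  # (batch, n_keypoints), n_k
        target = torch.nan_to_num(target)
        per_kp = ((pred - target) ** 2).mean(dim=(2, 3))             # per-keypoint MSE
        return (per_kp * mask).sum() / mask.sum().clamp(min=1)       # masked keypoints contribute nothing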

@themattinthehatt
Collaborator

Oh something else I just realized: the TopViewMouse dataset contains some datasets with multiple animals - how do you deal with this right now? LP cannot currently handle multi-animal pose estimation.

@hummuscience
Contributor Author

I removed the two datasets that have multiple animals (TriMouse and Golden Lab).

I also just realized that one of the datasets (Kiehn Lab Openfield) actually doesn't have labels, so I will rerun the training without it.

@themattinthehatt
Collaborator

Great. Let me know how the generalization looks with the 256x256 model when it's done; I'm still scratching my head a bit about the bad performance in the frames you showed above.

@hummuscience
Contributor Author

This is the current state of the training (it still has the Kiehn Lab Openfield data though).

The train_heatmap_mse_loss is plateauing, as is the RMSE loss. Not sure what to think about the val loss.

Should I stop the training in this case? If I stop the training, will it save the checkpoint? (I set the config to save multiple checkpoints but realized that it is not implemented on the dynamic_crop branch)

[Screenshots: TensorBoard loss curves for the current run]

@themattinthehatt
Collaborator

The noisiness in the validation plot is weird, especially compared to the black line. Is black 512x512 and magenta 256x256?

One thing you can do is hit the three dots in the upper right-hand corner of these plots and change the y-axis to a log scale; that's typically more helpful the further you get in training. My guess is that the train_heatmap_mse_loss isn't plateauing yet, it's just decreasing on smaller scales.

The model should be saving out weights along the way; you'll find them in the tb_logs directory (you'll have to go down a couple more subdirectories).

@themattinthehatt
Collaborator

Btw, if you're training on the dynamic_crop branch I would recommend switching over to main and pulling the latest updates; I recently made an update to the validation dataloader that might be relevant here (a5e3831). Previously the validation data was passing through the image augmentation pipeline, so the "validation" data was actually different on every single pass. In the most up-to-date main branch that has been fixed, and now the validation data is not augmented. Obviously you were doing the same thing with the 512x512 network and didn't see the weird spikes in the validation loss, but it wouldn't hurt to remove that factor.

@hummuscience
Contributor Author

Alright, re-training now on the unsupervised multiview branch with 8 workers and the unfreezing epoch set to 5.

I will also have a look at how well the magenta model performed.

@hummuscience
Contributor Author

So far so good. Orange is the latest model.

[Screenshots: TensorBoard loss curves; orange is the latest model]

@hummuscience
Contributor Author

(I wonder if we should move this to discussions?)

So, the model is trained and it seems to perform better than the previous ones. It might be due to the removal of the weird datasets though (2-3 mice and the one without labels).

I would have expected the early stopping to kick in at some point after 5.5 hours, but maybe I am misunderstanding how it works. How do I know which checkpoint was used for the inference in the end? I understand that Lightning uses the checkpoint with the "best score", but I can't find out how this is determined...

[Screenshots: TensorBoard training curves for the full run]

The predictions on some datasets look quite good:
[Example frames: predictions on three datasets that look good]

While on others, less so:
[Example frames: predictions on two datasets that look worse]

Especially this effect of all the predictions being "squished" onto the central line of the animal or to the front of the animal.

One obvious difference between the two types of datasets is the size of the animal. I thought this should only affect semi-supervised PCA methods, but I have a blank losses_to_use list in the config. For some weird reason though, my test_video predictions have a *_pca_singleview_error.csv file. Is the PCA still being computed and used? (I am on the semisupervised multiview branch.)

This issue reminds me of the situation that DLC had with the SuperAnimal models, for which they adopted spatial-pyramid search during testing and video adaptation during inference on previously unseen data.

One mistake I made was that I didn't adjust the train/val/test proportions after removing one of the datasets. I ended up with a small val/test set (80 images each) which did not contain all datasets. I wonder if it would make sense to perform the train/val/test split manually to make sure all datasets are represented in each split.

@themattinthehatt
Collaborator

awesome, thanks so much for looking into this. yeah maybe we can switch over to the discord? https://discord.gg/tDUPdRj4BM

to answer your most recent questions here though:

  • If you set training.early_stopping to true in the config file the model doesn't actually perform early stopping (apologies for this, old nomenclature that we should update). Rather, the trainer monitors the validation loss and saves out a model checkpoint every time the validation loss is lower than any previously recorded validation loss. The model is still trained for the full number of epochs. So when you load the weights you'll get the ones corresponding to that lowest-validation-loss epoch.
  • Those first three images you shared look pretty good. Do the turquoise markers correspond to the set of keypoints that were labeled for that particular dataset, and the blue markers to the unlabeled keypoints? (If so, that's very helpful.) I'm not sure why the generalization to missing keypoints would work in some instances but not others, especially that last one where...well...it's just a dark mouse on a bright background.
  • Regarding the PCA error file: if you set losses_to_use to be an empty list the model will not use the pca loss during training, but it will still automatically compute that loss on videos after running inference if columns_for_singleview_pca isn't null.
  • I still haven't read deeply into their video adaptation method, but I would think that's not necessary on frames from datasets that were used in training.
  • If you want to manually create a test dataset you can select a set of frames, remove those frames from the CollectedData.csv file, and place them in a file called CollectedData_new.csv. Then you can use all frames in CollectedData.csv for train/val. Our training function will automatically look for a file named CollectedData_new.csv (or, more generally, the name of the csv file used for training with "_new" appended to the end) and run inference on that. This is what we did for all of our OOD experiments in the LP paper, to have strict control over which frames were train/val vs OOD test. A rough sketch of this split is below.
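
A rough sketch of that split (assuming the first directory in each image path identifies the source dataset, reusing the CollectedData_all.csv name from your config, and using 40 frames per dataset purely as an example):

    import pandas as pd

    df = pd.read_csv("CollectedData_all.csv", header=[0, 1, 2], index_col=0)

    # assume the first path component identifies the source dataset
    source = df.index.to_series().str.split("/").str[0]
    ood_idx = (
        df.index.to_series()
        .groupby(source, group_keys=False)
        .apply(lambda s: s.sample(min(len(s), 40), random_state=0))
    )

    df.loc[df.index.isin(ood_idx)].to_csv("CollectedData_all_new.csv")  # OOD test frames
    df.loc[~df.index.isin(ood_idx)].to_csv("CollectedData_all.csv")     # train/val frames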

@hummuscience
Contributor Author

I am currently thinking that the issue of the labels on the sides of the animal getting bad predictions has to do with the relative proportion of these labels vs. others (such as ears or tail base) in the whole dataset.

I will try to do some stratification of the train/test data or some oversampling. I haven't found hints of this being implemented in LP so far. Do you think it might make sense to add?

@themattinthehatt
Collaborator

That's a good observation, it is indeed possible that oversampling would force the model to focus more on these less-frequent keypoints. One way to do this would be to implement a custom sampler for the pytorch data loader, which would allow upweighting of the labeled examples with more keypoints.

I don't have the bandwidth to work on this for the time being, but happy to discuss the details more if you're interested in trying it out.
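
As a rough illustration (not something currently wired into LP's training loop; it assumes the dataset indices line up with the rows of the labels CSV, and the weighting scheme itself is just one option), the sampler could be built like this:

    import pandas as pd
    import torch
    from torch.utils.data import WeightedRandomSampler

    df = pd.read_csv("CollectedData_all.csv", header=[0, 1, 2], index_col=0)

    # which keypoints are labeled in each frame (x column present vs NaN)
    labeled = ~df.xs("x", level=2, axis=1).isna()    # (n_frames, n_keypoints) bool
    kp_freq = labeled.mean(axis=0)                   # fraction of frames containing each keypoint
    frame_weights = (labeled / kp_freq).sum(axis=1)  # frames with rare keypoints get larger weights

    sampler = WeightedRandomSampler(
        weights=torch.as_tensor(frame_weights.to_numpy(), dtype=torch.double),
        num_samples=len(frame_weights),
        replacement=True,
    )
    # the sampler would then be passed to the training DataLoader in place of shuffling, e.g.
    # DataLoader(train_dataset, batch_size=16, sampler=sampler)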

@hummuscience
Contributor Author

Nice! Good to know there is a way to implement this directly, I will give it a look :)

In the meantime, I ran a few experiments. I took all the datasets I am working with, created an OOD test set (40 images from each modality/dataset), and used the rest for training with a 90% train and 10% eval split.

I tried 4 different approaches.

  1. Trained untouched ("Original")
  2. "Oversampled": the images with the under-represented labels were oversampled (I created a "weight" per image based on how many under-represented labels it has) while keeping the total number of training images fixed. The final training set had a total of around 4400 images (1900 if duplicates are removed).
  3. "Oversampled-Inflated": since I am "losing" precious training data if I oversample while keeping the number of images the same, I inflated the total number of images by 1.5x, kept the original dataset, and then added more images, oversampling the under-represented ones.
  4. "Oversampled-2.5Inflated": the same as 3, but with 2.5x inflation, to make sure even more under-represented images are present (in retrospect, I should have thought of a better way to bias the data more towards the under-represented labels).

Doing this led to the following distributions of labels in the training data:

[Figure: relative body-part proportions in each training set]

Here is the validation loss during training: (blue original, dark blue oversampled, pink inflated, purple 2.5 inflated)

[Screenshot: validation loss curves during training]

I also saved a snapshot every 50 epochs from each model to run the eval on the OOD dataset and see the development:

Here is the median log pixel error for some keypoints (mainly the problematic ones, plus mouse_center/tail_base, which are present in every dataset). The vertical lines show the "best" model as chosen by LP:
[Figure: checkpoint comparison, median log pixel error per keypoint]

I wanted to see if the models perform similarly on different origin datasets. Here is the log pixel error for "left_hip"

[Figure: checkpoint comparison, left_hip pixel error by origin dataset]

And tail_base:

[Figure: checkpoint comparison, tail_base pixel error by origin dataset]

It seems like the oversampling and inflation improve the model performance for these rare labels.

The varying performance between datasets is still a bit puzzling... I will have a closer look at the weights for each image; it might be that OFT/3CSI are over-represented for some reason.

@hummuscience
Contributor Author

The more interesting question (at least for me) is what happens with this "collapsing" of the left_ and right_ keypoints to the center. Here is an example of the same frame from 4 different models. The "no_topview" model is one that was trained only on in-house data (without the TopViewMouse dataset). You can see that the lateral keypoints are "collapsing" onto the spine of the animal in the other models.

[Figure: the same frame predicted by 4 different models, showing the collapse]

If I plot the distances between the left and right keypoints, the moments where this "collapse" happens are situations where the distance is close to 0. Compared to the "no_topview" model (in which this collapsing doesn't happen), all the other trained models show this, and they are not even consistent as to when it happens (even though it seems to be frames where the animal stands up, narrowing its body).

[Figure: left-right keypoint distances over time for each model]
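
For anyone wanting to reproduce this check: assuming the prediction CSVs use the same DLC-style scorer/bodyparts/coords header rows as the labels, the distances can be computed like the sketch below (the file name and the 5-pixel threshold are arbitrary):

    import numpy as np
    import pandas as pd

    def lr_distance(preds, left, right):
        """Frame-by-frame distance between a left/right keypoint pair."""
        l = preds.xs(left, level=1, axis=1).droplevel(0, axis=1)[["x", "y"]].to_numpy()
        r = preds.xs(right, level=1, axis=1).droplevel(0, axis=1)[["x", "y"]].to_numpy()
        return np.linalg.norm(l - r, axis=1)

    preds = pd.read_csv("test_video_predictions.csv", header=[0, 1, 2], index_col=0)
    d = lr_distance(preds, "left_hip", "right_hip")
    collapsed = np.where(d < 5)[0]  # frames where the pair has (nearly) collapsed onto the midline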

Here are the distributions of the distance for all trained models:

[Figure: distributions of left-right distances for all trained models]

In this case, it seems like inflation of the training dataset leads to a worse result.

This is, however, not the case for all datasets. The OFT dataset, which benefitted from the oversampling/inflation above, has this problem solved :)

[Figure: distributions of left-right distances for the OFT dataset]

So I might be getting closer to the solution. I think it has to do with how I am oversampling and weighting the data :)

@themattinthehatt
Collaborator

@hummuscience wow this is awesome, nice work! some misc comments/questions:

  1. the oversampled + inflated approach seems to be a simple (and potentially effective?) way forward. But I am also confused as to why the collapsing is still happening. In your final 3 plots above, are these results on labeled OOD frames or on a video?
  2. does the collapsing occur on all datasets, or just the datasets that didn't include the additional left_ and right_ keypoints?
  3. since you have your own in-house dataset, I'm curious if you see improved results by first fitting a model on the TopViewMouse dataset, and then fine-tuning on your own data
  4. is there an easy way to oversample + inflate the data such that the distribution over keypoints in your first plot is uniform (or closer to uniform)? the oversample_inflated and oversample_2.5inflated distributions are much more similar to each other than the original and oversampled distributions

@hummuscience
Contributor Author

  1. Results on two videos. One is from the ZN_free_moving dataset, the other from the OFT dataset. The collapsing only affects ZN_free_moving (and GK_sickness, but that is due to too few labelled frames), not the OFT dataset.

  2. Haven't tested on all datasets yet. I only have videos for OFT, EPM, ZN_free_moving, GK_sickness (last two are own projects). I could give them a try. All of these have labels for left_ and right_. But they lack some intermediate labels on the spine... I wonder if that is what is impacting the performance.

  3. That was a thought I had. It is basically what I did with the DLC model in the beginning. There is no possibility of using the TopViewMouse model as a backbone to train a semi-supervised/context model, right?

  4. Yeah, I was thinking of how to do it. The issue is that some keypoints are present in every dataset (tail_base for example). So oversampling an image with left_ear_tip by giving it a higher weight would automatically lead to more tail_base keypoints in the dataset. Currently, I can't think of a better way to deal with this (unless I am missing some method). Do you have an idea?

@themattinthehatt
Collaborator

  1. yes! you won't be able to train a semi-supervised/context model directly with the TopViewMouse data, but you can definitely use that supervised/no-context backbone to train a semi-supervised/context model on your own data (in the same way we use pretrained ImageNet/AP10k backbones for those models). All you need to do is train on the TopViewMouse and then point to those weights in the config file. The context model weights will be randomly initialized, but the ResNet-50 weights from TopViewMouse will be loaded in.
  2. ah yes of course. Are you planning on labeling all possible points (including left_ and right_ keypoints) in your own data, or just sticking with the most common ones in the TopViewMouse dataset?

@hummuscience
Contributor Author

  1. Sweeet! That's what I will try out then! First re-running the model with smaller val/test proportions and adding another dataset I was missing (with 1000ish images).

  2. In my own datasets, I have all keypoints labelled. But not many labelled images (in total I have around 300).

@hummuscience
Contributor Author

  1. The collapsing happens in all datasets, except for the OFT dataset... Funnily, the OFT dataset has quite a small number of labels to begin with. Very puzzling. I will do some digging to see if I made a mistake somewhere.

@hummuscience
Contributor Author

How do I actually do this?

"All you need to do is train on the TopViewMouse and then point to those weights in the config file. The context model weights will be randomly initialized, but the ResNet-50 weights from TopViewMouse will be loaded in."

I tried to pass the path to the ckpt file as backbone in the config, but that didn't work.

Now I went with resnet50_animal_ap10k as the backbone and passed the checkpoint to the "checkpoint" argument. Is that the correct way?

@themattinthehatt
Collaborator

Yes, your second approach is the way to go! It's a little bit inefficient because it builds the backbone, loads the ap10k weights, and then loads the weights from the checkpoint, but it's the most straightforward way to lay it out in the config file.
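
Under the hood this amounts to something like the sketch below - not LP's actual loading code, and the "backbone." prefix is an assumption about how the parameters are named in the checkpoint:

    import torch

    def load_backbone_weights(new_model, ckpt_path):
        """Copy only backbone weights from a saved checkpoint into a freshly built model."""
        ckpt = torch.load(ckpt_path, map_location="cpu")
        state = ckpt["state_dict"]  # Lightning checkpoints store weights under "state_dict"
        backbone_only = {k: v for k, v in state.items() if k.startswith("backbone.")}
        # strict=False: layers missing from the dict (new head / context weights) keep their init
        new_model.load_state_dict(backbone_only, strict=False)
        return new_model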

@hummuscience
Contributor Author

hummuscience commented Aug 30, 2024

Here is a small update. I am not completely done with the tests but it seems like there is a solution to the “collapsing” issue.

The upper part is the pixel error on test images using different data as input (combined is GK + ZN, two different experiments).

The middle part shows violin plots of the distance between left and right body parts in the same test video. The closer to 0, the worse the model.

I sorted them a bit by the average pixel error, to make it a bit easier to see patterns in the combinations.

I had a suspicion that something about the size of the input images was determining whether I see "collapsing" or not, so I ran some models with 512x512 instead of 256x256 as the image size. And this does the trick. I am not well enough versed in these models, but it could be that reducing the size to 256x256 gets the keypoints "too close" to each other and the model has issues learning to differentiate them?

All these models were run after pre-training on the TopViewMouse5k data with inflated+upsampled frames and 256x256 as the image size. I also ran tests there after seeing these results, so I will post an update another time. But 512x512 seems to be a better choice for this data, as it improves the results on datasets where the mice are small due to zoomed-out images.

[Figure: model comparison - pixel error on test images and left-right distance distributions]
