Re-Identification Model #141

Open · wants to merge 16 commits into base: main
Changes from 4 commits
2 changes: 2 additions & 0 deletions luxonis_train/attached_modules/losses/__init__.py
@@ -7,6 +7,7 @@
from .ohem_bce_with_logits import OHEMBCEWithLogitsLoss
from .ohem_cross_entropy import OHEMCrossEntropyLoss
from .ohem_loss import OHEMLoss
from .pml_loss import MetricLearningLoss
from .reconstruction_segmentation_loss import ReconstructionSegmentationLoss
from .sigmoid_focal_loss import SigmoidFocalLoss
from .smooth_bce_with_logits import SmoothBCEWithLogitsLoss
@@ -26,4 +27,5 @@
"OHEMCrossEntropyLoss",
"OHEMBCEWithLogitsLoss",
"FOMOLocalizationLoss",
"MetricLearningLoss",
]
122 changes: 122 additions & 0 deletions luxonis_train/attached_modules/losses/pml_loss.py
@@ -0,0 +1,122 @@
import warnings

from pytorch_metric_learning.losses import (
    AngularLoss,
    ArcFaceLoss,
    CircleLoss,
    ContrastiveLoss,
    CosFaceLoss,
    CrossBatchMemory,
    DynamicSoftMarginLoss,
    FastAPLoss,
    GeneralizedLiftedStructureLoss,
    HistogramLoss,
    InstanceLoss,
    IntraPairVarianceLoss,
    LargeMarginSoftmaxLoss,
    LiftedStructureLoss,
    ManifoldLoss,
    MarginLoss,
    MultiSimilarityLoss,
    NCALoss,
    NormalizedSoftmaxLoss,
    NPairsLoss,
    NTXentLoss,
    P2SGradLoss,
    PNPLoss,
    ProxyAnchorLoss,
    ProxyNCALoss,
    RankedListLoss,
    SignalToNoiseRatioContrastiveLoss,
    SoftTripleLoss,
    SphereFaceLoss,
    SubCenterArcFaceLoss,
    SupConLoss,
    TripletMarginLoss,
    TupletMarginLoss,
)
from torch import Tensor

from .base_loss import BaseLoss

# Dictionary mapping string keys to loss classes
loss_dict = {
"AngularLoss": AngularLoss,
"ArcFaceLoss": ArcFaceLoss,
"CircleLoss": CircleLoss,
"ContrastiveLoss": ContrastiveLoss,
"CosFaceLoss": CosFaceLoss,
"DynamicSoftMarginLoss": DynamicSoftMarginLoss,
"FastAPLoss": FastAPLoss,
"GeneralizedLiftedStructureLoss": GeneralizedLiftedStructureLoss,
"InstanceLoss": InstanceLoss,
"HistogramLoss": HistogramLoss,
"IntraPairVarianceLoss": IntraPairVarianceLoss,
"LargeMarginSoftmaxLoss": LargeMarginSoftmaxLoss,
"LiftedStructureLoss": LiftedStructureLoss,
"ManifoldLoss": ManifoldLoss,
"MarginLoss": MarginLoss,
"MultiSimilarityLoss": MultiSimilarityLoss,
"NCALoss": NCALoss,
"NormalizedSoftmaxLoss": NormalizedSoftmaxLoss,
"NPairsLoss": NPairsLoss,
"NTXentLoss": NTXentLoss,
"P2SGradLoss": P2SGradLoss,
"PNPLoss": PNPLoss,
"ProxyAnchorLoss": ProxyAnchorLoss,
"ProxyNCALoss": ProxyNCALoss,
"RankedListLoss": RankedListLoss,
"SignalToNoiseRatioContrastiveLoss": SignalToNoiseRatioContrastiveLoss,
"SoftTripleLoss": SoftTripleLoss,
"SphereFaceLoss": SphereFaceLoss,
"SubCenterArcFaceLoss": SubCenterArcFaceLoss,
"SupConLoss": SupConLoss,
"TripletMarginLoss": TripletMarginLoss,
"TupletMarginLoss": TupletMarginLoss,
}


class MetricLearningLoss(BaseLoss):
    """Wrapper that exposes pytorch_metric_learning losses by name."""

    def __init__(
        self,
        loss_name: str,
        embedding_size: int = 512,
        cross_batch_memory_size: int = 0,
        loss_kwargs: dict | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        if loss_kwargs is None:
            loss_kwargs = {}
        # Instantiate the selected pytorch_metric_learning loss object
        self.loss_func = loss_dict[loss_name](**loss_kwargs)
        if cross_batch_memory_size > 0:
            if loss_name in CrossBatchMemory.supported_losses():
                # Wrap the loss so embeddings from previous batches are kept
                # in memory; forward the configured memory size.
                self.loss_func = CrossBatchMemory(
                    self.loss_func,
                    embedding_size=embedding_size,
                    memory_size=cross_batch_memory_size,
                )
            else:
                # Warn that cross_batch_memory_size is ignored for unsupported losses
                warnings.warn(
                    f"Cross batch memory is not supported for {loss_name}. "
                    "Ignoring cross_batch_memory_size"
                )

        # self.miner_func = miner_func

    def prepare(self, inputs, labels):
        # The first output of the attached node is used as the embedding tensor
        embeddings = inputs["features"][0]

        assert (
            labels is not None and "id" in labels
        ), "ID labels are required for metric learning losses"
        IDs = labels["id"][0][:, 0]
        return embeddings, IDs

    def forward(self, inputs: Tensor, target: Tensor):
        # miner_output = self.miner_func(inputs, target)

        loss = self.loss_func(inputs, target)

        return loss
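
For context, a minimal sketch (not part of this PR) of what the wrapper delegates to. It assumes a batch of 32 embeddings of size 512, integer identity labels, and TripletMarginLoss with an illustrative margin of 0.1; the memory size of 1024 is likewise only an example.

import torch
from pytorch_metric_learning.losses import CrossBatchMemory, TripletMarginLoss

embeddings = torch.randn(32, 512)  # (batch_size, embedding_size), as returned by prepare()
ids = torch.randint(0, 10, (32,))  # one integer identity label per sample

# loss_kwargs are forwarded to the underlying loss constructor
loss_func = TripletMarginLoss(margin=0.1)
# With cross_batch_memory_size > 0, the loss is wrapped so embeddings persist across batches
loss_func = CrossBatchMemory(loss_func, embedding_size=512, memory_size=1024)

loss = loss_func(embeddings, ids)  # equivalent to MetricLearningLoss.forward(inputs, target)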
3 changes: 3 additions & 0 deletions luxonis_train/attached_modules/metrics/__init__.py
@@ -2,6 +2,7 @@
from .mean_average_precision import MeanAveragePrecision
from .mean_average_precision_keypoints import MeanAveragePrecisionKeypoints
from .object_keypoint_similarity import ObjectKeypointSimilarity
from .pml_metrics import ClosestIsPositiveAccuracy, MedianDistances
from .torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall

__all__ = [
@@ -14,4 +15,6 @@
"ObjectKeypointSimilarity",
"Precision",
"Recall",
"ClosestIsPositiveAccuracy",
"MedianDistances",
]