Re-Identification Model #141

Open · wants to merge 16 commits into main
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -20,4 +20,4 @@ repos:
hooks:
- id: mdformat
additional_dependencies:
-  - mdformat-gfm==0.3.6
+  - mdformat-gfm==0.3.6
2 changes: 2 additions & 0 deletions luxonis_train/attached_modules/losses/__init__.py
@@ -7,6 +7,7 @@
from .ohem_bce_with_logits import OHEMBCEWithLogitsLoss
from .ohem_cross_entropy import OHEMCrossEntropyLoss
from .ohem_loss import OHEMLoss
from .pml_loss import EmbeddingLossWrapper
from .reconstruction_segmentation_loss import ReconstructionSegmentationLoss
from .sigmoid_focal_loss import SigmoidFocalLoss
from .smooth_bce_with_logits import SmoothBCEWithLogitsLoss
@@ -26,4 +27,5 @@
"OHEMCrossEntropyLoss",
"OHEMBCEWithLogitsLoss",
"FOMOLocalizationLoss",
"EmbeddingLossWrapper",
]
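With the export in place, the wrapper is importable alongside the existing losses (and, assuming it follows the same registry convention as the other losses, referenceable by class name from a config). A quick sanity check:

```python
from luxonis_train.attached_modules.losses import EmbeddingLossWrapper
```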
119 changes: 119 additions & 0 deletions luxonis_train/attached_modules/losses/pml_loss.py
@@ -0,0 +1,119 @@
import logging

import pytorch_metric_learning.losses as pml_losses
from pytorch_metric_learning.losses import CrossBatchMemory
from torch import Tensor

from .base_loss import BaseLoss

logger = logging.getLogger(__name__)

ALL_EMBEDDING_LOSSES = [
"AngularLoss",
"ArcFaceLoss",
"CircleLoss",
"ContrastiveLoss",
"CosFaceLoss",
"DynamicSoftMarginLoss",
"FastAPLoss",
"HistogramLoss",
"InstanceLoss",
"IntraPairVarianceLoss",
"LargeMarginSoftmaxLoss",
"GeneralizedLiftedStructureLoss",
"LiftedStructureLoss",
"MarginLoss",
"MultiSimilarityLoss",
"NPairsLoss",
"NCALoss",
"NormalizedSoftmaxLoss",
"NTXentLoss",
"PNPLoss",
"ProxyAnchorLoss",
"ProxyNCALoss",
"RankedListLoss",
"SignalToNoiseRatioContrastiveLoss",
"SoftTripleLoss",
"SphereFaceLoss",
"SubCenterArcFaceLoss",
"SupConLoss",
"ThresholdConsistentMarginLoss",
"TripletMarginLoss",
"TupletMarginLoss",
]

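# Subset of the losses above that learns per-class parameters (e.g. class
# proxies or a classification weight matrix); __init__ rejects these below
# because their trainable parameters would need their own optimizer.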
CLASS_EMBEDDING_LOSSES = [
"ArcFaceLoss",
"CosFaceLoss",
"LargeMarginSoftmaxLoss",
"NormalizedSoftmaxLoss",
"ProxyAnchorLoss",
"ProxyNCALoss",
"SoftTripleLoss",
"SphereFaceLoss",
"SubCenterArcFaceLoss",
]


class EmbeddingLossWrapper(BaseLoss):
def __init__(
self,
loss_name: str,
embedding_size: int = 512,
        cross_batch_memory_size: int = 0,
num_classes: int = 0,
loss_kwargs: dict | None = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
if loss_kwargs is None:
loss_kwargs = {}

try:
loss_cls = getattr(pml_losses, loss_name)
except AttributeError as e:
raise ValueError(
f"Loss {loss_name} not found in pytorch_metric_learning"
) from e

if loss_name in CLASS_EMBEDDING_LOSSES:
            if num_classes <= 0:
raise ValueError(
f"Loss {loss_name} requires num_classes to be set to a positive value"
)
loss_kwargs["num_classes"] = num_classes
loss_kwargs["embedding_size"] = embedding_size

# If we wanted to support these losses, we would need to add a separate optimizer for them.
# They may be useful in some scenarios, so leaving this here for future reference.
raise ValueError(
f"Loss {loss_name} requires its own optimizer, and that is not currently supported."
)

self.loss_func = loss_cls(**loss_kwargs)

if cross_batch_memory_size > 0:
if loss_name in CrossBatchMemory.supported_losses():
                self.loss_func = CrossBatchMemory(
                    self.loss_func,
                    embedding_size=embedding_size,
                    memory_size=cross_batch_memory_size,
                )
else:
logger.warning(
f"Cross batch memory is not supported for {loss_name}. Ignoring cross_batch_memory_size."
)

def prepare(
self, inputs: dict[str, list[Tensor]], labels: dict[str, list[Tensor]]
) -> tuple[Tensor, Tensor]:
embeddings = self.get_input_tensors(inputs, "features")[0]

if labels is None or "id" not in labels:
raise ValueError("Labels must contain 'id' key")

ids = labels["id"][0][:, 0]
return embeddings, ids

def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
loss = self.loss_func(inputs, target)
return loss
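For reviewers unfamiliar with pytorch_metric_learning, a minimal standalone sketch of what the wrapper delegates to; the loss choice (TripletMarginLoss), batch size, and memory size here are illustrative, not defaults taken from this PR:

```python
import torch
from pytorch_metric_learning.losses import CrossBatchMemory, TripletMarginLoss

embedding_size = 512

# One embedding and one integer identity per sample, mirroring what
# prepare() extracts from inputs["features"] and labels["id"].
embeddings = torch.randn(32, embedding_size)
ids = torch.randint(0, 10, (32,))

loss_func = TripletMarginLoss(margin=0.05)

# Optional memory bank of past embeddings, so pairs/triplets can be mined
# across batches; the same wrapping is applied when cross_batch_memory_size > 0.
loss_func = CrossBatchMemory(
    loss_func, embedding_size=embedding_size, memory_size=1024
)

loss = loss_func(embeddings, ids)  # scalar Tensor
```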
3 changes: 3 additions & 0 deletions luxonis_train/attached_modules/metrics/__init__.py
@@ -3,6 +3,7 @@
from .mean_average_precision import MeanAveragePrecision
from .mean_average_precision_keypoints import MeanAveragePrecisionKeypoints
from .object_keypoint_similarity import ObjectKeypointSimilarity
from .pml_metrics import ClosestIsPositiveAccuracy, MedianDistances
from .torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall

__all__ = [
@@ -15,5 +16,7 @@
"ObjectKeypointSimilarity",
"Precision",
"Recall",
"ClosestIsPositiveAccuracy",
"ConfusionMatrix",
"MedianDistances",
]
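The metric implementations in pml_metrics.py are not shown in this diff, so as a reading aid only, here is a conceptual sketch of what a closest-is-positive accuracy computes; the real ClosestIsPositiveAccuracy may differ in distance metric, batching, and state handling:

```python
import torch

def closest_is_positive_accuracy(
    embeddings: torch.Tensor, ids: torch.Tensor
) -> torch.Tensor:
    """Fraction of samples whose nearest neighbour (excluding self) shares their id.

    Conceptual sketch only, not the implementation from this PR.
    """
    dists = torch.cdist(embeddings, embeddings)  # pairwise Euclidean distances
    dists.fill_diagonal_(float("inf"))           # a sample must not match itself
    nearest = dists.argmin(dim=1)                # index of the closest other sample
    return (ids[nearest] == ids).float().mean()
```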