-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #69 from gabrielmscampos/feat/predict-with-jetmet-…
…ml-models feat: add ml model inference pipeline coupled to file_ingesting_pipeline
- Loading branch information
Showing
51 changed files
with
2,672 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -336,3 +336,4 @@ DQMIO/ | |
DQMIO_samples/ | ||
usercert.pem | ||
userkey.pem | ||
ML_models/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from django.apps import AppConfig | ||
|
||
|
||
class MLBadLumisectionConfig(AppConfig):
    """Django application configuration for the ``ml_bad_lumisection`` app."""

    # Name under which the application is registered.
    name = "ml_bad_lumisection"
    # Implicit primary keys created for this app's models are 64-bit auto fields.
    default_auto_field = "django.db.models.BigAutoField"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import ClassVar | ||
|
||
from django_filters import rest_framework as filters | ||
from utils import filters_mixins | ||
|
||
from .models import MLBadLumisection | ||
|
||
|
||
class MLBadLumisectionFilter(filters_mixins.DatasetFilterMethods, filters_mixins.MEsMethods, filters.FilterSet):
    """Query-string filters available when listing ``MLBadLumisection`` rows."""

    class Meta:
        model = MLBadLumisection
        # Maps each filterable model field to the lookup expressions exposed for it.
        fields: ClassVar[dict[str, list[str]]] = {
            "model_id": ["exact", "in"],
            "dataset_id": ["exact"],
            "me_id": ["exact"],
            "run_number": ["exact", "in"],
            "ls_number": ["exact"],
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import ClassVar | ||
|
||
from django.db import models | ||
|
||
|
||
class MLBadLumisection(models.Model):
    """
    Unmanaged mapping of the ``fact_ml_bad_lumis`` table.

    - Django doesn't support composite primary keys, so only ``model_id`` is flagged
      as the primary key on the field definitions below.
    - The unique constraint declared in ``Meta`` does not exist in the database;
      it is used to select the composite primary key in the viewset and serves as documentation.
    """

    model_id = models.BigIntegerField(primary_key=True)
    dataset_id = models.BigIntegerField()
    file_id = models.BigIntegerField()
    run_number = models.IntegerField()
    ls_number = models.IntegerField()
    me_id = models.IntegerField()
    mse = models.FloatField()

    class Meta:
        # Django neither creates nor migrates this table.
        managed = False
        db_table = "fact_ml_bad_lumis"
        # The list holds UniqueConstraint objects, so it is annotated as such
        # (it was previously annotated as a list of models.Index).
        constraints: ClassVar[list[models.UniqueConstraint]] = [
            models.UniqueConstraint(
                name="fact_ml_bad_lumis_primary_key",
                fields=["model_id", "dataset_id", "run_number", "ls_number", "me_id"],
            ),
        ]

    def __str__(self) -> str:
        """Return a readable identifier built from the composite key columns."""
        return f"MLBadLumisection <{self.me_id}@{self.ls_number}@{self.run_number}@{self.dataset_id}@{self.model_id}>"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from rest_framework import routers | ||
|
||
from .viewsets import MLBadLumisectionViewSet | ||
|
||
|
||
# Router exposing the ML bad-lumisection endpoints under the "ml-bad-lumisection" prefix.
router = routers.SimpleRouter()
router.register(
    prefix=r"ml-bad-lumisection",
    viewset=MLBadLumisectionViewSet,
    basename="ml-bad-lumisection",
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from rest_framework import serializers | ||
|
||
from .models import MLBadLumisection | ||
|
||
|
||
class MLBadLumisectionSerializer(serializers.ModelSerializer):
    """Serializer exposing every field of ``MLBadLumisection``."""

    class Meta:
        model = MLBadLumisection
        # "__all__" includes all model fields in the serialized output.
        fields = "__all__"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import logging | ||
from typing import ClassVar | ||
|
||
from django.conf import settings | ||
from django.shortcuts import get_object_or_404 | ||
from django.utils.decorators import method_decorator | ||
from django.views.decorators.cache import cache_page | ||
from django.views.decorators.vary import vary_on_headers | ||
from django_filters.rest_framework import DjangoFilterBackend | ||
from lumisection.models import Lumisection | ||
from ml_models_index.models import MLModelsIndex | ||
from rest_framework import mixins, viewsets | ||
from rest_framework.authentication import BaseAuthentication | ||
from rest_framework.decorators import action | ||
from rest_framework.exceptions import ValidationError | ||
from rest_framework.response import Response | ||
from utils.common import list_to_range | ||
from utils.db_router import GenericViewSetRouter | ||
from utils.rest_framework_cern_sso.authentication import ( | ||
CERNKeycloakClientSecretAuthentication, | ||
CERNKeycloakConfidentialAuthentication, | ||
) | ||
|
||
from .filters import MLBadLumisectionFilter | ||
from .models import MLBadLumisection | ||
from .serializers import MLBadLumisectionSerializer | ||
|
||
|
||
logger = logging.getLogger(__name__)

# First model constraint whose name contains "primary_key" (None if there is none).
# Its ``fields`` stand in for the composite primary key that Django cannot express.
composite_pks = next(
    (constraint for constraint in MLBadLumisection._meta.constraints if "primary_key" in constraint.name),
    None,
)
|
||
|
||
@method_decorator(cache_page(settings.CACHE_TTL), name="list")
@method_decorator(cache_page(settings.CACHE_TTL), name="get_object")
@method_decorator(vary_on_headers(settings.WORKSPACE_HEADER), name="list")
@method_decorator(vary_on_headers(settings.WORKSPACE_HEADER), name="get_object")
class MLBadLumisectionViewSet(GenericViewSetRouter, mixins.ListModelMixin, viewsets.GenericViewSet):
    """
    Endpoints over the ML bad-lumisection predictions.

    Provides a filtered list, a lookup by the full composite key and two JSON
    reports ("cert-json" and "golden-json") built for one dataset, a set of runs
    and a set of models. Only ``list`` and ``get_object`` are cached.
    """

    queryset = MLBadLumisection.objects.all().order_by(*composite_pks.fields)
    serializer_class = MLBadLumisectionSerializer
    filterset_class = MLBadLumisectionFilter
    filter_backends: ClassVar[list[DjangoFilterBackend]] = [DjangoFilterBackend]
    authentication_classes: ClassVar[list[BaseAuthentication]] = [
        CERNKeycloakClientSecretAuthentication,
        CERNKeycloakConfidentialAuthentication,
    ]

    @staticmethod
    def _parse_json_query_params(request):
        """
        Parse the query parameters shared by the cert-json and golden-json actions.

        Returns:
            Tuple ``(dataset_id, run_number, model_id)`` where ``dataset_id`` is an int
            and the other two are lists of ints parsed from comma-separated values.

        Raises:
            ValidationError: if a parameter is missing or is not a valid integer.
        """
        params = request.query_params
        try:
            dataset_id = int(params.get("dataset_id"))
            run_number = list(map(int, params.get("run_number__in").split(",")))
            model_id = list(map(int, params.get("model_id__in").split(",")))
        except (AttributeError, TypeError, ValueError) as err:
            # A missing parameter yields None: int(None) raises TypeError and
            # None.split raises AttributeError. Both must become a 400, not a 500.
            raise ValidationError(
                "dataset_id and run_number must be valid integers and model_ids a valid list of integers"
            ) from err
        return dataset_id, run_number, model_id

    @action(
        detail=False,
        methods=["GET"],
        url_path=r"(?P<model_id>\d+)/(?P<dataset_id>\d+)/(?P<run_number>\d+)/(?P<ls_number>\d+)/(?P<me_id>\d+)",
    )
    def get_object(self, request, model_id=None, dataset_id=None, run_number=None, ls_number=None, me_id=None):
        """
        Retrieve a single prediction by its composite primary key.

        Since the MLBadLumisection table in the database has a composite primary key
        that Django doesn't support, this method acts as a custom retrieve that
        queries the table by all key columns.

        Raises:
            ValidationError: if any key component is not a valid integer.
            Http404: if no row matches the given key.
        """
        # NOTE(review): this action shares its name with DRF's GenericAPIView.get_object
        # but has a different signature; kept as-is so the route name does not change.
        try:
            lookup = {
                "model_id": int(model_id),
                "dataset_id": int(dataset_id),
                "run_number": int(run_number),
                "ls_number": int(ls_number),
                "me_id": int(me_id),
            }
        except (TypeError, ValueError) as err:
            raise ValidationError(
                "model_id, dataset_id, run_number, ls_number and me_id must be valid integers."
            ) from err

        instance = get_object_or_404(self.get_queryset(), **lookup)
        serializer = self.serializer_class(instance)
        return Response(serializer.data)

    @action(detail=False, methods=["GET"], url_path=r"cert-json")
    def generate_certificate_json(self, request):
        """
        Build the bad-lumisection certification JSON.

        Query parameters: ``dataset_id``, ``run_number__in`` and ``model_id__in``
        (the last two comma-separated).

        Returns:
            Response whose payload maps each requested run to a mapping of
            lumisection number to the list of predictions stored for it.

        Raises:
            ValidationError: if a query parameter is missing or malformed.
        """
        dataset_id, run_number, model_id = self._parse_json_query_params(request)

        # Select user's workspace
        workspace = self.get_workspace()

        # Fetch models' metadata in the given workspace, keyed by model id
        models_by_id = {
            row.get("model_id"): row
            for row in MLModelsIndex.objects.using(workspace).filter(model_id__in=model_id).values()
        }

        # Fetch predictions for a given dataset, multiple runs from multiple models
        predictions = (
            self.get_queryset()
            .filter(dataset_id=dataset_id, run_number__in=run_number, model_id__in=model_id)
            .order_by("run_number", "ls_number")
            .values()
        )

        # Format bad lumi certification json: every requested run gets an entry,
        # even when it has no predictions; predictions are grouped in a single pass.
        response = {run: {} for run in run_number}
        for pred in predictions:
            pred_model_id = pred.get("model_id")
            # Tolerate predictions whose model is absent from the workspace index:
            # the metadata fields are then reported as null instead of failing.
            model_meta = models_by_id.get(pred_model_id, {})
            response[pred.get("run_number")].setdefault(pred.get("ls_number"), []).append(
                {
                    "model_id": pred_model_id,
                    "me_id": pred.get("me_id"),
                    "filename": model_meta.get("filename"),
                    "me": model_meta.get("target_me"),
                    "mse": pred.get("mse"),
                }
            )

        return Response(response)

    @action(detail=False, methods=["GET"], url_path=r"golden-json")
    def generate_golden_json(self, request):
        """
        Build the ML golden JSON.

        For each requested run, the lumisections of the dataset that have no
        prediction row for any of the requested models are passed to
        ``list_to_range`` and stored under the run number.

        Query parameters: ``dataset_id``, ``run_number__in`` and ``model_id__in``
        (the last two comma-separated).

        Raises:
            ValidationError: if a query parameter is missing or malformed.
        """
        dataset_id, run_number, model_id = self._parse_json_query_params(request)

        # Select user's workspace
        workspace = self.get_workspace()

        # Generate ML golden json
        response = {}
        for run in run_number:
            # order_by("ls_number") replaces the queryset's default composite-key
            # ordering, so distinct() is applied on ls_number only.
            bad_lumis = set(
                self.get_queryset()
                .filter(dataset_id=dataset_id, run_number=run, model_id__in=model_id)
                .order_by("ls_number")
                .values_list("ls_number", flat=True)
                .distinct()
            )
            all_lumis = (
                Lumisection.objects.using(workspace)
                .filter(dataset_id=dataset_id, run_number=run)
                .values_list("ls_number", flat=True)
            )
            # Set membership keeps this linear in the number of lumisections.
            good_lumis = [ls for ls in all_lumis if ls not in bad_lumis]
            response[run] = list_to_range(good_lumis)

        return Response(response)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from django.apps import AppConfig | ||
|
||
|
||
class MLModelsIndexConfig(AppConfig):
    """Django application configuration for the ``ml_models_index`` app."""

    # Name under which the application is registered.
    name = "ml_models_index"
    # Implicit primary keys created for this app's models are 64-bit auto fields.
    default_auto_field = "django.db.models.BigAutoField"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from typing import ClassVar | ||
|
||
from django_filters import rest_framework as filters | ||
|
||
from .models import MLModelsIndex | ||
|
||
|
||
class MLModelsIndexFilter(filters.FilterSet):
    """Query-string filters available when listing ``MLModelsIndex`` rows."""

    class Meta:
        model = MLModelsIndex
        # Maps each filterable model field to the lookup expressions exposed for it.
        fields: ClassVar[dict[str, list[str]]] = {
            "model_id": ["exact", "in"],
            "target_me": ["exact", "regex"],
            "active": ["exact"],
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from typing import ClassVar | ||
|
||
from django.db import models | ||
|
||
|
||
class MLModelsIndex(models.Model):
    """Unmanaged mapping of the ``dim_ml_models_index`` table describing the ML models."""

    model_id = models.IntegerField(primary_key=True)
    filename = models.CharField(max_length=255)
    target_me = models.CharField(max_length=255)
    active = models.BooleanField()

    class Meta:
        # Django neither creates nor migrates this table.
        managed = False
        db_table = "dim_ml_models_index"
        indexes: ClassVar[list[models.Index]] = [
            models.Index(name="idx_active", fields=["active"]),
        ]

    def __str__(self) -> str:
        """Return a short label containing the model id."""
        return f"Model <{self.model_id}>"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from rest_framework import routers | ||
|
||
from .viewsets import MLModelsIndexViewSet | ||
|
||
|
||
# Router exposing the ML models index endpoints under the "ml-models-index" prefix.
router = routers.SimpleRouter()
router.register(
    prefix=r"ml-models-index",
    viewset=MLModelsIndexViewSet,
    basename="ml-models-index",
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from rest_framework import serializers | ||
|
||
from .models import MLModelsIndex | ||
|
||
|
||
class MLModelsIndexSerializer(serializers.ModelSerializer):
    """Serializer exposing every field of ``MLModelsIndex``."""

    class Meta:
        model = MLModelsIndex
        # "__all__" includes all model fields in the serialized output.
        fields = "__all__"
Oops, something went wrong.