Crucible service unit test coverage

This is a draft to archive what I've done before I evaporate for the rest of
2024, because, you know, I could get run over by a reindeer or something...

This brings Crucible service test coverage to 50%.

dbutenhof committed Dec 13, 2024
1 parent 1446089 commit 840a9cf
Showing 4 changed files with 769 additions and 81 deletions.
150 changes: 87 additions & 63 deletions backend/app/services/crucible_svc.py
@@ -10,13 +10,14 @@
aggregate, or Plotly graph format for UI display.
"""

import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
import logging
import time
from typing import Any, Iterator, Optional, Tuple, Union

from elasticsearch import AsyncElasticsearch, NotFoundError
from elasticsearch import AsyncElasticsearch
from fastapi import HTTPException, status
from pydantic import BaseModel

@@ -272,6 +273,15 @@ class CrucibleService:
"source",
)

# Set up a Logger at class level rather than at each instance creation
formatter = logging.Formatter(
"%(asctime)s %(process)d:%(thread)d %(levelname)s %(module)s:%(lineno)d %(message)s"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger("CrucibleService")
logger.addHandler(handler)
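
For reference, a minimal sketch of this shared class-level logger pattern (class and message names here are illustrative, not from the commit): configuring the handler once at class-definition time means every instance logs through the same handler rather than attaching a new one per __init__ call.

    import logging

    class Worker:
        # Runs once when the class body executes; all instances share
        # this logger and its single handler.
        formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger = logging.getLogger("Worker")
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)  # INFO is filtered out by default

        def greet(self):
            self.logger.info("resolved via the class attribute")

    Worker().greet()  # logs once; creating more Workers adds no extra handlers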

def __init__(self, configpath: str = "crucible"):
"""Initialize a Crucible CDM (OpenSearch) connection.
@@ -291,11 +301,23 @@ def __init__(self, configpath: str = "crucible"):
self.auth = (self.user, self.password) if self.user or self.password else None
self.url = self.cfg.get(configpath + ".url")
self.elastic = AsyncElasticsearch(self.url, basic_auth=self.auth)
self.logger.info("Initializing CDM V7 service to %s", self.url)

@staticmethod
def _get_index(root: str) -> str:
"""Expand the root index name to the full name"""
return "cdmv7dev-" + root

@staticmethod
def _get(source: dict[str, Any], fields: list[str], default: Optional[Any] = None):
"""Safely traverse nested dictionaries with a default value"""
r = source
last_missing = False
for f in fields:
last_missing = f not in r
r = r.get(f, {})
return default if last_missing else r
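
For illustration, the traversal on a hypothetical document; the final field decides between value and default, and a missing branch forces the default:

    doc = {"run": {"begin": 1702425600000}}
    CrucibleService._get(doc, ["run", "begin"])      # -> 1702425600000
    CrucibleService._get(doc, ["run", "end"], 0)     # -> 0 (leaf missing)
    CrucibleService._get(doc, ["oops", "begin"], 0)  # -> 0 (branch missing)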

@staticmethod
def _split_list(alist: Optional[list[str]] = None) -> list[str]:
"""Split a list of parameters
@@ -342,26 +364,24 @@ def _normalize_date(value: Optional[Union[int, str, datetime]]) -> int:
return value
elif isinstance(value, datetime):
return int(value.timestamp() * 1000.0)
elif isinstance(value, str):
else:
# If it's a stringified int, convert & return; otherwise try
# to decode as a date string.
try:
return int(value)
except ValueError:
pass
try:
d = datetime.fromisoformat(value)
return int(d.timestamp() * 1000.0)
except ValueError:
pass
d = datetime.fromisoformat(value)
return int(d.timestamp() * 1000.0)
except Exception as e:
print(f"normalizing {type(value).__name__} {value} failed with {str(e)}")
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f"Date representation {value} is not a date string or timestamp",
f"Date representation {value} is not valid: {str(e)!r}",
)
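
Assuming the module-level imports above, these illustrative inputs all normalize to the same epoch-millisecond value, while the last is rejected:

    CrucibleService._normalize_date(1702425600000)                # int passes through
    CrucibleService._normalize_date("1702425600000")              # stringified int
    CrucibleService._normalize_date(datetime(2023, 12, 13, tzinfo=timezone.utc))
    CrucibleService._normalize_date("2023-12-13T00:00:00+00:00")  # ISO-8601 string
    CrucibleService._normalize_date("next Tuesday")               # HTTPException(400)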

@staticmethod
@classmethod
def _hits(
payload: dict[str, Any], fields: Optional[list[str]] = None
cls, payload: dict[str, Any], fields: Optional[list[str]] = None
) -> Iterator[dict[str, Any]]:
"""Helper to iterate through OpenSearch query matches
@@ -377,20 +397,19 @@ def _hits(
Returns:
Yields each object from the "greatest hits" list
"""
if "hits" not in payload:
if "hits" not in payload or not isinstance(payload["hits"], dict):
raise HTTPException(
status_code=500, detail=f"Attempt to iterate hits for {payload}"
)
hits = payload.get("hits", {}).get("hits", [])
hits = cls._get(payload, ["hits", "hits"], [])
for h in hits:
source = h["_source"]
if fields:
for f in fields:
source = source[f]
yield source
yield source if not fields else cls._get(source, fields)
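
A sketch of consuming a hypothetical OpenSearch response, with and without the optional field path:

    payload = {
        "hits": {
            "total": 1,
            "hits": [{"_source": {"run": {"id": "abc123"}}}],
        }
    }
    list(CrucibleService._hits(payload))                 # full _source dicts
    list(CrucibleService._hits(payload, ["run", "id"]))  # -> ["abc123"]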

@staticmethod
def _aggs(payload: dict[str, Any], aggregation: str) -> Iterator[dict[str, Any]]:
@classmethod
def _aggs(
cls, payload: dict[str, Any], aggregation: str
) -> Iterator[dict[str, Any]]:
"""Helper to access OpenSearch aggregations
Iteratively yields the name and value of each aggregation returned
@@ -403,18 +422,20 @@ def _aggs(payload: dict[str, Any], aggregation: str) -> Iterator[dict[str, Any]]
Returns:
Yields each aggregation from an aggregation bucket list
"""
if "aggregations" not in payload:
if "aggregations" not in payload or not isinstance(
payload["aggregations"], dict
):
raise HTTPException(
status_code=500,
detail=f"Attempt to iterate missing aggregations for {payload}",
)
aggs = payload["aggregations"]
if aggregation not in aggs:
if aggregation not in aggs or not isinstance(aggs[aggregation], dict):
raise HTTPException(
status_code=500,
detail=f"Attempt to iterate missing aggregation {aggregation} for {payload}",
detail=f"Attempt to iterate missing aggregation {aggregation!r} for {payload}",
)
for agg in aggs[aggregation]["buckets"]:
for agg in cls._get(aggs, [aggregation, "buckets"], []):
yield agg
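
And the aggregation counterpart, again on hypothetical data:

    payload = {
        "aggregations": {
            "bench": {"buckets": [{"key": "uperf", "doc_count": 10}]}
        }
    }
    for bucket in CrucibleService._aggs(payload, "bench"):
        print(bucket["key"], bucket["doc_count"])  # -> uperf 10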

@staticmethod
@@ -423,7 +444,9 @@ def _format_timestamp(timestamp: Union[str, int]) -> str:
try:
ts = int(timestamp)
except Exception as e:
print(f"ERROR: invalid {timestamp!r}: {str(e)!r}")
CrucibleService.logger.warning(
"invalid timestamp %r: %r", timestamp, str(e)
)
ts = 0
return str(datetime.fromtimestamp(ts / 1000.00, timezone.utc))
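
For example (the failure case now warns through the class logger instead of printing, then falls back to epoch 0):

    CrucibleService._format_timestamp(1702425600000)  # "2023-12-13 00:00:00+00:00"
    CrucibleService._format_timestamp("garbage")      # warns; "1970-01-01 00:00:00+00:00"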

@@ -581,7 +604,7 @@ def _build_name_filters(
n, v = e.split("=", maxsplit=1)
except ValueError:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, f"Filter item {e} must be '<k>=<v>'"
status.HTTP_400_BAD_REQUEST, f"Filter item {e!r} must be '<k>=<v>'"
)
filters.append({"term": {f"metric_desc.names.{n}": v}})
return filters
@@ -656,7 +679,7 @@ def _build_metric_filters(
)

@classmethod
def _build_sort_terms(cls, sorters: Optional[list[str]]) -> list[dict[str, str]]:
def _build_sort_terms(cls, sorters: Optional[list[str]]) -> list[dict[str, Any]]:
"""Build sort term list
Sorters may reference any native `run` index field and must specify
@@ -676,16 +699,16 @@ def _build_sort_terms(cls, sorters: Optional[list[str]]) -> list[dict[str, str]]
if dir not in cls.DIRECTIONS:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f"Sort direction {dir!r} must be one of {','.join(DIRECTIONS)}",
f"Sort direction {dir!r} must be one of {','.join(cls.DIRECTIONS)}",
)
if key not in cls.FIELDS:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f"Sort key {key!r} must be one of {','.join(FIELDS)}",
f"Sort key {key!r} must be one of {','.join(cls.FIELDS)}",
)
sort_terms.append({f"run.{key}": dir})
sort_terms.append({f"run.{key}": {"order": dir}})
else:
sort_terms = [{"run.begin": "asc"}]
sort_terms = [{"run.begin": {"order": "asc"}}]
return sort_terms
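
The move from the {"run.begin": "asc"} shorthand to an explicit {"order": dir} object is OpenSearch's long-form sort syntax, which leaves room for options such as "missing" or "unmapped_type" later. Illustrative results, assuming "begin" is among the permitted FIELDS:

    CrucibleService._build_sort_terms(["begin:desc"])
    # -> [{"run.begin": {"order": "desc"}}]
    CrucibleService._build_sort_terms(None)
    # -> [{"run.begin": {"order": "asc"}}]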

async def _search(
@@ -704,9 +727,11 @@ async def _search(
idx = self._get_index(index)
start = time.time()
value = await self.elastic.search(index=idx, body=query, **kwargs)
print(
f"QUERY on {idx} took {time.time() - start} seconds, "
f"hits: {value.get('hits', {}).get('total')}"
self.logger.info(
"QUERY on %s took %.3f seconds, hits: %d",
idx,
time.time() - start,
value.get("hits", {}).get("total"),
)
return value

@@ -777,6 +802,11 @@ async def _get_metric_ids(
422 HTTP error (UNPROCESSABLE CONTENT) with a response body showing
the unsatisfied breakouts (name and available values).
TODO: Instead of either single metric or aggregation across multiple
metrics, we should support "breakouts", which would individually
process (graph, summarize, or list) data for each "loose" breakout
name. E.g., Busy-CPU might list per-core, or per-processor mode.
Args:
run: run ID
metric: combined metric name (e.g., sar-net::packets-sec)
@@ -823,7 +853,7 @@ async def _get_metric_ids(
# We want to help filter a consistent summary, so only show those
# breakout names with more than one value.
response["names"] = {n: sorted(v) for n, v in names.items() if v and len(v) > 1}
response["periods"] = list(periods)
response["periods"] = sorted(periods)
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=response
)
@@ -851,30 +881,26 @@ async def _build_timestamp_range_filters(
)
start = None
end = None
name = "<unknown>"
for h in self._hits(matches):
p = h["period"]
st = p["begin"]
et = p["end"]

# If any period is missing a timestamp, use the run's timestamp
# instead to avoid blowing up on a CDM error.
if st is None:
st = h["run"]["begin"]
if et is None:
et = h["run"]["end"]
st = p.get("begin")
et = p.get("end")
if not st or not et:
name = (
f"run {self._get(h, ['run', 'benchmark'])}:"
f"{self._get(h, ['run', 'begin'])},"
f"iteration {self._get(h, ['iteration', 'num'])},"
f"sample {self._get(h, ['sample', 'num'])}"
)
if st and (not start or st < start):
start = st
if et and (not end or et > end):
end = et
if start is None or end is None:
name = (
f"{h['run']['benchmark']}:{h['run']['begin']}-"
f"{h['iteration']['num']}-{h['sample']['num']}-"
f"{p['name']}"
)
raise HTTPException(
status.HTTP_422_UNPROCESSABLE_ENTITY,
f"Unable to compute {name!r} time range {start!r} -> {end!r}",
f"Unable to compute {name!r} time range: the run is missing period timestamps",
)
return [
{"range": {"metric_data.begin": {"gte": str(start)}}},
@@ -901,7 +927,7 @@ async def _get_run_ids(
filtered = await self.search(
index, source="run.id", filters=filters, ignore_unavailable=True
)
print(f"HITS: {filtered['hits']['hits']}")
self.logger.debug("HITS: %s", filtered["hits"]["hits"])
return set([x for x in self._hits(filtered, ["run", "id"])])

async def get_run_filters(self) -> dict[str, dict[str, list[str]]]:
@@ -1185,7 +1211,7 @@ async def get_runs(
run["begin_date"] = self._format_timestamp(run["begin"])
run["end_date"] = self._format_timestamp(run["end"])
except KeyError as e:
print(f"Missing 'run' key {str(e)} in {run}")
self.logger.warning("Missing 'run' key %r in %s", str(e), run)
run["begin_date"] = self._format_timestamp("0")
run["end_date"] = self._format_timestamp("0")

@@ -1254,8 +1280,11 @@ async def get_params(
iter = param["iteration"]["id"]
arg = param["param"]["arg"]
val = param["param"]["val"]
if response.get(iter) and response.get(iter).get(arg):
print(f"Duplicate param {arg} for iteration {iter}")
old = self._get(response, [iter, arg])
if old:
self.logger.warning(
"Duplicate param %s for iteration %s (%r, %r)", arg, iter, old, val
)
response[iter][arg] = val

# Filter out all parameter values that don't exist in all or which have
@@ -1313,7 +1342,6 @@ async def get_samples(
)
samples = []
for s in self._hits(hits):
print(f"SAMPLE's ITERATION {s['iteration']}")
sample = s["sample"]
sample["iteration"] = s["iteration"]["num"]
sample["primary_metric"] = s["iteration"]["primary-metric"]
@@ -1545,8 +1573,7 @@ async def get_metric_breakouts(
if len(pl) > 1:
response["periods"] = pl
response["breakouts"] = {n: v for n, v in breakouts.items() if len(v) > 1}
duration = time.time() - start
print(f"Processing took {duration} seconds")
self.logger.info("Processing took %.3f seconds", time.time() - start)
return response

async def get_metrics_data(
@@ -1645,8 +1672,7 @@ async def get_metrics_data(
for h in self._hits(data, ["metric_data"]):
response.append(self._format_data(h))
response.sort(key=lambda a: a["end"])
duration = time.time() - start
print(f"Processing took {duration} seconds")
self.logger.info("Processing took %.3f seconds", time.time() - start)
return response

async def get_metrics_summary(
Expand Down Expand Up @@ -1687,8 +1713,7 @@ async def get_metrics_summary(
filters=filters,
aggregations={"score": {"stats": {"field": "metric_data.value"}}},
)
duration = time.time() - start
print(f"Processing took {duration} seconds")
self.logger.info("Processing took %.3f seconds", time.time() - start)
return data["aggregations"]["score"]

async def _graph_title(
Expand Down Expand Up @@ -1983,6 +2008,5 @@ async def get_metrics_graph(self, graphdata: GraphList) -> dict[str, Any]:
axes[metric] = yref
graphitem["yaxis"] = yref
graphlist.append(graphitem)
duration = time.time() - start
print(f"Processing took {duration} seconds")
self.logger.info("Processing took %.3f seconds", time.time() - start)
return {"data": graphlist, "layout": layout}
2 changes: 1 addition & 1 deletion backend/pyproject.toml
@@ -64,7 +64,7 @@ allowlist_externals = ["bash", "echo", "coverage"]
commands = [
["echo", "{env:COVERAGE}"],
["pip", "list"],
["pytest", "-s", "--cov-branch", "--cov=app", "tests"],
["pytest", "-s", "--cov-branch", "--cov=app", "{posargs}", "tests"],
["coverage", "html", "--directory={env:COVERAGE}/html"],
["bash", "-c", "coverage report --format=markdown >{env:COVERAGE}/coverage.txt"],
]
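
With {posargs} in the pytest command, anything after -- on the tox command line is forwarded to pytest; for example, an invocation along the lines of tox -- -k metrics (exact form depending on the local tox setup) would run only the matching tests while the subsequent coverage steps still execute.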