|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from dataclasses import dataclass |
3 | 4 | from typing import ( |
4 | 5 | Mapping, |
5 | 6 | Sequence, |
6 | | - Tuple, |
7 | 7 | Union, |
8 | 8 | ) |
9 | 9 |
|
|
12 | 12 | from neptune_query.internal.identifiers import RunAttributeDefinition |
13 | 13 | from neptune_query.internal.retrieval.metrics import MetricValues |
14 | 14 |
|
@dataclass(frozen=True)
class FloatPointValue:
    """A single float metric point.

    Field order mirrors the historical tuple layout
    ``(timestamp_ms, step, value, is_preview, completion_ratio)``, and the
    sequence dunders (``__iter__``/``__getitem__``/``__len__``) keep the
    instance usable wherever the old tuple was expected.
    """

    timestamp_ms: float | None
    step: float | None
    value: float | None
    is_preview: bool | None = None
    completion_ratio: float | None = None

    @classmethod
    def create(
        cls,
        step: float | None,
        value: float | None,
        *,
        timestamp_ms: float | None = None,
        is_preview: bool | None = None,
        completion_ratio: float | None = None,
    ) -> "FloatPointValue":
        """Alternate constructor: takes ``step`` and ``value`` positionally."""
        return cls(
            timestamp_ms=timestamp_ms,
            step=step,
            value=value,
            is_preview=is_preview,
            completion_ratio=completion_ratio,
        )

    def has_timestamp(self) -> bool:
        """Return True when a timestamp was recorded for this point."""
        return self.timestamp_ms is not None

    def has_preview_data(self) -> bool:
        """Return True when both preview fields are populated."""
        return self.is_preview is not None and self.completion_ratio is not None

    def as_tuple(self) -> tuple[object, ...]:
        """Return the fields in the legacy tuple order."""
        return (
            self.timestamp_ms,
            self.step,
            self.value,
            self.is_preview,
            self.completion_ratio,
        )

    def __iter__(self):
        yield from self.as_tuple()

    def __getitem__(self, index: int) -> object:
        return self.as_tuple()[index]

    def __len__(self) -> int:
        return len(self.as_tuple())
16 | 59 |
|
17 | 60 |
|
def to_metric_values(points: Sequence[FloatPointValue]) -> MetricValues:
    """Pack a sequence of FloatPointValue points into a MetricValues buffer.

    Timestamp and preview columns are allocated only when at least one point
    carries that data (see the two ``any`` scans below).

    Args:
        points: Points to convert; ``step``/``value``/``timestamp_ms`` may be
            ``None`` and are stored as NaN in that case.

    Returns:
        A filled ``MetricValues`` of the same length as ``points``.
    """
    size = len(points)
    include_timestamp = any(point.has_timestamp() for point in points)
    include_preview = any(point.has_preview_data() for point in points)

    metric_values = MetricValues.allocate(
        size=size, include_timestamp=include_timestamp, include_preview=include_preview
    )

    for idx, point in enumerate(points):
        # step/value are typed `float | None`; map None to NaN instead of
        # letting float(None) raise TypeError.
        metric_values.steps[idx] = float(point.step) if point.step is not None else np.nan
        metric_values.values[idx] = float(point.value) if point.value is not None else np.nan

        if metric_values.timestamps is not None:
            metric_values.timestamps[idx] = float(point.timestamp_ms) if point.timestamp_ms is not None else np.nan

        if metric_values.is_preview is not None:
            # Missing preview flag defaults to "not a preview".
            metric_values.is_preview[idx] = bool(point.is_preview) if point.is_preview is not None else False

        if metric_values.completion_ratio is not None:
            # Missing ratio defaults to 1.0 (fully completed point).
            metric_values.completion_ratio[idx] = (
                float(point.completion_ratio) if point.completion_ratio is not None else 1.0
            )

    return metric_values
47 | 86 |
|
48 | 87 |
|
def normalize_metrics_data(
    metrics_data: Mapping[
        RunAttributeDefinition,
        Union[MetricValues, Sequence[FloatPointValue]],
    ],
) -> dict[RunAttributeDefinition, MetricValues]:
    """Coerce every value to ``MetricValues``, converting raw point sequences.

    Values that are already ``MetricValues`` are passed through untouched.
    """
    normalized: dict[RunAttributeDefinition, MetricValues] = {}
    for definition, value in metrics_data.items():
        if isinstance(value, MetricValues):
            normalized[definition] = value
        else:
            normalized[definition] = to_metric_values(value)
    return normalized
| 98 | + |
| 99 | + |
def assert_metric_mappings_equal(
    actual: Mapping[RunAttributeDefinition, MetricValues],
    expected: Mapping[RunAttributeDefinition, MetricValues],
) -> None:
    """Raise AssertionError unless both mappings agree on keys and values.

    Key-set mismatches are reported first (listing missing and unexpected
    definitions); otherwise the first differing value is reported.
    """
    expected_keys = set(expected.keys())
    missing = expected_keys - set(actual.keys())
    unexpected = set(actual.keys()) - expected_keys
    if missing or unexpected:
        raise AssertionError(f"Metric definitions mismatch. Missing: {missing}, unexpected: {unexpected}")

    for definition in expected_keys:
        if actual[definition] != expected[definition]:
            raise AssertionError(
                f"Metric values differ for {definition}: "
                f"actual={actual[definition]!r}, expected={expected[definition]!r}"
            )
0 commit comments