Commit ace595b

Fix: avoid unnecessary allocation and a redundant second sort of steps
1 parent bedb249 commit ace595b

File tree

4 files changed (+168, -13 lines)

jitter-test.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import timeit
from statistics import mean
from typing import Callable

import numpy as np

try:
    from numba import njit
except ImportError as exc:
    raise SystemExit(
        "Numba is required for this benchmark. Install it with `pip install numba` before running the script."
    ) from exc


NUM_LISTS = 4
LIST_LENGTH = 100000
REPEAT = 5
NUMBER = 1000

SENTINEL = np.int64(np.iinfo(np.int64).max)


def make_sorted_arrays(num_lists: int, list_length: int) -> list[np.ndarray]:
    rng = np.random.default_rng(seed=0)
    return [
        np.sort(rng.integers(0, 10_000, size=list_length, dtype=np.int64))
        for _ in range(num_lists)
    ]


@njit(cache=True)
def _merge_numba_impl(data: np.ndarray) -> np.ndarray:
    # k-way merge of pre-sorted rows with duplicate elimination.
    num_lists, list_len = data.shape
    total = num_lists * list_len
    indices = np.zeros(num_lists, dtype=np.int64)
    merged = np.empty(total, dtype=np.int64)

    out_idx = 0
    last_val = SENTINEL

    while True:
        # Find the smallest value currently at the head of any row.
        best_val = SENTINEL
        best_list = -1
        for list_idx in range(num_lists):
            pos = indices[list_idx]
            if pos < list_len:
                value = data[list_idx, pos]
                if value < best_val:
                    best_val = value
                    best_list = list_idx

        if best_list == -1:
            break  # every row is exhausted

        # Emit each distinct value exactly once.
        if best_val != last_val:
            merged[out_idx] = best_val
            out_idx += 1
            last_val = best_val

        # Advance every row past all copies of the emitted value.
        for list_idx in range(num_lists):
            pos = indices[list_idx]
            if pos < list_len:
                value = data[list_idx, pos]
                if value == best_val:
                    pos += 1
                    while pos < list_len and data[list_idx, pos] == best_val:
                        pos += 1
                    indices[list_idx] = pos

    return merged[:out_idx]


def merge_numba(data: list[np.ndarray], precomputed: np.ndarray | None = None) -> np.ndarray:
    stacked = precomputed if precomputed is not None else np.vstack(data)
    return _merge_numba_impl(stacked)


def merge_numpy(data: list[np.ndarray]) -> np.ndarray:
    return np.unique(np.concatenate(data))


def time_function(action: Callable[[], np.ndarray]) -> float:
    timer = timeit.Timer(action)
    runs = timer.repeat(repeat=REPEAT, number=NUMBER)
    return mean(runs) / NUMBER


if __name__ == "__main__":
    dataset = make_sorted_arrays(NUM_LISTS, LIST_LENGTH)

    stacked_dataset = np.vstack(dataset)
    numpy_result = merge_numpy(dataset)
    numba_result = merge_numba(dataset, stacked_dataset)
    assert np.array_equal(numpy_result, numba_result)

    # Ensure Numba compilation happens before timing.
    merge_numba(dataset, stacked_dataset)

    benchmarks = {
        "NumPy sort": lambda: merge_numpy(dataset),
        "Numba merge": lambda: merge_numba(dataset, stacked_dataset),
    }

    for label, action in benchmarks.items():
        per_call = time_function(action)
        print(f"{label}: {per_call * 1_000_000:.2f} microseconds per merge")
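
For a reference point outside the two timed variants, the standard library's heapq.merge performs the same k-way merge lazily in pure Python. The sketch below is not part of the benchmark above and is typically far slower than either timed function, but it makes the merge-with-dedup semantics explicit:

import heapq

import numpy as np


def merge_heapq(data: list[np.ndarray]) -> np.ndarray:
    # heapq.merge streams the pre-sorted inputs in ascending order;
    # duplicates are dropped by comparing against the last emitted value.
    out: list[int] = []
    last = None
    for value in heapq.merge(*data):
        if last is None or value != last:
            out.append(int(value))
            last = value
    return np.asarray(out, dtype=np.int64)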

src/neptune_query/internal/output_format.py

Lines changed: 7 additions & 12 deletions
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pathlib
-import sys
-import time
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import (
@@ -488,21 +486,18 @@ def from_observed_steps(
         total_rows_count = sum(len(steps) for steps in observed_steps.values())
         display_names: list[str] = [""] * total_rows_count
         step_values: np.ndarray = np.empty(shape=(total_rows_count,), dtype=np.float64)
-
         row_num: int = 0
-        for display_name in sorted(observed_steps.keys()):
+        sorted_observed_steps = sorted(observed_steps.items(), key=lambda x: x[0])
+        for display_name, steps in sorted_observed_steps:
             sys_id = display_name_to_sys_id[display_name]
-            sorted_steps = np.sort(observed_steps[display_name], kind="stable")
-            for i, step in enumerate(sorted_steps, start=row_num):
-                display_names[i] = display_name
-                step_values[i] = step
-
+            step_values[row_num:row_num + steps.size] = steps
+            display_names[row_num:row_num + steps.size] = [display_name] * steps.size
             if sys_id_ranges is not None:
-                sys_id_ranges[sys_id] = (row_num, row_num + sorted_steps.size)
+                sys_id_ranges[sys_id] = (row_num, row_num + steps.size)
             if row_dict_lookup is not None:
-                row_dict_lookup[sys_id] = {float(step): idx for idx, step in enumerate(sorted_steps, start=row_num)}
+                row_dict_lookup[sys_id] = {float(step): idx for idx, step in enumerate(steps, start=row_num)}
+            row_num += steps.size
 
-            row_num += sorted_steps.size
 
         return cls(
             display_names=display_names,
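
The change above assumes the step arrays arrive already sorted (the "second sort" the commit message refers to), so the per-name np.sort and the element-by-element copy were redundant; NumPy slice assignment fills each block in one operation. A minimal, self-contained sketch of the before/after behavior follows (standalone variable names are hypothetical, not the library's API):

import numpy as np

# Assumed input shape: each display name maps to an already sorted array.
observed_steps = {
    "run-b": np.array([0.0, 1.0, 2.0]),
    "run-a": np.array([0.5, 1.5]),
}
total = sum(a.size for a in observed_steps.values())

# Before: re-sorts each array and copies values one element at a time.
names_old = [""] * total
values_old = np.empty(total, dtype=np.float64)
row = 0
for name in sorted(observed_steps):
    for i, step in enumerate(np.sort(observed_steps[name]), start=row):
        names_old[i] = name
        values_old[i] = step
    row += observed_steps[name].size

# After: skips the redundant sort and lets NumPy copy whole blocks.
names_new = [""] * total
values_new = np.empty(total, dtype=np.float64)
row = 0
for name, steps in sorted(observed_steps.items(), key=lambda item: item[0]):
    values_new[row:row + steps.size] = steps
    names_new[row:row + steps.size] = [name] * steps.size
    row += steps.size

assert names_old == names_new and np.array_equal(values_old, values_new)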

src/neptune_query/internal/retrieval/metrics.py

Lines changed: 2 additions & 1 deletion
@@ -192,10 +192,11 @@ def _process_metrics_page(
 ) -> util.Page[tuple[identifiers.RunAttributeDefinition, MetricDatapoints]]:
     result = {}
     for series in data.series:
+        pass
         metric_values = MetricDatapoints.allocate(
             size=len(series.series.values), include_timestamp=include_timestamp, include_preview=include_preview
         )
-
+
         for i, point in enumerate(series.series.values):
             idx = metric_values.length - 1 - i if reverse_order else i

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import numpy as np
import pytest

from neptune_query.internal.identifiers import (
    AttributeDefinition,
    ProjectIdentifier,
    RunAttributeDefinition,
    RunIdentifier,
    SysId,
)
from neptune_query.internal.output_format import create_metrics_dataframe
from neptune_query.internal.retrieval.metrics import MetricDatapoints


def test_create_metrics_dataframe_large_debug_workload():
    num_runs = 1000000
    num_metrics = 1
    num_datapoints = 1

    project_identifier = ProjectIdentifier("debug/project")
    metrics_data: dict[RunAttributeDefinition, MetricDatapoints] = {}
    sys_id_label_mapping: dict[SysId, str] = {}
    base_steps = np.arange(num_datapoints, dtype=np.float64)

    for run_idx in range(num_runs):
        sys_id = SysId(f"sys{run_idx:04d}")
        sys_id_label_mapping[sys_id] = f"run-{run_idx:04d}"
        run_identifier = RunIdentifier(project_identifier, sys_id)

        for metric_idx in range(num_metrics):
            attribute_definition = AttributeDefinition(f"metric_{metric_idx:02d}", "float_series")
            run_attribute_definition = RunAttributeDefinition(run_identifier, attribute_definition)

            datapoints = MetricDatapoints.allocate(
                size=num_datapoints, include_timestamp=False, include_preview=False
            )
            step_offset = metric_idx + 1
            base_value = run_idx * num_metrics * num_datapoints + metric_idx * num_datapoints

            shifted_steps = base_steps + step_offset
            for idx, base_step in enumerate(base_steps):
                datapoints.append(step=float(shifted_steps[idx]), value=float(base_value + base_step))

            metrics_data[run_attribute_definition] = datapoints.compile()

    dataframe = create_metrics_dataframe(
        metrics_data=metrics_data,
        sys_id_label_mapping=sys_id_label_mapping,
        type_suffix_in_column_names=False,
        include_point_previews=False,
        index_column_name="run",
        timestamp_column_name=None,
    )
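
To exercise this stress test on its own, a pytest invocation along these lines should work (-k selects the test by function name; --durations=0 prints per-test runtimes; the test file's path is not shown in this diff, so run from the repository root):

    pytest -k test_create_metrics_dataframe_large_debug_workload --durations=0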
