Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasBeiske committed Apr 16, 2024
1 parent 83f3664 commit 34b6f8d
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions src/ctapipe/tools/tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

import numpy as np
import pytest

Expand Down Expand Up @@ -225,22 +227,50 @@ def test_no_cross_validation(tmp_path):
assert ret == 0
return out_file


def test_direction_uncertainty_regressor(tmp_path):
from ctapipe.tools.train_direction_uncertainty_regressor import TrainAngularErrorRegressor
from ctapipe.tools.aggregate_features import AggregateFeatures
from ctapipe.tools.train_direction_uncertainty_regressor import (
TrainDirectionUncertaintyRegressor,
)

agg_path = tmp_path / "aggregated.dl1.h5"
config_path = tmp_path / "config.json"
config = {
"FeatureAggregator": {
"image_parameters": [
("hillas", "length"),
("hillas", "width"),
("hillas", "skewness"),
("hillas", "kurtosis"),
],
}
}
with config_path.open("w") as f:
json.dump(config, f)

run_tool(
AggregateFeatures(),
argv=[
"--input=dataset://gamma_diffuse_dl2_train_small.dl2.h5",
f"--output={agg_path}",
f"--config={config_path}",
],
)

out_file = tmp_path / "direction_uncertainty.pkl"

tool = TrainAngularErrorRegressor()
tool = TrainDirectionUncertaintyRegressor()
config = resource_file("train_direction_uncertainty_regressor.yaml")
ret = run_tool(
tool,
argv=[
"--input=dataset://gamma_diffuse_dl2_train_small.dl2.h5",
f"--input={agg_path}",
f"--output={out_file}",
f"--config={config}",
"--log-level=INFO",
"--overwrite",
],
)
assert ret == 0
return out_file
return out_file

0 comments on commit 34b6f8d

Please sign in to comment.