Skip to content

Commit

Permalink
switched to the new yaw.randoms.BoxGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
jlvdb committed Dec 3, 2024
1 parent 6bd146c commit f7f545e
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 183 deletions.
11 changes: 7 additions & 4 deletions examples/full_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
"outputs": [],
"source": [
"# step 1\n",
"from rail.pipelines.estimation.randoms_yaw_v2 import UniformRandoms\n",
"import pandas as pd\n",
"from yaw.randoms import BoxRandoms\n",
"from rail.yaw_rail.utils import get_dc2_test_data\n",
"\n",
"from rail.estimation.algos.cc_yaw import (\n",
Expand Down Expand Up @@ -131,14 +132,16 @@
"metadata": {},
"outputs": [],
"source": [
"angular_rng = UniformRandoms(\n",
"generator = BoxRandoms(\n",
" test_data[\"ra\"].min(),\n",
" test_data[\"ra\"].max(),\n",
" test_data[\"dec\"].min(),\n",
" test_data[\"dec\"].max(),\n",
" redshifts=redshifts,\n",
" seed=12345,\n",
")\n",
"test_rand = angular_rng.generate(n_data * 10, draw_from=dict(z=redshifts))\n",
"test_rand = generator.generate_dataframe(n_data * 10)\n",
"test_rand.rename(columns=dict(redshifts=\"z\"), inplace=True)\n",
"\n",
"handle_test_rand = DS.add_data(\"input_rand\", test_rand, TableHandle)"
]
Expand Down Expand Up @@ -589,7 +592,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "yaw_rail",
"display_name": "rail_yaw",
"language": "python",
"name": "python3"
},
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"numpy>=2.0",
"h5py",
"pz-rail-base>=1.0.3",
"yet_another_wizz>=3.0.5",
"yet_another_wizz>=3.0.7",
]

[project.urls]
Expand Down
10 changes: 7 additions & 3 deletions src/rail/pipelines/estimation/build_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import os
from shutil import rmtree

import pandas as pd
from yaw.randoms import BoxRandoms

import rail.stages
from rail.core.stage import RailPipeline, RailStage
from rail.pipelines.estimation.randoms_yaw_v2 import UniformRandoms

rail.stages.import_and_attach_all()
from rail.stages import *
Expand Down Expand Up @@ -52,14 +54,16 @@ def create_datasets(root): # pragma: no cover
data_path = os.path.join(root, data_name)
test_data.to_parquet(data_path)

angular_rng = UniformRandoms(
generator = BoxRandoms(
test_data["ra"].min(),
test_data["ra"].max(),
test_data["dec"].min(),
test_data["dec"].max(),
redshifts=redshifts,
seed=12345,
)
test_rand = angular_rng.generate(n_data * 10, draw_from=dict(z=redshifts))
test_rand = generator.generate_dataframe(n_data * 10)
test_rand.rename(columns=dict(redshifts="z"), inplace=True)

rand_name = "input_rand.parquet"
rand_path = os.path.join(root, rand_name)
Expand Down
171 changes: 0 additions & 171 deletions src/rail/pipelines/estimation/randoms_yaw_v2.py

This file was deleted.

9 changes: 6 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

from pytest import fixture
from rail.pipelines.estimation.randoms_yaw_v2 import UniformRandoms
from yaw.randoms import BoxRandoms

from rail.core.stage import RailStage
from rail.yaw_rail.utils import get_dc2_test_data
Expand Down Expand Up @@ -50,11 +50,14 @@ def fixture_zlim(mock_data):
def fixture_mock_rand(mock_data, seed) -> DataFrame:
n_data = len(mock_data)
redshifts = mock_data["z"].to_numpy()
angular_rng = UniformRandoms(

generator = BoxRandoms(
mock_data["ra"].min(),
mock_data["ra"].max(),
mock_data["dec"].min(),
mock_data["dec"].max(),
redshifts=redshifts,
seed=seed,
)
return angular_rng.generate(2 * n_data, draw_from=dict(z=redshifts))
test_rand = generator.generate_dataframe(n_data * 10)
return test_rand.rename(columns=dict(redshifts="z"))
2 changes: 1 addition & 1 deletion tests/estimation/algos/test_cc_yaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_ceci_pipeline(tmp_path) -> None:
expect_path = write_expect_ncc(tmp_path)
expect_data = np.loadtxt(expect_path).T
output_data = np.loadtxt(f"{output_prefix}.dat").T
for i, (col_a, col_b) in enumerate(zip(expect_data, output_data)):
for i, (col_a, col_b) in enumerate(zip(output_data, expect_data)):
if i == 3: # error column differs every time since using patch_num
break
npt.assert_array_equal(col_a, col_b)

0 comments on commit f7f545e

Please sign in to comment.