Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 59 additions & 46 deletions sotrplib/handlers/prefect.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from pathlib import Path

from prefect import flow, task
from astropy.coordinates import SkyCoord
from prefect import flow, task, unmapped

from sotrplib.maps.core import ProcessableMap
from sotrplib.maps.postprocessor import MapPostprocessor
from sotrplib.maps.preprocessor import MapPreprocessor
from sotrplib.outputs.core import SourceOutput
from sotrplib.sifter.core import EmptySifter, SiftingProvider
from sotrplib.sims.sources.core import SourceSimulation
from sotrplib.source_catalog.database import EmptyMockSourceCatalog, MockDatabase
from sotrplib.sims.sim_source_generators import (
SimulatedSource,
SimulatedSourceGenerator,
)
from sotrplib.sims.source_injector import EmptySourceInjector, SourceInjector
from sotrplib.source_catalog.core import SourceCatalog
from sotrplib.sources.blind import EmptyBlindSearch
from sotrplib.sources.core import (
BlindSearchProvider,
Expand All @@ -20,9 +25,11 @@

class PrefectRunner:
maps: list[ProcessableMap]
source_simulators: list[SimulatedSourceGenerator] | None
source_injector: SourceInjector | None
source_catalogs: list[SourceCatalog] | None
preprocessors: list[MapPreprocessor] | None
postprocessors: list[MapPostprocessor] | None
source_simulators: list[SourceSimulation] | None
forced_photometry: ForcedPhotometryProvider | None
source_subtractor: SourceSubtractor | None
blind_search: BlindSearchProvider | None
Expand All @@ -32,25 +39,23 @@ class PrefectRunner:
def __init__(
self,
maps: list[ProcessableMap],
forced_photometry_catalog: MockDatabase | None,
source_catalogs: list[MockDatabase] | None,
source_simulators: list[SimulatedSourceGenerator] | None,
source_injector: SourceInjector | None,
source_catalogs: list[SourceCatalog] | None,
preprocessors: list[MapPreprocessor] | None,
postprocessors: list[MapPostprocessor] | None,
source_simulators: list[SourceSimulation] | None,
forced_photometry: ForcedPhotometryProvider | None,
source_subtractor: SourceSubtractor | None,
blind_search: BlindSearchProvider | None,
sifter: SiftingProvider | None,
outputs: list[SourceOutput] | None,
):
self.maps = maps
self.forced_photometry_catalog = (
forced_photometry_catalog or EmptyMockSourceCatalog()
)
self.source_simulators = source_simulators or []
self.source_injector = source_injector or EmptySourceInjector()
self.source_catalogs = source_catalogs or []
self.preprocessors = preprocessors or []
self.postprocessors = postprocessors or []
self.source_simulators = source_simulators or []
self.forced_photometry = forced_photometry or EmptyForcedPhotometry()
self.source_subtractor = source_subtractor or EmptySourceSubtractor()
self.blind_search = blind_search or EmptyBlindSearch()
Expand All @@ -60,37 +65,50 @@ def __init__(
return

@task
def analyze_map(
self, input_map: ProcessableMap
) -> tuple[list, object, ProcessableMap]:
def build_map(self, input_map: ProcessableMap):
task(input_map.build)()

for preprocessor in self.preprocessors:
input_map = task(preprocessor.preprocess)(input_map=input_map)
task(preprocessor.preprocess)(input_map=input_map)

task(input_map.finalize)()
@property
def bbox(self):
bbox = self.maps[0].bbox

for postprocessor in self.postprocessors:
input_map = task(postprocessor.postprocess)(input_map=input_map)
for input_map in self.maps[1:]:
map_bbox = input_map.bbox
left = min(bbox[0].ra, map_bbox[0].ra)
bottom = min(bbox[0].dec, map_bbox[0].dec)
right = max(bbox[1].ra, map_bbox[1].ra)
top = max(bbox[1].dec, map_bbox[1].dec)
bbox = [SkyCoord(ra=left, dec=bottom), SkyCoord(ra=right, dec=top)]
return bbox

crossmatch_catalog = []
for c in self.source_catalogs:
crossmatch_catalog.extend(
task(c.cat.get_sources_in_map)(input_map=input_map)
)
@task
def simulate_sources(self) -> list[SimulatedSource]:
"""Generate sources based upon maximal bounding box of all maps"""
all_simulated_sources = []
bbox = self.bbox
for simulator in self.source_simulators:
simulated_sources, catalog = task(simulator.generate)(box=bbox)

for_forced_photometry = task(
self.forced_photometry_catalog.cat.get_sources_in_map
)(input_map)
all_simulated_sources.extend(simulated_sources)
self.source_catalogs.append(catalog)
return all_simulated_sources

for simulator in self.source_simulators:
input_map, additional_sources = task(simulator.simulate)(
input_map=input_map
)
for_forced_photometry.extend(additional_sources)
@task
def analyze_map(self, input_map: ProcessableMap, simulated_sources: list[SimulatedSource]) -> tuple[list, object, ProcessableMap]:
task(input_map.finalize)()

input_map = task(self.source_injector.inject)(
input_map=input_map, simulated_sources=simulated_sources
)

for postprocessor in self.postprocessors:
task(postprocessor.postprocess)(input_map=input_map)

forced_photometry_candidates = task(self.forced_photometry.force)(
input_map=input_map, sources=for_forced_photometry
input_map=input_map, catalogs=self.source_catalogs
)

source_subtracted_map = task(self.source_subtractor.subtract)(
Expand All @@ -101,9 +119,10 @@ def analyze_map(
input_map=source_subtracted_map
)

self.sifter.catalog_sources = crossmatch_catalog
sifter_result = task(self.sifter.sift)(
sources=blind_sources, input_map=source_subtracted_map
sources=blind_sources,
catalogs=self.source_catalogs,
input_map=source_subtracted_map,
)

for output in self.outputs:
Expand All @@ -112,24 +131,18 @@ def analyze_map(
sifter_result=sifter_result,
input_map=input_map,
)
return forced_photometry_candidates, sifter_result, input_map

@flow
def run(self):
return self.analyze_map.map(self.maps).result()
return forced_photometry_candidates, sifter_result, input_map

@flow
def analyze_new_maps(
self, map_file: Path | str
) -> tuple[list, object, ProcessableMap]:
from sotrplib.config.config import Settings

input_data = Settings.from_file(map_file).to_basic()
return self.analyze_map.map(input_data.maps).result()
def run(self) -> tuple[list[list], list[object], list[ProcessableMap]]:
self.build_map.map(self.maps).wait()
all_simulated_sources = self.simulate_sources()
return self.analyze_map.map(self.maps, unmapped(all_simulated_sources)).result()


@flow
def analyze_from_configuration(config: Path | str):
def analyze_from_configuration(config: Path | str) -> tuple[list[list], list[object], list[ProcessableMap]]:
from sotrplib.config.config import Settings

pipeline = Settings.from_file(config).to_prefect()
Expand Down
51 changes: 27 additions & 24 deletions tests/test_prefect_pipeline/test_prefect_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,41 @@
{
"map_type": "simulated",
"observation_start": "2025-01-01",
"observation_end": "2025-01-02",
},
{
"map_type": "simulated",
"observation_start": "2025-01-02",
"observation_end": "2025-01-03",
},
"observation_end": "2025-01-02"
}
],
"source_simulators": [
{
"simulation_type": "random",
"parameters": {
"n_sources": 32,
"min_flux": "3.0 Jy",
"max_flux": "10.0 Jy",
"fwhm_uncertainty_frac": 0.0,
"fraction_return": 0.5,
},
"simulation_type": "fixed",
"number": 8,
"min_flux": "3.0 Jy",
"max_flux": "10.0 Jy",
"catalog_fraction": 0.5
}
],
"preprocessors": [{"preprocessor_type": "kappa_rho"}],
"source_subtractor": {"subtractor_type": "photutils"},
"blind_search": {"search_type": "photutils"},
"forced_photometry": {"photometry_type": "scipy", "reproject_thumbnails": "True"},
"sifter": {"sifter_type": "default"},
"source_injector": {
"injector_type": "photutils"
},
"forced_photometry": {
"photometry_type": "scipy"
},
"source_subtractor": {
"subtractor_type": "photutils"
},
"blind_search": {
"search_type": "photutils"
},
"sifter": {
"sifter_type": "simple"
},
"outputs": [
{
"output_type": "pickle",
"directory": "."
}
],
]
}


def test_basic_pipeline_scipy(
tmp_path, map_with_sources: tuple[SimulatedMap, list[RegisteredSource]]
):
Expand All @@ -65,15 +67,15 @@ def test_basic_pipeline_scipy(

runner = PrefectRunner(
maps=maps,
forced_photometry_catalog=None,
source_catalogs=[],
source_injector=None,
preprocessors=None,
postprocessors=None,
source_simulators=None,
forced_photometry=Scipy2DGaussianFitter(sources=sources),
source_subtractor=None,
blind_search=SigmaClipBlindSearch(),
sifter=DefaultSifter(catalog_sources=sources),
sifter=DefaultSifter(),
outputs=[PickleSerializer(directory=tmp_path)],
)

Expand Down Expand Up @@ -102,6 +104,7 @@ def test_basic_pipeline_from_file(

def _validate_pipeline_result(result, nmaps):
assert len(result) == nmaps
print(result[0])
candidates, sifter_result, output_map = result[0]
assert all([isinstance(candidate, MeasuredSource) for candidate in candidates])
assert isinstance(sifter_result, SifterResult)
Expand Down
Loading