diff --git a/sotrplib/handlers/prefect.py b/sotrplib/handlers/prefect.py index 49c50a4..7139f22 100644 --- a/sotrplib/handlers/prefect.py +++ b/sotrplib/handlers/prefect.py @@ -1,14 +1,19 @@ from pathlib import Path -from prefect import flow, task +from astropy.coordinates import SkyCoord +from prefect import flow, task, unmapped from sotrplib.maps.core import ProcessableMap from sotrplib.maps.postprocessor import MapPostprocessor from sotrplib.maps.preprocessor import MapPreprocessor from sotrplib.outputs.core import SourceOutput from sotrplib.sifter.core import EmptySifter, SiftingProvider -from sotrplib.sims.sources.core import SourceSimulation -from sotrplib.source_catalog.database import EmptyMockSourceCatalog, MockDatabase +from sotrplib.sims.sim_source_generators import ( + SimulatedSource, + SimulatedSourceGenerator, +) +from sotrplib.sims.source_injector import EmptySourceInjector, SourceInjector +from sotrplib.source_catalog.core import SourceCatalog from sotrplib.sources.blind import EmptyBlindSearch from sotrplib.sources.core import ( BlindSearchProvider, @@ -20,9 +25,11 @@ class PrefectRunner: maps: list[ProcessableMap] + source_simulators: list[SimulatedSourceGenerator] | None + source_injector: SourceInjector | None + source_catalogs: list[SourceCatalog] | None preprocessors: list[MapPreprocessor] | None postprocessors: list[MapPostprocessor] | None - source_simulators: list[SourceSimulation] | None forced_photometry: ForcedPhotometryProvider | None source_subtractor: SourceSubtractor | None blind_search: BlindSearchProvider | None @@ -32,11 +39,11 @@ class PrefectRunner: def __init__( self, maps: list[ProcessableMap], - forced_photometry_catalog: MockDatabase | None, - source_catalogs: list[MockDatabase] | None, + source_simulators: list[SimulatedSourceGenerator] | None, + source_injector: SourceInjector | None, + source_catalogs: list[SourceCatalog] | None, preprocessors: list[MapPreprocessor] | None, postprocessors: list[MapPostprocessor] | None, - source_simulators: list[SourceSimulation] | None, forced_photometry: ForcedPhotometryProvider | None, source_subtractor: SourceSubtractor | None, blind_search: BlindSearchProvider | None, @@ -44,13 +51,11 @@ def __init__( outputs: list[SourceOutput] | None, ): self.maps = maps - self.forced_photometry_catalog = ( - forced_photometry_catalog or EmptyMockSourceCatalog() - ) + self.source_simulators = source_simulators or [] + self.source_injector = source_injector or EmptySourceInjector() self.source_catalogs = source_catalogs or [] self.preprocessors = preprocessors or [] self.postprocessors = postprocessors or [] - self.source_simulators = source_simulators or [] self.forced_photometry = forced_photometry or EmptyForcedPhotometry() self.source_subtractor = source_subtractor or EmptySourceSubtractor() self.blind_search = blind_search or EmptyBlindSearch() @@ -60,37 +65,50 @@ def __init__( return @task - def analyze_map( - self, input_map: ProcessableMap - ) -> tuple[list, object, ProcessableMap]: + def build_map(self, input_map: ProcessableMap): task(input_map.build)() for preprocessor in self.preprocessors: - input_map = task(preprocessor.preprocess)(input_map=input_map) + task(preprocessor.preprocess)(input_map=input_map) - task(input_map.finalize)() + @property + def bbox(self): + bbox = self.maps[0].bbox - for postprocessor in self.postprocessors: - input_map = task(postprocessor.postprocess)(input_map=input_map) + for input_map in self.maps[1:]: + map_bbox = input_map.bbox + left = min(bbox[0].ra, map_bbox[0].ra) + bottom = min(bbox[0].dec, map_bbox[0].dec) + right = max(bbox[1].ra, map_bbox[1].ra) + top = max(bbox[1].dec, map_bbox[1].dec) + bbox = [SkyCoord(ra=left, dec=bottom), SkyCoord(ra=right, dec=top)] + return bbox - crossmatch_catalog = [] - for c in self.source_catalogs: - crossmatch_catalog.extend( - task(c.cat.get_sources_in_map)(input_map=input_map) - ) + @task + def simulate_sources(self) -> list[SimulatedSource]: + """Generate sources based upon maximal bounding box of all maps""" + all_simulated_sources = [] + bbox = self.bbox + for simulator in self.source_simulators: + simulated_sources, catalog = task(simulator.generate)(box=bbox) - for_forced_photometry = task( - self.forced_photometry_catalog.cat.get_sources_in_map - )(input_map) + all_simulated_sources.extend(simulated_sources) + self.source_catalogs.append(catalog) + return all_simulated_sources - for simulator in self.source_simulators: - input_map, additional_sources = task(simulator.simulate)( - input_map=input_map - ) - for_forced_photometry.extend(additional_sources) + @task + def analyze_map(self, input_map: ProcessableMap, simulated_sources: list[SimulatedSource]) -> tuple[list, object, ProcessableMap]: + task(input_map.finalize)() + + input_map = task(self.source_injector.inject)( + input_map=input_map, simulated_sources=simulated_sources + ) + + for postprocessor in self.postprocessors: + task(postprocessor.postprocess)(input_map=input_map) forced_photometry_candidates = task(self.forced_photometry.force)( - input_map=input_map, sources=for_forced_photometry + input_map=input_map, catalogs=self.source_catalogs ) source_subtracted_map = task(self.source_subtractor.subtract)( @@ -101,9 +119,10 @@ def analyze_map( input_map=source_subtracted_map ) - self.sifter.catalog_sources = crossmatch_catalog sifter_result = task(self.sifter.sift)( - sources=blind_sources, input_map=source_subtracted_map + sources=blind_sources, + catalogs=self.source_catalogs, + input_map=source_subtracted_map, ) for output in self.outputs: @@ -112,24 +131,18 @@ def analyze_map( sifter_result=sifter_result, input_map=input_map, ) - return forced_photometry_candidates, sifter_result, input_map - @flow - def run(self): - return self.analyze_map.map(self.maps).result() + return forced_photometry_candidates, sifter_result, input_map @flow - def analyze_new_maps( - self, map_file: Path | str - ) -> tuple[list, object, ProcessableMap]: - from sotrplib.config.config import Settings - - input_data = Settings.from_file(map_file).to_basic() - return self.analyze_map.map(input_data.maps).result() + def run(self) -> tuple[list[list], list[object], list[ProcessableMap]]: + self.build_map.map(self.maps).wait() + all_simulated_sources = self.simulate_sources() + return self.analyze_map.map(self.maps, unmapped(all_simulated_sources)).result() @flow -def analyze_from_configuration(config: Path | str): +def analyze_from_configuration(config: Path | str) -> tuple[list[list], list[object], list[ProcessableMap]]: from sotrplib.config.config import Settings pipeline = Settings.from_file(config).to_prefect() diff --git a/tests/test_prefect_pipeline/test_prefect_pipeline.py b/tests/test_prefect_pipeline/test_prefect_pipeline.py index 84d2661..a5dcf0c 100644 --- a/tests/test_prefect_pipeline/test_prefect_pipeline.py +++ b/tests/test_prefect_pipeline/test_prefect_pipeline.py @@ -20,39 +20,41 @@ { "map_type": "simulated", "observation_start": "2025-01-01", - "observation_end": "2025-01-02", - }, - { - "map_type": "simulated", - "observation_start": "2025-01-02", - "observation_end": "2025-01-03", - }, + "observation_end": "2025-01-02" + } ], "source_simulators": [ { - "simulation_type": "random", - "parameters": { - "n_sources": 32, - "min_flux": "3.0 Jy", - "max_flux": "10.0 Jy", - "fwhm_uncertainty_frac": 0.0, - "fraction_return": 0.5, - }, + "simulation_type": "fixed", + "number": 8, + "min_flux": "3.0 Jy", + "max_flux": "10.0 Jy", + "catalog_fraction": 0.5 } ], - "preprocessors": [{"preprocessor_type": "kappa_rho"}], - "source_subtractor": {"subtractor_type": "photutils"}, - "blind_search": {"search_type": "photutils"}, - "forced_photometry": {"photometry_type": "scipy", "reproject_thumbnails": "True"}, - "sifter": {"sifter_type": "default"}, + "source_injector": { + "injector_type": "photutils" + }, + "forced_photometry": { + "photometry_type": "scipy" + }, + "source_subtractor": { + "subtractor_type": "photutils" + }, + "blind_search": { + "search_type": "photutils" + }, + "sifter": { + "sifter_type": "simple" + }, "outputs": [ { "output_type": "pickle", + "directory": "." } - ], + ] } - def test_basic_pipeline_scipy( tmp_path, map_with_sources: tuple[SimulatedMap, list[RegisteredSource]] ): @@ -65,15 +67,15 @@ def test_basic_pipeline_scipy( runner = PrefectRunner( maps=maps, - forced_photometry_catalog=None, source_catalogs=[], + source_injector=None, preprocessors=None, postprocessors=None, source_simulators=None, forced_photometry=Scipy2DGaussianFitter(sources=sources), source_subtractor=None, blind_search=SigmaClipBlindSearch(), - sifter=DefaultSifter(catalog_sources=sources), + sifter=DefaultSifter(), outputs=[PickleSerializer(directory=tmp_path)], ) @@ -102,6 +104,7 @@ def test_basic_pipeline_from_file( def _validate_pipeline_result(result, nmaps): assert len(result) == nmaps + print(result[0]) candidates, sifter_result, output_map = result[0] assert all([isinstance(candidate, MeasuredSource) for candidate in candidates]) assert isinstance(sifter_result, SifterResult)