From 2ab880d0f4c664bb3370fb117616e086993652af Mon Sep 17 00:00:00 2001 From: Lynne Jones Date: Fri, 13 Sep 2024 18:57:32 -0700 Subject: [PATCH 1/2] Removed deprecated features and BasisFunctions; adjust unit tests --- .../basis_functions/basis_functions.py | 768 +----------------- .../basis_functions/feasibility_funcs.py | 41 +- .../basis_functions/mask_basis_funcs.py | 102 +-- .../basis_functions/rolling_funcs.py | 173 +--- .../scheduler/example/simple_examples.py | 4 +- .../scheduler/features/features.py | 327 +------- .../scheduler/surveys/dd_surveys.py | 4 +- .../scheduler/surveys/field_survey.py | 4 +- tests/scheduler/test_baseline.py | 265 +----- tests/scheduler/test_basisfuncs.py | 11 +- tests/scheduler/test_coresched.py | 20 +- tests/scheduler/test_features.py | 54 +- 12 files changed, 112 insertions(+), 1661 deletions(-) diff --git a/rubin_scheduler/scheduler/basis_functions/basis_functions.py b/rubin_scheduler/scheduler/basis_functions/basis_functions.py index bbf99da7..933e93f7 100644 --- a/rubin_scheduler/scheduler/basis_functions/basis_functions.py +++ b/rubin_scheduler/scheduler/basis_functions/basis_functions.py @@ -4,40 +4,26 @@ "ConstantBasisFunction", "SimpleArrayBasisFunction", "DelayStartBasisFunction", - "TargetMapBasisFunction", - "AvoidLongGapsBasisFunction", - "AvoidFastRevisits", "AvoidFastRevisitsBasisFunction", "VisitRepeatBasisFunction", "M5DiffBasisFunction", "M5DiffAtHpixBasisFunction", "StrictFilterBasisFunction", - "GoalStrictFilterBasisFunction", "FilterChangeBasisFunction", "SlewtimeBasisFunction", - "AggressiveSlewtimeBasisFunction", - "SkybrightnessLimitAtHpixBasisFunction", - "SkybrightnessLimitBasisFunction", - "CablewrapUnwrapBasisFunction", "CadenceEnhanceBasisFunction", "CadenceEnhanceTrapezoidBasisFunction", "AzimuthBasisFunction", "AzModuloBasisFunction", "DecModuloBasisFunction", "MapModuloBasisFunction", - "TemplateGenerateBasisFunction", - "FootprintNvisBasisFunction", - "ThirdObservationBasisFunction", "SeasonCoverageBasisFunction", "NObsPerYearBasisFunction", "CadenceInSeasonBasisFunction", - "NearSunTwilightBasisFunction", "NearSunHighAirmassBasisFunction", "NObsHighAmBasisFunction", "GoodSeeingBasisFunction", - "ObservedTwiceBasisFunction", "EclipticBasisFunction", - "LimitRepeatBasisFunction", "VisitGap", "NGoodSeeingBasisFunction", "AvoidDirectWind", @@ -467,148 +453,6 @@ def _calc_value(self, conditions, indx=None): return result -class AvoidLongGapsBasisFunction(BaseBasisFunction): - """Boost the reward on parts of the survey that haven't been - observed for a while. - - Parameters - ---------- - filtername : `str`, optional - The filter to consider when tracking visits. - nside : `int`, optional - The nside to use for the basis function. - Default None uses `set_default_nside()`. - footprint : `np.ndarray`, (N,) - The footprint to consider when tracking visits. - Default None uses `get_current_footprint()`. - min_gap : `float`, optional - The minimum gap of time before boosting (in days). - max_gap : `float`, optional - The maximum gap of time before stopping boosting (in days). - ha_limit : `float, optional - Only boost visits at parts of the sky with HA < ha_limit - (in hours). - """ - - def __init__( - self, - filtername=None, - nside=None, - footprint=None, - min_gap=4.0, - max_gap=40.0, - ha_limit=3.5, - ): - super(AvoidLongGapsBasisFunction, self).__init__(nside=nside, filtername=filtername) - self.min_gap = min_gap - self.max_gap = max_gap - self.filtername = filtername - if footprint is None: - footprints, labels = get_current_footprint(self.nside) - footprint = footprints[self.filtername] - self.footprint = footprint - self.ha_limit = 2.0 * np.pi * ha_limit / 24.0 # To radians - self.survey_features = {} - self.survey_features["last_observed"] = features.LastObserved(nside=nside, filtername=filtername) - self.result = np.zeros(hp.nside2npix(self.nside)) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, indx=None): - result = self.result.copy() - - gap = conditions.mjd - self.survey_features["last_observed"].feature - in_range = np.where((gap > self.min_gap) & (gap < self.max_gap) & (self.footprint > 0)) - result[in_range] = 1 - - # mask out areas beyond the hour angle limit. - out_ha = np.where((conditions.HA > self.ha_limit) & (conditions.HA < (2.0 * np.pi - self.ha_limit)))[ - 0 - ] - result[out_ha] = 0 - - return result - - -class TargetMapBasisFunction(BaseBasisFunction): - """Basis function that tracks number of observations and tries - to match a specified spatial distribution. - - In general, this is deprecated in favor of - `FootprintBasisFunction`. - - Parameters - ---------- - filtername: `str` ('r') - The name of the filter for this target map. - nside: `int` (default_nside) - The healpix resolution. - target_map : `np.array` (None) - A healpix map showing the ratio of observations desired - for all points on the sky - norm_factor : `float` (0.00010519) - for converting target map to number of observations. - Should be the area of the camera divided by the area of a healpixel - divided by the sum of all your goal maps. - Default value assumes LSST foV has 1.75 degree radius and the - standard goal maps. - If using multiple filters, see - rubin_scheduler.utils.calc_norm_factor - for a utility that computes norm_factor. - out_of_bounds_val : `float` (-10.) - Reward value to give regions where there are no - observations requested (unitless). - """ - - def __init__( - self, - filtername="r", - nside=None, - target_map=None, - norm_factor=None, - out_of_bounds_val=-10.0, - ): - super(TargetMapBasisFunction, self).__init__(nside=nside, filtername=filtername) - - if norm_factor is None: - warnings.warn("No norm_factor set, use utils.calc_norm_factor if using multiple filters.") - self.norm_factor = 0.00010519 - else: - self.norm_factor = norm_factor - - self.survey_features = {} - # Map of the number of observations in filter - self.survey_features["n_obs"] = features.NObservations(filtername=filtername, nside=self.nside) - # Count of all the observations - self.survey_features["n_obs_count_all"] = features.NObsCount(filtername=None) - if target_map is None: - target_maps, labels = utils.get_current_footprint(self.nside) - self.target_map = target_maps[filtername] - else: - self.target_map = target_map - self.out_of_bounds_area = np.where(self.target_map == 0)[0] - self.out_of_bounds_val = out_of_bounds_val - self.result = np.zeros(hp.nside2npix(self.nside), dtype=float) - self.all_indx = np.arange(self.result.size) - # As of 4/2024, - # this is used in the ts_fbs_utils maintel "anytime" test survey. - # It is also used in unit tests. - # send_unused_deprecation_warning(self.__class__.__name__) - # In general, it is deprecated in favor of FootprintBasisFunction - - def _calc_value(self, conditions, indx=None): - result = self.result.copy() - if indx is None: - indx = self.all_indx - - # Find out how many observations we want now at those points - goal_n = self.target_map[indx] * self.survey_features["n_obs_count_all"].feature * self.norm_factor - - result[indx] = goal_n - self.survey_features["n_obs"].feature[indx] - result[self.out_of_bounds_area] = self.out_of_bounds_val - - return result - - def az_rel_point(azs, point_az): az_rel_moon = (azs - point_az) % (2.0 * np.pi) if isinstance(azs, np.ndarray): @@ -828,89 +672,6 @@ def _calc_value(self, conditions, indx=None): return result -class FootprintNvisBasisFunction(BaseBasisFunction): - """Basis function to drive observations of a given footprint. - Good to target of opportunity targets where one might want to observe - a region 3 times. - - Parameters - ---------- - footprint : `np.array` - A healpix array (1 for desired, 0 for not desired) of the - target footprint. - nvis : `int` (1) - The number of visits to try and gather - """ - - def __init__( - self, - filtername="r", - nside=None, - footprint=None, - nvis=1, - out_of_bounds_val=np.nan, - ): - super(FootprintNvisBasisFunction, self).__init__(nside=nside, filtername=filtername) - if footprint is None: - footprint = np.zeros(hp.nside2npix(self.nside)) - self.footprint = footprint - self.nvis = nvis - - # Have a feature that tracks how many observations we have - self.survey_features = {} - # Map of the number of observations in filter - self.survey_features["n_obs"] = features.NObservations(filtername=filtername, nside=self.nside) - self.result = np.zeros(hp.nside2npix(self.nside)) - self.result.fill(out_of_bounds_val) - self.out_of_bounds_val = out_of_bounds_val - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, indx=None): - result = self.result.copy() - diff = IntRounded(self.footprint * self.nvis - self.survey_features["n_obs"].feature) - - result[np.where(diff > IntRounded(0))] = 1 - - # Any spot where we have enough visits is out of bounds now. - result[np.where(diff <= IntRounded(0))] = self.out_of_bounds_val - return result - - -class ThirdObservationBasisFunction(BaseBasisFunction): - """If there have been observations in two filters long enough ago, - go for a third - - Parameters - ---------- - gap_min : `float` (40.) - The minimum time gap to consider a pixel good (minutes) - gap_max : `float` (120) - The maximum time to consider going for a pair (minutes) - """ - - def __init__(self, nside=32, filtername1="r", filtername2="z", gap_min=40.0, gap_max=120.0): - super(ThirdObservationBasisFunction, self).__init__(nside=nside) - self.filtername1 = filtername1 - self.filtername2 = filtername2 - self.gap_min = IntRounded(gap_min / 60.0 / 24.0) - self.gap_max = IntRounded(gap_max / 60.0 / 24.0) - - self.survey_features = {} - self.survey_features["last_obs_f1"] = features.LastObserved(filtername=filtername1, nside=nside) - self.survey_features["last_obs_f2"] = features.LastObserved(filtername=filtername2, nside=nside) - self.result = np.empty(hp.nside2npix(self.nside)) - self.result.fill(np.nan) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, indx=None): - result = self.result.copy() - d1 = IntRounded(conditions.mjd - self.survey_features["last_obs_f1"].feature) - d2 = IntRounded(conditions.mjd - self.survey_features["last_obs_f2"].feature) - good = np.where((d1 > self.gap_min) & (d1 < self.gap_max) & (d2 > self.gap_min) & (d2 < self.gap_max)) - result[good] = 1 - return result - - class AvoidFastRevisitsBasisFunction(BaseBasisFunction): """Marks targets as unseen if they are in a specified time window in order to avoid fast revisits. @@ -951,13 +712,6 @@ def _calc_value(self, conditions, indx=None): return result -class AvoidFastRevisits(AvoidFastRevisitsBasisFunction): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - warnings.warn("Class has been renamed AvoidFastRevisitsBasisFunction", DeprecationWarning, 2) - - class NearSunHighAirmassBasisFunction(BaseBasisFunction): """Reward areas on the sky at high airmass, within 90 degrees azimuth of the Sun, such as suitable for the near-sun twilight microsurvey for @@ -995,13 +749,6 @@ def _calc_value(self, conditions, indx=None): return result -class NearSunTwilightBasisFunction(NearSunHighAirmassBasisFunction): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - warnings.warn("Class has been renamed NearSunHighAirmassBasisFunction", DeprecationWarning, 2) - - class VisitRepeatBasisFunction(BaseBasisFunction): """ Basis function to reward re-visiting an area on the sky. @@ -1160,192 +907,6 @@ def _calc_value(self, conditions, **kwargs): return result -class GoalStrictFilterBasisFunction(BaseBasisFunction): - """Remove the bonus for staying in the same filter - if certain conditions are met. - - If the moon rises/sets or twilight starts/ends, it makes a lot of sense - to consider a filter change. This basis function rewards if it matches - the current filter, the moon rises or sets, twilight starts or stops, - or there has been a large gap since the last observation. - - Parameters - ---------- - time_lag_min : `float` - Minimum time after a filter change for which a new filter change - will receive zero reward, or be denied at all (see unseen_before_lag). - time_lag_max : `float` - Time after a filter change where the reward for changing filters - achieve its maximum. - time_lag_boost : `float` - Time after a filter change to apply a boost on the reward. - boost_gain : `float` - A multiplier factor for the reward after time_lag_boost. - unseen_before_lag : `bool` - If True will make it impossible to switch filter before time_lag - has passed. - filtername : `str` - The filter for which this basis function will be used. - tag: `str` or None - When using filter proportion use only regions with this tag to - count for observations. - twi_change : `float` - Switch reward on when twilight changes. - proportion : `float` - The expected filter proportion distribution. - aways_available: `bool` - If this is true the basis function will aways be computed - regardless of the feasibility. - If False a more detailed feasibility check is performed. - When set to False, it may speed up the computation process by - avoiding the computation of the reward functions paired with this bf, - when observation is not feasible. - """ - - def __init__( - self, - time_lag_min=10.0, - time_lag_max=30.0, - time_lag_boost=60.0, - boost_gain=2.0, - unseen_before_lag=False, - filtername="r", - tag=None, - twi_change=-18.0, - proportion=1.0, - aways_available=False, - ): - super(GoalStrictFilterBasisFunction, self).__init__(filtername=filtername) - - self.time_lag_min = time_lag_min / 60.0 / 24.0 # Convert to days - self.time_lag_max = time_lag_max / 60.0 / 24.0 # Convert to days - self.time_lag_boost = time_lag_boost / 60.0 / 24.0 - self.boost_gain = boost_gain - self.unseen_before_lag = unseen_before_lag - - self.twi_change = np.radians(twi_change) - self.proportion = proportion - self.aways_available = aways_available - - self.survey_features = {} - self.survey_features["Last_observation"] = features.LastObservation() - self.survey_features["Last_filter_change"] = features.LastFilterChange() - self.survey_features["n_obs_all"] = features.NObsCount(filtername=None) - # Tag is not actually supported at observation level. - self.survey_features["n_obs"] = features.NObsCount(filtername=filtername, tag=tag) - send_unused_deprecation_warning(self.__class__.__name__) - - def filter_change_bonus(self, time): - lag_min = self.time_lag_min - lag_max = self.time_lag_max - - a = 1.0 / (lag_max - lag_min) - b = -a * lag_min - - bonus = a * time + b - # How far behind we are with respect to proportion? - nobs = self.survey_features["n_obs"].feature - nobs_all = self.survey_features["n_obs_all"].feature - goal = self.proportion - # need = 1. - nobs / nobs_all + goal if nobs_all > 0 else 1. + goal - need = goal / nobs * nobs_all if nobs > 0 else 1.0 - # need /= goal - if hasattr(time, "__iter__"): - before_lag = np.where(time <= lag_min) - bonus[before_lag] = -np.inf if self.unseen_before_lag else 0.0 - after_lag = np.where(time >= lag_max) - bonus[after_lag] = 1.0 if time < self.time_lag_boost else self.boost_gain - elif IntRounded(time) <= IntRounded(lag_min): - return -np.inf if self.unseen_before_lag else 0.0 - elif IntRounded(time) >= IntRounded(lag_max): - return 1.0 if IntRounded(time) < IntRounded(self.time_lag_boost) else self.boost_gain - - return bonus * need - - def check_feasibility(self, conditions): - """ - This method makes a pre-check of the feasibility of this - basis function. If a basis function returns False on the - feasibility check, it won't computed at all. - - Returns - ------- - feasibility : `bool` - """ - - # Make a quick check about the feasibility of this basis function. - # If current filter is none, telescope is parked and we could, - # in principle, switch to any filter. If this basis function - # computes reward for the current filter, then it is also feasible. - # At last we check for an "aways_available" flag. Meaning, we - # force this basis function to be aways be computed. - if ( - conditions.current_filter is None - or conditions.current_filter == self.filtername - or self.aways_available - ): - return True - - # If we arrive here, - # we make some extra checks to make sure this bf is - # feasible and should be computed. - - # Did the moon set or rise since last observation? - moon_changed = conditions.moon_alt * self.survey_features["Last_observation"].feature["moonAlt"] < 0 - - # Are we already in the filter (or at start of night)? - not_in_filter = conditions.current_filter != self.filtername - - # Has enough time past? - lag = conditions.mjd - self.survey_features["Last_filter_change"].feature["mjd"] - time_past = IntRounded(lag) > IntRounded(self.time_lag_min) - - # Did twilight start/end? - twi_changed = (conditions.sun_alt - self.twi_change) * ( - self.survey_features["Last_observation"].feature["sun_alt"] - self.twi_change - ) < 0 - - # Did we just finish a DD sequence - was_dd = self.survey_features["Last_observation"].feature["scheduler_note"] == "DD" - - # Is the filter mounted? - mounted = self.filtername in conditions.mounted_filters - - if (moon_changed | time_past | twi_changed | was_dd) & mounted & not_in_filter: - return True - else: - return False - - def _calc_value(self, conditions, **kwargs): - if conditions.current_filter is None: - return 0.0 # no bonus if no filter is mounted - - # Did the moon set or rise since last observation? - moon_changed = conditions.moon_alt * self.survey_features["Last_observation"].feature["moonAlt"] < 0 - - # Has enough time past? - lag = conditions.mjd - self.survey_features["Last_filter_change"].feature["mjd"] - time_past = lag > self.time_lag_min - - # Did twilight start/end? - twi_changed = (conditions.sun_alt - self.twi_change) * ( - self.survey_features["Last_observation"].feature["sun_alt"] - self.twi_change - ) < 0 - - # Did we just finish a DD sequence - was_dd = self.survey_features["Last_observation"].feature["scheduler_note"] == "DD" - - # Is the filter mounted? - mounted = self.filtername in conditions.mounted_filters - - if (moon_changed | time_past | twi_changed | was_dd) & mounted: - result = self.filter_change_bonus(lag) if time_past else 0.0 - else: - result = -100.0 if self.unseen_before_lag else 0.0 - - return result - - class FilterChangeBasisFunction(BaseBasisFunction): """Reward staying in the current filter.""" @@ -1416,226 +977,6 @@ def _calc_value(self, conditions, indx=None): return result -class AggressiveSlewtimeBasisFunction(BaseBasisFunction): - """Reward slews that take little time - - XXX--not sure how this is different from SlewtimeBasisFunction? - Looks like it's checking the slewtime to the field position - rather than the healpix maybe? - """ - - def __init__(self, max_time=135.0, order=1.0, hard_max=None, filtername="r", nside=None): - super(AggressiveSlewtimeBasisFunction, self).__init__(nside=nside, filtername=filtername) - - self.maxtime = max_time - self.hard_max = hard_max - self.order = order - self.result = np.zeros(hp.nside2npix(self.nside), dtype=float) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, indx=None): - # If we are in a different filter, the - # FilterChangeBasisFunction will take it - if conditions.current_filter != self.filtername: - result = 0.0 - else: - # Need to make sure smaller slewtime is larger reward. - if np.size(self.condition_features["slewtime"].feature) > 1: - result = self.result.copy() - result.fill(np.nan) - - good = np.where(np.bitwise_and(conditions.slewtime > 0.0, conditions.slewtime < self.maxtime)) - result[good] = ((self.maxtime - conditions.slewtime[good]) / self.maxtime) ** self.order - if self.hard_max is not None: - not_so_good = np.where(conditions.slewtime > self.hard_max) - result[not_so_good] -= 10.0 - fields = np.unique(conditions.hp2fields[good]) - for field in fields: - hp_indx = np.where(conditions.hp2fields == field) - result[hp_indx] = np.min(result[hp_indx]) - else: - result = (self.maxtime - conditions.slewtime) / self.maxtime - return result - - -class SkybrightnessLimitBasisFunction(BaseBasisFunction): - """Mask regions that are outside a sky brightness limit. - - Parameters - ---------- - nside : `int`, optional - The nside for the basis function. Default None uses - `set_default_nside()`. - filtername : `str`, optional - The filter to consider for the skybrightness pre values. - sbmin : `float`, optional - The minimum (brightest) skybrightness to consider (mags). - Default of 20 will cut out some times of night or parts of the - sky. - sbmax : `float`, optional - The maximum (faintest) skybrightness to consider (mags). - Default of 30 will pass all skybrightness values. - """ - - def __init__(self, nside=None, filtername="r", sbmin=20.0, sbmax=30.0): - super(SkybrightnessLimitBasisFunction, self).__init__(nside=nside, filtername=filtername) - - self.min = IntRounded(sbmin) - self.max = IntRounded(sbmax) - self.result = np.empty(hp.nside2npix(self.nside), dtype=float) - self.result.fill(np.nan) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, indx=None): - result = self.result.copy() - - # Replace non-finite values so IntRounded works, then set - # the result to nan. - sky_brightness = conditions.skybrightness[self.filtername].copy() - not_finite = np.where(~np.isfinite(sky_brightness)) - sky_brightness[not_finite] = 0 - rounded_sky_brightness = IntRounded(sky_brightness) - good = np.where( - np.bitwise_and( - rounded_sky_brightness > self.min, - rounded_sky_brightness < self.max, - ) - ) - result[not_finite] = np.nan - result[good] = 1.0 - - return result - - -class SkybrightnessLimitAtHpixBasisFunction( - HealpixLimitedBasisFunctionMixin, SkybrightnessLimitBasisFunction -): - pass - - -class CablewrapUnwrapBasisFunction(BaseBasisFunction): - """ - Parameters - ---------- - min_az : `float` (20.) - The minimum azimuth to activate bf (degrees) - max_az : `float` (82.) - The maximum azimuth to activate bf (degrees) - unwrap_until: `float` (90.) - The window in which the bf is activated (degrees) - """ - - def __init__( - self, - nside=None, - min_az=-270.0, - max_az=270.0, - min_alt=20.0, - max_alt=82.0, - activate_tol=20.0, - delta_unwrap=1.2, - unwrap_until=70.0, - max_duration=30.0, - ): - super(CablewrapUnwrapBasisFunction, self).__init__(nside=nside) - - self.min_az = np.radians(min_az) - self.max_az = np.radians(max_az) - - self.activate_tol = np.radians(activate_tol) - self.delta_unwrap = np.radians(delta_unwrap) - self.unwrap_until = np.radians(unwrap_until) - - self.min_alt = np.radians(min_alt) - self.max_alt = np.radians(max_alt) - # Convert to half-width for convienence - self.nside = nside - self.active = False - self.unwrap_direction = 0.0 # either -1., 0., 1. - self.max_duration = max_duration / 60.0 / 24.0 # Convert to days - self.activation_time = None - self.result = np.zeros(hp.nside2npix(self.nside), dtype=float) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, indx=None): - result = self.result.copy() - - current_abs_rad = np.radians(conditions.az) - unseen = np.where(np.bitwise_or(conditions.alt < self.min_alt, conditions.alt > self.max_alt)) - result[unseen] = np.nan - - if ( - self.min_az + self.activate_tol < current_abs_rad < self.max_az - self.activate_tol - ) and not self.active: - return result - elif self.active and self.unwrap_direction == 1 and current_abs_rad > self.min_az + self.unwrap_until: - self.active = False - self.unwrap_direction = 0.0 - self.activation_time = None - return result - elif ( - self.active and self.unwrap_direction == -1 and current_abs_rad < self.max_az - self.unwrap_until - ): - self.active = False - self.unwrap_direction = 0.0 - self.activation_time = None - return result - elif self.activation_time is not None and conditions.mjd - self.activation_time > self.max_duration: - self.active = False - self.unwrap_direction = 0.0 - self.activation_time = None - return result - - if not self.active: - self.activation_time = conditions.mjd - if current_abs_rad < 0.0: - self.unwrap_direction = 1 # clock-wise unwrap - else: - self.unwrap_direction = -1 # counter-clock-wise unwrap - - self.active = True - - max_abs_rad = self.max_az - min_abs_rad = self.min_az - - TWOPI = 2.0 * np.pi - - # Compute distance and accumulated az. - norm_az_rad = np.divmod(conditions.az - min_abs_rad, TWOPI)[1] + min_abs_rad - distance_rad = divmod(norm_az_rad - current_abs_rad, TWOPI)[1] - get_shorter = np.where(distance_rad > np.pi) - distance_rad[get_shorter] -= TWOPI - accum_abs_rad = current_abs_rad + distance_rad - - # Compute wrap regions and fix distances - mask_max = np.where(accum_abs_rad > max_abs_rad) - distance_rad[mask_max] -= TWOPI - mask_min = np.where(accum_abs_rad < min_abs_rad) - distance_rad[mask_min] += TWOPI - - # Step-2: Repeat but now with compute reward to unwrap - # using specified delta_unwrap - unwrap_current_abs_rad = current_abs_rad - ( - np.abs(self.delta_unwrap) if self.unwrap_direction > 0 else -np.abs(self.delta_unwrap) - ) - unwrap_distance_rad = divmod(norm_az_rad - unwrap_current_abs_rad, TWOPI)[1] - unwrap_get_shorter = np.where(unwrap_distance_rad > np.pi) - unwrap_distance_rad[unwrap_get_shorter] -= TWOPI - unwrap_distance_rad = np.abs(unwrap_distance_rad) - - if self.unwrap_direction < 0: - mask = np.where(accum_abs_rad > unwrap_current_abs_rad) - else: - mask = np.where(accum_abs_rad < unwrap_current_abs_rad) - - # Finally build reward map - result = (1.0 - unwrap_distance_rad / np.max(unwrap_distance_rad)) ** 2.0 - result[mask] = 0.0 - result[unseen] = np.nan - - return result - - class CadenceEnhanceBasisFunction(BaseBasisFunction): """Drive a certain cadence @@ -1995,109 +1336,6 @@ def _calc_value(self, conditions, **kwargs): return result -class TemplateGenerateBasisFunction(BaseBasisFunction): - """Emphasize areas that have not been observed in a long time - - Parameters - ---------- - nside : `int`, optional - The nside for the basis function and feature. - Default None uses `set_default_nside()`. - day_gap : `float`, optional - How long to wait before boosting the reward (days). - Default of 250 pushes visits into parts of the sky which missed - a significant chunk of a season. - filtername : `str`, optional - The filter to consider when tracking observations. - footprint : `np.array`, (N,) - The indices of the healpixels to apply the boost to. - Default None will call `get_current_footprint()`. - """ - - def __init__(self, nside=None, day_gap=250.0, filtername="r", footprint=None): - super(TemplateGenerateBasisFunction, self).__init__(nside=nside) - self.day_gap = day_gap - self.filtername = filtername - self.survey_features = {} - self.survey_features["Last_observed"] = features.LastObserved(filtername=filtername) - self.result = np.zeros(hp.nside2npix(self.nside)) - if footprint is None: - footprints, labels = get_current_footprint(self.nside) - fp = footprints[self.filtername] - else: - fp = footprint - self.out_of_bounds = np.where(fp == 0) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, **kwargs): - result = self.result.copy() - overdue = np.where( - (IntRounded(conditions.mjd - self.survey_features["Last_observed"].feature)) - > IntRounded(self.day_gap) - ) - result[overdue] = 1 - result[self.out_of_bounds] = 0 - - return result - - -class LimitRepeatBasisFunction(BaseBasisFunction): - """Mask out pixels that haven't been observed in the night. - - Parameters - ---------- - nside : `int` or None - Nside for the basis function values. Default None will use - default nside. - filtername : `str` or None - Filter to consider when tracking number of acquired observations. - - """ - - def __init__(self, nside=None, filtername="r", n_limit=2): - super(LimitRepeatBasisFunction, self).__init__(nside=nside) - self.filtername = filtername - self.n_limit = n_limit - self.survey_features = {} - self.survey_features["n_obs"] = features.NObsNight(nside=nside, filtername=filtername) - - self.result = np.zeros(hp.nside2npix(self.nside)) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, **kwargs): - result = self.result.copy() - good_pix = np.where(self.survey_features["n_obs"].feature >= self.n_limit)[0] - result[good_pix] = 1 - - return result - - -class ObservedTwiceBasisFunction(BaseBasisFunction): - """Mask out pixels that haven't been observed in the night""" - - def __init__(self, nside=None, filtername="r", n_obs_needed=2, n_obs_in_filt_needed=1): - super(ObservedTwiceBasisFunction, self).__init__(nside=nside) - self.n_obs_needed = n_obs_needed - self.n_obs_in_filt_needed = n_obs_in_filt_needed - self.filtername = filtername - self.survey_features = {} - self.survey_features["n_obs_infilt"] = features.NObsNight(nside=nside, filtername=filtername) - self.survey_features["n_obs_all"] = features.NObsNight(nside=nside, filtername="") - - self.result = np.zeros(hp.nside2npix(self.nside)) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, **kwargs): - result = self.result.copy() - good_pix = np.where( - (self.survey_features["n_obs_infilt"].feature >= self.n_obs_in_filt_needed) - & (self.survey_features["n_obs_all"].feature >= self.n_obs_needed) - )[0] - result[good_pix] = 1 - - return result - - class VisitGap(BaseBasisFunction): """Basis function to create a visit gap based on the survey note field. @@ -2219,8 +1457,8 @@ def __init__(self, nobs_reference, note_survey, note_interest, nside=None): self.nobs_reference = nobs_reference self.survey_features = {} - self.survey_features["n_obs_survey"] = features.NObsSurvey(note=note_survey) - self.survey_features["n_obs_survey_interest"] = features.NObsSurvey(note=note_interest) + self.survey_features["n_obs_survey"] = features.NObsCount(note=note_survey) + self.survey_features["n_obs_survey_interest"] = features.NObsCount(note=note_interest) def _calc_value(self, conditions, indx=None): return (1 + np.floor(self.survey_features["n_obs_survey_interest"].feature / self.nobs_reference)) / ( @@ -2255,7 +1493,7 @@ def __init__(self, n_obs_survey, note_survey, nside=None): self.n_obs_survey = n_obs_survey self.survey_features = {} - self.survey_features["n_obs_survey"] = features.NObsSurvey(note=note_survey) + self.survey_features["n_obs_survey"] = features.NObsCount(note=note_survey) def _calc_value(self, conditions, indx=None): return self.survey_features["n_obs_survey"].feature % self.n_obs_survey diff --git a/rubin_scheduler/scheduler/basis_functions/feasibility_funcs.py b/rubin_scheduler/scheduler/basis_functions/feasibility_funcs.py index 45ed0213..fac39abb 100644 --- a/rubin_scheduler/scheduler/basis_functions/feasibility_funcs.py +++ b/rubin_scheduler/scheduler/basis_functions/feasibility_funcs.py @@ -18,8 +18,6 @@ "NightModuloBasisFunction", "EndOfEveningBasisFunction", "TimeToScheduledBasisFunction", - "LimitObsPnightBasisFunction", - "SunHighLimitBasisFunction", "CloseToTwilightBasisFunction", "MoonDistPointRangeBasisFunction", "AirmassPointRangeBasisFunction", @@ -203,17 +201,6 @@ def check_feasibility(self, conditions): return result -class SunHighLimitBasisFunction(CloseToTwilightBasisFunction): - - def __init__(self, sun_alt_limit=-14.8, time_to_12deg=21.0, time_remaining=15.0): - super().__init__( - max_sun_alt_limit=sun_alt_limit, - min_time_remaining=time_remaining, - max_time_to_12deg=time_to_12deg, - ) - warnings.warn("Class has been renamed CloseToTwilightBasisFunction", DeprecationWarning, 2) - - class OnceInNightBasisFunction(BaseBasisFunction): """Stop observing if something has been executed already in the night @@ -254,22 +241,6 @@ def check_feasibility(self, conditions): return result -class LimitObsPnightBasisFunction(BaseBasisFunction): - """""" - - def __init__(self, survey_str="", nlimit=100.0): - super(LimitObsPnightBasisFunction, self).__init__() - self.nlimit = nlimit - self.survey_features["N_in_night"] = features.SurveyInNight(survey_str=survey_str) - send_unused_deprecation_warning(self.__class__.__name__) - - def check_feasibility(self, conditions): - if self.survey_features["N_in_night"].feature >= self.nlimit: - return False - else: - return True - - class NightModuloBasisFunction(BaseBasisFunction): """Only return true on certain nights""" @@ -464,8 +435,8 @@ def __init__( self.scheduler_note = scheduler_note self.survey_features["last_obs_self"] = features.LastObservation(scheduler_note=self.scheduler_note) self.fractions = fractions - self.survey_features["N_total"] = features.NObsSurvey(note=None) - self.survey_features["N_note"] = features.NObsSurvey(note=self.scheduler_note) + self.survey_features["N_total"] = features.NObsCount(note=None) + self.survey_features["N_note"] = features.NObsCount(note=self.scheduler_note) def check_feasibility(self, conditions): result = True @@ -544,8 +515,8 @@ def __init__(self, frac_total, scheduler_note=None, survey_name=None): else: self.scheduler_note = scheduler_note self.frac_total = frac_total - self.survey_features["N_total"] = features.NObsSurvey(note=None) - self.survey_features["N_note"] = features.NObsSurvey(note=self.scheduler_note) + self.survey_features["N_total"] = features.NObsCount(note=None) + self.survey_features["N_note"] = features.NObsCount(note=self.scheduler_note) def check_feasibility(self, conditions): # If nothing has been observed, fine to go @@ -615,8 +586,8 @@ def __init__( self.time_jump = time_jump / 60.0 / 24.0 # To days self.time_needed = time_needed / 60.0 / 24.0 # To days self.aggressive_fraction = aggressive_fraction - self.survey_features["N_total"] = features.NObsSurvey(note=None) - self.survey_features["N_note"] = features.NObsSurvey(note=self.scheduler_note) + self.survey_features["N_total"] = features.NObsCount(note=None) + self.survey_features["N_note"] = features.NObsCount(note=self.scheduler_note) def check_feasibility(self, conditions): result = True diff --git a/rubin_scheduler/scheduler/basis_functions/mask_basis_funcs.py b/rubin_scheduler/scheduler/basis_functions/mask_basis_funcs.py index 9bc7e253..247ed40a 100644 --- a/rubin_scheduler/scheduler/basis_functions/mask_basis_funcs.py +++ b/rubin_scheduler/scheduler/basis_functions/mask_basis_funcs.py @@ -1,26 +1,20 @@ __all__ = ( "SolarElongMaskBasisFunction", - "ZenithShadowMaskBasisFunction", "HaMaskBasisFunction", "MoonAvoidanceBasisFunction", "MapCloudBasisFunction", "PlanetMaskBasisFunction", - "MaskAzimuthBasisFunction", "SolarElongationMaskBasisFunction", "AreaCheckMaskBasisFunction", "AltAzShadowMaskBasisFunction", ) -import warnings - import healpy as hp import numpy as np from rubin_scheduler.scheduler.basis_functions import BaseBasisFunction from rubin_scheduler.scheduler.utils import HpInLsstFov, IntRounded -from rubin_scheduler.utils import Site, _angular_separation, _hpid2_ra_dec - -from .basis_functions import send_unused_deprecation_warning +from rubin_scheduler.utils import _angular_separation class SolarElongMaskBasisFunction(BaseBasisFunction): @@ -312,76 +306,6 @@ def _calc_value(self, conditions, indx=None): return result -class ZenithShadowMaskBasisFunction(BaseBasisFunction): - """Mask the zenith, and things that will soon pass near zenith. - Useful for making sure observations will not be too close to zenith - when they need to be observed again (e.g. for a pair). - - Parameters - ---------- - min_alt : float (20.) - The minimum alititude to alow. Everything lower is masked. (degrees) - max_alt : float (82.) - The maximum altitude to alow. Everything higher is masked. (degrees) - shadow_minutes : float (40.) - Mask anything that will pass through the max alt in the next - shadow_minutes time. (minutes) - """ - - def __init__( - self, - nside=None, - min_alt=20.0, - max_alt=82.0, - shadow_minutes=40.0, - penalty=np.nan, - site="LSST", - ): - warnings.warn( - "Deprecating ZenithShadowMaskBasisFunction in favor of AltAzShadowMaskBasisFunction.", - DeprecationWarning, - ) - - super(ZenithShadowMaskBasisFunction, self).__init__(nside=nside) - self.update_on_newobs = False - - self.penalty = penalty - - self.min_alt = np.radians(min_alt) - self.max_alt = np.radians(max_alt) - self.ra, self.dec = _hpid2_ra_dec(self.nside, np.arange(hp.nside2npix(self.nside))) - self.shadow_minutes = np.radians(shadow_minutes / 60.0 * 360.0 / 24.0) - # Compute the declination band where things could drift into zenith - self.decband = np.zeros(self.dec.size, dtype=float) - self.zenith_radius = np.radians(90.0 - max_alt) / 2.0 - site = Site(name=site) - self.lat_rad = site.latitude_rad - self.lon_rad = site.longitude_rad - self.decband[ - np.where( - (IntRounded(self.dec) < IntRounded(self.lat_rad + self.zenith_radius)) - & (IntRounded(self.dec) > IntRounded(self.lat_rad - self.zenith_radius)) - ) - ] = 1 - - self.result = np.empty(hp.nside2npix(self.nside), dtype=float) - self.result.fill(self.penalty) - - def _calc_value(self, conditions, indx=None): - result = self.result.copy() - alt_limit = np.where( - (IntRounded(conditions.alt) > IntRounded(self.min_alt)) - & (IntRounded(conditions.alt) < IntRounded(self.max_alt)) - )[0] - result[alt_limit] = 1 - to_mask = np.where( - (IntRounded(conditions.HA) > IntRounded(2.0 * np.pi - self.shadow_minutes - self.zenith_radius)) - & (self.decband == 1) - ) - result[to_mask] = np.nan - return result - - class MoonAvoidanceBasisFunction(BaseBasisFunction): """Avoid observing within `moon_distance` of the moon. @@ -515,27 +439,3 @@ def _calc_value(self, conditions, indx=None): result[clouded] = self.out_of_bounds_val return result - - -class MaskAzimuthBasisFunction(BaseBasisFunction): - """Mask pixels based on azimuth. - - Superseded by AltAzShadowMaskBasisFunction. - """ - - def __init__(self, nside=None, out_of_bounds_val=np.nan, az_min=0.0, az_max=180.0): - super(MaskAzimuthBasisFunction, self).__init__(nside=nside) - self.az_min = IntRounded(np.radians(az_min)) - self.az_max = IntRounded(np.radians(az_max)) - self.out_of_bounds_val = out_of_bounds_val - self.result = np.ones(hp.nside2npix(self.nside)) - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, indx=None): - to_mask = np.where( - (IntRounded(conditions.az) > self.az_min) & (IntRounded(conditions.az) < self.az_max) - )[0] - result = self.result.copy() - result[to_mask] = self.out_of_bounds_val - - return result diff --git a/rubin_scheduler/scheduler/basis_functions/rolling_funcs.py b/rubin_scheduler/scheduler/basis_functions/rolling_funcs.py index 02354521..470ef3dd 100644 --- a/rubin_scheduler/scheduler/basis_functions/rolling_funcs.py +++ b/rubin_scheduler/scheduler/basis_functions/rolling_funcs.py @@ -1,11 +1,7 @@ -__all__ = ( - "TargetMapModuloBasisFunction", - "FootprintBasisFunction", -) +__all__ = ("FootprintBasisFunction",) import warnings -import healpy as hp import numpy as np from rubin_scheduler.scheduler import features, utils @@ -37,7 +33,7 @@ class FootprintBasisFunction(BaseBasisFunction): The desired footprint. The default will set this to None, but in general this is really not desirable. In order to make default a kwarg, a current baseline footprint - is setup with a Constant footprint (not rolling, not even season + is set up with a Constant footprint (not rolling, not even season aware). out_of_bounds_val : `float`, optional The value to set the basis function for regions that are not in @@ -83,168 +79,3 @@ def _calc_value(self, conditions, indx=None): result = desired - self.survey_features["N_obs"].feature result[self.out_of_bounds_area] = self.out_of_bounds_val return result - - -class TargetMapModuloBasisFunction(BaseBasisFunction): - """Basis function that tracks number of observations and tries to match - a specified spatial distribution can enter multiple maps that will be - used at different times in the survey - - Parameters - ---------- - day_offset : np.array - Healpix map that has the offset to be applied to each pixel when - computing what season it is on. - filtername : (string 'r') - The name of the filter for this target map. - nside: int (default_nside) - The healpix resolution. - target_maps : list of numpy array (None) - healpix maps showing the ratio of observations desired for all - points on the sky. Last map will be used for season -1. Probably - shouldn't support going to season less than -1. - norm_factor : float (0.00010519) - for converting target map to number of observations. Should be the - area of the camera divided by the area of a healpixel divided by - the sum of all your goal maps. Default value assumes LSST foV has - 1.75 degree radius and the standard goal maps. If using mulitple - filters, see rubin_scheduler.scheduler.utils.calc_norm_factor for - a utility that computes norm_factor. - out_of_bounds_val : float (-10.) - Reward value to give regions where there are no observations - requested (unitless). - season_modulo : int (2) - The value to modulate the season by (years). - max_season : int (None) - For seasons higher than this value (pre-modulo), the final target - map is used. - - """ - - def __init__( - self, - day_offset=None, - filtername="r", - nside=None, - target_maps=None, - norm_factor=None, - out_of_bounds_val=-10.0, - season_modulo=2, - max_season=None, - season_length=365.25, - ): - super(TargetMapModuloBasisFunction, self).__init__(nside=nside, filtername=filtername) - - if norm_factor is None: - warnings.warn("No norm_factor set, use utils.calc_norm_factor if using multiple filters.") - self.norm_factor = 0.00010519 - else: - self.norm_factor = norm_factor - - self.survey_features = {} - # Map of the number of observations in filter - - for i, temp in enumerate(target_maps[0:-1]): - self.survey_features["N_obs_%i" % i] = features.N_observations_season( - i, - filtername=filtername, - nside=self.nside, - modulo=season_modulo, - offset=day_offset, - max_season=max_season, - season_length=season_length, - ) - # Count of all the observations taken in a season - self.survey_features["N_obs_count_all_%i" % i] = features.N_obs_count_season( - i, - filtername=None, - season_modulo=season_modulo, - offset=day_offset, - nside=self.nside, - max_season=max_season, - season_length=season_length, - ) - # Set the final one to be -1 - self.survey_features["N_obs_%i" % -1] = features.N_observations_season( - -1, - filtername=filtername, - nside=self.nside, - modulo=season_modulo, - offset=day_offset, - max_season=max_season, - season_length=season_length, - ) - self.survey_features["N_obs_count_all_%i" % -1] = features.N_obs_count_season( - -1, - filtername=None, - season_modulo=season_modulo, - offset=day_offset, - nside=self.nside, - max_season=max_season, - season_length=season_length, - ) - if target_maps is None: - target_maps, labels = utils.get_current_footprint(nside) - self.target_map = target_maps[filtername] - else: - self.target_maps = target_maps - # should probably actually loop over all the target maps? - self.out_of_bounds_area = np.where(self.target_maps[0] == 0)[0] - self.out_of_bounds_val = out_of_bounds_val - self.result = np.zeros(hp.nside2npix(self.nside), dtype=float) - self.all_indx = np.arange(self.result.size) - - # For computing what day each healpix is on - if day_offset is None: - self.day_offset = np.zeros(hp.nside2npix(self.nside), dtype=float) - else: - self.day_offset = day_offset - - self.season_modulo = season_modulo - self.max_season = max_season - self.season_length = season_length - send_unused_deprecation_warning(self.__class__.__name__) - - def _calc_value(self, conditions, indx=None): - """ - Parameters - ---------- - indx : list (None) - Index values to compute, if None, full map is computed - Returns - ------- - Healpix reward map - """ - - result = self.result.copy() - if indx is None: - indx = self.all_indx - - # Compute what season it is at each pixel - seasons = utils.season_calc( - conditions.night, - offset=self.day_offset, - modulo=self.season_modulo, - max_season=self.max_season, - season_length=self.season_length, - ) - - composite_target = self.result.copy()[indx] - composite_nobs = self.result.copy()[indx] - - composite_goal_n = self.result.copy()[indx] - - for season in np.unique(seasons): - season_indx = np.where(seasons == season)[0] - composite_target[season_indx] = self.target_maps[season][season_indx] - composite_nobs[season_indx] = self.survey_features["N_obs_%i" % season].feature[season_indx] - composite_goal_n[season_indx] = ( - composite_target[season_indx] - * self.survey_features["N_obs_count_all_%i" % season].feature - * self.norm_factor - ) - - result[indx] = composite_goal_n - composite_nobs[indx] - result[self.out_of_bounds_area] = self.out_of_bounds_val - - return result diff --git a/rubin_scheduler/scheduler/example/simple_examples.py b/rubin_scheduler/scheduler/example/simple_examples.py index 900b827d..1ac5fb8f 100644 --- a/rubin_scheduler/scheduler/example/simple_examples.py +++ b/rubin_scheduler/scheduler/example/simple_examples.py @@ -524,7 +524,7 @@ def simple_pairs_survey( # Tucking this here so we can look at how many observations # recorded for this survey and what was the last one. - pair_survey.extra_features["ObsRecorded"] = features.NObsSurvey() + pair_survey.extra_features["ObsRecorded"] = features.NObsCount() pair_survey.extra_features["LastObs"] = features.LastObservation() return pair_survey @@ -672,7 +672,7 @@ def simple_greedy_survey( # Tucking this here so we can look at how many observations # recorded for this survey and what was the last one. - greedy_survey.extra_features["ObsRecorded"] = features.NObsSurvey() + greedy_survey.extra_features["ObsRecorded"] = features.NObsCount() greedy_survey.extra_features["LastObs"] = features.LastObservation() return greedy_survey diff --git a/rubin_scheduler/scheduler/features/features.py b/rubin_scheduler/scheduler/features/features.py index b9668cc1..1085db64 100644 --- a/rubin_scheduler/scheduler/features/features.py +++ b/rubin_scheduler/scheduler/features/features.py @@ -2,21 +2,14 @@ "BaseFeature", "BaseSurveyFeature", "NObsCount", - "NObsSurvey", "LastObservation", - "LastsequenceObservation", - "LastFilterChange", "NObservations", - "CoaddedDepth", "LastObserved", "NObsNight", "PairInNight", "RotatorAngle", - "NObservationsSeason", - "NObsCountSeason", "NObservationsCurrentSeason", "LastNObsTimes", - "SurveyInNight", "NoteInNight", "NoteLastObserved", ) @@ -30,7 +23,7 @@ from rubin_scheduler.scheduler import utils from rubin_scheduler.scheduler.utils import IntRounded from rubin_scheduler.skybrightness_pre import dark_sky -from rubin_scheduler.utils import _hpid2_ra_dec, calc_season, m5_flat_sed, survey_start_mjd +from rubin_scheduler.utils import _hpid2_ra_dec, calc_season, survey_start_mjd def send_unused_deprecation_warning(name): @@ -111,34 +104,6 @@ def add_observation(self, observation, indx=None, **kwargs): raise NotImplementedError -class SurveyInNight(BaseSurveyFeature): - """Count appearances of `survey_str` within observation `note` in - the current night; `survey_str` must be contained in `note`. - - Useful to keep track of how many times a survey has executed in a night. - - Parameters - ---------- - survey_str : `str`, optional - String to search for in observation `scheduler_note`. - Default of "" means any observation will match. - """ - - def __init__(self, survey_str=""): - self.feature = 0 - self.survey_str = survey_str - self.night = -100 - send_unused_deprecation_warning(self.__class__.__name__) - - def add_observation(self, observation, indx=None): - if observation["night"] != self.night: - self.night = observation["night"] - self.feature = 0 - - if self.survey_str in observation["scheduler_note"][0]: - self.feature += 1 - - class NoteInNight(BaseSurveyFeature): """Count appearances of any of `scheduler_notes` in observation `scheduler_note` in the @@ -182,137 +147,6 @@ def add_observation(self, observation, indx=None): class NObsCount(BaseSurveyFeature): - """Count the number of observations. - Total number, not tracked over sky - - Parameters - ---------- - filtername : `str` or None - The filter to count. Default None, all filters counted. - - """ - - def __init__(self, filtername=None, tag=None): - self.feature = 0 - self.filtername = filtername - # 'tag' is used in GoalStrictFilterBasisFunction - self.tag = tag - if self.tag is not None: - warnings.warn( - "Tag is not a supported element" - "of the `observation` and this aspect of " - "the feature will be " - "deprecated in 2 minor releases.", - DeprecationWarning, - stack_level=2, - ) - - def add_observations_array(self, observations_array, observations_hpid): - if self.filtername is None: - self.feature += np.size(observations_array) - else: - in_filt = np.where(observations_array["filter"] == self.filtername)[0] - self.feature += np.size(in_filt) - - def add_observation(self, observation, indx=None): - if (self.filtername is None) and (self.tag is None): - # Track all observations - self.feature += 1 - elif ( - (self.filtername is not None) - and (self.tag is None) - and (observation["filter"][0] in self.filtername) - ): - # Track all observations on a specified filter - self.feature += 1 - elif (self.filtername is None) and (self.tag is not None) and (observation["tag"][0] in self.tag): - # Track all observations on a specified tag - self.feature += 1 - elif ( - (self.filtername is None) - and (self.tag is not None) - and - # Track all observations on a specified filter on a specified tag - (observation["filter"][0] in self.filtername) - and (observation["tag"][0] in self.tag) - ): - self.feature += 1 - - -class NObsCountSeason(BaseSurveyFeature): - """Count the number of observations in a season. - - Parameters - ---------- - filtername : `str` (None) - The filter to count (if None, all filters counted) - - Notes - ----- - Uses `season_calc` to calculate season value. - - Seems unused - added deprecation warning. - """ - - def __init__( - self, - season, - nside=None, - filtername=None, - tag=None, - season_modulo=2, - offset=None, - max_season=None, - season_length=365.25, - ): - self.feature = 0 - self.filtername = filtername - self.tag = tag - self.season = season - self.season_modulo = season_modulo - if offset is None: - self.offset = np.zeros(hp.nside2npix(nside), dtype=int) - else: - self.offset = offset - self.max_season = max_season - self.season_length = season_length - send_unused_deprecation_warning(self.__class__.__name__) - - def add_observation(self, observation, indx=None): - season = utils.season_calc( - observation["night"], - modulo=self.season_modulo, - offset=self.offset[indx], - max_season=self.max_season, - season_length=self.season_length, - ) - if self.season in season: - if (self.filtername is None) and (self.tag is None): - # Track all observations - self.feature += 1 - elif ( - (self.filtername is not None) - and (self.tag is None) - and (observation["filter"][0] in self.filtername) - ): - # Track all observations on a specified filter - self.feature += 1 - elif (self.filtername is None) and (self.tag is not None) and (observation["tag"][0] in self.tag): - # Track all observations on a specified tag - self.feature += 1 - elif ( - (self.filtername is None) - and (self.tag is not None) - and - # Track all observations on a specified filter on a - # specified tag - (observation["filter"][0] in self.filtername) - and (observation["tag"][0] in self.tag) - ): - self.feature += 1 - - -class NObsSurvey(BaseSurveyFeature): """Count the number of observations, whole sky (not per pixel). Because this feature will belong to a survey, it would count all @@ -323,28 +157,47 @@ class NObsSurvey(BaseSurveyFeature): note : `str` or None Count observations that match `str` in their scheduler_note field. Note can be a substring of scheduler_note, and will still match. + filtername : `str` or None + Optionally also (or independently) specify a filter to match. """ - def __init__(self, note=None): + def __init__(self, note=None, filtername=None): self.feature = 0 self.note = note if self.note == "": self.note = None + self.filtername = filtername def add_observations_array(self, observations_array, observations_hpid): - if self.note is None: + if self.note is None and self.filtername is None: self.feature += observations_array.size - else: + elif self.note is None and self.filtername is not None: + in_filt = np.where(observations_array["filter"] == self.filtername) + self.feature += np.size(in_filt) + elif self.note is not None and self.filtername is None: count = [self.note in note for note in observations_array["scheduler_note"]] self.feature += np.sum(count) + else: + # note and filtername are defined + in_filt = np.where(observations_array["filter"] == self.filtername) + count = [self.note in note for note in observations_array["scheduler_note"][in_filt]] + self.feature += np.sum(count) def add_observation(self, observation, indx=None): # Track all observations - if self.note is None: + if self.note is None and self.filtername is None: self.feature += 1 - else: + elif self.note is None and self.filtername is not None: + if observation["filter"][0] in self.filtername: + self.feature += 1 + elif self.note is not None and self.filtername is None: if self.note in observation["scheduler_note"][0]: self.feature += 1 + else: + if (observation["filter"][0] in self.filtername) and self.note in observation["scheduler_note"][ + 0 + ]: + self.feature += 1 class LastObservation(BaseSurveyFeature): @@ -387,39 +240,6 @@ def add_observation(self, observation, indx=None): self.feature = observation -class LastsequenceObservation(BaseSurveyFeature): - """When was the last observation""" - - def __init__(self, sequence_ids=""): - self.sequence_ids = sequence_ids # The ids of all sequence - # observations... - # Start out with an empty observation - self.feature = utils.ObservationArray() - send_unused_deprecation_warning(self.__class__.__name__) - - def add_observation(self, observation, indx=None): - if observation["survey_id"] in self.sequence_ids: - self.feature = observation - - -class LastFilterChange(BaseSurveyFeature): - """Record when the filter last changed.""" - - def __init__(self): - self.feature = {"mjd": 0.0, "previous_filter": None, "current_filter": None} - send_unused_deprecation_warning(self.__class__.__name__) - - def add_observation(self, observation, indx=None): - if self.feature["current_filter"] is None: - self.feature["mjd"] = observation["mjd"][0] - self.feature["previous_filter"] = None - self.feature["current_filter"] = observation["filter"][0] - elif observation["filter"][0] != self.feature["current_filter"]: - self.feature["mjd"] = observation["mjd"][0] - self.feature["previous_filter"] = self.feature["current_filter"] - self.feature["current_filter"] = observation["filter"][0] - - class NObservations(BaseSurveyFeature): """ Track the number of observations that have been made across the sky. @@ -480,70 +300,10 @@ def add_observation(self, observation, indx=None): self.feature[indx] += 1 -class NObservationsSeason(BaseSurveyFeature): - """ - Track the number of observations that have been made across sky - - Parameters - ---------- - season : `int` - Only count observations in this season (year). - filtername : `str` ('r') - String or list that has all the filters that can count. - nside : `int` (32) - The nside of the healpixel map to use - offset : `int` (0) - The offset to use when computing the season (days) - modulo : `int` (None) - How to mod the years when computing season - - Notes - ----- - Uses `season_calc` to calculate season value. - """ - - def __init__( - self, - season, - filtername=None, - nside=None, - offset=0, - modulo=None, - max_season=None, - season_length=365.25, - ): - if offset is None: - offset = np.zeros(hp.nside2npix(nside), dtype=int) - if nside is None: - nside = utils.set_default_nside() - - self.feature = np.zeros(hp.nside2npix(nside), dtype=float) - self.filtername = filtername - ## How does this work if the default is 0 -- in add_observation - # an index is referenced for offset, so the default should fail - self.offset = offset - self.modulo = modulo - self.season = season - self.max_season = max_season - self.season_length = season_length - send_unused_deprecation_warning(self.__class__.__name__) - - def add_observation(self, observation, indx=None): - # How does this work if indx is None -- self.offset[indx] should fail - observation_season = utils.season_calc( - observation["night"], - offset=self.offset[indx], - modulo=self.modulo, - max_season=self.max_season, - season_length=self.season_length, - ) - if self.season in observation_season: - if self.filtername is None or observation["filter"][0] in self.filtername: - self.feature[indx] += 1 - - class LargestN: def __init__(self, n): + # This is used within other features or basis functions, + # but is not a feature itself self.n = n def __call__(self, in_arr): @@ -807,39 +567,6 @@ def add_observation(self, observation, indx): self.feature[this_season_indx] += 1 -class CoaddedDepth(BaseSurveyFeature): - """Track the co-added depth that has been reached across the sky - - Parameters - ---------- - fwh_meff_limit : `float` (100) - The effective FWHM of the seeing (arcsecond). - Images will only be added to the coadded depth if the observation - FWHM is less than or equal to the limit. Default 100. - """ - - def __init__(self, filtername="r", nside=None, fwhm_eff_limit=100.0): - if nside is None: - nside = utils.set_default_nside() - self.filtername = filtername - self.fwhm_eff_limit = IntRounded(fwhm_eff_limit) - # Starting at limiting mag of zero should be fine. - self.feature = np.zeros(hp.nside2npix(nside), dtype=float) - - def add_observation(self, observation, indx=None): - if observation["filter"][0] == self.filtername: - if IntRounded(observation["FWHMeff"]) <= self.fwhm_eff_limit: - m5 = m5_flat_sed( - observation["filter"], - observation["skybrightness"], - observation["FWHMeff"], - observation["exptime"], - observation["airmass"], - ) - - self.feature[indx] = 1.25 * np.log10(10.0 ** (0.8 * self.feature[indx]) + 10.0 ** (0.8 * m5)) - - class LastObserved(BaseSurveyFeature): """ Track the MJD when a pixel was last observed. diff --git a/rubin_scheduler/scheduler/surveys/dd_surveys.py b/rubin_scheduler/scheduler/surveys/dd_surveys.py index 57381235..46a03c3d 100644 --- a/rubin_scheduler/scheduler/surveys/dd_surveys.py +++ b/rubin_scheduler/scheduler/surveys/dd_surveys.py @@ -114,8 +114,8 @@ def __init__( ) # to days if self.reward_value is None: - self.extra_features["Ntot"] = features.NObsSurvey() - self.extra_features["N_survey"] = features.NObsSurvey(note=self.survey_name) + self.extra_features["Ntot"] = features.NObsCount() + self.extra_features["N_survey"] = features.NObsCount(note=self.survey_name) @cached_property def roi_hpid(self): diff --git a/rubin_scheduler/scheduler/surveys/field_survey.py b/rubin_scheduler/scheduler/surveys/field_survey.py index 74dd9693..20c9cb89 100644 --- a/rubin_scheduler/scheduler/surveys/field_survey.py +++ b/rubin_scheduler/scheduler/surveys/field_survey.py @@ -8,7 +8,7 @@ from rubin_scheduler.utils import ra_dec2_hpid -from ..features import LastObservation, NObsSurvey +from ..features import LastObservation, NObsCount from ..utils import ObservationArray from . import BaseSurvey @@ -209,7 +209,7 @@ def __init__( # Tucking this here so we can look at how many observations # recorded for this field and what was the last one. - self.extra_features["ObsRecorded"] = NObsSurvey() + self.extra_features["ObsRecorded"] = NObsCount() self.extra_features["LastObs"] = LastObservation() def _generate_survey_name(self, target_name=None): diff --git a/tests/scheduler/test_baseline.py b/tests/scheduler/test_baseline.py index 57abbdb0..a76055bd 100644 --- a/tests/scheduler/test_baseline.py +++ b/tests/scheduler/test_baseline.py @@ -3,21 +3,12 @@ import numpy as np -import rubin_scheduler.scheduler.basis_functions as bf -import rubin_scheduler.scheduler.detailers as detailers import rubin_scheduler.utils as utils from rubin_scheduler.data import get_data_dir from rubin_scheduler.scheduler import sim_runner -from rubin_scheduler.scheduler.example import example_scheduler +from rubin_scheduler.scheduler.example import example_scheduler, simple_greedy_survey, simple_pairs_survey from rubin_scheduler.scheduler.model_observatory import ModelObservatory from rubin_scheduler.scheduler.schedulers import CoreScheduler -from rubin_scheduler.scheduler.surveys import ( - BlobSurvey, - GreedySurvey, - ScriptedSurvey, - generate_ddf_scheduled_obs, -) -from rubin_scheduler.scheduler.utils import SkyAreaGenerator, calc_norm_factor_array SAMPLE_BIG_DATA_FILE = os.path.join(get_data_dir(), "scheduler/dust_maps/dust_nside_32.npz") @@ -42,149 +33,6 @@ def return_conditions(self): return self.conditions -def ddf_surveys(detailers=None, season_unobs_frac=0.2, euclid_detailers=None, nside=None): - obs_array = generate_ddf_scheduled_obs(season_unobs_frac=season_unobs_frac) - - euclid_obs = np.where( - (obs_array["scheduler_note"] == "DD:EDFS_b") | (obs_array["scheduler_note"] == "DD:EDFS_a") - )[0] - all_other = np.where( - (obs_array["scheduler_note"] != "DD:EDFS_b") & (obs_array["scheduler_note"] != "DD:EDFS_a") - )[0] - - survey1 = ScriptedSurvey([bf.AvoidDirectWind(nside=nside)], detailers=detailers) - survey1.set_script(obs_array[all_other]) - - survey2 = ScriptedSurvey([bf.AvoidDirectWind(nside=nside)], detailers=euclid_detailers) - survey2.set_script(obs_array[euclid_obs]) - - return [survey1, survey2] - - -def gen_greedy_surveys(nside): - """ - Make a quick set of greedy surveys - """ - sky = SkyAreaGenerator(nside=nside) - target_map, labels = sky.return_maps() - filters = ["g", "r", "i", "z", "y"] - surveys = [] - - for filtername in filters: - bfs = [] - bfs.append(bf.M5DiffBasisFunction(filtername=filtername, nside=nside)) - bfs.append( - bf.TargetMapBasisFunction( - filtername=filtername, - target_map=target_map[filtername], - out_of_bounds_val=np.nan, - nside=nside, - ) - ) - bfs.append(bf.SlewtimeBasisFunction(filtername=filtername, nside=nside)) - bfs.append(bf.StrictFilterBasisFunction(filtername=filtername)) - # Masks, give these 0 weight - bfs.append(bf.AvoidDirectWind(nside=nside)) - bfs.append(bf.AltAzShadowMaskBasisFunction(nside=nside, shadow_minutes=60.0, max_alt=76.0)) - bfs.append(bf.MoonAvoidanceBasisFunction(nside=nside, moon_distance=30.0)) - bfs.append(bf.CloudedOutBasisFunction()) - - bfs.append(bf.FilterLoadedBasisFunction(filternames=filtername)) - bfs.append(bf.PlanetMaskBasisFunction(nside=nside)) - - weights = np.array([3.0, 0.3, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) - surveys.append( - GreedySurvey( - bfs, - weights, - block_size=1, - filtername=filtername, - dither=True, - nside=nside, - survey_name="greedy", - ) - ) - return surveys - - -def gen_blob_surveys(nside): - """ - make a quick set of blob surveys - """ - sky = SkyAreaGenerator(nside=nside) - target_map, labels = sky.return_maps() - norm_factor = calc_norm_factor_array(target_map) - - filter1s = ["g"] # , 'r', 'i', 'z', 'y'] - filter2s = ["g"] # , 'r', 'i', None, None] - - pair_surveys = [] - for filtername, filtername2 in zip(filter1s, filter2s): - detailer_list = [] - bfs = [] - bfs.append(bf.M5DiffBasisFunction(filtername=filtername, nside=nside)) - if filtername2 is not None: - bfs.append(bf.M5DiffBasisFunction(filtername=filtername2, nside=nside)) - bfs.append( - bf.TargetMapBasisFunction( - filtername=filtername, - target_map=target_map[filtername], - out_of_bounds_val=np.nan, - nside=nside, - norm_factor=norm_factor, - ) - ) - if filtername2 is not None: - bfs.append( - bf.TargetMapBasisFunction( - filtername=filtername2, - target_map=target_map[filtername2], - out_of_bounds_val=np.nan, - nside=nside, - norm_factor=norm_factor, - ) - ) - bfs.append(bf.SlewtimeBasisFunction(filtername=filtername, nside=nside)) - bfs.append(bf.StrictFilterBasisFunction(filtername=filtername)) - # Masks, give these 0 weight - bfs.append(bf.AvoidDirectWind(nside=nside)) - bfs.append(bf.AltAzShadowMaskBasisFunction(nside=nside, shadow_minutes=60.0, max_alt=76.0)) - bfs.append(bf.MoonAvoidanceBasisFunction(nside=nside, moon_distance=30.0)) - bfs.append(bf.CloudedOutBasisFunction()) - # feasibility basis fucntions. Also give zero weight. - filternames = [fn for fn in [filtername, filtername2] if fn is not None] - bfs.append(bf.FilterLoadedBasisFunction(filternames=filternames)) - bfs.append(bf.TimeToTwilightBasisFunction(time_needed=22.0)) - bfs.append(bf.NotTwilightBasisFunction()) - bfs.append(bf.PlanetMaskBasisFunction(nside=nside)) - - weights = np.array([3.0, 3.0, 0.3, 0.3, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) - if filtername2 is None: - # Need to scale weights up so filter balancing works properly. - weights = np.array([6.0, 0.6, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) - if filtername2 is None: - survey_name = "blob, %s" % filtername - else: - survey_name = "blob, %s%s" % (filtername, filtername2) - if filtername2 is not None: - detailer_list.append(detailers.TakeAsPairsDetailer(filtername=filtername2)) - - detailer_list.append(detailers.FlushByDetailer()) - pair_surveys.append( - BlobSurvey( - bfs, - weights, - filtername1=filtername, - filtername2=filtername2, - survey_name=survey_name, - ignore_obs="DD", - detailers=detailer_list, - nside=nside, - ) - ) - return pair_surveys - - class TestExample(unittest.TestCase): @unittest.skipUnless(os.path.isfile(SAMPLE_BIG_DATA_FILE), "Test data not available.") def test_example(self): @@ -208,71 +56,6 @@ def test_example(self): # Make sure a DDF executed assert np.any(["DD" in note for note in observations["scheduler_note"]]) - -class TestFeatures(unittest.TestCase): - @unittest.skipUnless(os.path.isfile(SAMPLE_BIG_DATA_FILE), "Test data not available.") - def test_greedy(self): - """ - Set up a greedy survey and run for a few days. - A crude way to touch lots of code. - """ - mjd_start = utils.survey_start_mjd() - nside = 32 - survey_length = 3.0 # days - - surveys = gen_greedy_surveys(nside) - # Deprecating Pairs_survey_scripted - # surveys.append(Pairs_survey_scripted(None, ignore_obs='DD')) - - # Set up the DD - dd_surveys = ddf_surveys(nside=nside) - surveys.extend(dd_surveys) - - scheduler = CoreScheduler(surveys, nside=nside) - observatory = ModelObservatory(nside=nside, mjd_start=mjd_start) - observatory, scheduler, observations = sim_runner( - observatory, scheduler, sim_duration=survey_length, filename=None - ) - - # Check that greedy observed some - assert "greedy" in observations["scheduler_note"] - # Make sure lots of observations executed - assert observations.size > 1000 - # Make sure nothing tried to look through the earth - assert np.min(observations["alt"]) > 0 - - @unittest.skipUnless(os.path.isfile(SAMPLE_BIG_DATA_FILE), "Test data not available.") - def test_blobs(self): - """ - Set up a blob selection survey - """ - mjd_start = utils.survey_start_mjd() - nside = 32 - survey_length = 3.0 # days - - surveys = [] - # Set up the DD - dd_surveys = ddf_surveys(nside=nside) - surveys.append(dd_surveys) - - surveys.append(gen_blob_surveys(nside)) - surveys.append(gen_greedy_surveys(nside)) - - scheduler = CoreScheduler(surveys, nside=nside) - observatory = ModelObservatory(nside=nside, mjd_start=mjd_start) - observatory, scheduler, observations = sim_runner( - observatory, scheduler, sim_duration=survey_length, filename=None - ) - # Make sure some blobs executed - assert "blob, gg, b" in observations["scheduler_note"] - assert "blob, gg, a" in observations["scheduler_note"] - # Make sure some greedy executed - assert "greedy" in observations["scheduler_note"] - # Make sure lots of observations executed - assert observations.size > 1000 - # Make sure nothing tried to look through the earth - assert np.min(observations["alt"]) > 0 - @unittest.skipUnless(os.path.isfile(SAMPLE_BIG_DATA_FILE), "Test data not available.") def test_wind(self): """ @@ -281,29 +64,21 @@ def test_wind(self): """ mjd_start = utils.survey_start_mjd() nside = 32 - survey_length = 4.0 # days - - surveys = [] - # Set up the DD - dd_surveys = ddf_surveys(nside=nside) - surveys.append(dd_surveys) + survey_length = 2.0 # days - surveys.append(gen_blob_surveys(nside)) - surveys.append(gen_greedy_surveys(nside)) + surveys = [simple_greedy_survey(filtername=f) for f in "gri"] scheduler = CoreScheduler(surveys, nside=nside) - observatory = ModelObservatoryWindy(nside=nside, mjd_start=mjd_start) + observatory = ModelObservatoryWindy( + nside=nside, mjd_start=mjd_start, downtimes="ideal", cloud_data="ideal" + ) observatory, scheduler, observations = sim_runner( observatory, scheduler, sim_duration=survey_length, filename=None ) - # Make sure some blobs executed - assert "blob, gg, b" in observations["scheduler_note"] - assert "blob, gg, a" in observations["scheduler_note"] - # Make sure some greedy executed - assert "greedy" in observations["scheduler_note"] - # Make sure lots of observations executed - assert observations.size > 1000 + # Make sure lots of observations executed, but allow short night + # if survey_start changes + assert observations.size > 700 * survey_length # Make sure nothing tried to look through the earth assert np.min(observations["alt"]) > 0 @@ -320,25 +95,25 @@ def test_nside(self): nside = 64 survey_length = 3.0 # days - surveys = [] - # Set up the DD - dd_surveys = ddf_surveys(nside=nside) - surveys.append(dd_surveys) - - surveys.append(gen_blob_surveys(nside)) - surveys.append(gen_greedy_surveys(nside)) + pairs_surveys = [ + simple_pairs_survey(filtername="g", filtername2="r", nside=nside), + simple_pairs_survey(filtername="i", filtername2="z", nside=nside), + ] + greedy_surveys = [ + simple_greedy_survey(filtername="z", nside=nside), + ] - scheduler = CoreScheduler(surveys, nside=nside) + scheduler = CoreScheduler([pairs_surveys, greedy_surveys], nside=nside) observatory = ModelObservatory(nside=nside, mjd_start=mjd_start) observatory, scheduler, observations = sim_runner( observatory, scheduler, sim_duration=survey_length, filename=None ) # Make sure some blobs executed - assert "blob, gg, b" in observations["scheduler_note"] - assert "blob, gg, a" in observations["scheduler_note"] + assert "simple pair 30, iz, a" in observations["scheduler_note"] + assert "simple pair 30, iz, b" in observations["scheduler_note"] # Make sure some greedy executed - assert "greedy" in observations["scheduler_note"] + assert "greedy z" in observations["scheduler_note"] # Make sure lots of observations executed assert observations.size > 1000 # Make sure nothing tried to look through the earth diff --git a/tests/scheduler/test_basisfuncs.py b/tests/scheduler/test_basisfuncs.py index b6d5be28..51dd9d56 100644 --- a/tests/scheduler/test_basisfuncs.py +++ b/tests/scheduler/test_basisfuncs.py @@ -344,15 +344,8 @@ def test_AltAzShadowMask(self): self.assertTrue(len(overlap[0]) > 0) def test_deprecated(self): - deprecated_basis_functions = [ - basis_functions.NearSunTwilightBasisFunction, - basis_functions.AvoidFastRevisits, - basis_functions.AvoidLongGapsBasisFunction, - basis_functions.FootprintNvisBasisFunction, - basis_functions.GoalStrictFilterBasisFunction, - basis_functions.ZenithShadowMaskBasisFunction, - basis_functions.MaskAzimuthBasisFunction, - ] + # Add to-be-deprecated functions here as they appear + deprecated_basis_functions = [] for dep_bf in deprecated_basis_functions: print(dep_bf) with warnings.catch_warnings(record=True) as w: diff --git a/tests/scheduler/test_coresched.py b/tests/scheduler/test_coresched.py index c1f6faf7..36f01714 100644 --- a/tests/scheduler/test_coresched.py +++ b/tests/scheduler/test_coresched.py @@ -1,29 +1,21 @@ import unittest -import numpy as np import pandas as pd import rubin_scheduler.scheduler.basis_functions as basis_functions import rubin_scheduler.scheduler.surveys as surveys +from rubin_scheduler.scheduler.example import simple_greedy_survey from rubin_scheduler.scheduler.model_observatory import ModelObservatory from rubin_scheduler.scheduler.schedulers import CoreScheduler -from rubin_scheduler.scheduler.utils import ObservationArray, generate_all_sky +from rubin_scheduler.scheduler.utils import ObservationArray class TestCoreSched(unittest.TestCase): def testsched(self): - nside = 32 - # Just set up a very simple target map, dec limited, one filter - sky_dict = generate_all_sky(nside, mask=-1) - target_map = np.where( - ((sky_dict["map"] >= 0) & (sky_dict["dec"] < 2) & (sky_dict["dec"] > -65)), 1, 0 - ) - - bfs = [] - bfs.append(basis_functions.M5DiffBasisFunction(nside=nside)) - bfs.append(basis_functions.TargetMapBasisFunction(target_map=target_map, norm_factor=1)) - weights = np.array([1.0, 1]) - survey = surveys.GreedySurvey(bfs, weights) + + # Just set up a very simple survey, one filter + survey = simple_greedy_survey(filtername="r") + scheduler = CoreScheduler([survey]) observatory = ModelObservatory() diff --git a/tests/scheduler/test_features.py b/tests/scheduler/test_features.py index 80d943b1..603e6a05 100644 --- a/tests/scheduler/test_features.py +++ b/tests/scheduler/test_features.py @@ -358,41 +358,65 @@ def test_NObservationsCurrentSeason(self): # in these cases with added requirements .. but will leave it # to the "restore" test in test_utils.py. - def test_NObsSurvey(self): + def test_NObsCount(self): # Make some observations to count - observations_list = make_observations_list(2) - observations_list[0]["scheduler_note"] = "survey a" - observations_list[1]["scheduler_note"] = "survey b" + observations_list = make_observations_list(5) + for i in [0, 2]: + observations_list[i]["scheduler_note"] = "survey a" + for i in [1, 3]: + observations_list[i]["scheduler_note"] = "survey b" + observations_list[4]["scheduler_note"] = "survey" + for i in [0, 1]: + observations_list[i]["filter"] = "r" + for i in [2, 3, 4]: + observations_list[i]["filter"] = "g" observations_array, observations_hpid_array = make_observations_arrays(observations_list) # Count the observations matching any note - count_feature = features.NObsSurvey(note=None) - # ... it matters significantly that we pass obs[0] and not obs. + count_feature = features.NObsCount(note=None, filtername=None) for obs in observations_list: count_feature.add_observation(obs) - self.assertTrue(count_feature.feature == 2) + self.assertTrue(count_feature.feature == 5) # and count again using add_observations_array - count_feature = features.NObsSurvey(note=None) + count_feature = features.NObsCount(note=None) count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) - self.assertTrue(count_feature.feature == 2) + self.assertTrue(count_feature.feature == 5) # Count using a note to match # Count the observations matching specific note - count_feature = features.NObsSurvey(note="survey a") + count_feature = features.NObsCount(note="survey a") for obs in observations_list: count_feature.add_observation(obs) - self.assertTrue(count_feature.feature == 1) + self.assertTrue(count_feature.feature == 2) # and count again using add_observations_array - count_feature = features.NObsSurvey(note="survey a") + count_feature = features.NObsCount(note="survey a") count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) - self.assertTrue(count_feature.feature == 1) + self.assertTrue(count_feature.feature == 2) # Count the observations matching subset of note - count_feature = features.NObsSurvey(note="survey") + count_feature = features.NObsCount(note="survey") + for obs in observations_list: + count_feature.add_observation(obs) + self.assertTrue(count_feature.feature == 5) + # and count again using add_observations_array + count_feature = features.NObsCount(note="survey") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature == 5) + # Count the observations matching filter + count_feature = features.NObsCount(note=None, filtername="r") for obs in observations_list: count_feature.add_observation(obs) self.assertTrue(count_feature.feature == 2) # and count again using add_observations_array - count_feature = features.NObsSurvey(note="survey") + count_feature = features.NObsCount(note=None, filtername="r") count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) self.assertTrue(count_feature.feature == 2) + # Count the observations matching filter and surveyname + count_feature = features.NObsCount(note="survey b", filtername="r") + for obs in observations_list: + count_feature.add_observation(obs) + self.assertTrue(count_feature.feature == 1) + # and count again using add_observations_array + count_feature = features.NObsCount(note="survey b", filtername="r") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature == 1) def test_LastObservation(self): # Make some observations to count From 2d075a2e14515b8f276912548145f93d1d0145a7 Mon Sep 17 00:00:00 2001 From: Lynne Jones Date: Fri, 13 Sep 2024 19:19:11 -0700 Subject: [PATCH 2/2] Update sunAltBasisFunction + test_nside --- .../scheduler/basis_functions/feasibility_funcs.py | 2 +- tests/scheduler/test_baseline.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/rubin_scheduler/scheduler/basis_functions/feasibility_funcs.py b/rubin_scheduler/scheduler/basis_functions/feasibility_funcs.py index fac39abb..d2ed5105 100644 --- a/rubin_scheduler/scheduler/basis_functions/feasibility_funcs.py +++ b/rubin_scheduler/scheduler/basis_functions/feasibility_funcs.py @@ -681,6 +681,6 @@ def __init__(self, alt_limit=-12.1): def check_feasibility(self, conditions): result = True - if conditions.sunAlt > self.alt_limit: + if conditions.sun_alt > self.alt_limit: result = False return result diff --git a/tests/scheduler/test_baseline.py b/tests/scheduler/test_baseline.py index a76055bd..6d33f90e 100644 --- a/tests/scheduler/test_baseline.py +++ b/tests/scheduler/test_baseline.py @@ -6,6 +6,7 @@ import rubin_scheduler.utils as utils from rubin_scheduler.data import get_data_dir from rubin_scheduler.scheduler import sim_runner +from rubin_scheduler.scheduler.basis_functions import SunAltLimitBasisFunction from rubin_scheduler.scheduler.example import example_scheduler, simple_greedy_survey, simple_pairs_survey from rubin_scheduler.scheduler.model_observatory import ModelObservatory from rubin_scheduler.scheduler.schedulers import CoreScheduler @@ -93,12 +94,17 @@ def test_nside(self): """ mjd_start = utils.survey_start_mjd() nside = 64 - survey_length = 3.0 # days + survey_length = 2.0 # days + # Add an avoidance of twilight+ for the pairs surveys - + # this ensures greedy survey will have some time to operate pairs_surveys = [ simple_pairs_survey(filtername="g", filtername2="r", nside=nside), simple_pairs_survey(filtername="i", filtername2="z", nside=nside), ] + for survey in pairs_surveys: + survey.basis_functions.append(SunAltLimitBasisFunction(alt_limit=-22)) + survey.basis_weights.append(0) greedy_surveys = [ simple_greedy_survey(filtername="z", nside=nside), ] @@ -113,9 +119,9 @@ def test_nside(self): assert "simple pair 30, iz, a" in observations["scheduler_note"] assert "simple pair 30, iz, b" in observations["scheduler_note"] # Make sure some greedy executed - assert "greedy z" in observations["scheduler_note"] + assert "simple greedy z" in observations["scheduler_note"] # Make sure lots of observations executed - assert observations.size > 1000 + assert observations.size > 800 # Make sure nothing tried to look through the earth assert np.min(observations["alt"]) > 0