Skip to content

Commit 7f241a6

Browse files
committed
Cleanups for the new adaptive background task.
- Replace "are we looping" variables with 'break' and 'for x in range'. - Use snake_case consistently for internal variables, but keep camelCase for configs, for consistency with SubtractBackgroundTask configs. - Move some magic numbers (not all) to config fields.
1 parent 1f35eff commit 7f241a6

File tree

1 file changed

+70
-61
lines changed

1 file changed

+70
-61
lines changed

python/lsst/meas/algorithms/adaptive_thresholds.py

Lines changed: 70 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
import numpy as np
3232

33-
from lsst.pex.config import Field, Config, ConfigField, DictField, FieldValidationError, ListField
33+
from lsst.pex.config import Field, Config, ConfigField, DictField, FieldValidationError, ListField, RangeField
3434
from lsst.pipe.base import Task
3535

3636
from lsst.afw.geom import SpanSet
@@ -327,17 +327,42 @@ def run(self, table, exposure, **kwargs):
327327
class AdaptiveThresholdBackgroundConfig(SubtractBackgroundConfig):
328328
detectedFractionBadMaskPlanes = ListField(
329329
"Mask planes to ignore when computing the detected fraction.", dtype=str,
330-
default=["BAD", "EDGE", "NO_DATA"]
330+
default=["BAD", "EDGE", "NO_DATA"],
331331
)
332332
minDetFracForFinalBg = Field(
333333
"Minimum detected fraction for the final background.",
334-
dtype=float, default=0.02
334+
dtype=float, default=0.02,
335335
)
336336
maxDetFracForFinalBg = Field(
337337
"Maximum detected fraction for the final background.",
338-
dtype=float, default=0.93
338+
dtype=float, default=0.93,
339+
)
340+
initialDilateRadius = RangeField(
341+
"Number of pixels to dilate the original mask by.",
342+
dtype=int, default=10, min=1,
343+
)
344+
doCheckPerAmpDetectionFraction = Field(
345+
"Whether to check per-amplifier detection fractions.",
346+
dtype=bool, default=True,
347+
)
348+
maxDetectionIterationCount = Field(
349+
"Maximum number of adaptive detection iterations.",
350+
dtype=int, default=39, # Original code had 40, but iterated from [0, n] inclusive.
351+
)
352+
detection = ConfigField(
353+
"Baseline configuration for SourceDetectionTask prior to iteration.",
354+
SourceDetectionConfig,
339355
)
340356

357+
def setDefaults(self):
358+
super().setDefaults()
359+
self.detection.doTempLocalBackground = False
360+
self.detection.nSigmaToGrow = 70.0
361+
self.detection.reEstimateBackground = False
362+
self.detection.includeThresholdMultiplier = 1.0
363+
self.detection.thresholdValue = 2.0 # Actually used only as a floor.
364+
self.detection.thresholdType = "pixel_stdev"
365+
341366

342367
class AdaptiveThresholdBackgroundTask(SubtractBackgroundTask):
343368
"""A background subtraction task that does its own masking of detected
@@ -369,123 +394,107 @@ def run(self, exposure, background=None, stats=True, statsKeys=None, backgroundT
369394
exposure.image.array += background.getImage().array
370395

371396
with self._restore_mask_when_done(exposure) as original_mask:
372-
self._dilate_original_mask(exposure, original_mask)
373-
self._set_adaptive_detection_mask(exposure, median_background)
397+
dilated_mask = self._dilate_original_mask(exposure, original_mask)
398+
self._set_adaptive_detection_mask(exposure, median_background, dilated_mask)
374399
# Do not pass the original background in, since we want to wholly
375400
# replace it.
376401
return super().run(exposure=exposure, stats=stats, statsKeys=statsKeys,
377402
backgroundToPhotometricRatio=backgroundToPhotometricRatio)
378403

379404
def _dilate_original_mask(self, exposure, original_mask):
380-
nPixToDilate = 10
381405
detected_fraction_orig = self._compute_mask_fraction(exposure.mask)
382406
# Dilate the current detected mask planes and don't clear
383407
# them in the detection step.
384-
inDilating = True
385-
while inDilating:
386-
dilatedMask = original_mask.clone()
387-
for maskName in self._DETECTED_MASK_PLANES:
408+
for n_pix_to_dilate in range(self.config.initialDilateRadius, 0, -1):
409+
dilated_mask = original_mask.clone()
410+
for mask_name in self._DETECTED_MASK_PLANES:
388411
# Compute the grown detection mask plane using SpanSet
389-
detectedMaskBit = dilatedMask.getPlaneBitMask(maskName)
390-
detectedMaskSpanSet = SpanSet.fromMask(dilatedMask, detectedMaskBit)
391-
detectedMaskSpanSet = detectedMaskSpanSet.dilated(nPixToDilate)
392-
detectedMaskSpanSet = detectedMaskSpanSet.clippedTo(dilatedMask.getBBox())
412+
detected_mask_bit = dilated_mask.getPlaneBitMask(mask_name)
413+
detected_mask_span_set = SpanSet.fromMask(dilated_mask, detected_mask_bit)
414+
detected_mask_span_set = detected_mask_span_set.dilated(n_pix_to_dilate)
415+
detected_mask_span_set = detected_mask_span_set.clippedTo(dilated_mask.getBBox())
393416
# Clear the detected mask plane
394-
detectedMask = dilatedMask.getMaskPlane(maskName)
395-
dilatedMask.clearMaskPlane(detectedMask)
417+
detectedMask = dilated_mask.getMaskPlane(mask_name)
418+
dilated_mask.clearMaskPlane(detectedMask)
396419
# Set the mask plane to the dilated one
397-
detectedMaskSpanSet.setMask(dilatedMask, detectedMaskBit)
420+
detected_mask_span_set.setMask(dilated_mask, detected_mask_bit)
398421

399-
detected_fraction_dilated = self._compute_mask_fraction(dilatedMask)
400-
if detected_fraction_dilated < self.config.maxDetFracForFinalBg or nPixToDilate == 1:
401-
inDilating = False
402-
else:
403-
nPixToDilate -= 1
404-
exposure.mask = dilatedMask
422+
detected_fraction_dilated = self._compute_mask_fraction(dilated_mask)
423+
if detected_fraction_dilated < self.config.maxDetFracForFinalBg or n_pix_to_dilate == 1:
424+
break
425+
exposure.mask = dilated_mask
405426
self.log.warning("detected_fraction_orig = %.3f detected_fraction_dilated = %.3f",
406427
detected_fraction_orig, detected_fraction_dilated)
407428
n_above_max_per_amp = -99
408429
highest_detected_fraction_per_amp = float("nan")
409-
doCheckPerAmpDetFraction = True
410-
if doCheckPerAmpDetFraction: # detected_fraction < maxDetFracForFinalBg:
430+
if self.config.doCheckPerAmpDetectionFraction:
411431
n_above_max_per_amp, highest_detected_fraction_per_amp, no_zero_det_amps = \
412432
self._compute_per_amp_fraction(exposure, detected_fraction_dilated)
413433
self.log.warning("Dilated mask: n_above_max_per_amp = %d, "
414434
"highest_detected_fraction_per_amp = %.3f",
415435
n_above_max_per_amp, highest_detected_fraction_per_amp)
436+
return dilated_mask
416437

417-
def _set_adaptive_detection_mask(self, exposure, median_background):
418-
inBackgroundDet = True
438+
def _set_adaptive_detection_mask(self, exposure, median_background, dilated_mask):
419439
detected_fraction = 1.0
420-
maxIter = 40
421-
nIter = 0
422-
nFootprintTemp = 1e12
423-
starBackgroundDetectionConfig = SourceDetectionConfig()
424-
starBackgroundDetectionConfig.doTempLocalBackground = False
425-
starBackgroundDetectionConfig.nSigmaToGrow = 70.0
426-
starBackgroundDetectionConfig.reEstimateBackground = False
427-
starBackgroundDetectionConfig.includeThresholdMultiplier = 1.0
428-
starBackgroundDetectionConfig.thresholdValue = max(2.0, 0.2*median_background)
429-
starBackgroundDetectionConfig.thresholdType = "pixel_stdev" # "stdev"
440+
n_footprint_tmp = 1e12
441+
detection_config = self.config.detection.copy()
442+
detection_config.thresholdValue = max(detection_config.thresholdValue, 0.2*median_background)
443+
detection_config.thresholdType = "pixel_stdev" # "stdev"
430444

445+
no_zero_det_amps = 0
431446
n_above_max_per_amp = -99
432447
highest_detected_fraction_per_amp = float("nan")
433-
doCheckPerAmpDetFraction = True
434448

435-
while inBackgroundDet:
436-
currentThresh = starBackgroundDetectionConfig.thresholdValue
449+
for n_iter in range(0, self.config.maxDetectionIterationCount):
450+
current_threshold = detection_config.thresholdValue
437451
if detected_fraction > self.config.maxDetFracForFinalBg:
438-
starBackgroundDetectionConfig.thresholdValue = 1.07*currentThresh
439-
if nFootprintTemp < 3 and detected_fraction > 0.9*self.config.maxDetFracForFinalBg:
440-
starBackgroundDetectionConfig.thresholdValue = 1.2*currentThresh
452+
detection_config.thresholdValue = 1.07*current_threshold
453+
if n_footprint_tmp < 3 and detected_fraction > 0.9*self.config.maxDetFracForFinalBg:
454+
detection_config.thresholdValue = 1.2*current_threshold
441455
if n_above_max_per_amp > 1:
442-
starBackgroundDetectionConfig.thresholdValue = 1.1*currentThresh
456+
detection_config.thresholdValue = 1.1*current_threshold
443457
if detected_fraction < self.config.minDetFracForFinalBg:
444-
starBackgroundDetectionConfig.thresholdValue = 0.8*currentThresh
445-
starBackgroundDetectionTask = SourceDetectionTask(
446-
config=starBackgroundDetectionConfig)
447-
tempDetections = starBackgroundDetectionTask.detectFootprints(
448-
exposure=exposure, clearMask=True)
449-
exposure.mask |= dilatedMask
450-
nFootprintTemp = (
458+
detection_config.thresholdValue = 0.8*current_threshold
459+
detection_task = SourceDetectionTask(config=detection_config)
460+
tempDetections = detection_task.detectFootprints(exposure=exposure, clearMask=True)
461+
exposure.mask |= dilated_mask
462+
n_footprint_tmp = (
451463
(len(tempDetections.positive.getFootprints()) if tempDetections is not None else 0)
452464
+ (len(tempDetections.negative.getFootprints()) if tempDetections.negative is not None else 0)
453465
)
454466
detected_fraction = self._compute_mask_fraction(exposure.mask)
455467
self.log.info("nIter = %d, thresh = %.2f: Fraction of pixels marked as DETECTED or "
456468
"DETECTED_NEGATIVE in star_background_detection = %.3f "
457469
"(max is %.3f; min is %.3f)",
458-
nIter, starBackgroundDetectionConfig.thresholdValue,
470+
n_iter, detection_config.thresholdValue,
459471
detected_fraction, self.config.maxDetFracForFinalBg, self.config.minDetFracForFinalBg)
460472

461473
n_amp = len(exposure.detector.getAmplifiers())
462-
if doCheckPerAmpDetFraction: # detected_fraction < maxDetFracForFinalBg:
474+
if self.config.doCheckPerAmpDetectionFraction:
463475
n_above_max_per_amp, highest_detected_fraction_per_amp, no_zero_det_amps = \
464476
self._compute_per_amp_fraction(exposure, detected_fraction)
465477

466478
if not no_zero_det_amps:
467-
starBackgroundDetectionConfig.thresholdValue = 0.95*currentThresh
468-
nIter += 1
469-
if nIter > maxIter:
470-
inBackgroundDet = False
479+
detection_config.thresholdValue = 0.95*current_threshold
471480

472481
if (detected_fraction < self.config.maxDetFracForFinalBg and detected_fraction > self.config.minDetFracForFinalBg
473482
and n_above_max_per_amp < int(0.75*n_amp)
474483
and no_zero_det_amps):
475484
if (n_above_max_per_amp < max(1, int(0.15*n_amp))
476485
or detected_fraction < 0.85*self.config.maxDetFracForFinalBg):
477-
inBackgroundDet = False
486+
break
478487
else:
479488
self.log.warning("Making small tweak....")
480-
starBackgroundDetectionConfig.thresholdValue = 1.05*currentThresh
489+
detection_config.thresholdValue = 1.05*current_threshold
481490
self.log.warning("n_above_max_per_amp = %d (abs max is %d)", n_above_max_per_amp, int(0.75*n_amp))
482491

483492
self.log.info("Fraction of pixels marked as DETECTED or DETECTED_NEGATIVE is now %.5f "
484493
"(highest per amp section = %.5f)",
485494
detected_fraction, highest_detected_fraction_per_amp)
486495

487496
if detected_fraction > self.config.maxDetFracForFinalBg:
488-
exposure.mask = dilatedMask
497+
exposure.mask = dilated_mask
489498
self.log.warning("Final fraction of pixels marked as DETECTED or DETECTED_NEGATIVE "
490499
"was too large in star_background_detection = %.3f (max = %.3f). "
491500
"Reverting to dilated mask from PSF detection...",

0 commit comments

Comments
 (0)