Skip to content

Commit 1f35eff

Browse files
committed
Move adaptive background subtraction to new task here.
1 parent be4be72 commit 1f35eff

File tree

1 file changed

+264
-1
lines changed

1 file changed

+264
-1
lines changed

python/lsst/meas/algorithms/adaptive_thresholds.py

Lines changed: 264 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,22 @@
2222
__all__ = [
2323
"AdaptiveThresholdDetectionConfig",
2424
"AdaptiveThresholdDetectionTask",
25+
"AdaptiveThresholdBackgroundConfig",
26+
"AdaptiveThresholdBackgroundTask",
2527
]
2628

29+
from contextlib import contextmanager
30+
2731
import numpy as np
2832

29-
from lsst.pex.config import Field, Config, ConfigField, DictField, FieldValidationError
33+
from lsst.pex.config import Field, Config, ConfigField, DictField, FieldValidationError, ListField
3034
from lsst.pipe.base import Task
3135

36+
from lsst.afw.geom import SpanSet
37+
from lsst.afw.image import Mask
38+
from lsst.afw.math import BackgroundList
3239
from .detection import SourceDetectionConfig, SourceDetectionTask
40+
from .subtractBackground import SubtractBackgroundConfig, SubtractBackgroundTask
3341

3442

3543
class AdaptiveThresholdDetectionConfig(Config):
@@ -314,3 +322,258 @@ def run(self, table, exposure, **kwargs):
314322
detections.includeThresholdMultiplier = adaptiveDetectionConfig.includeThresholdMultiplier
315323
return detections
316324

325+
326+
327+
class AdaptiveThresholdBackgroundConfig(SubtractBackgroundConfig):
328+
detectedFractionBadMaskPlanes = ListField(
329+
"Mask planes to ignore when computing the detected fraction.", dtype=str,
330+
default=["BAD", "EDGE", "NO_DATA"]
331+
)
332+
minDetFracForFinalBg = Field(
333+
"Minimum detected fraction for the final background.",
334+
dtype=float, default=0.02
335+
)
336+
maxDetFracForFinalBg = Field(
337+
"Maximum detected fraction for the final background.",
338+
dtype=float, default=0.93
339+
)
340+
341+
342+
class AdaptiveThresholdBackgroundTask(SubtractBackgroundTask):
343+
"""A background subtraction task that does its own masking of detected
344+
sources, using an adaptive scheme that iterates until bounds on the mask
345+
fraction are satisfied.
346+
347+
Notes
348+
-----
349+
This task is only designed for use on detector images, as it is aware of
350+
amplifier geometry (to deal with the fact that some amps have much higher
351+
noise than others, and hence very different detected-mask fractions for the
352+
same detection threshold.
353+
"""
354+
355+
ConfigClass = AdaptiveThresholdBackgroundConfig
356+
_DETECTED_MASK_PLANES = ("DETECTED", "DETECTED_NEGATIVE")
357+
358+
def run(self, exposure, background=None, stats=True, statsKeys=None, backgroundToPhotometricRatio=None):
359+
# Restore the previously measured background and remeasure it
360+
# using an adaptive threshold detection iteration to ensure a
361+
# "Goldilocks Zone" for the fraction of detected pixels.
362+
if not background:
363+
background = BackgroundList()
364+
median_background = 0.0
365+
else:
366+
median_background = np.median(background.getImage().array)
367+
self.log.warning("Original median_background = %.2f", median_background)
368+
# TODO: apply backgroundToPhotometricRatio here!
369+
exposure.image.array += background.getImage().array
370+
371+
with self._restore_mask_when_done(exposure) as original_mask:
372+
self._dilate_original_mask(exposure, original_mask)
373+
self._set_adaptive_detection_mask(exposure, median_background)
374+
# Do not pass the original background in, since we want to wholly
375+
# replace it.
376+
return super().run(exposure=exposure, stats=stats, statsKeys=statsKeys,
377+
backgroundToPhotometricRatio=backgroundToPhotometricRatio)
378+
379+
def _dilate_original_mask(self, exposure, original_mask):
380+
nPixToDilate = 10
381+
detected_fraction_orig = self._compute_mask_fraction(exposure.mask)
382+
# Dilate the current detected mask planes and don't clear
383+
# them in the detection step.
384+
inDilating = True
385+
while inDilating:
386+
dilatedMask = original_mask.clone()
387+
for maskName in self._DETECTED_MASK_PLANES:
388+
# Compute the grown detection mask plane using SpanSet
389+
detectedMaskBit = dilatedMask.getPlaneBitMask(maskName)
390+
detectedMaskSpanSet = SpanSet.fromMask(dilatedMask, detectedMaskBit)
391+
detectedMaskSpanSet = detectedMaskSpanSet.dilated(nPixToDilate)
392+
detectedMaskSpanSet = detectedMaskSpanSet.clippedTo(dilatedMask.getBBox())
393+
# Clear the detected mask plane
394+
detectedMask = dilatedMask.getMaskPlane(maskName)
395+
dilatedMask.clearMaskPlane(detectedMask)
396+
# Set the mask plane to the dilated one
397+
detectedMaskSpanSet.setMask(dilatedMask, detectedMaskBit)
398+
399+
detected_fraction_dilated = self._compute_mask_fraction(dilatedMask)
400+
if detected_fraction_dilated < self.config.maxDetFracForFinalBg or nPixToDilate == 1:
401+
inDilating = False
402+
else:
403+
nPixToDilate -= 1
404+
exposure.mask = dilatedMask
405+
self.log.warning("detected_fraction_orig = %.3f detected_fraction_dilated = %.3f",
406+
detected_fraction_orig, detected_fraction_dilated)
407+
n_above_max_per_amp = -99
408+
highest_detected_fraction_per_amp = float("nan")
409+
doCheckPerAmpDetFraction = True
410+
if doCheckPerAmpDetFraction: # detected_fraction < maxDetFracForFinalBg:
411+
n_above_max_per_amp, highest_detected_fraction_per_amp, no_zero_det_amps = \
412+
self._compute_per_amp_fraction(exposure, detected_fraction_dilated)
413+
self.log.warning("Dilated mask: n_above_max_per_amp = %d, "
414+
"highest_detected_fraction_per_amp = %.3f",
415+
n_above_max_per_amp, highest_detected_fraction_per_amp)
416+
417+
def _set_adaptive_detection_mask(self, exposure, median_background):
418+
inBackgroundDet = True
419+
detected_fraction = 1.0
420+
maxIter = 40
421+
nIter = 0
422+
nFootprintTemp = 1e12
423+
starBackgroundDetectionConfig = SourceDetectionConfig()
424+
starBackgroundDetectionConfig.doTempLocalBackground = False
425+
starBackgroundDetectionConfig.nSigmaToGrow = 70.0
426+
starBackgroundDetectionConfig.reEstimateBackground = False
427+
starBackgroundDetectionConfig.includeThresholdMultiplier = 1.0
428+
starBackgroundDetectionConfig.thresholdValue = max(2.0, 0.2*median_background)
429+
starBackgroundDetectionConfig.thresholdType = "pixel_stdev" # "stdev"
430+
431+
n_above_max_per_amp = -99
432+
highest_detected_fraction_per_amp = float("nan")
433+
doCheckPerAmpDetFraction = True
434+
435+
while inBackgroundDet:
436+
currentThresh = starBackgroundDetectionConfig.thresholdValue
437+
if detected_fraction > self.config.maxDetFracForFinalBg:
438+
starBackgroundDetectionConfig.thresholdValue = 1.07*currentThresh
439+
if nFootprintTemp < 3 and detected_fraction > 0.9*self.config.maxDetFracForFinalBg:
440+
starBackgroundDetectionConfig.thresholdValue = 1.2*currentThresh
441+
if n_above_max_per_amp > 1:
442+
starBackgroundDetectionConfig.thresholdValue = 1.1*currentThresh
443+
if detected_fraction < self.config.minDetFracForFinalBg:
444+
starBackgroundDetectionConfig.thresholdValue = 0.8*currentThresh
445+
starBackgroundDetectionTask = SourceDetectionTask(
446+
config=starBackgroundDetectionConfig)
447+
tempDetections = starBackgroundDetectionTask.detectFootprints(
448+
exposure=exposure, clearMask=True)
449+
exposure.mask |= dilatedMask
450+
nFootprintTemp = (
451+
(len(tempDetections.positive.getFootprints()) if tempDetections is not None else 0)
452+
+ (len(tempDetections.negative.getFootprints()) if tempDetections.negative is not None else 0)
453+
)
454+
detected_fraction = self._compute_mask_fraction(exposure.mask)
455+
self.log.info("nIter = %d, thresh = %.2f: Fraction of pixels marked as DETECTED or "
456+
"DETECTED_NEGATIVE in star_background_detection = %.3f "
457+
"(max is %.3f; min is %.3f)",
458+
nIter, starBackgroundDetectionConfig.thresholdValue,
459+
detected_fraction, self.config.maxDetFracForFinalBg, self.config.minDetFracForFinalBg)
460+
461+
n_amp = len(exposure.detector.getAmplifiers())
462+
if doCheckPerAmpDetFraction: # detected_fraction < maxDetFracForFinalBg:
463+
n_above_max_per_amp, highest_detected_fraction_per_amp, no_zero_det_amps = \
464+
self._compute_per_amp_fraction(exposure, detected_fraction)
465+
466+
if not no_zero_det_amps:
467+
starBackgroundDetectionConfig.thresholdValue = 0.95*currentThresh
468+
nIter += 1
469+
if nIter > maxIter:
470+
inBackgroundDet = False
471+
472+
if (detected_fraction < self.config.maxDetFracForFinalBg and detected_fraction > self.config.minDetFracForFinalBg
473+
and n_above_max_per_amp < int(0.75*n_amp)
474+
and no_zero_det_amps):
475+
if (n_above_max_per_amp < max(1, int(0.15*n_amp))
476+
or detected_fraction < 0.85*self.config.maxDetFracForFinalBg):
477+
inBackgroundDet = False
478+
else:
479+
self.log.warning("Making small tweak....")
480+
starBackgroundDetectionConfig.thresholdValue = 1.05*currentThresh
481+
self.log.warning("n_above_max_per_amp = %d (abs max is %d)", n_above_max_per_amp, int(0.75*n_amp))
482+
483+
self.log.info("Fraction of pixels marked as DETECTED or DETECTED_NEGATIVE is now %.5f "
484+
"(highest per amp section = %.5f)",
485+
detected_fraction, highest_detected_fraction_per_amp)
486+
487+
if detected_fraction > self.config.maxDetFracForFinalBg:
488+
exposure.mask = dilatedMask
489+
self.log.warning("Final fraction of pixels marked as DETECTED or DETECTED_NEGATIVE "
490+
"was too large in star_background_detection = %.3f (max = %.3f). "
491+
"Reverting to dilated mask from PSF detection...",
492+
detected_fraction, self.config.maxDetFracForFinalBg)
493+
494+
def _compute_mask_fraction(self, mask):
495+
"""Evaluate the fraction of masked pixels in a (set of) mask plane(s).
496+
497+
Parameters
498+
----------
499+
mask : `lsst.afw.image.Mask`
500+
The mask on which to evaluate the fraction.
501+
502+
Returns
503+
-------
504+
detected_fraction : `float`
505+
The calculated fraction of masked pixels
506+
"""
507+
bad_pixel_mask = Mask.getPlaneBitMask(self.config.detectedFractionBadMaskPlanes)
508+
n_good_pix = np.sum(mask.array & bad_pixel_mask == 0)
509+
if n_good_pix == 0:
510+
detected_fraction = float("nan")
511+
return detected_fraction
512+
detected_pixel_mask = Mask.getPlaneBitMask(self._DETECTED_MASK_PLANES)
513+
n_detected_pix = np.sum((mask.array & detected_pixel_mask != 0)
514+
& (mask.array & bad_pixel_mask == 0))
515+
detected_fraction = n_detected_pix/n_good_pix
516+
return detected_fraction
517+
518+
def _compute_per_amp_fraction(self, exposure, detected_fraction):
519+
"""Evaluate the maximum per-amplifier fraction of masked pixels.
520+
521+
Parameters
522+
----------
523+
exposure : `lsst.afw.image.ExposureF`
524+
The exposure on which to compute the per-amp masked fraction.
525+
detected_fraction : `float`
526+
The current detected_fraction of the detected mask planes for the
527+
full detector.
528+
529+
Returns
530+
-------
531+
n_above_max_per_amp : `int`
532+
The number of amplifiers with masked fractions above a maximum
533+
value (set by the current full-detector ``detected_fraction``).
534+
highest_detected_fraction_per_amp : `float`
535+
The highest value of the per-amplifier fraction of masked pixels.
536+
no_zero_det_amps : `bool`
537+
A boolean representing whether any of the amplifiers has zero
538+
masked pixels.
539+
"""
540+
highest_detected_fraction_per_amp = -9.99
541+
n_above_max_per_amp = 0
542+
n_no_zero_det_amps = 0
543+
no_zero_det_amps = True
544+
amps = exposure.detector.getAmplifiers()
545+
if amps is not None:
546+
for ia, amp in enumerate(amps):
547+
amp_bbox = amp.getBBox()
548+
exp_bbox = exposure.getBBox()
549+
if not exp_bbox.contains(amp_bbox):
550+
self.log.info("Bounding box of amplifier (%s) does not fit in exposure's "
551+
"bounding box (%s). Skipping...", amp_bbox, exp_bbox)
552+
continue
553+
sub_image = exposure.subset(amp.getBBox())
554+
detected_fraction_amp = self._compute_mask_fraction(sub_image.mask)
555+
self.log.debug("Current detected fraction for amplifier %s = %.3f",
556+
amp.getName(), detected_fraction_amp)
557+
if detected_fraction_amp < 0.002:
558+
n_no_zero_det_amps += 1
559+
if n_no_zero_det_amps > 2:
560+
no_zero_det_amps = False
561+
break
562+
highest_detected_fraction_per_amp = max(detected_fraction_amp,
563+
highest_detected_fraction_per_amp)
564+
if highest_detected_fraction_per_amp > min(0.998, max(0.8, 3.0*detected_fraction)):
565+
n_above_max_per_amp += 1
566+
if n_above_max_per_amp > 2:
567+
break
568+
else:
569+
self.log.info("No amplifier object for detector %d, so skipping per-amp "
570+
"detection fraction checks.", exposure.detector.getId())
571+
return n_above_max_per_amp, highest_detected_fraction_per_amp, no_zero_det_amps
572+
573+
@contextmanager
574+
def _restore_mask_when_done(self, exposure):
575+
original_mask = exposure.mask.clone()
576+
try:
577+
yield original_mask
578+
finally:
579+
exposure.mask = original_mask

0 commit comments

Comments
 (0)