diff --git a/doc/content.rst b/doc/content.rst index 14503f43c..26678212c 100644 --- a/doc/content.rst +++ b/doc/content.rst @@ -104,6 +104,7 @@ This module calculates the wavefront error by solving the TIE. * **CentroidDefault**: Default centroid class. * **CentroidRandomWalk**: CentroidDefault child class to get the centroid of donut by the random walk model. * **CentroidOtsu**: CentroidDefault child class to get the centroid of donut by the Otsu's method. +* **CentroidConvolveTemplate**: CentroidDefault child class to get the centroids of one or more donuts in an image by convolution with a template donut. * **BaseCwfsTestCase**: Base class for CWFS tests. .. _lsst.ts.wep-modules_wep_deblend: diff --git a/doc/uml/cwfsClass.uml b/doc/uml/cwfsClass.uml index 20b4c2b8c..7d7ddb847 100644 --- a/doc/uml/cwfsClass.uml +++ b/doc/uml/cwfsClass.uml @@ -5,8 +5,11 @@ Algorithm -- CompensableImage CompensableImage ..> Instrument CentroidDefault <|-- CentroidRandomWalk CentroidDefault <|-- CentroidOtsu +CentroidDefault <|-- CentroidConvolveTemplate CentroidFindFactory ..> CentroidRandomWalk CentroidFindFactory ..> CentroidOtsu +CentroidFindFactory ..> CentroidConvolveTemplate +CentroidConvolveTemplate *-- CentroidRandomWalk Image ..> CentroidFindFactory Image *-- CentroidDefault BaseCwfsTestCase ..> CompensableImage diff --git a/doc/versionHistory.rst b/doc/versionHistory.rst index 99cbfa7dd..3c39b820c 100644 --- a/doc/versionHistory.rst +++ b/doc/versionHistory.rst @@ -6,6 +6,14 @@ Version History ################## +.. _lsst.ts.wep-1.5.0: + +------------- +1.5.0 +------------- + +* Add ``CentroidConvolveTemplate`` as a new centroid finding method. + .. _lsst.ts.wep-1.4.9: ------------- diff --git a/policy/default.yaml b/policy/default.yaml index e956a954c..02c2513d4 100644 --- a/policy/default.yaml +++ b/policy/default.yaml @@ -47,7 +47,7 @@ defocalDistInMm: 1.5 # Donut image size in pixel (default value at 1.5 mm) donutImgSizeInPixel: 160 -# Centroid find algorithm. It can be "randomWalk" or "otsu" +# Centroid find algorithm. It can be "randomWalk", "otsu", or "convolveTemplate" centroidFindAlgo: randomWalk # Camera mapper for the data butler to use diff --git a/python/lsst/ts/wep/Utility.py b/python/lsst/ts/wep/Utility.py index 9ce6b57cd..143d0e176 100644 --- a/python/lsst/ts/wep/Utility.py +++ b/python/lsst/ts/wep/Utility.py @@ -62,6 +62,7 @@ class ImageType(IntEnum): class CentroidFindType(IntEnum): RandomWalk = 1 Otsu = auto() + ConvolveTemplate = auto() class DeblendDonutType(IntEnum): @@ -359,7 +360,7 @@ def getCentroidFindType(centroidFindType): Parameters ---------- centroidFindType : str - Centroid find algorithm to use (randomWalk or otsu). + Centroid find algorithm to use (randomWalk, otsu, or convolveTemplate). Returns ------- @@ -376,6 +377,8 @@ def getCentroidFindType(centroidFindType): return CentroidFindType.RandomWalk elif centroidFindType == "otsu": return CentroidFindType.Otsu + elif centroidFindType == "convolveTemplate": + return CentroidFindType.ConvolveTemplate else: raise ValueError("The %s is not supported." % centroidFindType) diff --git a/python/lsst/ts/wep/cwfs/CentroidConvolveTemplate.py b/python/lsst/ts/wep/cwfs/CentroidConvolveTemplate.py new file mode 100644 index 000000000..5df4e1eaf --- /dev/null +++ b/python/lsst/ts/wep/cwfs/CentroidConvolveTemplate.py @@ -0,0 +1,206 @@ +# This file is part of ts_wep. +# +# Developed for the LSST Telescope and Site Systems. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import numpy as np +from copy import copy +from lsst.ts.wep.cwfs.CentroidDefault import CentroidDefault +from lsst.ts.wep.cwfs.CentroidRandomWalk import CentroidRandomWalk +from scipy.signal import correlate +from sklearn.cluster import KMeans + + +class CentroidConvolveTemplate(CentroidDefault): + def __init__(self): + """CentroidDefault child class to get the centroid of donut by + convolution with a template donut image.""" + + super(CentroidConvolveTemplate, self).__init__() + self._centRandomWalk = CentroidRandomWalk() + + def getImgBinary(self, imgDonut): + """Get the binary image. + + Parameters + ---------- + imgDonut : numpy.ndarray + Donut image to do the analysis. + + Returns + ------- + numpy.ndarray [int] + Binary image of donut. + """ + + return self._centRandomWalk.getImgBinary(imgDonut) + + def getCenterAndR(self, imgDonut, templateDonut=None, peakThreshold=0.95): + """Get the centroid data and effective weighting radius. + + Parameters + ---------- + imgDonut : numpy.ndarray + Donut image. + templateDonut : None or numpy.ndarray, optional + Template image for a single donut. If set to None + then the image will be convolved with itself. (The Default is None) + peakThreshold : float, optional + This value is a specifies a number between 0 and 1 that is + the fraction of the highest pixel value in the convolved image. + The code then sets all pixels with a value below this to 0 before + running the K-means algorithm to find peaks that represent possible + donut locations. (The default is 0.95) + + Returns + ------- + float + Centroid x. + float + Centroid y. + float + Effective weighting radius. + """ + + imgBinary = self.getImgBinary(imgDonut) + + if templateDonut is None: + templateBinary = copy(imgBinary) + else: + templateBinary = self.getImgBinary(templateDonut) + + return self.getCenterAndRfromImgBinary( + imgBinary, templateBinary=templateBinary, peakThreshold=peakThreshold, + ) + + def getCenterAndRfromImgBinary( + self, imgBinary, templateBinary=None, peakThreshold=0.95 + ): + """Get the centroid data and effective weighting radius. + + Parameters + ---------- + imgBinary : numpy.ndarray + Binary image of donut. + templateBinary : None or numpy.ndarray, optional + Binary image of template for a single donut. If set to None + then the image will be convolved with itself. (The Default is None) + peakThreshold : float, optional + This value is a specifies a number between 0 and 1 that is + the fraction of the highest pixel value in the convolved image. + The code then sets all pixels with a value below this to 0 before + running the K-means algorithm to find peaks that represent possible + donut locations. (The default is 0.95) + + Returns + ------- + float + Centroid x. + float + Centroid y. + float + Effective weighting radius. + """ + + x, y, radius = self.getCenterAndRfromTemplateConv( + imgBinary, + templateImgBinary=templateBinary, + nDonuts=1, + peakThreshold=peakThreshold, + ) + + return x[0], y[0], radius + + def getCenterAndRfromTemplateConv( + self, imageBinary, templateImgBinary=None, nDonuts=1, peakThreshold=0.95 + ): + """ + Get the centers of the donuts by convolving a binary template image + with the binary image of the donut or donuts. + + Peaks will appear as bright spots in the convolved image. Since we + use binary images the brightness of the stars does not matter and + the peaks of any stars in the image should have about the same + brightness if the template is correct. + + Parameters + ---------- + imageBinary: numpy.ndarray + Binary image of postage stamp. + templateImgBinary: None or numpy.ndarray, optional + Binary image of template donut. If set to None then the image + is convolved with itself. (The default is None) + nDonuts: int, optional + Number of donuts there should be in the binary image. Needs to + be >= 1. (The default is 1) + peakThreshold: float, optional + This value is a specifies a number between 0 and 1 that is + the fraction of the highest pixel value in the convolved image. + The code then sets all pixels with a value below this to 0 before + running the K-means algorithm to find peaks that represent possible + donut locations. (The default is 0.95) + + Returns + ------- + list + X pixel coordinates for donut centroid. + list + Y pixel coordinates for donut centroid. + float + Effective weighting radius calculated using the template image. + """ + + if templateImgBinary is None: + templateImgBinary = copy(imageBinary) + + nDonutsAssertStr = "nDonuts must be an integer >= 1" + assert (nDonuts >= 1) & (type(nDonuts) is int), nDonutsAssertStr + + # We set the mode to be "same" because we need to return the same + # size image to the code. + tempConvolve = correlate(imageBinary, templateImgBinary, mode="same") + + # Then we rank the pixel values keeping only those above + # some fraction of the highest value. + rankedConvolve = np.argsort(tempConvolve.flatten())[::-1] + cutoff = len( + np.where(tempConvolve.flatten() > peakThreshold * np.max(tempConvolve))[0] + ) + rankedConvolveCutoff = rankedConvolve[:cutoff] + nx, ny = np.unravel_index(rankedConvolveCutoff, np.shape(imageBinary)) + + # Then to find peaks in the image we use K-Means with the + # specified number of donuts + kmeans = KMeans(n_clusters=nDonuts) + labels = kmeans.fit_predict(np.array([nx, ny]).T) + + # Then in each cluster we take the brightest pixel as the centroid + centX = [] + centY = [] + for labelNum in range(nDonuts): + nxLabel, nyLabel = np.unravel_index( + rankedConvolveCutoff[labels == labelNum][0], np.shape(imageBinary) + ) + centX.append(nxLabel) + centY.append(nyLabel) + + # Get the radius of the donut from the template image + radius = np.sqrt(np.sum(templateImgBinary) / np.pi) + + return centX, centY, radius diff --git a/python/lsst/ts/wep/cwfs/CentroidDefault.py b/python/lsst/ts/wep/cwfs/CentroidDefault.py index 1757476fc..9493a4eed 100644 --- a/python/lsst/ts/wep/cwfs/CentroidDefault.py +++ b/python/lsst/ts/wep/cwfs/CentroidDefault.py @@ -26,13 +26,15 @@ class CentroidDefault(object): """Default Centroid class.""" - def getCenterAndR(self, imgDonut): + def getCenterAndR(self, imgDonut, **kwargs): """Get the centroid data and effective weighting radius. Parameters ---------- imgDonut : numpy.ndarray Donut image. + **kwargs : dict[str, any] + Dictionary of input argument: new value for that input argument. Returns ------- @@ -48,7 +50,7 @@ def getCenterAndR(self, imgDonut): return self.getCenterAndRfromImgBinary(imgBinary) - def getCenterAndRfromImgBinary(self, imgBinary): + def getCenterAndRfromImgBinary(self, imgBinary, **kwargs): """Get the centroid data and effective weighting radius from the binary image. @@ -56,6 +58,8 @@ def getCenterAndRfromImgBinary(self, imgBinary): ---------- imgBinary : numpy.ndarray [int] Binary image of donut. + **kwargs : dict[str, any] + Dictionary of input argument: new value for that input argument. Returns ------- diff --git a/python/lsst/ts/wep/cwfs/CentroidFindFactory.py b/python/lsst/ts/wep/cwfs/CentroidFindFactory.py index af83fc236..520c5c983 100644 --- a/python/lsst/ts/wep/cwfs/CentroidFindFactory.py +++ b/python/lsst/ts/wep/cwfs/CentroidFindFactory.py @@ -22,6 +22,7 @@ from lsst.ts.wep.Utility import CentroidFindType from lsst.ts.wep.cwfs.CentroidRandomWalk import CentroidRandomWalk from lsst.ts.wep.cwfs.CentroidOtsu import CentroidOtsu +from lsst.ts.wep.cwfs.CentroidConvolveTemplate import CentroidConvolveTemplate class CentroidFindFactory(object): @@ -39,7 +40,7 @@ def createCentroidFind(centroidFindType): Returns ------- - CentroidRandomWalk, CentroidOtsu + Child class of centroidDefault Centroid find object. Raises @@ -52,5 +53,7 @@ def createCentroidFind(centroidFindType): return CentroidRandomWalk() elif centroidFindType == CentroidFindType.Otsu: return CentroidOtsu() + elif centroidFindType == CentroidFindType.ConvolveTemplate: + return CentroidConvolveTemplate() else: raise ValueError("The %s is not supported." % centroidFindType) diff --git a/tests/cwfs/test_centroidConvolveTemplate.py b/tests/cwfs/test_centroidConvolveTemplate.py new file mode 100644 index 000000000..469962686 --- /dev/null +++ b/tests/cwfs/test_centroidConvolveTemplate.py @@ -0,0 +1,192 @@ +# This file is part of ts_wep. +# +# Developed for the LSST Telescope and Site Systems. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest +import numpy as np + +from lsst.ts.wep.cwfs.CentroidConvolveTemplate import CentroidConvolveTemplate + + +class TestCentroidConvolveTemplate(unittest.TestCase): + """Test the CentroidConvolveTemplate class.""" + + def setUp(self): + + self.centroidConv = CentroidConvolveTemplate() + + def _createData(self, radiusInner, radiusOuter, imageSize, addNoise=False): + + # Create two images. One with a single donut and one with two donuts. + singleDonut = np.zeros((imageSize, imageSize)) + doubleDonut = np.zeros((imageSize, imageSize)) + + for x in range(imageSize): + for y in range(imageSize): + # For single donut put the donut at the center of the image + if ( + np.sqrt((imageSize / 2 - x) ** 2 + (imageSize / 2 - y) ** 2) + <= radiusOuter + ): + singleDonut[x, y] += 1 + if ( + np.sqrt((imageSize / 2 - x) ** 2 + (imageSize / 2 - y) ** 2) + <= radiusInner + ): + singleDonut[x, y] -= 1 + # For double donut put the two donuts along same line + # halfway down the image and provide 10 pixels between + # image edge and outer edge of donut on either side of image + if ( + np.sqrt(((radiusOuter + 10) - x) ** 2 + (imageSize / 2 - y) ** 2) + <= radiusOuter + ): + doubleDonut[x, y] += 1 + if ( + np.sqrt(((radiusOuter + 10) - x) ** 2 + (imageSize / 2 - y) ** 2) + <= radiusInner + ): + doubleDonut[x, y] -= 1 + if ( + np.sqrt( + (imageSize - (radiusOuter + 10) - x) ** 2 + + (imageSize / 2 - y) ** 2 + ) + <= radiusOuter + ): + doubleDonut[x, y] += 1 + if ( + np.sqrt( + (imageSize - (radiusOuter + 10) - x) ** 2 + + (imageSize / 2 - y) ** 2 + ) + <= radiusInner + ): + doubleDonut[x, y] -= 1 + # Make binary image + doubleDonut[doubleDonut > 0.5] = 1 + + if addNoise is True: + # Add noise so the images are not binary + randState = np.random.RandomState(42) + singleDonut += randState.normal(scale=0.01, size=np.shape(singleDonut)) + doubleDonut += randState.normal(scale=0.01, size=np.shape(doubleDonut)) + + eff_radius = np.sqrt(radiusOuter ** 2 - radiusInner ** 2) + + return singleDonut, doubleDonut, eff_radius + + def testGetImgBinary(self): + + singleDonut, doubleDonut, eff_radius = self._createData( + 20, 40, 160, addNoise=False + ) + + noisySingle, noisyDouble, eff_radius = self._createData( + 20, 40, 160, addNoise=True + ) + + binarySingle = self.centroidConv.getImgBinary(noisySingle) + + np.testing.assert_array_equal(singleDonut, binarySingle) + + def testGetCenterAndRWithoutTemplate(self): + + singleDonut, doubleDonut, eff_radius = self._createData( + 20, 40, 160, addNoise=True + ) + + # Test recovery with defaults + centX, centY, rad = self.centroidConv.getCenterAndR(singleDonut) + + self.assertEqual(centX, 80.0) + self.assertEqual(centY, 80.0) + self.assertAlmostEqual(rad, eff_radius, delta=0.1) + + def testGetCenterAndRWithTemplate(self): + + singleDonut, doubleDonut, eff_radius = self._createData( + 20, 40, 160, addNoise=True + ) + + # Test recovery with defaults + centX, centY, rad = self.centroidConv.getCenterAndR( + singleDonut, templateDonut=singleDonut + ) + + self.assertEqual(centX, 80.0) + self.assertEqual(centY, 80.0) + self.assertAlmostEqual(rad, eff_radius, delta=0.1) + + def testGetCenterAndRFromImgBinary(self): + + singleDonut, doubleDonut, eff_radius = self._createData(20, 40, 160) + + # Test recovery with defaults + centX, centY, rad = self.centroidConv.getCenterAndRfromImgBinary(singleDonut) + + self.assertEqual(centX, 80.0) + self.assertEqual(centY, 80.0) + self.assertAlmostEqual(rad, eff_radius, delta=0.1) + + def testNDonutsAssertion(self): + + singleDonut, doubleDonut, eff_radius = self._createData(20, 40, 160) + + nDonutsAssertMsg = "nDonuts must be an integer >= 1" + with self.assertRaises(AssertionError, msg=nDonutsAssertMsg): + cX, cY, rad = self.centroidConv.getCenterAndRfromTemplateConv( + singleDonut, nDonuts=0 + ) + + with self.assertRaises(AssertionError, msg=nDonutsAssertMsg): + cX, cY, rad = self.centroidConv.getCenterAndRfromTemplateConv( + singleDonut, nDonuts=-1 + ) + + with self.assertRaises(AssertionError, msg=nDonutsAssertMsg): + cX, cY, rad = self.centroidConv.getCenterAndRfromTemplateConv( + singleDonut, nDonuts=1.5 + ) + + def testGetCenterAndRFromTemplateConv(self): + + singleDonut, doubleDonut, eff_radius = self._createData(20, 40, 160) + + # Test recovery of single donut + singleCX, singleCY, rad = self.centroidConv.getCenterAndRfromTemplateConv( + singleDonut + ) + self.assertEqual(singleCX, [80.0]) + self.assertEqual(singleCY, [80.0]) + self.assertAlmostEqual(rad, eff_radius, delta=0.1) + + # Test recovery of two donuts at once + doubleCX, doubleCY, rad = self.centroidConv.getCenterAndRfromTemplateConv( + doubleDonut, templateImgBinary=singleDonut, nDonuts=2 + ) + self.assertCountEqual(doubleCX, [50.0, 110.0]) + self.assertEqual(doubleCY, [80.0, 80.0]) + self.assertAlmostEqual(rad, eff_radius, delta=0.1) + + +if __name__ == "__main__": + + unittest.main() diff --git a/tests/cwfs/test_centroidFindFactory.py b/tests/cwfs/test_centroidFindFactory.py index a9ccfb6b8..199156f49 100644 --- a/tests/cwfs/test_centroidFindFactory.py +++ b/tests/cwfs/test_centroidFindFactory.py @@ -25,6 +25,7 @@ from lsst.ts.wep.cwfs.CentroidFindFactory import CentroidFindFactory from lsst.ts.wep.cwfs.CentroidRandomWalk import CentroidRandomWalk from lsst.ts.wep.cwfs.CentroidOtsu import CentroidOtsu +from lsst.ts.wep.cwfs.CentroidConvolveTemplate import CentroidConvolveTemplate class TestCentroidFindFactory(unittest.TestCase): @@ -42,6 +43,13 @@ def testCreateCentroidFindOtsu(self): centroidFind = CentroidFindFactory.createCentroidFind(CentroidFindType.Otsu) self.assertTrue(isinstance(centroidFind, CentroidOtsu)) + def testCreateCentroidFindConvolveTemplate(self): + + centroidFind = CentroidFindFactory.createCentroidFind( + CentroidFindType.ConvolveTemplate + ) + self.assertTrue(isinstance(centroidFind, CentroidConvolveTemplate)) + def testCreateCentroidFindWrongType(self): self.assertRaises(