Skip to content

Commit bb4876a

Browse files
committed
a better approach to dev
1 parent 42697de commit bb4876a

File tree

3 files changed

+335
-0
lines changed

3 files changed

+335
-0
lines changed
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import warnings
2+
3+
import astropy.units as u
4+
import numpy as np
5+
import sunpy.map
6+
from skimage.feature import match_template
7+
from sunpy.util.exceptions import SunpyUserWarning
8+
9+
10+
############################ Coalignment Interface begins #################################
11+
@u.quantity_input
12+
def _clip_edges(data, yclips: u.pix, xclips: u.pix):
13+
"""
14+
Clips off the "y" and "x" edges of a 2D array according to a list of pixel
15+
values. This function is useful for removing data at the edge of 2d images
16+
that may be affected by shifts from solar de- rotation and layer co-
17+
registration, leaving an image unaffected by edge effects.
18+
19+
Parameters
20+
----------
21+
data : `numpy.ndarray`
22+
A numpy array of shape ``(ny, nx)``.
23+
yclips : `astropy.units.Quantity`
24+
The amount to clip in the y-direction of the data. Has units of
25+
pixels, and values should be whole non-negative numbers.
26+
xclips : `astropy.units.Quantity`
27+
The amount to clip in the x-direction of the data. Has units of
28+
pixels, and values should be whole non-negative numbers.
29+
30+
Returns
31+
-------
32+
`numpy.ndarray`
33+
A 2D image with edges clipped off according to ``yclips`` and ``xclips``
34+
arrays.
35+
"""
36+
ny = data.shape[0]
37+
nx = data.shape[1]
38+
# The purpose of the int below is to ensure integer type since by default
39+
# astropy quantities are converted to floats.
40+
return data[int(yclips[0].value) : ny - int(yclips[1].value), int(xclips[0].value) : nx - int(xclips[1].value)]
41+
42+
43+
@u.quantity_input
44+
def _calculate_clipping(y: u.pix, x: u.pix):
45+
"""
46+
Return the upper and lower clipping values for the "y" and "x" directions.
47+
48+
Parameters
49+
----------
50+
y : `astropy.units.Quantity`
51+
An array of pixel shifts in the y-direction for an image.
52+
x : `astropy.units.Quantity`
53+
An array of pixel shifts in the x-direction for an image.
54+
55+
Returns
56+
-------
57+
`tuple`
58+
The tuple is of the form ``([y0, y1], [x0, x1])``.
59+
The number of (integer) pixels that need to be clipped off at each
60+
edge in an image. The first element in the tuple is a list that gives
61+
the number of pixels to clip in the y-direction. The first element in
62+
that list is the number of rows to clip at the lower edge of the image
63+
in y. The clipped image has "clipping[0][0]" rows removed from its
64+
lower edge when compared to the original image. The second element in
65+
that list is the number of rows to clip at the upper edge of the image
66+
in y. The clipped image has "clipping[0][1]" rows removed from its
67+
upper edge when compared to the original image. The second element in
68+
the "clipping" tuple applies similarly to the x-direction (image
69+
columns). The parameters ``y0, y1, x0, x1`` have the type
70+
`~astropy.units.Quantity`.
71+
"""
72+
return (
73+
[_lower_clip(y.value), _upper_clip(y.value)] * u.pix,
74+
[_lower_clip(x.value), _upper_clip(x.value)] * u.pix,
75+
)
76+
77+
78+
def _upper_clip(z):
79+
"""
80+
Find smallest integer bigger than all the positive entries in the input
81+
array.
82+
"""
83+
zupper = 0
84+
zcond = z >= 0
85+
if np.any(zcond):
86+
zupper = int(np.max(np.ceil(z[zcond])))
87+
return zupper
88+
89+
90+
def _lower_clip(z):
91+
"""
92+
Find smallest positive integer bigger than the absolute values of the
93+
negative entries in the input array.
94+
"""
95+
zlower = 0
96+
zcond = z <= 0
97+
if np.any(zcond):
98+
zlower = int(np.max(np.ceil(-z[zcond])))
99+
return zlower
100+
101+
102+
def convert_array_to_map(array_obj, map_obj):
103+
"""
104+
Convert a 2D numpy array to a sunpy Map object using the header of a given
105+
map object.
106+
107+
Parameters
108+
----------
109+
array_obj : `numpy.ndarray`
110+
The 2D numpy array to be converted.
111+
map_obj : `sunpy.map.Map`
112+
The map object whose header is to be used for the new map.
113+
114+
Returns
115+
-------
116+
`sunpy.map.Map`
117+
A new sunpy map object with the data from `array_obj` and the header from `map_obj`.
118+
"""
119+
header = map_obj.meta.copy()
120+
header["crpix1"] -= array_obj.shape[1] / 2.0 - map_obj.data.shape[1] / 2.0
121+
header["crpix2"] -= array_obj.shape[0] / 2.0 - map_obj.data.shape[0] / 2.0
122+
return sunpy.map.Map(array_obj, header)
123+
124+
125+
def coalignment_interface(method, input_map, template_map, handle_nan=None):
126+
"""
127+
Interface for performing image coalignment using a specified method.
128+
129+
Parameters
130+
----------
131+
method : str
132+
The name of the registered coalignment method to use.
133+
input_map : `sunpy.map.Map`
134+
The input map to be coaligned.
135+
template_map : `sunpy.map.Map`
136+
The template map to which the input map is to be coaligned.
137+
handle_nan : callable, optional
138+
Function to handle NaN values in the input and template arrays.
139+
140+
Returns
141+
-------
142+
`sunpy.map.Map`
143+
The coaligned input map.
144+
145+
Raises
146+
------
147+
ValueError
148+
If the specified method is not registered.
149+
"""
150+
if method not in registered_methods:
151+
msg = f"Method {method} is not a registered method. Please register before using."
152+
raise ValueError(msg)
153+
input_array = np.float64(input_map.data)
154+
template_array = np.float64(template_map.data)
155+
156+
# Warn user if any NANs, Infs, etc are present in the input or the template array
157+
if not np.all(np.isfinite(input_array)):
158+
if not handle_nan:
159+
warnings.warn(
160+
"The layer image has nonfinite entries. "
161+
"This could cause errors when calculating shift between two "
162+
"images. Please make sure there are no infinity or "
163+
"Not a Number values. For instance, replacing them with a "
164+
"local mean.",
165+
SunpyUserWarning,
166+
stacklevel=3,
167+
)
168+
else:
169+
input_array = handle_nan(input_array)
170+
171+
if not np.all(np.isfinite(template_array)):
172+
if not handle_nan:
173+
warnings.warn(
174+
"The template image has nonfinite entries. "
175+
"This could cause errors when calculating shift between two "
176+
"images. Please make sure there are no infinity or "
177+
"Not a Number values. For instance, replacing them with a "
178+
"local mean.",
179+
SunpyUserWarning,
180+
stacklevel=3,
181+
)
182+
else:
183+
template_array = handle_nan(template_array)
184+
185+
shifts = registered_methods[method](input_array, template_array)
186+
# Calculate the clipping required
187+
yclips, xclips = _calculate_clipping(shifts["x"] * u.pix, shifts["y"] * u.pix)
188+
# Clip 'em
189+
coaligned_input_array = _clip_edges(input_array, yclips, xclips)
190+
return convert_array_to_map(coaligned_input_array, input_map)
191+
192+
######################################## Coalignment interface ends ##################
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import warnings
2+
3+
import astropy.units as u
4+
import numpy as np
5+
import sunpy.map
6+
from skimage.feature import match_template
7+
from sunpy.util.exceptions import SunpyUserWarning
8+
9+
######################################## Defining a method ###########################
10+
def _parabolic_turning_point(y):
11+
"""
12+
Calculate the turning point of a parabola given three points.
13+
14+
Parameters
15+
----------
16+
y : `numpy.ndarray`
17+
An array of three points defining the parabola.
18+
19+
Returns
20+
-------
21+
float
22+
The x-coordinate of the turning point.
23+
"""
24+
numerator = -0.5 * y.dot([-1, 0, 1])
25+
denominator = y.dot([1, -2, 1])
26+
return numerator / denominator
27+
28+
29+
def _get_correlation_shifts(array):
30+
"""
31+
Calculate the shifts in x and y directions based on the correlation array.
32+
33+
Parameters
34+
----------
35+
array : `numpy.ndarray`
36+
A 2D array representing the correlation values.
37+
38+
Returns
39+
-------
40+
tuple
41+
The shifts in y and x directions.
42+
43+
Raises
44+
------
45+
ValueError
46+
If the input array dimensions are greater than 3 in any direction.
47+
"""
48+
ny, nx = array.shape
49+
if nx > 3 or ny > 3:
50+
msg = "Input array dimension should not be greater than 3 in any dimension."
51+
raise ValueError(msg)
52+
53+
ij = np.unravel_index(np.argmax(array), array.shape)
54+
x_max_location, y_max_location = ij[::-1]
55+
56+
y_location = _parabolic_turning_point(array[:, x_max_location]) if ny == 3 else 1.0 * y_max_location
57+
x_location = _parabolic_turning_point(array[y_max_location, :]) if nx == 3 else 1.0 * x_max_location
58+
59+
return y_location, x_location
60+
61+
62+
def _find_best_match_location(corr):
63+
"""
64+
Find the best match location in the correlation array.
65+
66+
Parameters
67+
----------
68+
corr : `numpy.ndarray`
69+
The correlation array.
70+
71+
Returns
72+
-------
73+
tuple
74+
The best match location in the y and x directions.
75+
"""
76+
ij = np.unravel_index(np.argmax(corr), corr.shape)
77+
cor_max_x, cor_max_y = ij[::-1]
78+
79+
array_maximum = corr[
80+
max(0, cor_max_y - 1) : min(cor_max_y + 2, corr.shape[0] - 1),
81+
max(0, cor_max_x - 1) : min(cor_max_x + 2, corr.shape[1] - 1),
82+
]
83+
84+
y_shift_maximum, x_shift_maximum = _get_correlation_shifts(array_maximum)
85+
86+
y_shift_correlation_array = y_shift_maximum + cor_max_y
87+
x_shift_correlation_array = x_shift_maximum + cor_max_x
88+
89+
return y_shift_correlation_array, x_shift_correlation_array
90+
91+
92+
def match_template_coalign(input_array, template_array):
93+
"""
94+
Perform coalignment by matching the template array to the input array.
95+
96+
Parameters
97+
----------
98+
input_array : `numpy.ndarray`
99+
The input 2D array to be coaligned.
100+
template_array : `numpy.ndarray`
101+
The template 2D array to align to.
102+
103+
Returns
104+
-------
105+
dict
106+
A dictionary containing the shifts in x and y directions.
107+
"""
108+
corr = match_template(input_array, template_array)
109+
110+
# Find the best match location
111+
y_shift, x_shift = _find_best_match_location(corr)
112+
113+
# Apply the shift to get the coaligned input array
114+
return {"x": x_shift, "y": y_shift}
115+
116+
117+
################################ Registering the defined method ########################
118+
register_coalignment_method("match_template", match_template_coalign)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import warnings
2+
3+
import astropy.units as u
4+
import numpy as np
5+
import sunpy.map
6+
from skimage.feature import match_template
7+
from sunpy.util.exceptions import SunpyUserWarning
8+
9+
10+
## This dictionary will be further replaced in a more appropriate location, once the decorator structure is in place.
11+
registered_methods = {}
12+
13+
14+
def register_coalignment_method(name, method):
15+
"""
16+
Registers a coalignment method to be used by the coalignment interface.
17+
18+
Parameters
19+
----------
20+
name : str
21+
The name of the coalignment method.
22+
method : callable
23+
The function implementing the coalignment method.
24+
"""
25+
registered_methods[name] = method

0 commit comments

Comments
 (0)