-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2337 from kif/2315_medfilt_ng_cython
Cython implementation of medfilt_ng
- Loading branch information
Showing
9 changed files
with
251 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
__contact__ = "[email protected]" | ||
__license__ = "MIT" | ||
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France" | ||
__date__ = "12/06/2024" | ||
__date__ = "19/11/2024" | ||
__status__ = "development" | ||
|
||
from collections import namedtuple | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ | |
__contact__ = "[email protected]" | ||
__license__ = "MIT" | ||
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France" | ||
__date__ = "14/11/2024" | ||
__date__ = "19/11/2024" | ||
__status__ = "development" | ||
|
||
from collections.abc import Iterable | ||
|
@@ -395,7 +395,9 @@ def medfilt(self, data, dark=None, dummy=None, delta_dummy=None, | |
variance=None, dark_variance=None, | ||
flat=None, solidangle=None, polarization=None, absorption=None, | ||
safe=True, error_model=None, | ||
normalization_factor=1.0, quantile=0.5 | ||
normalization_factor=1.0, | ||
quant_min=0.5, | ||
quant_max=0.5, | ||
): | ||
""" | ||
Perform a median-filter/quantile mean in azimuthal space. | ||
|
@@ -425,18 +427,11 @@ def medfilt(self, data, dark=None, dummy=None, delta_dummy=None, | |
:param safe: Unused in this implementation | ||
:param error_model: Enum or str, "azimuthal" or "poisson" | ||
:param normalization_factor: divide raw signal by this value | ||
:param quantile: which percentile/100 use for cutting out quantil. | ||
can be a 2-tuple to specify a region to average out. | ||
By default, takes the median | ||
:return: namedtuple with "position intensity error signal variance normalization count" | ||
:param quant_min: start percentile/100 to use. Use 0.5 for the median (default). 0<=quant_min<=1 | ||
:param quant_max: stop percentile/100 to use. Use 0.5 for the median (default). 0<=quant_max<=1 | ||
:return: namedtuple with "position intensity error signal variance normalization count" | ||
""" | ||
if isinstance(quantile, Iterable): | ||
q_start = min(quantile) | ||
q_stop = max(quantile) | ||
else: | ||
q_stop = q_start = quantile | ||
|
||
indptr = self._csr.indptr | ||
indices = self._csr.indices | ||
csr_data = self._csr.data | ||
|
@@ -484,7 +479,7 @@ def medfilt(self, data, dark=None, dummy=None, delta_dummy=None, | |
upper = numpy.cumsum(tmp["norm"]) | ||
last = upper[-1] | ||
lower = numpy.concatenate(([0],upper[:-1])) | ||
mask = numpy.logical_and(upper>=q_start*last, lower<=q_stop*last) | ||
mask = numpy.logical_and(upper>=quant_min*last, lower<=quant_max*last) | ||
tmp = tmp[mask] | ||
cnt[i] = tmp.size | ||
signal[i] = tmp["sig"].sum(dtype=numpy.float64) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,18 +29,47 @@ | |
|
||
__author__ = "Jérôme Kieffer" | ||
__contact__ = "[email protected]" | ||
__date__ = "06/09/2024" | ||
__date__ = "19/11/2024" | ||
__status__ = "stable" | ||
__license__ = "MIT" | ||
|
||
|
||
from libcpp cimport bool | ||
from libcpp.algorithm cimport sort | ||
from cython cimport floating | ||
|
||
import os | ||
import cython | ||
from cython.parallel import prange | ||
import numpy | ||
|
||
from .preproc import preproc | ||
from ..containers import Integrate1dtpl, Integrate2dtpl, ErrorModel | ||
|
||
# cdef Py_ssize_t MAX_THREADS = 8 | ||
# try: | ||
# MAX_THREADS = min(MAX_THREADS, len(os.sched_getaffinity(os.getpid()))) # Limit to the actual number of threads | ||
# except Exception: | ||
# MAX_THREADS = min(MAX_THREADS, os.cpu_count() or 1) | ||
|
||
|
||
# Record of four single-precision floats, one per CSR entry in the median filter (medfilt). | ||
# As filled by medfilt: s0 = sort key (signal/normalization, later overwritten by the | ||
# cumulative normalization), s1 = weighted signal, s2 = weighted variance, | ||
# s3 = weighted normalization. | ||
cdef struct float4_t: | ||
float s0 | ||
float s1 | ||
float s2 | ||
float s3 | ||
# NumPy structured dtype with the same four 'f4' fields, used to allocate float4_t work buffers. | ||
float4_d = numpy.dtype([('s0','f4'),('s1','f4'),('s2','f4'),('s3','f4')]) | ||
|
||
cdef inline bool cmp(float4_t a, float4_t b) noexcept nogil: | ||
    # Strict "less than" on the first field (s0): predicate handed to sort() in sort_float4. | ||
    return a.s0 < b.s0 | ||
|
||
cdef inline void sort_float4(float4_t[::1] ary) noexcept nogil: | ||
"Sort in place of an array of float4 along first element (s0)" | ||
cdef: | ||
int size | ||
size = ary.shape[0] | ||
# sort() receives a [first, last) pointer pair: `&ary[size-1]+1` is one past the last record. | ||
# `cmp` compares only the s0 field, so records end up in ascending s0 order. | ||
sort(&ary[0], &ary[size-1]+1, cmp) | ||
|
||
|
||
cdef class CsrIntegrator(object): | ||
"""Abstract class which implements only the integrator... | ||
|
@@ -298,8 +327,8 @@ cdef class CsrIntegrator(object): | |
:type absorption: ndarray | ||
:param normalization_factor: divide the valid result by this value | ||
:param bool weighted_average: set to False to use an unweighted mean (similar to legacy) instead of the weighted average. WIP | ||
:return: positions, pattern, weighted_histogram and unweighted_histogram | ||
:rtype: Integrate1dtpl 4-named-tuple of ndarrays | ||
:return: namedtuple with "position intensity sigma signal variance normalization count std sem norm_sq" | ||
:rtype: Integrate1dtpl named-tuple of ndarrays | ||
""" | ||
cdef: | ||
index_t i, j, idx = 0 | ||
|
@@ -485,8 +514,9 @@ cdef class CsrIntegrator(object): | |
:param error_model: set to "poissonian" to use signal as variance (minimum 1), "azimuthal" to use the variance in a ring. | ||
:param normalization_factor: divide the valid result by this value | ||
:return: positions, pattern, weighted_histogram and unweighted_histogram | ||
:rtype: Integrate1dtpl 4-named-tuple of ndarrays | ||
:return: namedtuple with "position intensity sigma signal variance normalization count std sem norm_sq" | ||
:rtype: Integrate1dtpl named-tuple of ndarrays | ||
""" | ||
error_model = ErrorModel.parse(error_model) | ||
cdef: | ||
|
@@ -691,7 +721,167 @@ cdef class CsrIntegrator(object): | |
stda[i] = empty | ||
sema[i] = empty | ||
|
||
#"position intensity error signal variance normalization count" | ||
#"position intensity sigma signal variance normalization count std sem norm_sq" | ||
return Integrate1dtpl(self.bin_centers, | ||
numpy.asarray(merged),numpy.asarray(sema) , | ||
numpy.asarray(sum_sig),numpy.asarray(sum_var), | ||
numpy.asarray(sum_norm), numpy.asarray(sum_count), | ||
numpy.asarray(stda), numpy.asarray(sema), numpy.asarray(sum_norm_sq)) | ||
|
||
|
||
def medfilt( self, | ||
weights, | ||
dark=None, | ||
dummy=None, | ||
delta_dummy=None, | ||
variance=None, | ||
dark_variance=None, | ||
flat=None, | ||
solidangle=None, | ||
polarization=None, | ||
absorption=None, | ||
bint safe=True, | ||
error_model=ErrorModel.NO, | ||
data_t normalization_factor=1.0, | ||
double quant_min=0.5, | ||
double quant_max=0.5, | ||
): | ||
"""Perform a median filter/quantile averaging in azimuthal space | ||
The error is propagated like Poisson or pre-defined variance, no azimuthal variance for now. | ||
Integration is performed using the CSR representation of the look-up table on all | ||
arrays: signal, variance, normalization and count | ||
All data are duplicated, sorted and the relevant values (i.e. within [quant_min..quant_max]) | ||
are averaged like in `integrate_ng` | ||
:param weights: input image | ||
:type weights: ndarray | ||
:param dark: array with the dark-current value to be subtracted (if any) | ||
:type dark: ndarray | ||
:param dummy: value for dead pixels (optional) | ||
:type dummy: float | ||
:param delta_dummy: precision for dead-pixel value in dynamic masking | ||
:type delta_dummy: float | ||
:param variance: the variance associate to the image | ||
:type variance: ndarray | ||
:param dark_variance: the variance associate to the dark | ||
:type dark_variance: ndarray | ||
:param flat: array with the flat-field value to be divided by (if any) | ||
:type flat: ndarray | ||
:param solidangle: array with the solid angle of each pixel to be divided by (if any) | ||
:type solidangle: ndarray | ||
:param polarization: array with the polarization correction values to be divided by (if any) | ||
:type polarization: ndarray | ||
:param absorption: Apparent efficiency of a pixel due to parallax effect | ||
:type absorption: ndarray | ||
:param safe: set to True to save some tests | ||
:param error_model: set to "poissonian" to use signal as variance (minimum 1), "azimuthal" to use the variance in a ring. | ||
:param normalization_factor: divide the valid result by this value | ||
:param quant_min: start percentile/100 to use. Use 0.5 for the median (default). 0<=quant_min<=1 | ||
:param quant_max: stop percentile/100 to use. Use 0.5 for the median (default). 0<=quant_max<=1 | ||
:return: namedtuple with "position intensity sigma signal variance normalization count std sem norm_sq" | ||
:rtype: Integrate1dtpl named-tuple of ndarrays | ||
""" | ||
error_model = ErrorModel.parse(error_model) | ||
cdef: | ||
index_t i, j, c, bad_pix, npix = self._indices.shape[0], idx = 0, start, stop, cnt=0 | ||
acc_t acc_sig = 0.0, acc_var = 0.0, acc_norm = 0.0, acc_count = 0.0, coef = 0.0, acc_norm_sq=0.0 | ||
acc_t cumsum = 0.0 | ||
data_t qmin, qmax | ||
data_t empty, sig, var, nrm, weight, nrm2 | ||
acc_t[::1] sum_sig = numpy.zeros(self.output_size, dtype=acc_d) | ||
acc_t[::1] sum_var = numpy.zeros(self.output_size, dtype=acc_d) | ||
acc_t[::1] sum_norm = numpy.zeros(self.output_size, dtype=acc_d) | ||
acc_t[::1] sum_norm_sq = numpy.zeros(self.output_size, dtype=acc_d) | ||
index_t[::1] sum_count = numpy.zeros(self.output_size, dtype=index_d) | ||
data_t[::1] merged = numpy.zeros(self.output_size, dtype=data_d) | ||
data_t[::1] stda = numpy.zeros(self.output_size, dtype=data_d) | ||
data_t[::1] sema = numpy.zeros(self.output_size, dtype=data_d) | ||
data_t[:, ::1] preproc4 | ||
bint do_azimuthal_variance = error_model == ErrorModel.AZIMUTHAL | ||
bint do_hybrid_variance = error_model == ErrorModel.HYBRID | ||
float4_t element, former_element | ||
float4_t[::1] work = numpy.zeros(npix, dtype=float4_d) | ||
|
||
assert weights.size == self.input_size, "weights size" | ||
empty = dummy if dummy is not None else self.empty | ||
#Call the preprocessor ... | ||
preproc4 = preproc(weights.ravel(), | ||
dark=dark, | ||
flat=flat, | ||
solidangle=solidangle, | ||
polarization=polarization, | ||
absorption=absorption, | ||
mask=self.cmask if self.check_mask else None, | ||
dummy=dummy, | ||
delta_dummy=delta_dummy, | ||
normalization_factor=normalization_factor, | ||
empty=self.empty, | ||
split_result=4, | ||
variance=variance, | ||
dtype=data_d, | ||
error_model=error_model, | ||
out=self.preprocessed) | ||
# print("start nogil", npix) | ||
with nogil: | ||
# Duplicate the input data and populate the large work-array | ||
for i in range(npix): # NOT faster in parallel ! | ||
weight = self._data[i] | ||
j = self._indices[i] | ||
sig = preproc4[j,0] | ||
var = preproc4[j,1] | ||
nrm = preproc4[j,2] | ||
element.s0 = sig/nrm # average signal | ||
element.s1 = sig * weight # weighted raw signal | ||
element.s2 = var * weight * weight # weighted raw variance | ||
element.s3 = nrm * weight # weighted raw normalization | ||
work[i] = element | ||
for idx in prange(self.output_size, schedule="guided"): | ||
start = self._indptr[idx] | ||
stop = self._indptr[idx+1] | ||
acc_sig = acc_var = acc_norm = acc_norm_sq = 0.0 | ||
cnt = 0 | ||
cumsum = 0.0 | ||
|
||
sort_float4(work[start:stop]) | ||
|
||
for i in range(start, stop): | ||
cumsum = cumsum + work[i].s3 | ||
work[i].s0 = cumsum | ||
|
||
qmin = quant_min * cumsum | ||
qmax = quant_max * cumsum | ||
|
||
element.s0 = 0.0 | ||
for i in range(start, stop): | ||
former_element = element | ||
element = work[i] | ||
if (qmin<=former_element.s0) and (element.s0 <= qmax): | ||
acc_sig = acc_sig + element.s1 | ||
acc_var = acc_var + element.s2 | ||
acc_norm = acc_norm + element.s3 | ||
acc_norm_sq = acc_norm_sq + element.s3*element.s3 | ||
cnt = cnt + 1 | ||
|
||
#collect things ... | ||
sum_sig[idx] = acc_sig | ||
sum_var[idx] = acc_var | ||
sum_norm[idx] = acc_norm | ||
sum_norm_sq[idx] = acc_norm_sq | ||
sum_count[idx] = cnt | ||
if (acc_norm_sq): | ||
merged[idx] = acc_sig/acc_norm | ||
stda[idx] = sqrt(acc_var / acc_norm_sq) | ||
sema[idx] = sqrt(acc_var) / acc_norm | ||
else: | ||
merged[idx] = empty | ||
stda[idx] = empty | ||
sema[idx] = empty | ||
|
||
#"position intensity sigma signal variance normalization count std sem norm_sq" | ||
return Integrate1dtpl(self.bin_centers, | ||
numpy.asarray(merged),numpy.asarray(sema) , | ||
numpy.asarray(sum_sig),numpy.asarray(sum_var), | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
|
||
__author__ = "Jerome Kieffer" | ||
__contact__ = "[email protected]" | ||
__date__ = "21/08/2024" | ||
__date__ = "19/11/2024" | ||
__status__ = "stable" | ||
__license__ = "MIT" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
#cython: embedsignature=True, language_level=3, binding=True | ||
#cython: boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False, | ||
## This is for developing | ||
## cython: profile=True, warn.undeclared=True, warn.unused=True, warn.unused_result=False, warn.unused_arg=True | ||
##cython: profile=True, warn.undeclared=True, warn.unused=True, warn.unused_result=False, warn.unused_arg=True | ||
# | ||
# Project: Fast Azimuthal Integration | ||
# https://github.com/silx-kit/pyFAI | ||
|
@@ -35,7 +35,7 @@ Sparse matrix represented using the CompressedSparseRow. | |
|
||
__author__ = "Jérôme Kieffer" | ||
__contact__ = "[email protected]" | ||
__date__ = "04/10/2023" | ||
__date__ = "19/11/2024" | ||
__status__ = "stable" | ||
__license__ = "MIT" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,7 +32,7 @@ | |
__contact__ = "[email protected]" | ||
__license__ = "MIT" | ||
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France" | ||
__date__ = "14/11/2024" | ||
__date__ = "19/11/2024" | ||
|
||
import sys | ||
import unittest | ||
|
Oops, something went wrong.