Add antialiasing option for downsampling #1314
Conversation
@romainVala, could you please take a look? |
Hi there! That's funny, it reminds me of a discussion I had on nibabel. I'll try to summarize the main point (hopefully short and clear): when doing a resampling by a factor of 8 with ITK, I get a strange result (similar to your example 1), but if I do three consecutive resamplings it looks better.
I am definitely missing some basics of signal processing, and I am not sure I understand the antialiasing filter... but you must be right, since this is what Lestropie says: nipy/nibabel#1264 (comment) |
OK, my multiple resampling works only because it uses 3 successive interpolations. Indeed, the result differs if you compare it with the same 3 consecutive resamplings but with ... What I do not understand is: why add a smoothing, when a simple average pooling does the job exactly (i.e. without interpolation)? |
@fepegar: how do you construct (or get) the test image? (So that I can look at the difference with average pooling.) |
Thank you, @romainVala. For reference, here are the images generated by your code. I also added the one with nearest-neighbor interpolation.

This is definitely a digital signal processing (DSP) problem, but I can try to explain it intuitively and from a DSP point of view (although my undergrad happened a long time ago).

Let's define some downsampling functions for 1D signals:

>>> import torch
>>> def avg_pool(x, k=2):
...     return torch.nn.functional.avg_pool1d(x.unsqueeze(0).unsqueeze(0), kernel_size=k, stride=k)[0, 0]
...
>>> def resample(x, f=2, mode='linear'):
...     align_corners = False if mode == 'linear' else None
...     return torch.nn.functional.interpolate(x.unsqueeze(0).unsqueeze(0), scale_factor=1/f, mode=mode, align_corners=align_corners)[0, 0]

Here's our "image":

>>> x = torch.tensor([1, 8, 3, 9]).float()

To resample by a factor of 2, we try to figure out which values would go between our current values.
Because it's linear interpolation and we're right between two samples, we use 0.5 * left + 0.5 * right = 0.5 * (left + right) = (left + right) / 2. Here, (1 + 8) / 2 = 4.5 and (3 + 9) / 2 = 6.

>>> resampled_2 = resample(x, f=2)
>>> resampled_2
tensor([4.5000, 6.0000])

If we do that again with our output, we have (4.5 + 6) / 2 = 5.25.

>>> resampled_2_2 = resample(resampled_2, f=2)
>>> resampled_2_2
tensor([5.2500])

If we resample directly by a factor of 4, the single output sample lands right between the two middle values, and we get (8 + 3) / 2 = 5.5. Different from the previous 5.25!

>>> resampled_4 = resample(x, f=4)
>>> resampled_4
tensor([5.5000])

If we use nearest-neighbor interpolation, we only look at the closest sample.
As we're right in between two samples, we need to arbitrarily choose which one we look at. That'd be an implementation detail. Let's see what PyTorch does:

>>> resampled_nn_2 = resample(x, f=2, mode='nearest')
>>> resampled_nn_2
tensor([1., 3.])

It seems to use the value from the smaller index. You can start seeing here why nearest neighbor is bad: the output is not a smooth version of the input at all. If we do this again, we'll just get the first value again.

I was curious to see what would happen with a factor of 4.

>>> resampled_nn_4 = resample(x, f=4, mode='nearest')
>>> resampled_nn_4
tensor([1.])

It just takes the first index! I checked the documentation and noticed that there is also a 'nearest-exact' mode:

>>> resampled_nn_exact_2 = resample(x, f=2, mode='nearest-exact')
>>> resampled_nn_exact_2
tensor([8., 9.])
>>> resampled_nn_exact_2_2 = resample(resampled_nn_exact_2, f=2, mode='nearest-exact')
>>> resampled_nn_exact_2_2
tensor([9.])

It seems to take the higher index instead.

>>> resampled_nn_exact_4 = resample(x, f=4, mode='nearest-exact')
>>> resampled_nn_exact_4
tensor([3.])

This is better than plain 'nearest'. Anyway, let's move on.

Because of the downsampling factor of 2, average pooling is like resampling with linear interpolation:

>>> avg_pooled_2 = avg_pool(x, k=2)
>>> avg_pooled_2
tensor([4.5000, 6.0000])

And of course, doing it again:

>>> avg_pooled_2_2 = avg_pool(avg_pooled_2, k=2)
>>> avg_pooled_2_2
tensor([5.2500])

Now, average pooling with a kernel of size 4 is different. You need the average of the four values: (1 + 8 + 3 + 9) / 4 = 5.25:

>>> avg_pooled_4 = avg_pool(x, k=4)
>>> avg_pooled_4
tensor([5.2500])

From a DSP point of view, there are a few things to take into account. When you resample a signal with a specific sampling frequency, aliases (copies) of the signal spectrum centered on multiples of the sampling frequency are generated (this can be shown with maths 🪄). If your sampling frequency is very high, you're good because your aliases are far from your signal. But for downsampling, your sampling frequency is low and the aliases will overlap with your signal, creating artifacts.

The solution is to apply a low-pass filter to make your spectrum narrower, so that the aliases no longer overlap with your signal. To apply a low-pass filter in Fourier space, we can multiply the spectrum F(k) of our image f(x) by a step (rectangular) function S(k). That removes high frequencies, i.e., blurs the image (from an old coursework exercise).

But a multiplication F(k) × S(k) in Fourier space is mathematically equivalent to a convolution f(x) * s(x) in image space, and the inverse Fourier transform of the step function is a sinc function. So we could convolve our image with a sinc for resampling, which is nice but expensive; that's essentially what (windowed-)sinc methods such as Lanczos interpolation do. Instead of a fancy sinc with a wide support, we could use a more boring approximation with a small kernel. How about approximating the sinc with something that looks like a rectangle and convolving with that? But wait, that's exactly what we did above!
Interestingly, the Fourier transform of a rectangular (box) kernel is a sinc. If we effectively multiply the spectrum by a sinc, the low-pass filter is not very nice, but it still does something reasonable. So both 1) resampling with linear interpolation and 2) average pooling are not terrible choices for downsampling! But using large downsampling factors means low sampling frequencies, and therefore aliases that will overlap a lot with our image unless we perform aggressive low-pass filtering. Average pooling with a kernel of size 8 works better than resampling by a factor of 8 because the kernel of average pooling (8 samples wide) is larger than the one used by linear interpolation (2 samples wide). A larger (flat) kernel in image space means a narrower sinc in Fourier space, i.e., a more aggressive low-pass filter, and therefore aliasing is reduced.

Now, maybe we can use a nicer filter than a sinc in Fourier space. How about a Gaussian? It turns out that the Fourier transform of a Gaussian is a Gaussian; that's why we use a Gaussian kernel for filtering (convolving). The question is how to choose the variance of the Gaussian, but thankfully smart people have written about that, and I cite them in this PR.

A bit messy and not super accurate, but hopefully that helps a bit! |
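To make the "smooth before you downsample" idea concrete, here is a minimal 1D sketch (an illustration only, not the code added in this PR). The choice of sigma below, a Gaussian whose FWHM equals the downsampling factor, is just one common heuristic and may differ from the formula the PR actually uses.

import math

import numpy as np
import torch
from scipy.ndimage import gaussian_filter1d


def downsample_1d(signal, factor, antialias=False):
    x = torch.as_tensor(signal, dtype=torch.float32)
    if antialias:
        # Assumed heuristic: Gaussian FWHM equal to the downsampling factor
        sigma = factor / (2 * math.sqrt(2 * math.log(2)))
        x = torch.as_tensor(gaussian_filter1d(x.numpy(), sigma))
    x = x.unsqueeze(0).unsqueeze(0)  # shape (N, C, L) expected by interpolate
    out = torch.nn.functional.interpolate(
        x, scale_factor=1 / factor, mode='linear', align_corners=False
    )
    return out[0, 0]


# A 1D "zone plate": the local frequency increases along the signal
t = np.linspace(0, 1, 1024)
chirp = np.cos(2 * math.pi * 64 * t**2)

aliased = downsample_1d(chirp, factor=8)                      # high frequencies fold back as ripples
antialiased = downsample_1d(chirp, factor=8, antialias=True)  # high frequencies are attenuated instead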
Also, this is the test image:

import math

import numpy as np
import torch
import torchio as tio


def zone_plate_3d(
    shape: tuple[int, int, int],
    base_frequency: float = 1,
) -> np.ndarray:
    coords = [
        np.linspace(-1.0, 1.0, n, endpoint=False)
        for n in shape
    ]
    z, y, x = np.meshgrid(*coords, indexing="ij", copy=False)
    r2 = x * x + y * y + z * z
    phase = math.pi * base_frequency * r2
    volume = 0.5 * (1.0 + np.cos(phase))
    return volume


shape = 3 * (256,)
volume = zone_plate_3d(shape, base_frequency=16)
zone_plate = tio.ScalarImage(tensor=volume[None])
zone_plate.plot()

And here is the version downsampled with average pooling:

avg_pool_8 = tio.Lambda(
    lambda x: torch.nn.functional.avg_pool3d(x, kernel_size=8, stride=8)
)
downsampled_zone_plate = avg_pool_8(zone_plate)
downsampled_zone_plate.plot()

Compare to the image filtered with a Gaussian kernel: note that the artifacts are less visible when Gaussian smoothing is applied before resampling. |
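For reference, here is a rough sketch of that comparison, smoothing with a Gaussian before resampling (an illustration only, not the exact code used to generate the figures; the sigma heuristic, FWHM equal to the downsampling factor, is an assumption):

from scipy.ndimage import gaussian_filter

factor = 8
# Assumed heuristic: Gaussian FWHM equal to the downsampling factor
sigma = factor / (2 * math.sqrt(2 * math.log(2)))

smoothed = gaussian_filter(volume, sigma)
smoothed_zone_plate = tio.ScalarImage(tensor=smoothed[None])

resample_8 = tio.Resample(factor)      # linear interpolation onto an 8 times coarser grid
resample_8(zone_plate).plot()          # aliasing rings visible
resample_8(smoothed_zone_plate).plot() # much weaker artifacts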
Thanks a lot for the nicely detailed answer, it helps indeed! I need more time to get the second part, but I like the first part. I only consider the case where you downsample by a factor 2^n: then the operation you want is the simple average over the 2^n voxels. (This is what you expect to measure with MRI if you compare a 1 mm acquisition with a 2^n mm one: supposing the FOV is aligned, any voxel value will be the average of the corresponding 1 mm voxels, 8 of them for a 2 mm acquisition...) What do I miss? I think the aliasing story is more related to the content of the image (related to the Gibbs effect...), but I still do not understand why it should be taken into account when downsampling by a factor 2^n... |
Of course smoothing removes noise and high spatial frequencies (which will change the input image a lot by removing the rings). It is funny to do the same plot with ... But my point is that I am not sure this is an artifact; the avgpool version is the correct low-resolution content... |
I disagree that these methods are "right" or "wrong". I think you just didn't expect resampling with linear interpolation to behave the way it does, and average pooling is closer to what you expected.
That's just how linear interpolation works. You go to the point in the target grid whose value you want to estimate, look for the two (1D), four (2D), or eight (3D) closest points, and compute a weighted sum of their values, where each point's weight decreases linearly with its distance to the target point.
If that's what you want, don't use linear interpolation. Or you could apply a low-pass filter first and then resample, for example using a box kernel, which is equivalent to average pooling (though a Gaussian kernel will work better).
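As a quick sanity check of the claim that a box filter followed by subsampling is the same as average pooling, here is a minimal 1D sketch (not code from this PR):

>>> import torch
>>> x = torch.tensor([1., 8., 3., 9., 2., 7., 4., 6.]).view(1, 1, -1)
>>> k = 4
>>> box = torch.full((1, 1, k), 1 / k)  # normalized box (uniform) smoothing kernel
>>> # Convolve with the box kernel and keep every k-th output (stride k)...
>>> torch.nn.functional.conv1d(x, box, stride=k)[0, 0]
tensor([5.2500, 4.7500])
>>> # ...which is exactly average pooling with kernel size and stride k
>>> torch.nn.functional.avg_pool1d(x, kernel_size=k, stride=k)[0, 0]
tensor([5.2500, 4.7500])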
Yes, that's about how you sample a continuous signal, not how you resample a discrete signal. They're different things!
Yes, artifacts will be more visible in images with higher-frequency content.
I tried to explain above, I'm not sure how much more I can say... For a factor of 2, it doesn't matter much. But I think it's been demonstrated that smoothing before resampling matters a lot for large downsampling factors, and that Gaussian smoothing works better than average pooling. |
Maybe, but if you can conclude that Gaussian smoothing is better, then at least avg pooling is better than linear resampling (right?). What is your criterion for saying "better"? (If it is to show fewer artifacts, then just smooth more and you'll reduce them even more...) When I say avg pooling is right, it is regarding the following question: how to downsample a discrete signal so that the resulting low-resolution discrete signal is the same as the one we would get from sampling the same continuous signal? If the input and the output grids are aligned, then the correct downsampling operation should be average pooling. |
Yes, that's expected. Intuitively, as the image is already smooth, the smoothing filter will have little effect on it, and downsampling won't make a large difference. From a theory point of view, the signal spectrum is narrow, so the alias spectra generated at the new sampling frequency are still far away and don't overlap with the signal spectrum, and no smoothing filter is needed. If we do smooth, the filter's passband is probably much wider than the signal's spectrum, so again nothing changes much.
Generally, yes.
Intuitively, using a Gaussian kernel to compute the averaging weights makes more sense (is better) than average pooling (a box kernel): values further away from the new point should contribute less than values that are closer. From a theory point of view, the Fourier transform of a Gaussian is a better low-pass filter than the Fourier transform of a box: the box's transform (a sinc) has zeros and ripples (side lobes) in Fourier space, which can manifest as artifacts. |
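To visualize that last point, here is a small numpy sketch (not from this PR) comparing the magnitude spectra of a box kernel and a Gaussian kernel of comparable width; the choice of sigma (FWHM equal to the box width) is an assumption:

import numpy as np

n = 256    # length used to compute the spectra
width = 8  # box kernel width; the Gaussian below has the same FWHM (an assumed comparable choice)

box = np.zeros(n)
box[:width] = 1 / width  # normalized box (average-pooling) kernel

x = np.arange(n)
sigma = width / (2 * np.sqrt(2 * np.log(2)))  # FWHM = width
gauss = np.exp(-0.5 * ((x - n / 2) / sigma) ** 2)
gauss /= gauss.sum()

box_spectrum = np.abs(np.fft.rfft(box))      # sinc-like: exact zeros and ripples (side lobes)
gauss_spectrum = np.abs(np.fft.rfft(gauss))  # Gaussian-like: smooth, monotonically decaying

print(box_spectrum[[32, 48]])    # approximately [0, 0.22]: a zero crossing, then a side lobe
print(gauss_spectrum[[32, 48]])  # keeps decreasing, already small at these frequencies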
Related to #48.
Description
Note e.g. the ventricles and basal ganglia in the sagittal view.
One more example with a test image. Note aliasing is smaller when enabling the new setting.
Checklist
I have read the CONTRIBUTING docs and have a developer setup ready.