# Source code for pytomography.transforms.shared.filters
import torch
import numpy as np
from torch.nn.functional import conv1d
import pytomography
from pytomography.metadata import ObjectMeta, ProjMeta
from pytomography.transforms import Transform
class GaussianFilter(Transform):
    """Applies a Gaussian smoothing filter to the reconstructed object with the specified full-width-half-max (FWHM).

    The 3D Gaussian is separable, so smoothing is done as three successive 1D
    convolutions, one along each of axes 1, 2 and 3 of the object. Kernel taps are
    spaced by ``object_meta.dr``, so ``FWHM`` must be given in the same units as
    ``object_meta.dr``.

    Args:
        FWHM (float): Specifies the width of the gaussian. Must be non-negative; a value
            of 0 leaves the object unchanged.
        n_sigmas (float): Each kernel is truncated at ``n_sigmas`` standard deviations on
            either side of its centre (rounded up to a whole voxel). Defaults to 3.

    Raises:
        ValueError: If ``FWHM`` is negative.
    """
    def __init__(self, FWHM: float, n_sigmas: float = 3):
        if FWHM < 0:
            raise ValueError(f"FWHM must be non-negative, got {FWHM}")
        # Convert FWHM to the Gaussian standard deviation: FWHM = 2*sqrt(2*ln 2)*sigma.
        self.sigma = FWHM / (2*np.sqrt(2*np.log(2)))
        self.n_sigmas = n_sigmas

    def _get_kernels(self):
        """Obtains required kernels for smoothing.

        Builds one normalized 1D Gaussian kernel per spatial dimension and stores them
        in ``self.kernels``. Kernel ``i`` is sampled at multiples of
        ``self.object_meta.dr[i]`` and has odd length ``2*ceil(n_sigmas*sigma/dr[i]) + 1``,
        so it is centred and ``padding='same'`` is symmetric. Each kernel has shape
        ``(1, 1, length)``, which is the weight layout ``conv1d`` expects for one channel.
        """
        self.kernels = []
        for i in range(3):
            dx = self.object_meta.dr[i]
            if self.sigma == 0:
                # Zero width means no smoothing; the Gaussian formula below would give 0/0 = NaN.
                k = torch.ones(1, device=pytomography.device)
            else:
                half_width = int(np.ceil(self.n_sigmas*self.sigma/dx))
                x = torch.arange(-half_width, half_width+1, device=pytomography.device)*dx
                k = torch.exp(-x**2/(2*self.sigma**2))
            # Normalize so the filter preserves total intensity.
            self.kernels.append((k/k.sum()).reshape(1, 1, -1))

    def __call__(self, object):
        """Alternative way to call"""
        return self.forward(object)

    def forward(self, object):
        """Applies the Gaussian smoothing

        Axis ``i`` (for ``i`` in 1, 2, 3) is convolved with ``self.kernels[i-1]``; axis 0
        is left untouched (presumably a batch axis — confirm against callers). The
        convolution zero-pads the edges (``padding='same'``), so the output has the same
        shape as the input.

        Args:
            object (torch.tensor): Object to smooth. Assumed to have (at least) axes 1-3
                matching the spatial dimensions described by ``object_meta.dr``.

        Returns:
            torch.tensor: Smoothed object
        """
        if not hasattr(self, 'kernels'):
            # Kernels depend on object_meta; build them on first use if nobody has yet.
            self._get_kernels()
        for i in [1, 2, 3]:
            # Move the axis being smoothed to the end so conv1d can act along it.
            object = object.swapaxes(i, -1)
            new_shape = object.shape
            # Use the size of the (now last) axis being smoothed. Indexing with `i` here
            # would give the size of the axis that was swapped *out*, which silently
            # mixes rows whenever the object is not cubic.
            object = object.reshape(-1, 1, object.shape[-1])
            kernel = self.kernels[i-1]
            if object.is_floating_point():
                # conv1d needs input and weight on the same device with the same dtype.
                kernel = kernel.to(device=object.device, dtype=object.dtype)
            object = conv1d(object, kernel, padding='same')
            object = object.reshape(new_shape)
            object = object.swapaxes(i, -1)
        return object

    def backward(self, object, norm_constant=None):
        """Applies Gaussian smoothing in back projection. Because the operation is symmetric, it is the same as the forward projection.

        Args:
            object (torch.tensor): Object to smooth
            norm_constant (torch.tensor, optional): Normalization constant used in iterative algorithms. Defaults to None.

        Returns:
            torch.tensor | tuple[torch.tensor, torch.tensor]: Smoothed object, or
            ``(smoothed object, smoothed norm_constant)`` when ``norm_constant`` is given.
        """
        object = self.forward(object)
        if norm_constant is None:
            return object
        return object, self.forward(norm_constant)