import numpy as np
import torch
from numbers import Number
import copy as cp
from psf_generator.propagators import VectorialCartesianPropagator
from psf_generator.utils.zernike import create_special_pupil, create_zernike_aberrations
from zernikepy.zernike_polynomials import lookup_table
from .detector import custom_detector
from .utils import partial_convolution
# %% functions
[docs]
class GridParameters:
    """
    Sampling grid and detector geometry used by the PSF simulations.

    Attributes
    ----------
    pxsizex : float
        Lateral pixel size of the simulation space [nm] (default 40).
    pxsizez : float
        Distance between consecutive axial planes [nm] (default 50).
    Nx : int
        Number of samples along each lateral axis of the simulation array (default 100).
    Nz : int
        Number of axial planes (typically an odd integer, default 1).
    pxpitch : float
        Pixel pitch of the detector [nm] (real space, default 75000).
    pxdim : float
        Detector element size [nm] (real space, default 50000).
    pinhole_shape : str
        Shape of the individual pinhole: 'square', 'circle', or 'hexagon'.
        NOTE(review): earlier notes spelled the second option 'cirle' — confirm the
        exact string accepted by ``custom_detector``.
    geometry : str
        Detector geometry: 'rect' or 'hex'.
    N : int or pair of int
        Number of detector elements per dimension (default 5). When a pair is
        given, the total number of elements is ``N[0] * N[1]``.
    M : float
        Total magnification of the optical system (default 450).
    rotation : float
        Detector rotation angle [rad].
    mirroring : int
        Flip of the horizontal axis of the detector plane (+1 or -1).
    name : str or None
        If 'airyscan', the simulated detector is the commercial 32-element
        Airyscan from Zeiss.

    Properties
    ----------
    rangex : float
        Lateral extent of the simulation space, ``Nx * pxsizex`` [nm].
    rangez : float
        Axial extent of the simulation space, ``Nz * pxsizez`` [nm].
    Nch : number
        Total number of detector elements.
    """

    def __init__(self, pxsizex=40, pxsizez=50, Nx=100, Nz=1, pxpitch=75e3, pxdim=50e3, N=5, M=450):
        # NOTE: the assignment order defines the order in which Print() lists the attributes.
        self.pxsizex = pxsizex  # nm - lateral pixel size of the images
        self.pxsizez = pxsizez  # nm - distance of the axial planes
        self.Nx = Nx  # number of samples along the X and Y axis
        self.Nz = Nz  # number of axial planes
        self.pxpitch = pxpitch  # nm - SPAD array pixel pitch (real space)
        self.pxdim = pxdim  # nm - SPAD pixel size (real space)
        self.pinhole_shape = 'square'  # pinhole shape, see class docstring
        self.geometry = 'rect'  # pinhole arrangement: 'rect' or 'hex'
        self.N = N  # number of detector elements per dimension (5x5 typically)
        self.M = M  # overall magnification of the system
        self.rotation = 0  # rotation angle of the detector array (rad)
        self.mirroring = 1  # flip of the x_d axis (+/- 1)
        self.name = None  # None or 'airyscan'

    @property
    def rangex(self):
        """Lateral extent of the simulation space [nm]."""
        return self.Nx * self.pxsizex

    @property
    def rangez(self):
        """Axial extent of the simulation space [nm]."""
        return self.Nz * self.pxsizez

    @property
    def Nch(self):
        """
        Total number of detector elements (channels).

        Raises
        ------
        ValueError
            If ``N`` is neither a scalar nor a 1-D pair, or if ``geometry`` is
            neither 'rect' nor 'hex' for a scalar ``N``.
        """
        ndim = np.ndim(self.N)
        if ndim == 0:
            if self.geometry == 'rect':
                return self.N ** 2
            if self.geometry == 'hex':
                # NOTE(review): for even N the exponent (N + 1) / 2 is not an integer,
                # so this expression is not real-valued — confirm the intended formula.
                return self.N ** 2 - self.N // 2 - (1 + (-1) ** ((self.N + 1) / 2)) * 0.5
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(f"Unknown geometry '{self.geometry}'. Valid choices are 'rect' or 'hex'.")
        if ndim == 1:
            return self.N[1] * self.N[0]
        raise ValueError('N has to a be a single or a couple of positive integers.')

    def spad_size(self, mode: str = 'magnified', simPar=None):
        """
        Lateral extent of the detector array along one axis.

        The extent is ``pxpitch * (N - 1) + pxdim``.

        Parameters
        ----------
        mode : str
            'real' returns the size on the detector [nm]; 'magnified' divides it
            by the magnification ``M``. Ignored when ``simPar`` is given.
        simPar : simSettings, optional
            If given, the magnified size is further divided by ``simPar.airy_unit``.

        Raises
        ------
        ValueError
            If ``simPar`` is None and ``mode`` is neither 'magnified' nor 'real'.
        """
        size = self.pxpitch * (self.N - 1) + self.pxdim
        if simPar is not None:
            return size / self.M / simPar.airy_unit
        if mode == 'magnified':
            return size / self.M
        if mode == 'real':
            return size
        # Previously an unknown mode silently returned None.
        raise ValueError(f"Unknown mode '{mode}'. Valid choices are 'magnified' or 'real'.")

    def copy(self):
        """Return a shallow copy of the parameters."""
        return cp.copy(self)

    def Print(self):
        """Print every attribute, with numbers formatted to two decimals."""
        for name, value in self.__dict__.items():
            print(name, end='')
            print(' ' * int(14 - len(name)), end='')
            if value is None:
                print("")
            elif isinstance(value, Number):
                print(f'{value:.2f}')
            else:
                print(value)
[docs]
class simSettings:
    """
    Optical parameters of a single beam (excitation, emission or STED) used to
    compute a PSF. Naming follows PyFocus: https://pyfocus.readthedocs.io/en/latest/

    Attributes
    ----------
    na : float
        Numerical aperture of the objective.
    n : float
        Refractive index of the sample.
    wl : float
        Wavelength in vacuum [nm].
    h : float
        Radius of the aperture of the objective lens [mm].
    w0 : float
        Radius of the incident Gaussian beam [mm].
    gamma : float
        Polarization parameter (amplitude) [deg].
    beta : float
        Polarization parameter (phase) [deg].
    I0 : float
        Intensity of the entrance field [W/m**2].
    field : str
        Spatial distribution of the entrance field:
        'PlaneWave' (flat field) or 'Gaussian' (Gaussian beam of waist w0).
    mask : str or None
        Phase mask: None (no mask) or 'VP' (vortex phase plate).
    mask_sampl : int
        Number of samples of the entrance field and of the phase mask.
    sted_sat : float
        STED maximum saturation factor.
    sted_pulse : float
        STED pulse duration [ns].
    sted_tau : float
        Fluorescence lifetime [ns].
    abe_index : int or array
        Zernike index (or indices) of the aberration(s).
    abe_ampli : float or array
        Aberration amplitude(s) [rad].

    Properties
    ----------
    f : float
        Focal length, ``h * n / na`` [mm].
    alpha : float
        Semi-angular aperture, ``arcsin(na / n)`` [rad].
    airy_unit : float
        ``1.22 * wl / na`` [nm].
    depth_of_field : float
        ``2 * n * wl / na**2`` [nm].
    aberration : str or list of str
        'None' if no aberration is set, otherwise the names of the selected
        Zernike terms (in ``lookup_table`` order).
    wavefront : np.ndarray
        Pupil phase [rad] from the Zernike aberrations and the phase mask,
        set to NaN outside the unit disk.
    """

    def __init__(self, na=1.4, n=1.5, wl=485.0, h=2.8, gamma=45.0, beta=90.0,
                 w0=100.0, I0=1, field='PlaneWave', mask=None, mask_sampl=200,
                 sted_sat=50, sted_pulse=1, sted_tau=3.5,
                 abe_index=None, abe_ampli=None):
        # The assignment order below is the order used by Print(); keep it stable.
        # -- objective and medium
        self.na = na
        self.n = n
        self.wl = wl  # [nm], vacuum
        self.h = h  # [mm]
        self.w0 = w0  # [mm]
        # -- polarization of the entrance field [deg]
        self.gamma = gamma
        self.beta = beta
        # -- entrance field and phase mask
        self.I0 = I0
        self.field = field  # 'PlaneWave' or 'Gaussian'
        self.mask = mask  # None or 'VP'
        self.mask_sampl = mask_sampl
        # -- STED depletion
        self.sted_sat = sted_sat
        self.sted_pulse = sted_pulse  # [ns]
        self.sted_tau = sted_tau  # [ns]
        # -- Zernike aberrations
        self.abe_index = abe_index
        self.abe_ampli = abe_ampli  # [rad]

    @property
    def f(self):
        """Focal length of the objective lens [mm]."""
        return self.n * self.h / self.na

    @property
    def alpha(self):
        """Semi-angular aperture of the objective [rad]."""
        return np.arcsin(self.na / self.n)

    @property
    def airy_unit(self):
        """Airy unit, 1.22 * wl / na [nm]."""
        return 1.22 * self.wl / self.na

    @property
    def depth_of_field(self):
        """Depth of field, 2 * n * wl / na**2 [nm]."""
        return 2 * self.n * self.wl / self.na ** 2

    @property
    def aberration(self):
        """Names of the selected Zernike aberrations, or the string 'None'."""
        if self.abe_index is None:
            return 'None'
        wanted = [self.abe_index] if np.isscalar(self.abe_index) else self.abe_index
        return [label for label, idx in lookup_table.items() if idx in wanted]

    @property
    def wavefront(self):
        """Pupil phase [rad], NaN outside the unit disk."""
        if self.abe_index is None:
            coefficients = torch.zeros(1)
        else:
            coefficients = torch.zeros(np.max(self.abe_index) + 1)
            coefficients[self.abe_index] = torch.as_tensor(self.abe_ampli).float()
        zernike_pupil = create_zernike_aberrations(coefficients, self.mask_sampl, 'cartesian')
        mask_pupil = 1 if self.mask is None else create_special_pupil(self.mask_sampl, self.mask)
        phase = torch.angle(zernike_pupil * mask_pupil).cpu().detach().numpy()
        # Unit-disk support on a mask_sampl x mask_sampl grid spanning [-1, 1].
        axis = np.linspace(-1, 1, num=self.mask_sampl)
        gx, gy = np.meshgrid(axis, axis)
        inside = np.sqrt(gx ** 2 + gy ** 2) < 1
        return phase * np.where(inside, 1, np.nan)

    def copy(self):
        """Return a shallow copy of the settings."""
        return cp.copy(self)

    def Print(self):
        """Print every attribute name followed by its value."""
        for key, value in self.__dict__.items():
            print(key + ' ' * int(14 - len(key)) + str(value))
[docs]
def singlePSF(par, pxsizex, Nx, rangez, nz, device: str = 'cpu'):
    """
    Simulate a z-stack of vectorial PSFs with ``VectorialCartesianPropagator``.

    Parameters
    ----------
    par : simSettings
        Optical parameters of the beam (NA, refractive index, wavelength,
        polarization, entrance field, phase mask, aberrations, intensity).
    pxsizex : float
        Lateral pixel size of the simulation space [nm].
    Nx : int
        Number of pixels along each lateral axis of the simulation array.
    rangez : sequence of two floats
        ``[z_first, z_last]``, forwarded as ``defocus_min`` / ``defocus_max``
        (same length unit as ``pxsizex``).
    nz : int
        Number of defocus planes (``n_defocus``).
    device : str or torch.device
        Device on which the propagator runs. Default is 'cpu'.

    Returns
    -------
    psf : torch.Tensor
        Intensity PSF: squared field magnitude summed over dimension 1 of
        ``fields``, rescaled so that its total sum over all planes equals ``par.I0``.
    fields : torch.Tensor
        Output of ``propagator.compute_focus_field()``.
        NOTE(review): dimension 1 is presumably the field-component axis — confirm.
    """
    kwargs = {
        'apod_factor': True,
        'defocus_min': rangez[0],
        'defocus_max': rangez[1],
        'n_defocus': nz,
        'n_pix_psf': Nx,
        'fov': Nx * pxsizex,
        'n_pix_pupil': par.mask_sampl,
        'na': par.na,
        'n_i0': par.n,
        # NOTE(review): the vacuum wavelength is divided by the sample index here —
        # confirm which wavelength the propagator expects.
        'wavelength': par.wl / par.n,
        # Jones vector of the entrance field; gamma and beta are given in degrees.
        'e0x': np.cos(np.deg2rad(par.gamma)),
        'e0y': np.sin(np.deg2rad(par.gamma)) * np.exp(1j * np.deg2rad(par.beta)),
    }
    # Amplitude envelope. The attribute is ``w0`` (``par.wo`` raised AttributeError).
    # NOTE(review): w0 is documented in mm while fov is in nm — confirm the unit
    # expected by the propagator.
    if par.field == 'Gaussian':
        kwargs['envelope'] = par.w0
    # Phase mask
    if par.mask is not None:
        kwargs['special_phase_mask'] = par.mask
    # Zernike aberrations: dense coefficient vector indexed by abe_index
    if par.abe_index is not None:
        zernike = torch.zeros(np.max(par.abe_index) + 1)
        zernike[par.abe_index] = torch.as_tensor(par.abe_ampli).float()
        kwargs['zernike_coefficients'] = zernike
    propagator = VectorialCartesianPropagator(**kwargs, device=device)
    fields = propagator.compute_focus_field()
    psf = torch.sum(torch.abs(fields) ** 2, 1)
    # Normalize the whole stack (not each plane) to the total intensity I0.
    psf = psf * par.I0 / torch.sum(psf)
    return psf, fields
[docs]
def SPAD_PSF_3D(gridPar, exPar, emPar, stedPar=None, spad=None, n_photon_excitation: int = 1,
                stack: str = 'symmetrical', normalize: bool = True, process: str = 'gpu', output: str = 'numpy'):
    """
    Calculate a z-stack of PSFs for all the elements of the SPAD array detector.

    Parameters
    ----------
    gridPar : GridParameters
        Simulation space and detector parameters.
    exPar : simSettings
        Excitation beam parameters.
    emPar : simSettings
        Emission (detection) parameters.
    stedPar : simSettings, optional
        STED beam parameters. When given, a vortex-phase-plate ('VP') donut is
        simulated from a copy of these settings (the caller's object is not
        modified). The donut is rescaled to a maximum of ``sted_sat``, and the
        excitation PSF is multiplied by ``exp(-donut * sted_pulse / sted_tau)``.
    spad : array or tensor, optional
        Pinhole distribution. If None it is computed with ``custom_detector``.
        NOTE(review): it is passed to ``partial_convolution`` with the 'xyc'
        layout, so presumably (Nx, Nx, Nch) — confirm.
    n_photon_excitation : int
        Order of the non-linear excitation; the excitation PSF is raised to this
        power. Default is 1.
    stack : str or sequence of float
        Axial positions of the planes [nm].
        "symmetrical": planes ``(arange(Nz) - Nz // 2) * pxsizez`` around z = 0.
        "positive" / "negative": planes starting at z = 0 and moving in that direction.
        A sequence is used as custom positions; only its first and last values
        define the simulated range.
        Default: "symmetrical".
    normalize : bool
        If True, all PSFs are divided by the total flux of the plane at z = 0.
        Default is True.
    process : str
        'gpu' runs on CUDA when available; anything else runs on the CPU.
    output : str
        'numpy' returns NumPy arrays, 'tensor' returns torch tensors.

    Returns
    -------
    PSF : (Nz, Nx, Nx, Nch)
        Overall PSFs for each detector element.
    detPSF : (Nz, Nx, Nx, Nch)
        Detection PSFs for each detector element.
    exPSF : (Nz, Nx, Nx)
        Excitation PSF (after the non-linearity and the optional STED depletion).

    Raises
    ------
    ValueError
        If ``output`` or a string ``stack`` is unknown, or if ``normalize`` is True
        and the stack does not contain z = 0.
    """
    # Validate cheap arguments before running the (expensive) simulation.
    if output not in ('numpy', 'tensor'):
        raise ValueError('Output unknown')

    device = torch.device("cuda:0" if torch.cuda.is_available() and process == 'gpu' else "cpu")

    # Axial positions of the planes
    if isinstance(stack, str):
        if stack == "symmetrical":
            zeta = (np.arange(gridPar.Nz) - gridPar.Nz // 2) * gridPar.pxsizez
        elif stack == "positive":
            zeta = np.arange(gridPar.Nz) * gridPar.pxsizez
        elif stack == "negative":
            zeta = -np.arange(gridPar.Nz) * gridPar.pxsizez
        else:
            raise ValueError(f"Unknown stack '{stack}'. Valid choices are 'symmetrical', "
                             f"'positive', 'negative', or a sequence of positions.")
    else:
        zeta = stack
    z_range = [zeta[0], zeta[-1]]

    # Simulate detector array
    pinholes = custom_detector(gridPar, device) if spad is None else spad

    # Simulate excitation and detection PSFs
    exPSF, _ = singlePSF(exPar, gridPar.pxsizex, gridPar.Nx, z_range, gridPar.Nz, device)
    emPSF, _ = singlePSF(emPar, gridPar.pxsizex, gridPar.Nx, z_range, gridPar.Nz, device)
    detPSF = partial_convolution(emPSF, pinholes, 'zxy', 'xyc', 'xy')

    # Apply non-linearity to excitation
    if n_photon_excitation > 1:
        exPSF = exPSF ** n_photon_excitation

    # Simulate donut and deplete the excitation
    if isinstance(stedPar, simSettings):
        sted = cp.copy(stedPar)  # do not modify the caller's settings
        sted.mask = 'VP'
        donut, _ = singlePSF(sted, gridPar.pxsizex, gridPar.Nx, z_range, gridPar.Nz, device)
        donut = donut * (sted.sted_sat / torch.max(donut))
        stedPSF = torch.exp(-donut * sted.sted_pulse / sted.sted_tau)
        exPSF = exPSF * stedPSF

    # Calculate total PSF
    PSF = torch.einsum('zxyc, zxy -> zxyc', detPSF, exPSF)

    if normalize:
        # NOTE(review): with a custom stack this assumes its entries map one-to-one
        # onto the Nz simulated planes.
        focal_idx = np.flatnonzero(np.asarray(zeta) == 0)
        if focal_idx.size == 0:
            raise ValueError('normalize=True requires the plane z = 0 to be part of the stack.')
        PSF = PSF / PSF[int(focal_idx[0])].sum()

    if output == 'numpy':
        return tuple(t.cpu().detach().numpy() for t in (PSF, detPSF, exPSF))
    return PSF, detPSF, exPSF
[docs]
def SPAD_PSF_2D(gridPar, exPar, emPar, n_photon_excitation=1, stedPar=None, z_shift=0, spad=None, normalize=True,
                process: str = 'gpu', output: str = 'numpy'):
    """
    Calculate the PSFs of all detector elements on a single plane.

    This delegates to ``SPAD_PSF_3D`` using a one-plane copy of ``gridPar`` located
    at ``z_shift``. The singleton dimensions of the results are squeezed out.

    Parameters
    ----------
    gridPar : GridParameters
        Simulation space and detector parameters (not modified).
    exPar : simSettings
        Excitation beam parameters.
    emPar : simSettings
        Emission (detection) parameters.
    n_photon_excitation : int
        Order of non-linear excitation. Default is 1.
    stedPar : simSettings, optional
        STED beam parameters.
    z_shift : float
        Distance from the focal plane at which the PSF is generated [nm].
    spad : array or tensor, optional
        Pinhole distribution. If None it is computed from ``gridPar``.
    normalize : bool
        If True, each returned PSF is divided by its own total sum.
        Default is True.
    process : str
        'gpu' runs on CUDA when available; anything else runs on the CPU.
    output : str
        'numpy' or 'tensor', forwarded to ``SPAD_PSF_3D``.

    Returns
    -------
    PSF : (Nx, Nx, Nch)
        Overall PSFs for each detector element.
    detPSF : (Nx, Nx, Nch)
        Detection PSFs for each detector element.
    exPSF : (Nx, Nx)
        Excitation PSF.
    """
    single_plane = gridPar.copy()
    single_plane.Nz = 1
    stacks = SPAD_PSF_3D(single_plane, exPar, emPar,
                         stedPar=stedPar,
                         spad=spad,
                         n_photon_excitation=n_photon_excitation,
                         stack=[z_shift, z_shift],
                         normalize=False,
                         process=process,
                         output=output)
    PSF, detPSF, exPSF = (np.squeeze(arr) for arr in stacks)
    if normalize is True:
        PSF, detPSF, exPSF = (arr / arr.sum() for arr in (PSF, detPSF, exPSF))
    return PSF, detPSF, exPSF