"""Source code for brighteyes_ism.simulation.PSF_sim: PSF simulation for SPAD-array (ISM) detectors."""

import numpy as np
import torch
from numbers import Number
import copy as cp

from psf_generator.propagators import VectorialCartesianPropagator
from psf_generator.utils.zernike import create_special_pupil, create_zernike_aberrations
from zernikepy.zernike_polynomials import lookup_table

from .detector import custom_detector
from .utils import partial_convolution

# %% functions

class GridParameters:
    """
    Sampling grid and detector geometry used to simulate the PSFs of a SPAD array detector.

    Attributes
    ----------
    pxsizex : float
        Lateral pixel size of the simulation space [nm] (default 40)
    pxsizez : float
        Distance between consecutive axial planes [nm] (default 50)
    Nx : int
        Number of pixels along each lateral dimension of the simulation array (e.g. 1024)
    Nz : int
        Number of axial planes (typically an odd integer)
    pxpitch : float
        Pixel pitch of the detector [nm] (real space, typically 75000)
    pxdim : float
        Detector element size [nm] (real space, typically 50000)
    N : int or tuple of two ints
        Number of detector elements in each dimension (typically 5)
    M : float
        Total magnification of the optical system (typically 500)
    pinhole_shape : str
        Shape of the individual pinhole: 'square', 'circle' or 'hexagon'.
        NOTE(review): earlier notes spelled the circular option 'cirle'; confirm the exact
        key accepted by custom_detector.
    geometry : str
        Detector geometry. Valid choices are 'rect' or 'hex'.
    rotation : float
        Detector rotation angle [rad]
    mirroring : int
        Flip of the horizontal axis of the detector plane (+1 or -1)
    name : str or None
        If 'airyscan', the simulated detector is the commercial 32-element AiryScan from Zeiss.

    Properties
    ----------
    rangex : float
        Lateral extent of the simulation space, Nx * pxsizex [nm].
    rangez : float
        Axial extent of the simulated stack, Nz * pxsizez [nm].
    Nch : int
        Total number of detector elements (channels).
    """

    def __init__(self, pxsizex=40, pxsizez=50, Nx=100, Nz=1, pxpitch=75e3, pxdim=50e3, N=5, M=450):
        # The assignment order matters: Print() lists attributes in insertion order.
        self.pxsizex = pxsizex  # nm - lateral pixel size of the images
        self.pxsizez = pxsizez  # nm - distance of the axial planes
        self.Nx = Nx  # number of samples along the X and Y axis
        self.Nz = Nz  # number of axial planes
        self.pxpitch = pxpitch  # nm - spad array pixel pitch (real space)
        self.pxdim = pxdim  # nm - spad pixel size (real space) 57.3e-3 for cooled spad
        self.pinhole_shape = 'square'  # pinhole shape (see class docstring)
        self.geometry = 'rect'  # pinholes arrangement: 'rect' or 'hex'
        self.N = N  # number of pixels in the detector in each dimension (5x5 typically)
        self.M = M  # overall magnification of the system
        self.rotation = 0  # rotation angle of the detector array (rad)
        self.mirroring = 1  # flip of the x_d axis (+/- 1)
        self.name = None  # None or 'airyscan'

    @property
    def rangex(self):
        return self.Nx * self.pxsizex

    @property
    def rangez(self):
        return self.Nz * self.pxsizez

    @property
    def Nch(self):
        """
        Total number of detector elements.

        Raises
        ------
        ValueError
            If N is neither a scalar nor a pair of integers, if geometry is unknown,
            or if geometry is 'hex' with an even N.
        """
        if np.ndim(self.N) == 0:
            if self.geometry == 'rect':
                return self.N ** 2
            if self.geometry == 'hex':
                # The hexagonal count formula is only defined for odd N: the former
                # expression (-1) ** ((N + 1) / 2) produced a complex number for even N.
                if self.N % 2 == 0:
                    raise ValueError(f"'hex' geometry requires an odd N, got {self.N!r}.")
                # Same value as N**2 - N//2 - (1 + (-1)**((N+1)/2)) / 2 for odd N,
                # computed with integer arithmetic.
                correction = 1 if ((self.N + 1) // 2) % 2 == 0 else 0
                return self.N ** 2 - self.N // 2 - correction
            raise ValueError(f"geometry must be 'rect' or 'hex', got {self.geometry!r}.")
        if np.ndim(self.N) == 1 and len(self.N) == 2:
            return self.N[1] * self.N[0]
        raise ValueError('N has to a be a single or a couple of positive integers.')

    def spad_size(self, mode: str = 'magnified', simPar=None):
        """
        Overall size of the detector array, pxpitch * (N - 1) + pxdim.

        Parameters
        ----------
        mode : str
            'magnified' (divided by M, i.e. projected onto the sample) or 'real'.
        simPar : simSettings object, optional
            If given, the magnified size is expressed in Airy units of simPar,
            regardless of ``mode``.

        Returns
        -------
        float or np.ndarray
            An array with one size per dimension when N is a pair of integers.

        Raises
        ------
        ValueError
            If mode is neither 'magnified' nor 'real' (and simPar is None).
        """
        # Per-dimension sizes for a rectangular (Ny, Nx) detector.
        n_elements = self.N if np.ndim(self.N) == 0 else np.asarray(self.N)
        size = self.pxpitch * (n_elements - 1) + self.pxdim
        if simPar is not None:
            return size / self.M / simPar.airy_unit
        if mode == 'magnified':
            return size / self.M
        if mode == 'real':
            return size
        raise ValueError(f"mode must be 'magnified' or 'real', got {mode!r}.")

    def copy(self):
        """Return a shallow copy of the parameters."""
        return cp.copy(self)

    def Print(self):
        """Print every attribute as a left-aligned ``name  value`` table (numbers with 2 decimals)."""
        for name, value in self.__dict__.items():
            print(name, end='')
            print(' ' * int(14 - len(name)), end='')
            if value is None:
                print("")
            elif isinstance(value, Number):
                print(f'{value:.2f}')
            else:
                print(value)
class simSettings:
    """
    Optical parameters of one beam (excitation, emission or STED) used to compute a PSF.

    Attributes
    ----------
    na : float
        Numerical aperture.
    n : float
        Sample refractive index.
    wl : float
        Wavelength in vacuum [nm].
    h : float
        Radius of the aperture of the objective lens [mm].
    w0 : float
        Radius of the incident Gaussian beam [mm].
    gamma : float
        Polarization parameter (amplitude), in degrees.
    beta : float
        Polarization parameter (phase), in degrees.
    I0 : float
        Intensity of the entrance field [W/m**2].
    field : str
        Spatial distribution of the entrance field: 'PlaneWave' (flat field)
        or 'Gaussian' (Gaussian beam of waist w0).
    mask : str or None
        Phase mask: None (no mask) or 'VP' (vortex phase plate).
    mask_sampl : int
        Number of sampling points of the entrance field and of the mask.
    sted_sat : float
        STED maximum saturation factor.
    sted_pulse : float
        STED pulse duration [ns].
    sted_tau : float
        Fluorescence lifetime [ns].
    abe_index : int or array
        Zernike index (or indices) of the aberrations.
    abe_ampli : float or array
        Aberration amplitude(s) [rad].

    Properties
    ----------
    f : float
        Focal length of the objective [mm].
    alpha : float
        Semi-angular aperture [rad].
    airy_unit : float
        1.22 * wl / na [nm].
    depth_of_field : float
        2 * n * wl / na**2 [nm].
    aberration : str or list
        'None' when no aberration is set, otherwise the aberration names.
    wavefront : np.ndarray
        Pupil phase [rad], NaN outside the unit pupil.
    """

    def __init__(self, na=1.4, n=1.5, wl=485.0, h=2.8, gamma=45.0, beta=90.0, w0=100.0, I0=1,
                 field='PlaneWave', mask=None, mask_sampl=200, sted_sat=50, sted_pulse=1,
                 sted_tau=3.5, abe_index=None, abe_ampli=None):
        # Insertion order is what Print() shows, so it is kept unchanged.
        # -- objective and medium
        self.na = na
        self.n = n
        self.wl = wl
        self.h = h
        # -- entrance field (polarization angles are in degrees)
        self.w0 = w0
        self.gamma = gamma
        self.beta = beta
        self.I0 = I0
        self.field = field
        self.mask = mask
        self.mask_sampl = mask_sampl
        # -- STED depletion
        self.sted_sat = sted_sat
        self.sted_pulse = sted_pulse
        self.sted_tau = sted_tau
        # -- Zernike aberrations
        self.abe_index = abe_index
        self.abe_ampli = abe_ampli

    @property
    def f(self):
        """Focal length of the objective lens [mm]: h * n / na."""
        return self.n * self.h / self.na

    @property
    def alpha(self):
        """Semi-angular aperture of the objective [rad]."""
        sin_alpha = self.na / self.n
        return np.arcsin(sin_alpha)

    @property
    def airy_unit(self):
        """Airy unit, 1.22 * wl / na [nm]."""
        return 1.22 * self.wl / self.na

    @property
    def depth_of_field(self):
        """Depth of field, 2 * n * wl / na**2 [nm]."""
        numerator = 2 * self.n * self.wl
        return numerator / (self.na ** 2)

    @property
    def aberration(self):
        """Names of the Zernike aberrations selected by abe_index ('None' if unset)."""
        if self.abe_index is None:
            return 'None'
        selected = [self.abe_index] if np.isscalar(self.abe_index) else self.abe_index
        return [label for label, idx in lookup_table.items() if idx in selected]

    @property
    def wavefront(self):
        """
        Phase [rad] of the pupil function (Zernike aberrations times optional special mask).

        Points outside the unit pupil are set to NaN. The pupil grid has mask_sampl points
        per side (presumably the phase map has the same shape — confirm).
        """
        if self.abe_index is None:
            coefficients = torch.zeros(1)
        else:
            coefficients = torch.zeros(np.max(self.abe_index) + 1)
            coefficients[self.abe_index] = torch.as_tensor(self.abe_ampli).float()
        zernike_pupil = create_zernike_aberrations(coefficients, self.mask_sampl, 'cartesian')
        special_pupil = 1 if self.mask is None else create_special_pupil(self.mask_sampl, self.mask)
        phase = torch.angle(zernike_pupil * special_pupil).cpu().detach().numpy()

        axis = np.linspace(-1, 1, num=self.mask_sampl)
        grid_x, grid_y = np.meshgrid(axis, axis)
        radius = np.sqrt(grid_x**2 + grid_y**2)
        aperture = np.where(radius < 1, 1, np.nan)
        return phase * aperture

    def copy(self):
        """Return a shallow copy of the settings."""
        return cp.copy(self)

    def Print(self):
        """Print every attribute as a left-aligned ``name  value`` table."""
        for key, val in self.__dict__.items():
            print(key, end='')
            print(' ' * int(14 - len(key)), end='')
            print(str(val))
def singlePSF(par, pxsizex, Nx, rangez, nz, device: str = 'cpu'):
    """
    Simulate a vectorial PSF z-stack with psf_generator's VectorialCartesianPropagator.

    Parameters
    ----------
    par : simSettings object
        Object with PSF parameters
    pxsizex : float
        Pixel size of the simulation space in XY [nm]
    Nx : int
        Number of pixels in XY dimensions in the simulation array, e.g. 1024
    rangez : sequence of two floats
        [defocus_min, defocus_max] of the stack, in the same length unit as pxsizex (nm)
    nz : int
        Number of axial planes sampled between rangez[0] and rangez[1]
    device : str or torch.device
        Device on which the propagator runs

    Returns
    -------
    psf : torch.Tensor
        |E|**2 summed over dimension 1 of the focal field (presumably the field
        components — confirm against psf_generator), scaled so that its sum over the
        whole array equals par.I0.
    fields : torch.Tensor
        Focal fields as returned by ``compute_focus_field()``.
    """
    kwargs = {
        'apod_factor': True,
        'defocus_min': rangez[0],
        'defocus_max': rangez[1],
        'n_defocus': nz,
        'n_pix_psf': Nx,
        'fov': Nx * pxsizex,
        'n_pix_pupil': par.mask_sampl,
        'na': par.na,
        'n_i0': par.n,
        # NOTE(review): the wavelength is passed divided by the refractive index;
        # confirm psf_generator expects the in-medium wavelength here.
        'wavelength': par.wl / par.n,
        # Polarization of the entrance field; gamma and beta are given in degrees.
        'e0x': np.cos(np.deg2rad(par.gamma)),
        'e0y': np.sin(np.deg2rad(par.gamma)) * np.exp(1j * np.deg2rad(par.beta)),
    }

    # Amplitude envelope. simSettings stores the beam radius as ``w0``; the previous
    # ``par.wo`` raised AttributeError for every Gaussian field.
    # NOTE(review): w0 is documented in mm; confirm the unit psf_generator expects.
    if par.field == 'Gaussian':
        kwargs['envelope'] = par.w0

    # Phase mask
    if par.mask is not None:
        kwargs['special_phase_mask'] = par.mask

    # Zernike aberrations: coefficient vector long enough to hold the highest index.
    if par.abe_index is not None:
        zernike = torch.zeros(np.max(par.abe_index) + 1)
        zernike[par.abe_index] = torch.as_tensor(par.abe_ampli).float()
        kwargs['zernike_coefficients'] = zernike

    propagator = VectorialCartesianPropagator(**kwargs, device=device)
    fields = propagator.compute_focus_field()
    psf = torch.sum(torch.abs(fields) ** 2, 1)
    psf = psf * par.I0 / torch.sum(psf)
    return psf, fields
def SPAD_PSF_3D(gridPar, exPar, emPar, stedPar=None, spad=None, n_photon_excitation: int = 1,
                stack: str = 'symmetrical', normalize: bool = True, process: str = 'gpu',
                output: str = 'numpy'):
    """
    It calculates a z-stack of PSFs for all the elements of the SPAD array detector.

    Parameters
    ----------
    gridPar : GridParameters object
        object with simulation space parameters
    exPar : simSettings object
        object with excitation PSF parameters
    emPar : simSettings object
        object with emission PSF parameters
    stedPar : simSettings object
        object with STED beam parameters. If given, a copy of it with a vortex phase
        mask ('VP') is used to compute the depletion donut; the caller's object is not
        modified.
    spad : array-like
        Pinholes distribution (presumably laid out as (x, y, channel), matching the
        'xyc' layout passed to partial_convolution — confirm). If None it is calculated
        using the input parameters.
    n_photon_excitation : int
        Order of non-linear excitation. Default is 1.
    stack : str or sequence of float
        Axial positions of the simulation. "symmetrical" generates planes around z = 0
        in both directions, "positive" and "negative" generate planes starting from
        z = 0 in one direction. An explicit sequence of z positions [nm] may be given
        instead: only its first and last values are used as the defocus range, which is
        sampled with gridPar.Nz planes. Default: "symmetrical".
    normalize : bool
        If true, the returned PSFs are divided by the total flux calculated on the focal
        plane (z=0), which must then be one of the stack positions. Default is True.
    process : str
        'gpu' runs on cuda:0 when CUDA is available; anything else runs on the CPU.
    output : str
        'numpy' returns numpy arrays, 'tensor' returns torch tensors.

    Returns
    -------
    PSF : np.array(Nz x Nx x Nx x Nch)
        array with the overall PSFs for each detector element
    detPSF : np.array(Nz x Nx x Nx x Nch)
        array with the detection PSFs for each detector element
    exPSF : np.array(Nz x Nx x Nx)
        array with the excitation PSF (including STED depletion, if any)

    Raises
    ------
    ValueError
        If ``output`` or a string ``stack`` is unknown, or if ``normalize`` is requested
        but z = 0 is not among the stack positions.
    """
    # Validate before running the (expensive) simulation.
    if output not in ('numpy', 'tensor'):
        raise ValueError('Output unknown')

    device = torch.device("cuda:0" if torch.cuda.is_available() and process == 'gpu' else "cpu")

    # Axial positions of the simulated planes [nm]
    if isinstance(stack, str):
        if stack == "symmetrical":
            zeta = (np.arange(gridPar.Nz) - gridPar.Nz // 2) * gridPar.pxsizez
        elif stack == "positive":
            zeta = np.arange(gridPar.Nz) * gridPar.pxsizez
        elif stack == "negative":
            zeta = -np.arange(gridPar.Nz) * gridPar.pxsizez
        else:
            raise ValueError(
                "stack must be 'symmetrical', 'positive', 'negative' or a sequence of "
                f"z positions, got {stack!r}."
            )
    else:
        zeta = stack
    # The propagator samples gridPar.Nz planes between the first and the last position.
    z_range = [zeta[0], zeta[-1]]

    # simulate detector array
    pinholes = custom_detector(gridPar, device) if spad is None else spad

    # Simulate ism psfs
    exPSF, _ = singlePSF(exPar, gridPar.pxsizex, gridPar.Nx, z_range, gridPar.Nz, device)
    emPSF, _ = singlePSF(emPar, gridPar.pxsizex, gridPar.Nx, z_range, gridPar.Nz, device)
    detPSF = partial_convolution(emPSF, pinholes, 'zxy', 'xyc', 'xy')

    # Apply non-linearity to excitation
    if n_photon_excitation > 1:
        exPSF = exPSF ** n_photon_excitation

    # Simulate donut. Work on a copy so the caller's stedPar keeps its own mask. Compute
    # on the same device as the other PSFs, and stay in torch: singlePSF returns a
    # (psf, fields) tuple of tensors.
    if isinstance(stedPar, simSettings):
        sted = stedPar.copy()
        sted.mask = 'VP'
        donut, _ = singlePSF(sted, gridPar.pxsizex, gridPar.Nx, z_range, gridPar.Nz, device)
        donut = donut * (sted.sted_sat / donut.max())
        stedPSF = torch.exp(-donut * sted.sted_pulse / sted.sted_tau)
        exPSF = exPSF * stedPSF

    # Calculate total PSF: per-channel detection PSF times the shared excitation PSF.
    PSF = torch.einsum('zxyc, zxy -> zxyc', detPSF, exPSF)

    if normalize:
        # np.asarray lets explicit list stacks be searched too.
        focal_planes = np.flatnonzero(np.asarray(zeta) == 0)
        if focal_planes.size == 0:
            raise ValueError('normalize=True requires the focal plane (z = 0) to be part of the stack.')
        PSF = PSF / PSF[int(focal_planes[0])].sum()

    if output == 'numpy':
        return PSF.cpu().detach().numpy(), detPSF.cpu().detach().numpy(), exPSF.cpu().detach().numpy()
    return PSF, detPSF, exPSF
def SPAD_PSF_2D(gridPar, exPar, emPar, n_photon_excitation=1, stedPar=None, z_shift=0, spad=None,
                normalize=True, process: str = 'gpu', output: str = 'numpy'):
    """
    Calculate the PSFs of all the SPAD array elements on a single axial plane.

    Parameters
    ----------
    gridPar : GridParameters object
        object with simulation space parameters (its Nz is ignored; the caller's object
        is not modified)
    exPar : simSettings object
        object with excitation PSF parameters
    emPar : simSettings object
        object with emission PSF parameters
    n_photon_excitation : int
        Order of non-linear excitation. Default is 1.
    stedPar : simSettings object
        object with STED beam parameters
    z_shift : float
        Distance from the focal plane at which the PSF is generated [nm]
    spad : array-like
        Pinholes distribution. If None it is calculated using the input parameters
    normalize : bool
        If true, each returned PSF is divided by its own total flux. Default is True.
    process : str
        'gpu' runs on cuda:0 when CUDA is available; anything else runs on the CPU.
    output : str
        'numpy' returns numpy arrays, 'tensor' returns torch tensors.

    Returns
    -------
    PSF : np.array(Nx x Nx x Nch)
        array with the overall PSFs for each detector element
    detPSF : np.array(Nx x Nx x Nch)
        array with the detection PSFs for each detector element
    exPSF : np.array(Nx x Nx)
        array with the excitation PSF
    """
    # Single-plane grid on a shallow copy, so gridPar.Nz is left untouched.
    grid = gridPar.copy()
    grid.Nz = 1

    # A one-plane stack located at z_shift; normalization is applied here instead.
    PSF, detPSF, exPSF = SPAD_PSF_3D(grid, exPar, emPar, stedPar=stedPar, spad=spad,
                                     n_photon_excitation=n_photon_excitation,
                                     stack=[z_shift, z_shift], normalize=False,
                                     process=process, output=output)

    # Drop the singleton axial axis (np.squeeze removes every length-1 axis).
    PSF = np.squeeze(PSF)
    detPSF = np.squeeze(detPSF)
    exPSF = np.squeeze(exPSF)

    if normalize:
        PSF = PSF / PSF.sum()
        detPSF = detPSF / detPSF.sum()
        exPSF = exPSF / exPSF.sum()

    return PSF, detPSF, exPSF