# Source code for tapqir.distributions.util

# Copyright Contributors to the Tapqir project.
# SPDX-License-Identifier: Apache-2.0

"""
Distribution utility functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""

import math
from functools import lru_cache

import torch


def gaussian_spots(
    height: torch.Tensor,  # (N, F, C, K) or (N, F, Q, 1, K)
    width: torch.Tensor,  # (N, F, C, K) or (N, F, Q, 1, K)
    x: torch.Tensor,  # (N, F, C, K) or (N, F, Q, 1, K)
    y: torch.Tensor,  # (N, F, C, K) or (N, F, Q, 1, K)
    target_locs: torch.Tensor,  # (N, F, C, 1, 2) or (N, F, 1, C, 1, 2)
    P: int,
    m: torch.Tensor = None,
) -> torch.Tensor:
    r"""
    Render ideal 2D-Gaussian spot images from spot parameters and target
    positions.

    .. math::

        \mu^S_{\mathsf{pixelX}(i), \mathsf{pixelY}(j)} =
        \dfrac{m \cdot h}{2 \pi w^2}
        \exp{\left( -\dfrac{(i-x-x^\mathsf{target})^2 +
        (j-y-y^\mathsf{target})^2}{2 w^2} \right)}

    :param height: Integrated spot intensity. Should be broadcastable to
        ``batch_shape``.
    :param width: Spot width. Should be broadcastable to ``batch_shape``.
    :param x: Spot center on x-axis, relative to the target location.
        Should be broadcastable to ``batch_shape``.
    :param y: Spot center on y-axis, relative to the target location.
        Should be broadcastable to ``batch_shape``.
    :param target_locs: Target location. Should have the rightmost size ``2``
        corresponding to locations on x- and y-axes, and be broadcastable to
        ``batch_shape + (2,)``.
    :param P: Number of pixels along the axis.
    :param m: Optional spot presence indicator that multiplies ``height``.
        Should be broadcastable to ``batch_shape``.
    :return: A tensor of a shape ``batch_shape + (P, P)`` representing
        2D-Gaussian spots.
    """
    # P x P grid of pixel coordinates; with "xy" indexing the entry at
    # [row, col] holds (col, row), i.e. (x-pixel, y-pixel).
    pixels = torch.arange(P, device=height.device)
    grid = torch.stack(torch.meshgrid(pixels, pixels, indexing="xy"), dim=-1)

    # Absolute spot centers and widths, padded with trailing singleton axes
    # so that they broadcast against the (P, P, 2) pixel grid.
    centers = (target_locs + torch.stack((x, y), -1))[..., None, None, :]
    sigma = width[..., None, None, None]

    # Per-axis log-density of a 1D normal; summing over the last (x/y) axis
    # yields the log of the normalized 2D Gaussian.
    log_norm = math.log(math.sqrt(2 * math.pi))
    log_density = (
        -((grid - centers) ** 2) / (2 * sigma**2) - sigma.log() - log_norm
    )
    unit_spot = torch.exp(log_density.sum(-1))
    # (N, F, C, K, P, P) or (N, F, Q, C, K, P, P)

    amplitude = height if m is None else m * height
    return amplitude[..., None, None] * unit_spot


def truncated_poisson_probs(lamda: torch.Tensor, K: int) -> torch.Tensor:
    r"""
    Probability of the number of non-specific spots.

    .. math::

        \mathbf{TruncatedPoisson}(\lambda, K) =
        \begin{cases}
            1 - e^{-\lambda} \sum_{i=0}^{K-1} \dfrac{\lambda^i}{i!}
            & \textrm{if } k = K \\
            \dfrac{\lambda^k e^{-\lambda}}{k!} & \mathrm{otherwise}
        \end{cases}

    :param lamda: Average rate of target-nonspecific binding.
    :param K: Maximum number of spots that can be present in a single image.
    :return: A tensor of a shape ``lamda.shape + (K+1,)`` of probabilities,
        allocated with the dtype and on the device of ``lamda``.
    """
    # Allocate on lamda's device (not the global default device) so that the
    # arithmetic below never mixes devices.
    device = lamda.device
    rate = lamda.unsqueeze(-1)
    result = torch.zeros(lamda.shape + (K + 1,), dtype=lamda.dtype, device=device)
    kdx = torch.arange(K, device=device)

    # Poisson pmf for k = 0..K-1, evaluated in log space:
    # k * log(lamda) - lamda - log(k!). xlogy returns 0 when k == 0.
    result[..., :-1] = torch.exp(kdx.xlogy(rate) - rate - (kdx + 1).lgamma())
    # The last entry (k = K) takes the remaining mass so the row sums to one.
    result[..., -1] = 1 - result[..., :-1].sum(-1)
    return result


def probs_m(lamda: torch.Tensor, K: int) -> torch.Tensor:
    r"""
    Prior spot presence probability :math:`p(m | \theta, \lambda)`.

    .. math::

        p(m_{\mathsf{spot}(k)} | \theta, \lambda) =
        \begin{cases}
            \mathbf{Bernoulli}(1) & \text{$\theta = k$} \\
            \mathbf{Bernoulli} \left( \sum_{l=1}^K
            \dfrac{l \cdot \mathbf{TruncPoisson}(l; \lambda, K)}{K} \right)
            & \text{$\theta = 0$} \rule{0pt}{4ex} \\
            \mathbf{Bernoulli} \left( \sum_{l=1}^{K-1}
            \dfrac{l \cdot \mathbf{TruncPoisson}(l; \lambda, K-1)}{K-1} \right)
            & \text{otherwise} \rule{0pt}{4ex}
        \end{cases}

    :param lamda: Average rate of target-nonspecific binding.
    :param K: Maximum number of spots that can be present in a single image.
    :return: A tensor of a shape ``lamda.shape + (1 + K, K)`` of probabilities,
        indexed by ``theta`` (second to last axis) and spot index (last axis).
    """
    # Allocate on lamda's device (not the global default device) so that the
    # arithmetic below never mixes devices.
    device = lamda.device
    result = torch.zeros(lamda.shape + (1 + K, K), dtype=lamda.dtype, device=device)
    kdx = torch.arange(K, device=device)

    # Default ("otherwise") case: theta points at a different spot, so the
    # remaining K - 1 spots share the non-specific binding events.
    # For K == 1 this is 0 / 0, but every entry is overwritten below.
    tr_pois_km1 = truncated_poisson_probs(lamda, K - 1)
    km1 = torch.arange(1, K, device=device)
    result[..., :, :] = (km1 * tr_pois_km1[..., km1]).sum(-1).unsqueeze(-1).unsqueeze(
        -1
    ) / (K - 1)

    # theta == 0: no target-specific spot, all K spots are non-specific.
    tr_pois_k = truncated_poisson_probs(lamda, K)
    k = torch.arange(1, K + 1, device=device)
    result[..., 0, :] = (k * tr_pois_k[..., k]).sum(-1).unsqueeze(-1) / K

    # theta == k: the target-specific spot is present with certainty.
    result[..., kdx + 1, kdx] = 1
    return result


def expand_offtarget(probs: torch.Tensor) -> torch.Tensor:
    r"""
    Extend state probabilities ``probs`` (e.g., :math:`\pi` or :math:`A`)
    with a companion distribution for off-target AOIs.

    .. math::

        p(\mathsf{state}) =
        \begin{cases}
            \mathbf{Categorical}\left( \mathsf{probs} \right)
            & \textrm{if on-target} \\
            \mathbf{Categorical}\left( \left[ 1, 0, \dots, 0 \right] \right)
            & \textrm{if off-target}
        \end{cases}

    :param probs: Probability of target-specific states.
    :return: A tensor of a shape ``probs.shape + (2,)`` of probabilities
        for off-target (``0``) and on-target (``1``) AOI.
    """
    on_target = probs
    # Off-target AOIs put all probability mass on the first state.
    off_target = torch.zeros_like(on_target)
    off_target[..., 0] = 1
    return torch.stack((off_target, on_target), dim=-1)


@lru_cache(maxsize=None)
def probs_theta(K: int, device: torch.device) -> torch.Tensor:
    r"""
    Prior probability for target-specific spot index :math:`p(\theta | z)`.

    .. math::

        p(\theta | z) =
        \begin{cases}
            \mathbf{Categorical}\left( \begin{bmatrix}
            0 & 1/K & \dots & 1/K \end{bmatrix} \right) & z > 0 \\
            \mathbf{Categorical}\left( \begin{bmatrix}
            1 & 0 & \dots & 0 \end{bmatrix} \right) & z = 0
        \end{cases}

    The result is memoized per ``(K, device)`` pair, so repeated calls return
    the same tensor object.

    :param K: Maximum number of spots that can be present in a single image.
    :param device: Device on which the tensor is allocated.
    :return: A tensor of a shape ``(2, 1 + K)`` of :math:`\theta` probabilities
        for spot-absent (``0``) and spot-present (``1``) cases.
    """
    table = torch.zeros(2, 1 + K, device=device)
    absent_row, present_row = table[0], table[1]
    # Spot absent: theta is always 0.
    absent_row[0] = 1
    # Spot present: theta is uniform over the K spot indices 1..K.
    present_row[1:] = 1 / K
    return table