# Copyright Contributors to the Tapqir project.
# SPDX-License-Identifier: Apache-2.0


import pykeops
import torch
from pykeops.torch import Genred
from pyro.distributions import TorchDistribution
from torch.distributions import Categorical, constraints
from torch.distributions.utils import lazy_property

from .util import gaussian_spots


class KSMOGN(TorchDistribution):
    r"""
    K-Spots Marginalized Offset Gamma Noise Image Distribution.

    .. math::

        \mu^S_{\mathsf{pixelX}(i), \mathsf{pixelY}(j)} =
            \dfrac{m \cdot h}{2 \pi w^2}
            \exp{\left( -\dfrac{(i-x-x^\mathsf{target})^2 +
            (j-y-y^\mathsf{target})^2}{2 w^2} \right)}

    .. math::

        \mu^I = b + \sum_{\mathsf{spot}} \mu^S

    .. math::

        p(D|\mu^I, g) = \sum_\delta p(\delta) p(D|\mu^I, g, \delta)
            = \sum_\delta \delta_\mathsf{weights} \cdot
            \mathrm{Gamma}(D - \delta_\mathsf{samples} | \mu^I, g)

    **Reference**:

    1. Ordabayev YA, Friedman LJ, Gelles J, Theobald DL.
       Bayesian machine learning analysis of single-molecule fluorescence
       colocalization images. bioRxiv. 2021 Oct. doi:
       `10.1101/2021.09.30.462536 <https://doi.org/10.1101/2021.09.30.462536>`_.

    :param height: Integrated spot intensity. Should be broadcastable to
        ``batch_shape + (K,)``.
    :param width: Spot width. Should be broadcastable to
        ``batch_shape + (K,)``.
    :param x: Spot center on x-axis. Should be broadcastable to
        ``batch_shape + (K,)``.
    :param y: Spot center on y-axis. Should be broadcastable to
        ``batch_shape + (K,)``.
    :param target_locs: Target location. Should have the rightmost size ``2``
        corresponding to locations on x- and y-axes, and be broadcastable to
        ``batch_shape + (2,)``.
    :param background: Background intensity. Should be broadcastable to
        ``batch_shape``.
    :param gain: Camera gain.
    :param offset_samples: Offset samples from the empirical distribution.
    :param offset_logits: Offset log weights corresponding to the offset
        samples.
    :param int P: Number of pixels along the axis.
    :param m: Spot presence indicator. Should be broadcastable to
        ``batch_shape + (K,)``. May be ``None``, in which case it is passed
        through to ``gaussian_spots`` unchanged.
    :param alpha: Signal cross-talk coefficient matrix. Should be
        broadcastable to ``(Q, C)``. When given, the event shape becomes
        ``(C, P, P)`` instead of ``(P, P)``.
    :param bool use_pykeops: Use pykeops as backend to marginalize out offset.
    """

    arg_constraints = {}
    support = constraints.positive

    def __init__(
        self,
        height: torch.Tensor,
        width: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        target_locs: torch.Tensor,
        background: torch.Tensor,
        gain: torch.Tensor,
        offset_samples: torch.Tensor,
        offset_logits: torch.Tensor,
        P: int,
        m: torch.Tensor = None,
        alpha: torch.Tensor = None,
        use_pykeops: bool = True,
        validate_args=None,
    ):
        # shapes for cosmos and crosstalk models
        self.height = height  # (N, F, C, K) or (N, F, Q, K)
        self.width = width  # (N, F, C, K) or (N, F, Q, K)
        self.x = x
        self.y = y
        self.target_locs = target_locs  # (N, F, C, 2)
        self.m = m  # (N, F, C, K) or (N, F, Q, K)
        # two trailing singleton axes so the background broadcasts over pixels
        self.background = background[..., None, None]

        if alpha is not None:
            C = alpha.shape[-1]
            # three trailing singleton axes to broadcast over the (C, P, P) event
            self.gain = gain[..., None, None, None]
            # mix the per-dye intensities into channels with the cross-talk matrix
            self.height = (
                self.height.unsqueeze(-2) * alpha[..., None]
            )  # (N, F, Q, C, K)
            self.width = self.width.unsqueeze(-2)  # (N, F, Q, 1, K)
            self.x = self.x.unsqueeze(-2)  # (N, F, Q, 1, K)
            self.y = self.y.unsqueeze(-2)  # (N, F, Q, 1, K)
            # m is optional (defaults to None); only reshape it when provided
            if self.m is not None:
                self.m = self.m.unsqueeze(-2)  # (N, F, Q, 1, K)
            self.target_locs = self.target_locs.unsqueeze(-3)  # (N, F, 1, C, 2)
        else:
            # two trailing singleton axes to broadcast over the (P, P) event
            self.gain = gain[..., None, None]

        self.alpha = alpha
        self.rate = 1 / self.gain
        self.offset_samples = offset_samples
        self.offset_logits = offset_logits
        self.P = P
        self.use_pykeops = use_pykeops
        if self.use_pykeops:
            device = self.target_locs.device.type
            self.device_pykeops = "GPU" if device == "cuda" else "CPU"

        # calculate batch shape from the arguments as they were passed in
        batch_shape = torch.broadcast_shapes(
            height.shape, width.shape, x.shape, y.shape
        )  # (N, F, C, K) or (N, F, Q, K)
        if m is not None:
            batch_shape = torch.broadcast_shapes(
                batch_shape, m.shape
            )  # (N, F, C, K) or (N, F, Q, K)
        event_shape = torch.Size([P, P])  # (P, P)
        bg_shape = background.shape  # (N, F, C)
        target_shape = target_locs.shape[:-1]  # (N, F, C)
        # remove K dim
        batch_shape = batch_shape[:-1]  # (N, F, C) or (N, F, Q)
        if alpha is not None:
            # remove Q dim
            batch_shape = batch_shape[:-1]  # (N, F)
            # add C dim
            event_shape = torch.Size([C, P, P])  # (C, P, P)
            # remove C dim
            bg_shape = bg_shape[:-1]  # (N, F)
            target_shape = target_shape[:-1]  # (N, F)
        batch_shape = torch.broadcast_shapes(
            batch_shape, bg_shape, target_shape
        )  # (N, F, C) or (N, F)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @lazy_property
    def gaussians(self):
        """Per-spot images returned by ``gaussian_spots`` (computed once)."""
        return gaussian_spots(
            self.height,  # (N, F, C, K) or (N, F, Q, C, K)
            self.width,  # (N, F, C, K) or (N, F, Q, 1, K)
            self.x,  # (N, F, C, K) or (N, F, Q, 1, K)
            self.y,  # (N, F, C, K) or (N, F, Q, 1, K)
            self.target_locs.unsqueeze(-2),  # (N, F, C, 1, 2) or (N, F, 1, C, 1, 2)
            self.P,
            self.m,
        )

    @lazy_property
    def image(self):
        """Ideal noise-free image: background plus the sum of all spots."""
        gaussians = self.gaussians  # (N, F, C, K, P, P) or (N, F, Q, C, K, P, P)
        if self.alpha is not None:
            # sum over the signal from multiple fluorescent dyes
            gaussians = gaussians.sum(-5)
        # sum the background and all the spots
        return self.background + gaussians.sum(-3)

    @lazy_property
    def concentration(self):
        """Gamma concentration parameter, ``image / gain``."""
        return self.image / self.gain

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw an image sample: a Gamma-distributed signal plus a sampled offset.

        :param sample_shape: Leading shape of independent draws.
        :return: Tensor of shape ``sample_shape + batch_shape + event_shape``.
        """
        shape = self._extended_shape(sample_shape)
        # Draw one offset index per sampled pixel.  Passing ``sample_shape``
        # keeps the offsets independent across draws instead of reusing a
        # single batch-sized draw for every sample.
        odx = (
            Categorical(logits=self.offset_logits)
            .expand(self.batch_shape + self.event_shape)
            .sample(sample_shape)
        )
        offset = self.offset_samples[odx]
        value = torch._standard_gamma(
            self.concentration.expand(shape)
        ) / self.rate.expand(shape)
        value.detach().clamp_(
            min=torch.finfo(value.dtype).tiny
        )  # do not record in autograd graph
        return value + offset

    def log_prob(self, value):
        """
        Log-density of ``value`` with the offset marginalized out.

        :param value: Observed image(s); must broadcast against
            ``batch_shape + event_shape``.
        :return: Log-probability summed over the event dimensions.
        """
        if self.use_pykeops:
            result = self._log_prob_pykeops(value)
        else:
            result = self._log_prob_torch(value)
        event_dims = tuple(-i for i in range(1, len(self.event_shape) + 1))
        return result.sum(event_dims)

    def _log_prob_pykeops(self, value):
        """Per-pixel log-density using a pykeops LogSumExp reduction over offsets."""
        # Symbols, in the order the tensors are passed to the routine below:
        #   wj = offset logit, gj = offset sample (reduced over, "j" index)
        #   ai = concentration, bi = rate, xi = observed value ("i" index)
        # The Log(Step(...)) term is meant to exclude offsets that are not
        # below the observed value; IfElse keeps the Log argument positive
        # for those excluded terms.
        formula = "wj+Log(Step(xi-gj-IntCst(1)))+(ai-IntCst(1))*Log(IfElse(xi-gj-IntCst(1),xi-gj,xi))-bi*(xi-gj)"
        variables = [
            "wj = Vj(1)",
            "gj = Vj(1)",
            "ai = Vi(1)",
            "bi = Vi(1)",
            "xi = Vi(1)",
        ]
        dtype = self.concentration.dtype
        my_routine = Genred(
            formula,
            variables,
            reduction_op="LogSumExp",
            axis=1,
        )
        concentration, value, rate = torch.broadcast_tensors(
            self.concentration, value, self.rate
        )
        shape = value.shape
        # flatten every pixel into one row for the reduction
        result = my_routine(
            self.offset_logits.reshape(-1, 1),
            self.offset_samples.reshape(-1, 1).to(dtype),
            concentration.reshape(-1, 1),
            rate.reshape(-1, 1).contiguous(),
            value.reshape(-1, 1).to(dtype),
            backend=self.device_pykeops,
        )
        result = result.reshape(shape)
        # terms that do not depend on the offset are added outside the reduction
        return (
            self.concentration * torch.log(self.rate)
            - torch.lgamma(self.concentration)
            + result
        )

    def _log_prob_torch(self, value):
        """Per-pixel log-density using plain torch ops with a trailing offset axis."""
        value = torch.as_tensor(value).unsqueeze(-1)
        concentration = self.concentration.unsqueeze(-1)
        # give rate the same trailing offset axis so its dimensions stay
        # aligned with value / concentration even when gain is not a scalar
        rate = self.rate.unsqueeze(-1)
        mask = value > self.offset_samples
        # substitute 1 where the offset is not below the value so that log()
        # stays finite; those entries are removed by log(mask) = -inf below
        new_value = torch.where(
            mask, value - self.offset_samples, value.new_ones(())
        )
        obs_logits = (
            concentration * torch.log(rate)
            + (concentration - 1) * torch.log(new_value)
            - rate * new_value
            - torch.lgamma(concentration)
        )
        result = obs_logits + self.offset_logits + torch.log(mask)
        return torch.logsumexp(result, -1)