# Source code for tapqir.models.cosmos

# Copyright Contributors to the Tapqir project.
# SPDX-License-Identifier: Apache-2.0

"""
cosmos
^^^^^^
"""

import itertools
import math
from functools import reduce
from typing import Tuple

import torch
import torch.distributions.constraints as constraints
from pyro.ops.indexing import Vindex
from pyroapi import distributions as dist
from pyroapi import handlers, infer, pyro
from torch.distributions.utils import lazy_property
from torch.nn.functional import one_hot

from tapqir.distributions import KSMOGN, AffineBeta
from tapqir.distributions.util import expand_offtarget, probs_m, probs_theta
from tapqir.models.model import Model
from tapqir.utils.stats import torch_to_scipy_dist


class cosmos(Model):
    r"""
    **Multi-Color Time-Independent Colocalization Model**

    **Reference**:

    1. Ordabayev YA, Friedman LJ, Gelles J, Theobald DL.
       Bayesian machine learning analysis of single-molecule fluorescence
       colocalization images.
       eLife. 2022 March. doi: `10.7554/eLife.73860 <https://doi.org/10.7554/eLife.73860>`_.

    :param S: Number of non-zero values of the hidden state ``z``; the ``pi``
        parameter has ``S + 1`` categories.
    :param K: Maximum number of spots that can be present in a single image.
    :param Q: Size of the leading axis of ``pi`` and ``lamda`` and of the
        trailing axis of the per-spot variational parameters. Presumably the
        number of color channels -- confirm against ``Model``.
    :param device: Computation device (cpu or gpu).
    :param dtype: Floating point precision.
    :param use_pykeops: Use pykeops as backend to marginalize out offset.
    :param priors: Dictionary of parameters of prior distributions. When
        omitted, a fresh dictionary with the default values is created for
        this instance.
    """

    name = "cosmos"

    def __init__(
        self,
        S: int = 1,
        K: int = 2,
        Q: int = None,
        device: str = "cpu",
        dtype: str = "double",
        use_pykeops: bool = True,
        priors: dict = None,
    ):
        # Build the defaults per call: a dict literal in the signature would
        # be a single object shared (and mutable) across all instances.
        if priors is None:
            priors = {
                "background_mean_std": 1000.0,
                "background_std_std": 100.0,
                "lamda_rate": 1.0,
                "height_std": 10000.0,
                "width_min": 0.75,
                "width_max": 2.25,
                "proximity_rate": 1.0,
                "gain_std": 50.0,
            }
        super().__init__(S=S, K=K, Q=Q, device=device, dtype=dtype, priors=priors)
        self._global_params = ["gain", "proximity", "lamda", "pi"]
        self.use_pykeops = use_pykeops
        self.conv_params = ["-ELBO", "proximity_loc", "gain_loc", "lamda_loc"]
        # names handled by ``compute_params``
        self.ci_params = [
            "gain",
            "pi",
            "lamda",
            "proximity",
            "background",
            "height",
            "width",
            "x",
            "y",
        ]

    def model(self):
        r"""
        **Generative Model**

        Model parameters:

        +-----------------+-----------+-------------------------------------+
        | Parameter       | Shape     | Description                         |
        +=================+===========+=====================================+
        | |g| - :math:`g` | (1,)      | camera gain                         |
        +-----------------+-----------+-------------------------------------+
        | |sigma| - |prox|| (1,)      | proximity                           |
        +-----------------+-----------+-------------------------------------+
        | ``lamda`` - |ld|| (1,)      | average rate of target-nonspecific  |
        |                 |           | binding                             |
        +-----------------+-----------+-------------------------------------+
        | ``pi`` - |pi|   | (1,)      | average binding probability of      |
        |                 |           | target-specific binding             |
        +-----------------+-----------+-------------------------------------+
        | |bg| - |b|      | (N, F)    | background intensity                |
        +-----------------+-----------+-------------------------------------+
        | |z| - :math:`z` | (N, F)    | target-specific spot presence       |
        +-----------------+-----------+-------------------------------------+
        | |t| - |theta|   | (N, F)    | target-specific spot index          |
        +-----------------+-----------+-------------------------------------+
        | |m| - :math:`m` | (K, N, F) | spot presence indicator             |
        +-----------------+-----------+-------------------------------------+
        | |h| - :math:`h` | (K, N, F) | spot intensity                      |
        +-----------------+-----------+-------------------------------------+
        | |w| - :math:`w` | (K, N, F) | spot width                          |
        +-----------------+-----------+-------------------------------------+
        | |x| - :math:`x` | (K, N, F) | spot position on x-axis             |
        +-----------------+-----------+-------------------------------------+
        | |y| - :math:`y` | (K, N, F) | spot position on y-axis             |
        +-----------------+-----------+-------------------------------------+
        | |D| - :math:`D` | |shape|   | observed images                     |
        +-----------------+-----------+-------------------------------------+

        .. |ps| replace:: :math:`p(\mathsf{specific})`
        .. |theta| replace:: :math:`\theta`
        .. |prox| replace:: :math:`\sigma^{xy}`
        .. |ld| replace:: :math:`\lambda`
        .. |b| replace:: :math:`b`
        .. |shape| replace:: (N, F, P, P)
        .. |sigma| replace:: ``proximity``
        .. |bg| replace:: ``background``
        .. |h| replace:: ``height``
        .. |w| replace:: ``width``
        .. |D| replace:: ``data``
        .. |m| replace:: ``m``
        .. |z| replace:: ``z``
        .. |t| replace:: ``theta``
        .. |x| replace:: ``x``
        .. |y| replace:: ``y``
        .. |pi| replace:: :math:`\pi`
        .. |g| replace:: ``gain``

        Full joint distribution:

        .. math::

            \begin{aligned}
                p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
                \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b)
                \prod_{\mathsf{frame}} \left[ \vphantom{\prod_{F}}
                p(b | \mu^b, \sigma^b) p(z | \pi) p(\theta | z)
                \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}}
                \cdot \right. \right. \\
                &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}}
                p(m | \theta, \lambda) p(h) p(w) p(x | \sigma^{xy}, \theta)
                p(y | \sigma^{xy}, \theta) \right] \left. \left.
                \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}
                \sum_{\delta} p(\delta) p(D | \mu^I, g, \delta) \right] \right]
            \end{aligned}

        :math:`z` and :math:`\theta` marginalized joint distribution:

        .. math::

            \begin{aligned}
                \sum_{z, \theta} p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
                \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b)
                \prod_{\mathsf{frame}} \left[ \vphantom{\prod_{F}}
                p(b | \mu^b, \sigma^b) \sum_{z} p(z | \pi) \sum_{\theta} p(\theta | z)
                \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}}
                \cdot \right. \right. \\
                &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}}
                p(m | \theta, \lambda) p(h) p(w) p(x | \sigma^{xy}, \theta)
                p(y | \sigma^{xy}, \theta) \right] \left. \left.
                \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}
                \sum_{\delta} p(\delta) p(D | \mu^I, g, \delta) \right] \right]
            \end{aligned}
        """
        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
        pi = pyro.sample(
            "pi",
            dist.Dirichlet(torch.ones((self.Q, self.S + 1)) / (self.S + 1)).to_event(
                1
            ),
        )
        pi = expand_offtarget(pi)
        lamda = pyro.sample(
            "lamda",
            dist.Exponential(torch.full((self.Q,), self.priors["lamda_rate"])).to_event(
                1
            ),
        )
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.priors["proximity_rate"])
        )
        # Second AffineBeta argument for x/y: entry 0 is the constant 2.0,
        # entry 1 is derived from ``proximity``; the entry is picked below by
        # whether the spot is the one selected by ``theta``.
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity)) ** 2 - 1),
            ),
            dim=-1,
        )

        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-2,
        )
        # color channels
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        # x and y share the same symmetric support
        half_extent = (self.data.P + 1) / 2

        with channels as cdx, aois as ndx:
            ndx = ndx[:, None, None]
            mask = Vindex(self.data.mask.to(self.device))[ndx]
            with handlers.mask(mask=mask):
                # background mean and std
                background_mean = pyro.sample(
                    "background_mean",
                    dist.HalfNormal(self.priors["background_mean_std"]),
                )
                background_std = pyro.sample(
                    "background_std", dist.HalfNormal(self.priors["background_std_std"])
                )
                with frames as fdx:
                    fdx = fdx[:, None]
                    # fetch data
                    obs, target_locs, is_ontarget = self.data.fetch(ndx, fdx, cdx)
                    # sample background intensity
                    background = pyro.sample(
                        "background",
                        dist.Gamma(
                            (background_mean / background_std) ** 2,
                            background_mean / background_std**2,
                        ),
                    )

                    # sample hidden model state (1+S,)
                    z = pyro.sample(
                        "z",
                        dist.Categorical(Vindex(pi)[..., cdx, :, is_ontarget.long()]),
                        infer={"enumerate": "parallel"},
                    )
                    theta = pyro.sample(
                        "theta",
                        dist.Categorical(
                            Vindex(probs_theta(self.K, self.device))[
                                torch.clamp(z, min=0, max=1)
                            ]
                        ),
                        infer={"enumerate": "parallel"},
                    )
                    onehot_theta = one_hot(theta, num_classes=1 + self.K)

                    # does not depend on the spot index, so evaluate it once
                    spot_probs = probs_m(lamda, self.K)

                    ms, heights, widths, xs, ys = [], [], [], [], []
                    for kdx in range(self.K):
                        # 1 where theta selects this spot, 0 otherwise
                        specific = onehot_theta[..., 1 + kdx]
                        # spot presence
                        m = pyro.sample(
                            f"m_k{kdx}",
                            dist.Bernoulli(Vindex(spot_probs)[..., cdx, theta, kdx]),
                        )
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            height = pyro.sample(
                                f"height_k{kdx}",
                                dist.HalfNormal(self.priors["height_std"]),
                            )
                            width = pyro.sample(
                                f"width_k{kdx}",
                                AffineBeta(
                                    1.5,
                                    2,
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            x = pyro.sample(
                                f"x_k{kdx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -half_extent,
                                    half_extent,
                                ),
                            )
                            y = pyro.sample(
                                f"y_k{kdx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -half_extent,
                                    half_extent,
                                ),
                            )

                        # append
                        ms.append(m)
                        heights.append(height)
                        widths.append(width)
                        xs.append(x)
                        ys.append(y)

                    # observed data
                    pyro.sample(
                        "data",
                        KSMOGN(
                            torch.stack(heights, -1),
                            torch.stack(widths, -1),
                            torch.stack(xs, -1),
                            torch.stack(ys, -1),
                            target_locs,
                            background,
                            gain,
                            self.data.offset.samples,
                            self.data.offset.logits.to(self.dtype),
                            self.data.P,
                            torch.stack(torch.broadcast_tensors(*ms), -1),
                            use_pykeops=self.use_pykeops,
                        ),
                        obs=obs,
                    )

    def guide(self):
        r"""
        **Variational Distribution**

        .. math::

            \begin{aligned}
                q(\phi \setminus \{z, \theta\}) =~&q(g) q(\sigma^{xy}) q(\pi)
                q(\lambda) \cdot \\
                &\prod_{\mathsf{AOI}} \left[ q(\mu^b) q(\sigma^b)
                \prod_{\mathsf{frame}} \left[ \vphantom{\prod_{F}} q(b)
                \prod_{\mathsf{spot}} q(m) q(h | m) q(w | m) q(x | m) q(y | m)
                \right] \right]
            \end{aligned}
        """
        # global parameters
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        pyro.sample(
            "pi",
            dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size")).to_event(1),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ).to_event(1),
        )
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-2,
        )
        # color channels
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        # x and y share the same symmetric support
        half_extent = (self.data.P + 1) / 2

        with channels as cdx, aois as ndx:
            ndx = ndx[:, None, None]
            mask = Vindex(self.data.mask.to(self.device))[ndx]
            with handlers.mask(mask=mask):
                # point estimates (Delta) for the background mean and std
                pyro.sample(
                    "background_mean",
                    dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0, cdx]),
                )
                pyro.sample(
                    "background_std",
                    dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0, cdx]),
                )
                with frames as fdx:
                    fdx = fdx[:, None]
                    # sample background intensity
                    pyro.sample(
                        "background",
                        dist.Gamma(
                            Vindex(pyro.param("b_loc"))[ndx, fdx, cdx]
                            * Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                            Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                        ),
                    )

                    for kdx in range(self.K):
                        # sample spot presence m
                        m = pyro.sample(
                            f"m_k{kdx}",
                            dist.Bernoulli(
                                Vindex(pyro.param("m_probs"))[kdx, ndx, fdx, cdx]
                            ),
                            infer={"enumerate": "parallel"},
                        )
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            pyro.sample(
                                f"height_k{kdx}",
                                dist.Gamma(
                                    Vindex(pyro.param("h_loc"))[kdx, ndx, fdx, cdx]
                                    * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx],
                                ),
                            )
                            pyro.sample(
                                f"width_k{kdx}",
                                AffineBeta(
                                    Vindex(pyro.param("w_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("w_size"))[kdx, ndx, fdx, cdx],
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            pyro.sample(
                                f"x_k{kdx}",
                                AffineBeta(
                                    Vindex(pyro.param("x_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx],
                                    -half_extent,
                                    half_extent,
                                ),
                            )
                            pyro.sample(
                                f"y_k{kdx}",
                                AffineBeta(
                                    Vindex(pyro.param("y_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx],
                                    -half_extent,
                                    half_extent,
                                ),
                            )

    def init_parameters(self):
        """
        Initialize variational parameters.

        Registers ``pi_mean``, ``pi_size`` and ``m_probs`` and then the
        parameters shared between models via :meth:`_init_parameters`.
        """
        device = self.device
        data = self.data

        # NOTE(review): for S > 0 the rows of ones sum to S + 1, i.e. the
        # initial value lies outside the declared simplex constraint. Left
        # unchanged because altering it would change the fitting trajectory;
        # confirm whether this is intended.
        pyro.param(
            "pi_mean",
            lambda: torch.ones((self.Q, self.S + 1), device=device),
            constraint=constraints.simplex,
        )
        # float fill value: keeps the tensor floating point explicitly
        pyro.param(
            "pi_size",
            lambda: torch.full((self.Q, 1), 2.0, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "m_probs",
            lambda: torch.full((self.K, data.Nt, data.F, self.Q), 0.5, device=device),
            constraint=constraints.unit_interval,
        )
        self._init_parameters()

    def _init_parameters(self):
        """
        Parameters shared between different models.

        All fill values are written as floats so the initial tensors are
        floating point without relying on type promotion inside the
        constraint transforms.
        """
        device = self.device
        data = self.data
        eps = torch.finfo(self.dtype).eps
        spot_shape = (self.K, data.Nt, data.F, self.Q)
        half_extent = (data.P + 1) / 2
        # Same bounds the model and guide pass to AffineBeta for the width,
        # so the variational mean can never leave the distribution support.
        width_min = self.priors["width_min"]
        width_max = self.priors["width_max"]

        pyro.param(
            "proximity_loc",
            lambda: torch.tensor(0.5, device=device),
            constraint=constraints.interval(
                0,
                (self.data.P + 1) / math.sqrt(12) - eps,
            ),
        )
        pyro.param(
            "proximity_size",
            lambda: torch.tensor(100.0, device=device),
            constraint=constraints.greater_than(2.0),
        )
        pyro.param(
            "lamda_loc",
            lambda: torch.full((self.Q,), 0.5, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "lamda_beta",
            lambda: torch.full((self.Q,), 100.0, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "gain_loc",
            lambda: torch.tensor(5.0, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "gain_beta",
            lambda: torch.tensor(100.0, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "background_mean_loc",
            lambda: (data.median.to(device) - data.offset.mean).expand(
                data.Nt, 1, data.C
            ),
            constraint=constraints.positive,
        )
        pyro.param(
            "background_std_loc",
            lambda: torch.ones(data.Nt, 1, data.C, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "b_loc",
            lambda: (data.median.to(device) - self.data.offset.mean).expand(
                data.Nt, data.F, data.C
            ),
            constraint=constraints.positive,
        )
        pyro.param(
            "b_beta",
            lambda: torch.ones(data.Nt, data.F, data.C, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "h_loc",
            lambda: torch.full(spot_shape, 2000.0, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "h_beta",
            lambda: torch.full(spot_shape, 0.001, device=device),
            constraint=constraints.positive,
        )
        # midpoint of the width support; equals 1.5 for the default priors
        pyro.param(
            "w_mean",
            lambda: torch.full(
                spot_shape, (width_min + width_max) / 2, device=device
            ),
            constraint=constraints.interval(
                width_min + eps,
                width_max - eps,
            ),
        )
        pyro.param(
            "w_size",
            lambda: torch.full(spot_shape, 100.0, device=device),
            constraint=constraints.greater_than(2.0),
        )
        pyro.param(
            "x_mean",
            lambda: torch.zeros(self.K, data.Nt, data.F, self.Q, device=device),
            constraint=constraints.interval(
                -half_extent + eps,
                half_extent - eps,
            ),
        )
        pyro.param(
            "y_mean",
            lambda: torch.zeros(self.K, data.Nt, data.F, self.Q, device=device),
            constraint=constraints.interval(
                -half_extent + eps,
                half_extent - eps,
            ),
        )
        pyro.param(
            "size",
            lambda: torch.full(spot_shape, 200.0, device=device),
            constraint=constraints.greater_than(2.0),
        )

    def TraceELBO(self, jit=False):
        """
        A trace implementation of ELBO-based SVI that supports
        - exhaustive enumeration over discrete sample sites, and
        - local parallel sampling over any sample site in the guide.

        :param jit: If true, return the JIT-compiled variant of the ELBO.
        """
        elbo_cls = infer.JitTraceEnum_ELBO if jit else infer.TraceEnum_ELBO
        return elbo_cls(max_plate_nesting=3, ignore_jit_warnings=True)

    @lazy_property
    def compute_probs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the ``z`` and ``theta`` probabilities.

        The guide is traced with 50 particles and the model is replayed
        against that trace with the ``data`` site hidden; ``m`` is enumerated
        in the guide and ``z``/``theta`` in the model. The work is done in
        batches of ``nbatch_size`` x ``fbatch_size``. While batching, the
        subsample attributes (``n``, ``f``, ``nbatch_size``,
        ``fbatch_size``) are overridden; they are always reset on exit, even
        if an exception is raised. Everything runs without gradient tracking.

        :return: Tuple ``(z_probs, theta_probs)`` of shapes
            ``(Nt, F, Q, 1 + S)`` and ``(K, Nt, F, Q)``. Only the first
            ``sum(data.is_ontarget)`` AOI rows are filled; the rest stay zero.
        """
        z_probs = torch.zeros(self.data.Nt, self.data.F, self.Q, 1 + self.S)
        theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
        nbatch_size = self.nbatch_size
        fbatch_size = self.fbatch_size
        N = sum(self.data.is_ontarget)
        # sites whose model log-probabilities enter p(z, theta, phi)
        site_names = [
            f"{prefix}_k{k}" for prefix in ("m", "x", "y") for k in range(self.K)
        ]
        site_names += ["z", "theta"]
        # leading (enumeration) dims of the summed log-probability:
        # 0 - theta
        # 1 - z
        # 2 - m_1
        # 3 - m_0
        theta_dims = (0,)
        z_dims = (1,)
        m_dims = tuple(range(2, self.K + 2))
        try:
            with torch.no_grad():
                for ndx in torch.split(torch.arange(N), nbatch_size):
                    for fdx in torch.split(torch.arange(self.data.F), fbatch_size):
                        # restrict the plates in model/guide to this batch
                        self.n = ndx
                        self.f = fdx
                        self.nbatch_size = len(ndx)
                        self.fbatch_size = len(fdx)
                        qdx = torch.arange(self.Q)
                        with pyro.plate(
                            "particles", size=50, dim=-4
                        ), handlers.enum(first_available_dim=-5):
                            guide_tr = handlers.trace(self.guide).get_trace()
                            model_tr = handlers.trace(
                                handlers.replay(
                                    handlers.block(self.model, hide=["data"]),
                                    trace=guide_tr,
                                )
                            ).get_trace()
                        model_tr.compute_log_prob()
                        guide_tr.compute_log_prob()
                        # p(z, theta, phi)
                        logp = 0
                        for name in site_names:
                            logp = logp + model_tr.nodes[name]["unscaled_log_prob"]
                        # p(z, theta | phi) = p(z, theta, phi) - p(z, theta, phi).sum(z, theta)
                        logp = logp - logp.logsumexp(z_dims + theta_dims)
                        m_log_probs = [
                            guide_tr.nodes[f"m_k{k}"]["unscaled_log_prob"]
                            for k in range(self.K)
                        ]
                        expectation = reduce(lambda a, b: a + b, m_log_probs) + logp
                        # average over m
                        result = expectation.logsumexp(m_dims)
                        # marginalize theta
                        z_logits = result.logsumexp(theta_dims)
                        z_probs[ndx[:, None, None], fdx[:, None], qdx] = (
                            z_logits.exp().mean(-4).permute(1, 2, 3, 0)
                        )
                        # marginalize z
                        theta_logits = result.logsumexp(z_dims)
                        # drop index 0 of theta, keeping the K spot entries
                        theta_probs[:, ndx[:, None, None], fdx[:, None], qdx] = (
                            theta_logits[1:].exp().mean(-4)
                        )
        finally:
            self.n = None
            self.f = None
            self.nbatch_size = nbatch_size
            self.fbatch_size = fbatch_size
        return z_probs, theta_probs

    @property
    def z_probs(self) -> torch.Tensor:
        r"""
        Probability of there being a target-specific spot :math:`p(z=1)`
        """
        return self.compute_probs[0]

    @property
    def theta_probs(self) -> torch.Tensor:
        r"""
        Posterior target-specific spot probability :math:`q(\theta = k)`
        for :math:`k \in \{1, \dots, K\}`.
        """
        return self.compute_probs[1]

    @property
    def m_probs(self) -> torch.Tensor:
        r"""
        Posterior spot presence probability :math:`q(m=1)`.
        """
        return pyro.param("m_probs").data

    @property
    def pspecific(self) -> torch.Tensor:
        r"""
        Probability of there being a target-specific spot :math:`p(\mathsf{specific})`
        """
        return self.z_probs

    @property
    def z_map(self) -> torch.Tensor:
        """Index of the largest entry of :attr:`z_probs` along the last axis."""
        return torch.argmax(self.z_probs, dim=-1)

    def z_sample(self, num_samples):
        """
        Draw ``num_samples`` categorical samples of ``z`` from the first
        ``data.N`` rows of ``self.params["z_probs"]``.
        """
        return dist.Categorical(self.params["z_probs"][: self.data.N]).sample(
            (num_samples,)
        )

    def _ci_distribution(self, param):
        """
        Build the variational distribution of ``param`` from the parameter store.

        :param param: One of the names accepted in ``ci_params``.
        :raises ValueError: If ``param`` has no known variational distribution.
        """
        if param in ("gain", "lamda"):
            return dist.Gamma(
                pyro.param(f"{param}_loc") * pyro.param(f"{param}_beta"),
                pyro.param(f"{param}_beta"),
            )
        # "alpha", "init" and "trans" are not in this class's ``ci_params``;
        # they are kept because the original method handled them.
        if param in ("alpha", "pi", "init", "trans"):
            return dist.Dirichlet(
                pyro.param(f"{param}_mean") * pyro.param(f"{param}_size")
            )
        if param == "proximity":
            return AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            )
        if param == "background":
            return dist.Gamma(
                pyro.param("b_loc") * pyro.param("b_beta"),
                pyro.param("b_beta"),
            )
        if param == "height":
            return dist.Gamma(
                pyro.param("h_loc") * pyro.param("h_beta"),
                pyro.param("h_beta"),
            )
        if param == "width":
            return AffineBeta(
                pyro.param("w_mean"),
                pyro.param("w_size"),
                self.priors["width_min"],
                self.priors["width_max"],
            )
        if param in ("x", "y"):
            return AffineBeta(
                pyro.param(f"{param}_mean"),
                pyro.param("size"),
                -(self.data.P + 1) / 2,
                (self.data.P + 1) / 2,
            )
        raise ValueError(f"No variational distribution known for {param!r}")

    @torch.no_grad()
    def compute_params(self, CI):
        """
        Summarize the fitted variational distributions.

        :param CI: Confidence level passed to the SciPy ``interval`` method.
        :return: Dictionary with, for every name in ``ci_params``, the keys
            ``"LL"``, ``"UL"`` and ``"Mean"``, plus the entries ``"m_probs"``,
            ``"z_probs"``, ``"theta_probs"``, ``"z_map"`` and
            ``"p_specific"``. All tensors are on the CPU.
        :raises ValueError: If ``ci_params`` contains an unknown name.
        """
        params = {}
        for param in self.ci_params:
            fn = self._ci_distribution(param)
            scipy_dist = torch_to_scipy_dist(fn)
            # Positional on purpose: the keyword was renamed from ``alpha``
            # to ``confidence`` in SciPy, positional works with both.
            LL, UL = scipy_dist.interval(CI)
            params[param] = {
                "LL": torch.as_tensor(LL, device=torch.device("cpu")),
                "UL": torch.as_tensor(UL, device=torch.device("cpu")),
                "Mean": fn.mean.detach().cpu(),
            }
        params["m_probs"] = self.m_probs.cpu()
        params["z_probs"] = self.z_probs.cpu()
        params["theta_probs"] = self.theta_probs.cpu()
        params["z_map"] = self.z_map.data.cpu()
        params["p_specific"] = params["theta_probs"].sum(0)
        return params