# Source code for tapqir.models.hmm

# Copyright Contributors to the Tapqir project.
# SPDX-License-Identifier: Apache-2.0

"""
hmm
^^^
"""

import math

import funsor
import torch
import torch.distributions.constraints as constraints
from pyro.distributions.hmm import _logmatmulexp, _sequential_index
from pyro.ops.indexing import Vindex
from pyroapi import distributions as dist
from pyroapi import handlers, infer, pyro
from torch.distributions.utils import lazy_property
from torch.nn.functional import one_hot

from tapqir.distributions import KSMOGN, AffineBeta
from tapqir.distributions.util import expand_offtarget, probs_m, probs_theta
from tapqir.handlers import trace, vectorized_markov
from tapqir.infer.elbo import TraceMarkovEnum_ELBO
from tapqir.models.cosmos import cosmos


class hmm(cosmos):
    r"""
    **Multi-Color Hidden Markov Colocalization Model**

    EXPERIMENTAL This model relies on Funsor backend.

    :param S: Number of non-null hidden states; the hidden state ``z`` takes ``1 + S`` values.
    :param K: Maximum number of spots that can be present in a single image.
    :param device: Computation device (cpu or gpu).
    :param dtype: Floating point precision.
    :param use_pykeops: Use pykeops as backend to marginalize out offset.
    :param vectorized: Vectorize time-dimension.
    :param priors: Dictionary of parameters of prior distributions. If ``None``
        (the default), a fresh dictionary with the built-in default values is used.
    """

    name = "cosmos+hmm"

    def __init__(
        self,
        S: int = 1,
        K: int = 2,
        device: str = "cpu",
        dtype: str = "double",
        use_pykeops: bool = True,
        vectorized: bool = True,
        priors: "dict | None" = None,
    ):
        if priors is None:
            # Build a new dict on every call; a dict literal used as the default
            # argument would be a single object shared by all model instances.
            priors = {
                "background_mean_std": 1000.0,
                "background_std_std": 100.0,
                "lamda_rate": 1.0,
                "height_std": 10000.0,
                "width_min": 0.75,
                "width_max": 2.25,
                "proximity_rate": 1.0,
                "gain_std": 50.0,
            }
        self.vectorized = vectorized
        super().__init__(
            S=S, K=K, device=device, dtype=dtype, use_pykeops=use_pykeops, priors=priors
        )
        self._global_params = ["gain", "proximity", "lamda", "trans"]
        self.ci_params = [
            "gain",
            "init",
            "trans",
            "lamda",
            "proximity",
            "background",
            "height",
            "width",
            "x",
            "y",
        ]

    def model(self):
        """
        **Generative Model**

        Samples the global parameters (gain, init, trans, lamda, proximity).
        Then, for every AOI and color channel, it samples background mean and std.
        For every frame it samples the background intensity, the hidden state ``z``
        (a Markov chain over frames) and the target-specific spot index ``theta``.
        For every spot it samples presence ``m`` and the spot parameters, and
        finally it scores the observed image with ``KSMOGN``.
        """
        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
        init = pyro.sample(
            "init",
            dist.Dirichlet(torch.ones(self.Q, self.S + 1) / (self.S + 1)).to_event(1),
        )
        init = expand_offtarget(init)
        trans = pyro.sample(
            "trans",
            dist.Dirichlet(
                torch.ones(self.Q, self.S + 1, self.S + 1) / (self.S + 1)
            ).to_event(2),
        )
        trans = expand_offtarget(trans)
        lamda = pyro.sample(
            "lamda",
            dist.Exponential(torch.full((self.Q,), self.priors["lamda_rate"])).to_event(
                1
            ),
        )
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.priors["proximity_rate"])
        )
        # AffineBeta concentration-like "size" parameters, indexed below by
        # whether a spot is target-specific (index 1) or not (index 0).
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity)) ** 2 - 1),
            ),
            dim=-1,
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames
        frames = (
            vectorized_markov(name="frames", size=self.data.F, dim=-2)
            if self.vectorized
            else pyro.markov(range(self.data.F))
        )
        # color channels
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        with channels as cdx, aois as ndx:
            ndx = ndx[:, None, None]
            mask = Vindex(self.data.mask.to(self.device))[ndx]
            with handlers.mask(mask=mask):
                # background mean and std
                background_mean = pyro.sample(
                    "background_mean",
                    dist.HalfNormal(self.priors["background_mean_std"]),
                )
                background_std = pyro.sample(
                    "background_std", dist.HalfNormal(self.priors["background_std_std"])
                )
                z_prev = None
                for fdx in frames:
                    if self.vectorized:
                        # vectorized_markov yields (site-name suffix, frame index)
                        fsx, fdx = fdx
                        fdx = torch.as_tensor(fdx)
                        fdx = fdx.unsqueeze(-1)
                    else:
                        fsx = fdx
                    # fetch data
                    obs, target_locs, is_ontarget = self.data.fetch(ndx, fdx, cdx)
                    # sample background intensity
                    background = pyro.sample(
                        f"background_f{fsx}",
                        dist.Gamma(
                            (background_mean / background_std) ** 2,
                            background_mean / background_std**2,
                        ),
                    )

                    # sample hidden model state (1+S,)
                    z_probs = (
                        Vindex(init)[..., cdx, :, is_ontarget.long()]
                        if z_prev is None
                        else Vindex(trans)[..., cdx, z_prev, :, is_ontarget.long()]
                    )
                    z_curr = pyro.sample(f"z_f{fsx}", dist.Categorical(z_probs))

                    theta = pyro.sample(
                        f"theta_f{fsx}",
                        dist.Categorical(
                            Vindex(probs_theta(self.K, self.device))[
                                torch.clamp(z_curr, min=0, max=1)
                            ]
                        ),
                        infer={"enumerate": "parallel"},
                    )
                    onehot_theta = one_hot(theta, num_classes=1 + self.K)

                    ms, heights, widths, xs, ys = [], [], [], [], []
                    for kdx in spots:
                        specific = onehot_theta[..., 1 + kdx]
                        # spot presence
                        m_probs = Vindex(probs_m(lamda, self.K))[..., cdx, theta, kdx]
                        m = pyro.sample(
                            f"m_k{kdx}_f{fsx}",
                            dist.Categorical(torch.stack((1 - m_probs, m_probs), -1)),
                        )
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            height = pyro.sample(
                                f"height_k{kdx}_f{fsx}",
                                dist.HalfNormal(self.priors["height_std"]),
                            )
                            width = pyro.sample(
                                f"width_k{kdx}_f{fsx}",
                                AffineBeta(
                                    1.5,
                                    2,
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            x = pyro.sample(
                                f"x_k{kdx}_f{fsx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            y = pyro.sample(
                                f"y_k{kdx}_f{fsx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )

                        # append
                        ms.append(m)
                        heights.append(height)
                        widths.append(width)
                        xs.append(x)
                        ys.append(y)

                    # observed data
                    pyro.sample(
                        f"data_f{fsx}",
                        KSMOGN(
                            torch.stack(heights, -1),
                            torch.stack(widths, -1),
                            torch.stack(xs, -1),
                            torch.stack(ys, -1),
                            target_locs,
                            background,
                            gain,
                            self.data.offset.samples,
                            self.data.offset.logits.to(self.dtype),
                            self.data.P,
                            torch.stack(torch.broadcast_tensors(*ms), -1),
                            use_pykeops=self.use_pykeops,
                        ),
                        obs=obs,
                    )
                    z_prev = z_curr

    def guide(self):
        """
        **Variational Distribution**

        Mirrors the sample sites of :meth:`model` (except ``theta`` and the
        observed data). ``z`` and ``m`` are marked for parallel enumeration and
        their probabilities are read from the ``z_trans`` and ``m_probs`` params.
        """
        # global parameters
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        pyro.sample(
            "init",
            dist.Dirichlet(pyro.param("init_mean") * pyro.param("init_size")).to_event(
                1
            ),
        )
        pyro.sample(
            "trans",
            dist.Dirichlet(
                pyro.param("trans_mean") * pyro.param("trans_size")
            ).to_event(2),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ).to_event(1),
        )
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames
        frames = (
            vectorized_markov(name="frames", size=self.data.F, dim=-2)
            if self.vectorized
            else pyro.markov(range(self.data.F))
        )
        # color channels
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        with channels as cdx, aois as ndx:
            ndx = ndx[:, None, None]
            mask = Vindex(self.data.mask.to(self.device))[ndx]
            with handlers.mask(mask=mask):
                pyro.sample(
                    "background_mean",
                    dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0, cdx]),
                )
                pyro.sample(
                    "background_std",
                    dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0, cdx]),
                )
                z_prev = None
                for fdx in frames:
                    if self.vectorized:
                        fsx, fdx = fdx
                        fdx = torch.as_tensor(fdx)
                        fdx = fdx.unsqueeze(-1)
                    else:
                        fsx = fdx
                    # sample background intensity
                    pyro.sample(
                        f"background_f{fsx}",
                        dist.Gamma(
                            Vindex(pyro.param("b_loc"))[ndx, fdx, cdx]
                            * Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                            Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                        ),
                    )

                    # sample hidden model state; the first frame reads row 0
                    # of z_trans, later frames the row of the previous state
                    z_probs = (
                        Vindex(pyro.param("z_trans"))[ndx, fdx, cdx, 0]
                        if z_prev is None
                        else Vindex(pyro.param("z_trans"))[ndx, fdx, cdx, z_prev]
                    )
                    z_curr = pyro.sample(
                        f"z_f{fsx}",
                        dist.Categorical(z_probs),
                        infer={"enumerate": "parallel"},
                    )

                    for kdx in spots:
                        # spot presence
                        m_probs = Vindex(pyro.param("m_probs"))[
                            z_curr, kdx, ndx, fdx, cdx
                        ]
                        m = pyro.sample(
                            f"m_k{kdx}_f{fsx}",
                            dist.Categorical(torch.stack((1 - m_probs, m_probs), -1)),
                            infer={"enumerate": "parallel"},
                        )
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            pyro.sample(
                                f"height_k{kdx}_f{fsx}",
                                dist.Gamma(
                                    Vindex(pyro.param("h_loc"))[kdx, ndx, fdx, cdx]
                                    * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx],
                                ),
                            )
                            pyro.sample(
                                f"width_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("w_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("w_size"))[kdx, ndx, fdx, cdx],
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            pyro.sample(
                                f"x_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("x_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            pyro.sample(
                                f"y_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("y_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                    z_prev = z_curr

    def init_parameters(self):
        """
        Initialize variational parameters.

        Registers the HMM-specific params (``init_*``, ``trans_*``, ``z_trans``,
        ``m_probs``) and then delegates the remaining ones to ``_init_parameters``.
        """
        device = self.device
        data = self.data

        pyro.param(
            "init_mean",
            lambda: torch.ones(self.Q, self.S + 1, device=device),
            constraint=constraints.simplex,
        )
        pyro.param(
            "init_size",
            # float fill value: an int fill would make torch.full return int64
            lambda: torch.full((self.Q, 1), 2.0, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "trans_mean",
            lambda: torch.ones(self.Q, self.S + 1, self.S + 1, device=device),
            constraint=constraints.simplex,
        )
        pyro.param(
            "trans_size",
            lambda: torch.full((self.Q, self.S + 1, 1), 2.0, device=device),
            constraint=constraints.positive,
        )

        # classification
        pyro.param(
            "z_trans",
            lambda: torch.ones(
                data.Nt,
                data.F,
                data.C,
                1 + self.S,
                1 + self.S,
                device=device,
            ),
            constraint=constraints.simplex,
        )
        pyro.param(
            "m_probs",
            lambda: torch.full(
                (1 + self.S, self.K, data.Nt, data.F, data.C),
                0.5,
                device=device,
            ),
            constraint=constraints.unit_interval,
        )

        self._init_parameters()

    def TraceELBO(self, jit=False):
        """
        A trace implementation of ELBO-based SVI that supports
        - exhaustive enumeration over discrete sample sites, and
        - local parallel sampling over any sample site in the guide.

        In vectorized mode ``TraceMarkovEnum_ELBO`` is returned and ``jit`` is ignored.
        """
        if self.vectorized:
            return TraceMarkovEnum_ELBO(max_plate_nesting=3, ignore_jit_warnings=True)
        return (infer.JitTraceEnum_ELBO if jit else infer.TraceEnum_ELBO)(
            max_plate_nesting=3, ignore_jit_warnings=True
        )

    @staticmethod
    def _sequential_logmatmulexp(logits: torch.Tensor) -> torch.Tensor:
        """
        For a tensor ``x`` whose time dimension is -4, computes::

            x[..., 0, :, :, :] @ x[..., 1, :, :, :] @ ... @ x[..., T-1, :, :, :]

        but does so numerically stably in log space.

        The result keeps the time dimension: entry ``t`` holds the running
        product ``x[0] @ ... @ x[t]``. It is computed with an up-sweep / down-sweep
        pairwise scan, with dim -3 (channels) carried along as a batch dimension.
        """
        batch_shape = logits.shape[:-4]
        state_dim = logits.size(-1)
        c_dim = logits.size(-3)
        sum_terms = []
        # up sweep: repeatedly contract adjacent pairs of time steps
        while logits.size(-4) > 1:
            time = logits.size(-4)
            even_time = time // 2 * 2
            even_part = logits[..., :even_time, :, :, :]
            x_y = even_part.reshape(
                batch_shape + (even_time // 2, 2, c_dim, state_dim, state_dim)
            )
            x, y = x_y.unbind(-4)
            contracted = _logmatmulexp(x, y)
            if time > even_time:
                # odd length: carry the last step up unchanged
                contracted = torch.cat((contracted, logits[..., -1:, :, :, :]), dim=-4)
            sum_terms.append(logits)
            logits = contracted
        else:
            sum_terms.append(logits)
        # handle root case
        sum_term = sum_terms.pop()
        left_term = hmm._contraction_identity(sum_term)
        # down sweep: propagate exclusive prefix products back to every step
        while sum_terms:
            sum_term = sum_terms.pop()
            new_left_term = hmm._contraction_identity(sum_term)
            time = sum_term.size(-4)
            even_time = time // 2 * 2
            if time > even_time:
                new_left_term[..., time - 1 : time, :, :, :] = left_term[
                    ..., even_time // 2 : even_time // 2 + 1, :, :, :
                ]
                left_term = left_term[..., : even_time // 2, :, :, :]

            left_sum = sum_term[..., :even_time:2, :, :, :]
            left_sum_and_term = _logmatmulexp(left_term, left_sum)
            new_left_term[..., :even_time:2, :, :, :] = left_term
            new_left_term[..., 1:even_time:2, :, :, :] = left_sum_and_term
            left_term = new_left_term
        else:
            # turn exclusive prefixes into inclusive ones
            alphas = _logmatmulexp(left_term, sum_term)
        return alphas

    @staticmethod
    def _contraction_identity(logits: torch.Tensor) -> torch.Tensor:
        """
        Log-space identity matrices shaped like ``logits``.

        The result has ``logits``' leading dims and ``(c_dim, state_dim, state_dim)``
        trailing dims. It holds 0 on the diagonal and -inf elsewhere and is created
        with ``logits``' dtype and device, so it can be combined with ``logits``
        directly.
        """
        batch_shape = logits.shape[:-3]
        state_dim = logits.size(-1)
        c_dim = logits.size(-3)
        result = torch.eye(state_dim, dtype=logits.dtype, device=logits.device).log()
        result = result.reshape((1,) * len(batch_shape) + (1, state_dim, state_dim))
        result = result.repeat(batch_shape + (c_dim, 1, 1))
        return result

    @lazy_property
    def compute_probs(self) -> torch.Tensor:
        """
        Compute posterior ``theta`` probabilities with ``z`` fixed to ``z_map``.

        The model and guide are traced in batches of AOIs with 5 particles and
        parallel enumeration. ``p(z=z_map, theta | phi)`` is averaged over
        ``m ~ q`` and over particles. The result has shape ``(K, Nt, F, Q)`` and is
        cached by ``lazy_property``. ``self.n`` and ``self.nbatch_size`` are
        restored even if an error occurs.
        """
        theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
        nbatch_size = self.nbatch_size
        N = sum(self.data.is_ontarget)
        # Site-name prefixes whose model log-probs form p(z, theta, phi); the
        # order matches the original hard-coded list for K == 2.
        site_names = (
            ["z", "theta"]
            + [f"m_k{kdx}" for kdx in range(self.K)]
            + [f"x_k{kdx}" for kdx in range(self.K)]
            + [f"y_k{kdx}" for kdx in range(self.K)]
        )
        try:
            for ndx in torch.split(torch.arange(N), nbatch_size):
                self.n = ndx
                self.nbatch_size = len(ndx)
                with torch.no_grad(), pyro.plate(
                    "particles", size=5, dim=-4
                ), handlers.enum(first_available_dim=-5):
                    guide_tr = trace()(self.guide).get_trace()
                    model_tr = trace()(
                        handlers.replay(self.model, trace=guide_tr)
                    ).get_trace()
                model_tr.compute_log_prob()
                guide_tr.compute_log_prob()
                logp = {}
                result = {}
                for fsx in ("0", f"slice(1, {self.data.F}, None)"):
                    m_sites = [f"m_k{kdx}_f{fsx}" for kdx in range(self.K)]
                    logp[fsx] = 0
                    # collect log_prob terms p(z, theta, phi)
                    for name in site_names:
                        logp[fsx] += model_tr.nodes[f"{name}_f{fsx}"]["funsor"][
                            "log_prob"
                        ]
                    if fsx == "0":
                        # substitute MAP values of z into p(z=z_map, theta, phi)
                        z_map = funsor.Tensor(
                            self.z_map[ndx, 0].long(), dtype=self.S + 1
                        )["aois", "channels"]
                        subs = {f"z_f{fsx}": z_map}
                    else:
                        # substitute MAP values of z into p(z=z_map, theta, phi)
                        z_map = funsor.Tensor(
                            self.z_map[ndx, 1:].long(), dtype=self.S + 1
                        )["aois", "frames", "channels"]
                        z_map_prev = funsor.Tensor(
                            self.z_map[ndx, :-1].long(), dtype=self.S + 1
                        )["aois", "frames", "channels"]
                        fsx_prev = f"slice(0, {self.data.F-1}, None)"
                        subs = {f"z_f{fsx}": z_map, f"z_f{fsx_prev}": z_map_prev}
                    logp[fsx] = logp[fsx](**subs)
                    # compute log_measure q for given z_map (sum over all spots)
                    log_measure = guide_tr.nodes[m_sites[0]]["funsor"]["log_measure"]
                    for site in m_sites[1:]:
                        log_measure = (
                            log_measure + guide_tr.nodes[site]["funsor"]["log_measure"]
                        )
                    log_measure = log_measure(**subs)

                    # compute p(z_map, theta | phi) = p(z_map, theta, phi) - p(z_map, phi)
                    logp[fsx] = logp[fsx] - logp[fsx].reduce(
                        funsor.ops.logaddexp, f"theta_f{fsx}"
                    )
                    # average over m in p * q
                    result[fsx] = (logp[fsx] + log_measure).reduce(
                        funsor.ops.logaddexp, frozenset(m_sites)
                    )
                    # average over particles
                    result[fsx] = result[fsx].exp().reduce(funsor.ops.mean, "particles")
                theta_probs[:, ndx, 0] = result["0"].data[..., 1:].permute(2, 0, 1)
                theta_probs[:, ndx, 1:] = (
                    result[f"slice(1, {self.data.F}, None)"]
                    .data[..., 1:]
                    .permute(3, 0, 1, 2)
                )
        finally:
            # leave the model in full-data mode even if tracing fails
            self.n = None
            self.nbatch_size = nbatch_size
        return theta_probs

    @property
    def z_probs(self) -> torch.Tensor:
        r"""
        Marginal probabilities of the hidden state :math:`z` for every frame.

        They are obtained by chaining the ``z_trans`` transition matrices over
        time and reading row 0 (the initial-state row).
        """
        result = self._sequential_logmatmulexp(pyro.param("z_trans").data.log())
        return result[..., 0, :].exp()

    @property
    def theta_probs(self) -> torch.Tensor:
        r"""
        Posterior target-specific spot probability :math:`q(\theta = k, z=z_\mathsf{MAP})`.
        """
        return self.compute_probs

    @property
    def pspecific(self) -> torch.Tensor:
        r"""
        Probability of there being a target-specific spot :math:`p(\mathsf{specific})`
        """
        return self.z_probs

    @property
    def m_probs(self) -> torch.Tensor:
        r"""
        Posterior spot presence probability :math:`q(m=1, z=z_\mathsf{MAP})`.
        """
        return Vindex(torch.permute(pyro.param("m_probs").data, (1, 2, 3, 4, 0)))[
            ..., self.z_map.long()
        ]

    def z_sample(self, num_samples):
        """
        Draw ``num_samples`` hidden-state trajectories from ``z_trans``.

        The first state is drawn from row 0 of frame 0. Later states follow the
        per-frame transition rows, composed with ``_sequential_index``.
        """
        init_probs = self.params["z_trans"][: self.data.N, 0, :, 0]
        init_probs = init_probs.expand((num_samples,) + init_probs.shape)
        x = dist.Categorical(init_probs).sample()
        trans_probs = self.params["z_trans"][: self.data.N, 1:].permute(0, 2, 1, 3, 4)
        trans_probs = trans_probs.expand((num_samples,) + trans_probs.shape)
        xs = dist.Categorical(trans_probs).sample()
        xs = _sequential_index(xs)
        # NOTE(review): the initial state ``x`` is only used to index ``xs`` and
        # does not seem to be prepended, so frames 1..F-1 appear to be returned —
        # confirm this is intended.
        x = Vindex(xs)[..., :, x]
        return x.permute(0, 1, 3, 2)

    @torch.no_grad()
    def compute_params(self, CI):
        """Extend the parent's computed parameters with ``z_trans`` (moved to CPU)."""
        params = super().compute_params(CI)
        params["z_trans"] = pyro.param("z_trans").cpu()
        return params