# Source code for tapqir.models.model

# Copyright Contributors to the Tapqir project.
# SPDX-License-Identifier: Apache-2.0

import logging
import random
from collections import defaultdict, deque
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd
import torch
from pyroapi import infer, optim, pyro
from sklearn.metrics import (
    confusion_matrix,
    matthews_corrcoef,
    precision_score,
    recall_score,
)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from tapqir import __version__ as tapqir_version
from tapqir.exceptions import CudaOutOfMemoryError, TapqirFileNotFoundError
from tapqir.utils.dataset import load
from tapqir.utils.stats import save_stats

logger = logging.getLogger(__name__)


[docs]class Model: r""" Base class for tapqir models. Derived models must implement the methods * :meth:`model` * :meth:`guide` * :meth:`init_parameters` * :meth:`TraceELBO` :param S: Number of distinct molecular states for the binder molecules. :param K: Maximum number of spots that can be present in a single image. :param Q: Number of fluorescent dyes. :param device: Computation device (cpu or gpu). :param dtype: Floating point precision. :param priors: Dictionary of parameters of prior distributions. """ def __init__( self, S: int = 1, K: int = 2, Q: int = None, device: str = "cpu", dtype: str = "double", priors: dict = None, ): self.S = S self.K = K self._Q = Q self.nbatch_size = None self.fbatch_size = None # priors settings self.priors = priors # for plotting self.n = None self.f = None self.data_path = None self.path = None self.run_path = None # set device & dtype self.to(device, dtype)
[docs] def to(self, device: str, dtype: str = "double") -> None: """ Change tensor device and dtype. :param device: Computation device, either "gpu" or "cpu". :param dtype: Floating point precision, either "double" or "float". """ self.dtype = getattr(torch, dtype) self.device = torch.device(device) if device == "cuda" and dtype == "double": torch.set_default_tensor_type(torch.cuda.DoubleTensor) elif device == "cuda" and dtype == "float": torch.set_default_tensor_type(torch.cuda.FloatTensor) elif device == "cpu" and dtype == "double": torch.set_default_tensor_type(torch.DoubleTensor) else: torch.set_default_tensor_type(torch.FloatTensor)
@property def Q(self): return self._Q or self.data.C
[docs] def load(self, path: Union[str, Path], data_only: bool = True) -> None: """ Load data and optionally parameters from a specified path :param path: Path to Tapqir analysis folder. :param data_only: Load only data or both data and parameters. """ # set path self.path = Path(path) self.run_path = self.path / ".tapqir" # load data self.data = load(self.path, self.device) logger.debug(f"Loaded data from {self.path / 'data.tpqr'}") # load fit results if not data_only: try: self.params = torch.load(self.path / f"{self.name}_params.tpqr") except FileNotFoundError: raise TapqirFileNotFoundError( "parameter", self.path / f"{self.name}_params.tpqr" ) try: self.summary = pd.read_csv( self.path / f"{self.name}_summary.csv", index_col=0 ) except FileNotFoundError: raise TapqirFileNotFoundError( "summary", self.path / f"{self.name}_summary.csv" )
[docs] def model(self): """ Generative Model """ raise NotImplementedError
[docs] def guide(self): """ Variational Distribution """ raise NotImplementedError
[docs] def TraceELBO(self, jit): """ A trace implementation of ELBO-based SVI. """ raise NotImplementedError
[docs] def init_parameters(self): """ Initialize variational parameters. """ raise NotImplementedError
[docs] def init( self, lr: float = 0.005, nbatch_size: int = 5, fbatch_size: int = 512, jit: bool = False, ) -> None: """ Initialize SVI object. :param lr: Learning rate. :param nbatch_size: AOI batch size. :param fbatch_size: Frame batch size. :param jit: Use JIT compiler. """ self.lr = lr self.optim_fn = optim.Adam self.optim_args = {"lr": lr, "betas": [0.9, 0.999]} self.optim = self.optim_fn(self.optim_args) try: self.load_checkpoint() except TapqirFileNotFoundError: pyro.clear_param_store() self.iter = 0 self.converged = False self._rolling = defaultdict(lambda: deque([], maxlen=100)) self.init_parameters() self.elbo = self.TraceELBO(jit) self.svi = infer.SVI(self.model, self.guide, self.optim, loss=self.elbo) self.nbatch_size = min(nbatch_size, self.data.Nt) self.fbatch_size = min(fbatch_size, self.data.F)
[docs] def run(self, num_iter: int = 0, progress_bar=tqdm) -> None: """ Run inference procedure for a specified number of iterations. If num_iter equals zero then run till model converges. :param num_iter: Number of iterations. """ use_crit = False if not num_iter: use_crit = True num_iter = 100000 logger.debug("Tapqir version - {}".format(tapqir_version)) logger.debug("Model - {}".format(self.name)) logger.debug("Device - {}".format(self.device)) logger.debug("Floating precision - {}".format(self.dtype)) logger.debug("Optimizer - {}".format(self.optim_fn.__name__)) logger.debug("Learning rate - {}".format(self.lr)) logger.debug("AOI batch size - {}".format(self.nbatch_size)) logger.debug("Frame batch size - {}".format(self.fbatch_size)) with SummaryWriter(log_dir=self.run_path / "logs" / self.name) as writer: for i in progress_bar(range(num_iter)): try: self.iter_loss = self.svi.step() # save a checkpoint every 200 iterations if not self.iter % 200: self.save_checkpoint(writer) if use_crit and self.converged: logger.info(f"Iteration #{self.iter} model converged.") break self.iter += 1 except ValueError: # load last checkpoint self.init( lr=self.lr, nbatch_size=self.nbatch_size, fbatch_size=self.fbatch_size, ) # change rng seed new_seed = random.randint(0, 100) pyro.set_rng_seed(new_seed) logger.warning( f"Iteration #{self.iter} restarting with a new seed: {new_seed}." ) except RuntimeError as err: assert err.args[0].startswith("CUDA out of memory") raise CudaOutOfMemoryError() else: logger.warning(f"Iteration #{self.iter} model has not converged.")
[docs] def save_checkpoint(self, writer: SummaryWriter = None): """ Save checkpoint. :param writer: SummaryWriter object. """ # save only if no NaN values for k, v in pyro.get_param_store().items(): if torch.isnan(v).any() or torch.isinf(v).any(): raise ValueError( "Iteration #{}. Detected NaN values in {}".format(self.iter, k) ) # update convergence criteria parameters for name in self.conv_params: if name == "-ELBO": self._rolling["-ELBO"].append(self.iter_loss) elif pyro.param(name).ndim == 1: for i in range(len(pyro.param(name))): self._rolling[f"{name}_{i}"].append(pyro.param(name)[i].item()) else: self._rolling[name].append(pyro.param(name).item()) # check convergence status self.converged = False if len(self._rolling["-ELBO"]) == self._rolling["-ELBO"].maxlen: crit = all( torch.tensor(value).std() / torch.tensor(value)[-50:].std() < 1.05 for value in self._rolling.values() ) if crit: self.converged = True # save the model state torch.save( { "iter": self.iter, "params": pyro.get_param_store().get_state(), "optimizer": self.optim.get_state(), "rolling": dict(self._rolling), "convergence_status": self.converged, }, self.run_path / f"{self.name}_model.tpqr", ) # save global parameters for tensorboard writer.add_scalar("-ELBO", self.iter_loss, self.iter) for name, val in pyro.get_param_store().items(): if val.dim() == 0: writer.add_scalar(name, val.item(), self.iter) elif val.dim() == 1 and len(val) <= self.Q * 2: scalars = {str(i): v.item() for i, v in enumerate(val)} writer.add_scalars(name, scalars, self.iter) elif val.dim() == 2 and len(val) <= self.Q * 2: scalars = { f"{i}_{j}": k.item() for i, v in enumerate(val) for j, k in enumerate(v) } writer.add_scalars(name, scalars, self.iter) if False and self.data.labels is not None: pred_labels = ( self.pspecific_map[self.data.is_ontarget].cpu().numpy().ravel() ) true_labels = self.data.labels["z"].ravel() metrics = {} with np.errstate(divide="ignore", invalid="ignore"): metrics["MCC"] = 
matthews_corrcoef(true_labels, pred_labels) metrics["Recall"] = recall_score(true_labels, pred_labels, zero_division=0) metrics["Precision"] = precision_score( true_labels, pred_labels, zero_division=0 ) neg, pos = {}, {} neg["TN"], neg["FP"], pos["FN"], pos["TP"] = confusion_matrix( true_labels, pred_labels, labels=(0, 1) ).ravel() writer.add_scalars("ACCURACY", metrics, self.iter) writer.add_scalars("NEGATIVES", neg, self.iter) writer.add_scalars("POSITIVES", pos, self.iter) logger.debug(f"Iteration #{self.iter}: Successful.")
[docs] def load_checkpoint( self, path: Union[str, Path] = None, param_only: bool = False, warnings: bool = False, ): """ Load checkpoint. :param path: Path to model checkpoint. :param param_only: Load only parameters. :param warnings: Give warnings if loaded model has not been fully trained. """ device = self.device path = Path(path) if path else self.run_path model_path = path / f"{self.name}_model.tpqr" try: checkpoint = torch.load(model_path, map_location=device) except FileNotFoundError: raise TapqirFileNotFoundError("model", model_path) pyro.clear_param_store() pyro.get_param_store().set_state(checkpoint["params"]) if not param_only: self.converged = checkpoint["convergence_status"] self._rolling = checkpoint["rolling"] self.iter = checkpoint["iter"] self.optim.set_state(checkpoint["optimizer"]) logger.info( f"Iteration #{self.iter}. Loaded a model checkpoint from {model_path}" ) if warnings and not checkpoint["convergence_status"]: logger.warning(f"Model at {path} has not been fully trained")
[docs] def compute_stats(self, CI: float = 0.95, save_matlab: bool = False): """ Compute credible regions (CI) and other stats. :param CI: credible region. :param save_matlab: Save output in Matlab format as well. """ try: save_stats(self, self.path, CI=CI, save_matlab=save_matlab) except RuntimeError as err: assert err.args[0].startswith("CUDA out of memory") raise CudaOutOfMemoryError() logger.debug("Computing stats: Successful.")