# Source code for tapqir.utils.imscroll

# Copyright Contributors to the Tapqir project.
# SPDX-License-Identifier: Apache-2.0

from functools import singledispatch

import numpy as np
import pandas as pd
import torch
from pyro.ops.stats import pi, resample
from pyroapi import distributions as dist


@singledispatch
def count_intervals(labels):
    r"""
    Count binding interval data.

    ``labels`` must be three-dimensional; its axes are unpacked as
    (posterior sample, AOI, frame). Consecutive frames of one AOI that share
    the same boolean state form one interval.

    Co-localization and absent intervals are coded as -3 and -2 respectively
    when they are the first (or only) interval in a record, 3 and 2 when they
    are the last interval in a record and 1 and 0 elsewhere.

    Returns a :class:`pandas.DataFrame` with one row per interval and the
    columns ``posterior_sample``, ``aoi``, ``start_frame``, ``stop_frame``
    (inclusive), ``dwell_time`` (``stop_frame + 1 - start_frame``),
    ``low_or_high`` (the interval code described above) and ``z`` (the
    original, un-cast label at the first frame of the interval).

    Implementations are registered for :class:`numpy.ndarray` and
    :class:`torch.Tensor`; any other type raises ``NotImplementedError``.

    Reference::

        @article{friedman2015multi,
          title={Multi-wavelength single-molecule fluorescence analysis of
                 transcription mechanisms},
          author={Friedman, Larry J and Gelles, Jeff},
          journal={Methods},
          volume={86},
          pages={27--36},
          year={2015},
          publisher={Elsevier}
        }
    """
    raise NotImplementedError


@count_intervals.register(np.ndarray)
def _(labels):
    z = labels
    labels = labels.astype("bool")

    # An interval starts at frame 0 (forced by prepending the negated first
    # frame) and wherever the state differs from the previous frame.
    previous = np.concatenate((~labels[..., 0:1], labels[..., :-1]), axis=-1)
    start_sample, start_aoi, start_frame = np.nonzero(previous != labels)
    # Start codes: 0/1 in general, -2/-3 for an interval starting at frame 0.
    # np.int64 rather than the platform-dependent "long" alias.
    start_type = labels.astype(np.int64)
    start_type[..., 0] = -start_type[..., 0] - 2
    start_type = start_type[start_sample, start_aoi, start_frame]

    # An interval stops wherever the next frame differs, and always at the
    # last frame (forced by appending a column of ones).
    stop_condition = np.concatenate(
        (labels[..., :-1] != labels[..., 1:], np.ones_like(labels[..., 0:1])),
        axis=-1,
    )
    stop_sample, stop_aoi, stop_frame = np.nonzero(stop_condition)
    # Stop codes: 0/1 in general, 2/3 for an interval ending at the last frame.
    stop_type = labels.astype(np.int64)
    stop_type[..., -1] += 2
    stop_type = stop_type[stop_sample, stop_aoi, stop_frame]

    # Internal invariant: starts and stops pair up one-to-one per AOI.
    assert np.array_equal(start_aoi, stop_aoi)

    # Keep whichever boundary code carries the first/last marker. ``>=`` makes
    # the start code win a tie, so a record made of a single interval is coded
    # -3/-2 as documented (a strict ``>`` would yield 3/2 instead).
    low_or_high = np.where(abs(start_type) >= abs(stop_type), start_type, stop_type)
    z_type = z[start_sample, start_aoi, start_frame]

    return pd.DataFrame(
        data={
            "posterior_sample": start_sample,
            "aoi": start_aoi,
            "start_frame": start_frame,
            "stop_frame": stop_frame,
            "dwell_time": stop_frame + 1 - start_frame,
            "low_or_high": low_or_high,
            "z": z_type,
        }
    )


@count_intervals.register(torch.Tensor)
def _(labels):
    # The result is a pandas DataFrame, so the data has to live in host memory
    # anyway: move the tensor to a NumPy array (this also makes non-CPU and
    # grad-tracking tensors work) and reuse the ndarray implementation.
    return count_intervals(labels.detach().cpu().numpy())


def _padded_dwell_times(intervals, code):
    """Collect dwell times of intervals whose ``low_or_high`` equals ``code``.

    Returns a float32 array with one row per distinct ``posterior_sample``
    value that has at least one matching interval (rows sorted by sample
    value). Rows are right-padded with zeros to the length of the longest row.
    An empty selection yields an array of shape ``(0, 0)``.
    """
    assert isinstance(intervals, pd.DataFrame)
    selected = intervals.loc[
        intervals["low_or_high"] == code, ["posterior_sample", "dwell_time"]
    ]
    if selected.empty:
        return np.zeros((0, 0), dtype=np.float32)
    # Group by the actual sample ids instead of assuming they are exactly
    # 0..n-1; otherwise samples lacking such intervals shift/drop other rows.
    grouped = selected.groupby("posterior_sample", sort=True)["dwell_time"]
    data = np.zeros((grouped.ngroups, int(grouped.size().max())), dtype=np.float32)
    for row, (_, dwell_times) in enumerate(grouped):
        data[row, : len(dwell_times)] = dwell_times.to_numpy()
    return data


def bound_dwell_times(intervals):
    """Return zero-padded dwell times of complete bound intervals.

    ``intervals`` is a DataFrame with ``posterior_sample``, ``dwell_time`` and
    ``low_or_high`` columns (as produced by :func:`count_intervals`). Only
    intervals coded ``1`` are used, i.e. bound intervals that are neither the
    first nor the last one in a record.

    Returns a float32 array with one row per posterior sample that has such
    intervals, right-padded with zeros.
    """
    return _padded_dwell_times(intervals, 1)


def unbound_dwell_times(intervals):
    """Return zero-padded dwell times of complete unbound intervals.

    ``intervals`` is a DataFrame with ``posterior_sample``, ``dwell_time`` and
    ``low_or_high`` columns (as produced by :func:`count_intervals`). Only
    intervals coded ``0`` are used, i.e. unbound intervals that are neither the
    first nor the last one in a record.

    Returns a float32 array with one row per posterior sample that has such
    intervals, right-padded with zeros.
    """
    return _padded_dwell_times(intervals, 0)
@singledispatch
def time_to_first_binding(labels):
    r"""
    Measure the time elapsed prior to the first binding.

    The last axis of ``labels`` is the frame axis (length :math:`F`); the
    result has the shape of ``labels`` without that axis. A trace that starts
    in the bound state gives 0 and a trace without any binding gives
    :math:`F`.

    Time-to-first binding for a binary data:

    :math:`\mathrm{ttfb} =
    \sum_{f=1}^{F-1} f z_{n,f} \prod_{f^\prime=0}^{f-1} (1 - z_{n,f^\prime})
    + F \prod_{f^\prime=0}^{F-1} (1 - z_{n,f^\prime})`

    Expected value of the time-to-first binding:

    :math:`\mathbb{E}[\mathrm{ttfb}] =
    \sum_{f=1}^{F-1} f q(z_{n,f}=1) \prod_{f^\prime=0}^{f-1} q(z_{n,f^\prime}=0)
    + F \prod_{f^\prime=0}^{F-1} q(z_{n,f^\prime}=0)`

    Implementations are registered for :class:`numpy.ndarray` and
    :class:`torch.Tensor`; any other type raises ``NotImplementedError``.

    Reference::

        @article{friedman2015multi,
          title={Multi-wavelength single-molecule fluorescence analysis of
                 transcription mechanisms},
          author={Friedman, Larry J and Gelles, Jeff},
          journal={Methods},
          volume={86},
          pages={27--36},
          year={2015},
          publisher={Elsevier}
        }
    """
    raise NotImplementedError


@time_to_first_binding.register(np.ndarray)
def _(labels):
    labels = labels.astype("float")
    F = labels.shape[-1]
    # Weight of position f is f + 1: the frame at which binding is observed,
    # or F for the "never bound" term stored in the last position.
    frames = np.arange(1, F + 1)
    # q1[f] = z[f + 1]; the final entry stays 1 so that the last term reduces
    # to F * prod(1 - z).
    q1 = np.ones_like(labels)
    q1[..., :-1] = labels[..., 1:]
    # cumq0[f] = prod_{f' <= f} (1 - z[f']): unbound in every frame up to f.
    cumq0 = np.cumprod(1 - labels, axis=-1)
    return (frames * q1 * cumq0).sum(-1)


@time_to_first_binding.register(torch.Tensor)
def _(labels):
    labels = labels.float()
    F = labels.shape[-1]
    # Build the weights on the device (and with the dtype) of ``labels`` so
    # the product below also works for tensors that do not live on the CPU.
    frames = torch.arange(1, F + 1, dtype=labels.dtype, device=labels.device)
    # q1[f] = z[f + 1]; the final entry stays 1 so that the last term reduces
    # to F * prod(1 - z).
    q1 = torch.ones_like(labels)
    q1[..., :-1] = labels[..., 1:]
    # cumq0[f] = prod_{f' <= f} (1 - z[f']): unbound in every frame up to f.
    cumq0 = torch.cumprod(1 - labels, dim=-1)
    return (frames * q1 * cumq0).sum(-1)
@singledispatch
def association_rate(labels):
    r"""
    Compute the on-rate from the binary data assuming a two-state HMM model.

    The rate is the number of unbound-to-bound transitions between
    consecutive frames divided by the number of frames (excluding the last
    one) spent in the unbound state. Both counts are summed over the last two
    axes of ``labels``; the last axis is the frame axis.

    Implementations are registered for :class:`numpy.ndarray` and
    :class:`torch.Tensor`; any other type raises ``NotImplementedError``.
    """
    raise NotImplementedError


@association_rate.register(np.ndarray)
def _(labels):
    current = labels[..., :-1]
    following = labels[..., 1:]
    unbound_now = 1 - current
    # unbound in frame f and bound in frame f + 1
    n_binding = (unbound_now * following).sum((-2, -1))
    n_unbound = unbound_now.sum((-2, -1))
    return n_binding / n_unbound


@association_rate.register(torch.Tensor)
def _(labels):
    as_float = labels.float()
    current = as_float[..., :-1]
    following = as_float[..., 1:]
    # unbound in frame f and bound in frame f + 1
    n_binding = ((1 - current) * following).sum((-2, -1))
    n_unbound = (1 - current).sum((-2, -1))
    return n_binding / n_unbound
@singledispatch
def dissociation_rate(labels):
    r"""
    Compute the off-rate from the binary data assuming a two-state HMM model.

    The rate is the number of bound-to-unbound transitions between
    consecutive frames divided by the number of frames (excluding the last
    one) spent in the bound state. Both counts are summed over the last two
    axes of ``labels``; the last axis is the frame axis.

    Implementations are registered for :class:`numpy.ndarray` and
    :class:`torch.Tensor`; any other type raises ``NotImplementedError``.
    """
    raise NotImplementedError


@dissociation_rate.register(np.ndarray)
def _(labels):
    current = labels[..., :-1]
    following = labels[..., 1:]
    # bound in frame f and unbound in frame f + 1
    n_release = (current * (1 - following)).sum((-2, -1))
    n_bound = current.sum((-2, -1))
    return n_release / n_bound


@dissociation_rate.register(torch.Tensor)
def _(labels):
    as_float = labels.float()
    current = as_float[..., :-1]
    following = as_float[..., 1:]
    # bound in frame f and unbound in frame f + 1
    n_release = (current * (1 - following)).sum((-2, -1))
    n_bound = current.sum((-2, -1))
    return n_release / n_bound
@singledispatch
def bootstrap(samples, estimator, repetitions=1000, probs=0.68):
    r"""
    Estimate the confidence interval of an estimator by constructing
    approximating distributions using the bootstrap method (resampling with
    replacement).

    ``samples`` is resampled along its first axis, ``len(samples)`` entries
    per repetition, and ``estimator`` is called on each resampled set; it must
    return a scalar. ``repetitions`` is the number of bootstrap rounds and
    ``probs`` the probability mass covered by the central interval.

    For :class:`numpy.ndarray` input the result is a ``(lower, upper)`` tuple
    of quantiles; for :class:`torch.Tensor` input it is whatever
    ``pyro.ops.stats.pi`` returns for the collected estimates. Any other
    input type raises ``NotImplementedError``.
    """
    raise NotImplementedError


@bootstrap.register(np.ndarray)
def _(samples, estimator, repetitions=1000, probs=0.68):
    n_samples = len(samples)
    estimand = np.zeros((repetitions,))
    for i in range(repetitions):
        # Draw row indices instead of calling np.random.choice on the data:
        # choice only accepts 1-D arrays, whereas indexing resamples along the
        # first axis for any dimensionality (matching the tensor version).
        picks = np.random.randint(0, n_samples, size=n_samples)
        estimand[i] = estimator(samples[picks])
    tail = (1 - probs) / 2
    return np.quantile(estimand, tail), np.quantile(estimand, 1 - tail)


@bootstrap.register(torch.Tensor)
def _(samples, estimator, repetitions=1000, probs=0.68):
    estimand = torch.zeros(repetitions)
    for i in range(repetitions):
        bootstrap_values = resample(samples, num_samples=len(samples), replacement=True)
        estimand[i] = estimator(bootstrap_values)
    return pi(estimand, probs)
@singledispatch
def posterior_estimate(dist, estimator, repetitions=1000, probs=0.68):
    r"""
    Estimate an interval for an estimator from draws of a distribution.

    ``repetitions`` draws are taken from ``dist`` in a single ``sample`` call,
    ``estimator`` is evaluated on every draw (it must return a scalar), and
    the collected values are summarized with ``pyro.ops.stats.pi`` at
    probability ``probs``. No resampling of the draws takes place.

    An implementation is registered for ``Distribution`` objects; any other
    type raises ``NotImplementedError``.
    """
    raise NotImplementedError


@posterior_estimate.register(dist.Distribution)
def _(dist, estimator, repetitions=1000, probs=0.68):
    # one batched call; the leading axis indexes the individual draws
    draws = dist.sample((repetitions,))
    estimates = torch.zeros(repetitions)
    for index in range(repetitions):
        estimates[index] = estimator(draws[index])
    return pi(estimates, probs)
@singledispatch
def sample_and_bootstrap(
    dist, estimator, preprocess=None, repetitions=1000, probs=0.68
):
    r"""
    A version of bootstrapping method where samples are first drawn from a
    distribution and then resampled with replacement.

    In each of the ``repetitions`` rounds one sample is drawn from ``dist``,
    optionally transformed by ``preprocess``, resampled with replacement to
    its own length, and passed to ``estimator`` (which must return a scalar).
    The collected values are summarized with ``pyro.ops.stats.pi`` at
    probability ``probs``.

    An implementation is registered for ``Distribution`` objects; any other
    type raises ``NotImplementedError``.
    """
    raise NotImplementedError


@sample_and_bootstrap.register(dist.Distribution)
def _(dist, estimator, preprocess=None, repetitions=1000, probs=0.68):
    estimates = torch.zeros(repetitions)
    for round_index in range(repetitions):
        draw = dist.sample()
        if preprocess is not None:
            draw = preprocess(draw)
        # NOTE(review): resampling goes through NumPy, so the data must be
        # one-dimensional and the estimator receives an ndarray - presumably
        # because ``preprocess`` may return NumPy data; confirm with callers.
        resampled = np.random.choice(draw, size=len(draw), replace=True)
        estimates[round_index] = estimator(resampled)
    return pi(estimates, probs)