Phylodynamics inference with scPhyloX
Import necessary packages
[1]:
import scPhyloX as spx
import numpy as np
import pandas as pd
import gzip
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec
import matplotlib.lines as mlines
import pickle
import os
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
[2]:
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['font.size'] = 12
Here, we use the overshoot simulation dataset as an example to demonstrate the scPhyloX inference pipeline.
[3]:
os.chdir('../../')
[4]:
with gzip.open('datasets/simulation/overshoot/character_matrix.csv.gz', 'rb') as f:
    character_matrix = pd.read_csv(f, index_col=0)
character_matrix = character_matrix.to_numpy()
with gzip.open('datasets/simulation/overshoot/simulation_data.csv.gz', 'rb') as f:
    ground_truth = pd.read_csv(f, index_col=0)
ground_truth = ground_truth.to_numpy()
# Keep every 1000th row (time point) of the simulated trajectory
ground_truth = ground_truth[::1000]
time = ground_truth[:, 0]          # first column: simulation time
cell_number = ground_truth[:, 1:]  # remaining columns: cell counts
Derive leaf-root and leaf-progenitor distances from the character matrix
[5]:
mutnum = spx.data_factory.get_mutnum(character_matrix)         # leaf-root distances
branch_len = spx.data_factory.get_branchlen(character_matrix)  # leaf-progenitor distances
100%|██████████| 125250/125250 [00:04<00:00, 26687.96it/s]
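The notebook delegates both distance computations to `spx.data_factory`. To make the two quantities concrete, here is a hypothetical sketch under simplifying assumptions that are ours, not the package's: 0 marks an unedited site, every non-zero entry is a mutation, there is no missing data, and the most recent common progenitor of two cells carries exactly the sites where both cells share the same non-zero state. The function names are illustrative only.

```python
import numpy as np


def leaf_root_distance(char_matrix):
    """Hypothetical: number of mutated (non-zero) sites per cell."""
    return (np.asarray(char_matrix) != 0).sum(axis=1)


def leaf_progenitor_distances(char_matrix):
    """Hypothetical: for each unordered cell pair (i, j), the mutations each
    cell carries beyond those assumed inherited from their common progenitor
    (sites where both cells share the same non-zero state)."""
    m = np.asarray(char_matrix)
    mutated = m != 0
    n = m.shape[0]
    out = []
    for i in range(n):
        for j in range(i + 1, n):
            shared = (mutated[i] & (m[i] == m[j])).sum()
            out.append((mutated[i].sum() - shared, mutated[j].sum() - shared))
    return np.array(out)


# Toy character matrix: 3 cells x 3 target sites
toy = np.array([[1, 2, 0],
                [1, 3, 0],
                [0, 0, 4]])
print(leaf_root_distance(toy))                  # [2 2 1]
print(leaf_progenitor_distances(toy).tolist())  # [[1, 1], [2, 1], [2, 1]]
```

Cells 0 and 1 share state 1 at the first site, so each carries one mutation beyond their common progenitor; cell 2 shares nothing with the others.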
Perform mutation rate estimation
[6]:
idata_bl = spx.est_mr.mutation_rate_mcmc(branch_len, draw=500, tune=500)
Population sampling (4 chains)
DEMetropolis: [mu, delta]
Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`.
100.00% [4/4 00:00<00:00]
100.00% [1000/1000 01:24<00:00]
Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 85 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
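The first warning flags an r-hat above 1.01, meaning the chains disagree more than sampling noise alone would explain. ArviZ reports the rank-normalized split r-hat from the paper linked in the warning; the sketch below implements the simpler classic split r-hat (between-chain versus within-chain variance after splitting each chain in half), which conveys the same idea.

```python
import numpy as np


def split_rhat(samples):
    """Classic split-R-hat for one scalar parameter.

    samples: array of shape (n_chains, n_draws).
    """
    x = np.asarray(samples, dtype=float)
    n_chains, n_draws = x.shape
    half = n_draws // 2
    # Split every chain into two halves (dropping the middle draw if n_draws
    # is odd) so that drift within a chain also inflates R-hat.
    splits = np.concatenate([x[:, :half], x[:, n_draws - half:]], axis=0)
    n = splits.shape[1]
    chain_means = splits.mean(axis=1)
    W = splits.var(axis=1, ddof=1).mean()  # mean within-chain variance
    B = n * chain_means.var(ddof=1)        # between-chain variance
    var_plus = (n - 1) / n * W + B / n     # pooled estimate of posterior variance
    return np.sqrt(var_plus / W)


rng = np.random.default_rng(0)
well_mixed = rng.normal(size=(4, 500))
shifted = well_mixed + np.arange(4)[:, None]  # each chain stuck at a different level
print(split_rhat(well_mixed) < 1.01, split_rhat(shifted) > 1.1)  # True True
```

Applied to this notebook, `split_rhat(idata_bl.posterior['mu'].to_numpy())` would give a comparable (not identical) diagnostic for μ, since the posterior array has shape (chain, draw).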
MAP estimation of cell generation numbers
[7]:
ge = spx.est_mr.GenerationEst(mutnum, 2)
gen_num = ge.estimate(cell_number[-1].sum())
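`GenerationEst`'s model is not shown in this notebook. As an illustration only, assume that mutations accumulate as a Poisson process with mean μ per generation (we take the second argument of `GenerationEst`, here 2, to be the mutation rate) and place a flat prior on the generation g. The MAP generation for a cell carrying m mutations is then the g that maximizes the Poisson log-likelihood. `map_generation` and its `max_gen` cap are hypothetical names.

```python
import numpy as np
from math import lgamma


def map_generation(m, mu, max_gen):
    """Hypothetical MAP generation for a cell with m mutations, assuming
    m ~ Poisson(mu * g) and a flat prior on g in 1..max_gen."""
    g = np.arange(1, max_gen + 1)
    lam = mu * g
    # Poisson log-likelihood log P(m | lam); lgamma(m + 1) = log(m!)
    loglik = m * np.log(lam) - lam - lgamma(m + 1)
    return int(g[np.argmax(loglik)])


print(map_generation(20, 2.0, 50))   # 10  (close to m / mu)
print(map_generation(200, 2.0, 50))  # 50  (capped by max_gen)
```

Under this model the continuous maximizer is g = m / μ, so the MAP is roughly the mutation count divided by the mutation rate; applying it per cell, e.g. `[map_generation(m, 2.0, 60) for m in mutnum]`, yields one estimate per leaf.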
Pre-estimate the phylodynamic parameters using the differential evolution (DE) algorithm.
[8]:
res_de = spx.est_tissue.para_inference_DE(gen_num, T=35, c0=100)
iter99, loss:[184671.53399582], est=[1.2288103 0.32414623 1.05964971 4.66003588 2.39814377 7.29721605]
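The internals of `para_inference_DE` are not shown in this notebook. As a generic illustration of the technique (not scPhyloX's implementation), here is a minimal DE/rand/1/bin minimizer: each population member is challenged by a trial vector built from three other members and kept only if the trial is no worse. Because selection is greedy, the best loss recorded per generation can never increase.

```python
import numpy as np


def differential_evolution(f, bounds, pop_size=30, F=0.8, CR=0.9, n_iter=200, seed=0):
    """Minimal DE/rand/1/bin minimizer; returns (best_x, loss_history)."""
    rng = np.random.default_rng(seed)
    lo, hi = np.asarray(bounds, dtype=float).T
    dim = lo.size
    pop = lo + rng.random((pop_size, dim)) * (hi - lo)  # uniform initial population
    fit = np.array([f(x) for x in pop])
    history = []
    for _ in range(n_iter):
        for i in range(pop_size):
            # Mutation: combine three distinct members other than i
            others = [j for j in range(pop_size) if j != i]
            a, b, c = pop[rng.choice(others, 3, replace=False)]
            mutant = np.clip(a + F * (b - c), lo, hi)
            # Binomial crossover, forcing at least one coordinate from the mutant
            cross = rng.random(dim) < CR
            cross[rng.integers(dim)] = True
            trial = np.where(cross, mutant, pop[i])
            # Greedy selection
            f_trial = f(trial)
            if f_trial <= fit[i]:
                pop[i], fit[i] = trial, f_trial
        history.append(fit.min())
    return pop[np.argmin(fit)], history


sphere = lambda x: float(np.sum(x ** 2))
best, hist = differential_evolution(sphere, [(-5, 5)] * 3, n_iter=150, seed=1)
print(sphere(best) < 1e-4)  # True
```

For real use, `scipy.optimize.differential_evolution` provides a well-tested implementation of the same scheme.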
[9]:
fig, ax = plt.subplots(figsize=(3.2,3))
ax.plot(np.array(res_de[1]).flatten())
ax.set_xlabel('iter')
ax.set_ylabel('loss')
[Figure: DE loss versus iteration (x-axis: iter, y-axis: loss)]
Set the MCMC priors for each parameter based on the DE estimates.
[10]:
axh, bxh, rh, dh, kh, t0h = res_de[0][-1]  # final DE estimates
dh = 10**(-dh)  # d is estimated on a -log10 scale; convert it back
mcmc_prior = (axh, bxh, rh, dh, kh, t0h)
Perform phylodynamics inference
[11]:
idata = spx.est_tissue.mcmc_inference(gen_num, mcmc_prior, T=35, c0=100, sigma=100)
Population sampling (8 chains)
DEMetropolis: [ax, bx, r, k, t0, d]
Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`.
100.00% [8/8 00:01<00:00]
100.00% [2000/2000 10:54<00:00]
Sampling 8 chains for 1_000 tune and 1_000 draw iterations (8_000 + 8_000 draws total) took 656 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
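The second warning concerns the effective sample size (ESS): autocorrelated draws carry less information than independent ones. ArviZ computes a rank-normalized, multi-chain ESS (see the paper linked in the warning); the sketch below is a simpler variant that averages the per-chain autocorrelation and truncates the sum with Geyer's initial positive sequence. It ignores between-chain differences, so treat it as an approximation.

```python
import numpy as np


def autocorr(x):
    """Normalized autocorrelation of a 1-D series, computed via FFT."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    n = x.size
    f = np.fft.rfft(x, n=2 * n)  # zero-pad to avoid circular wrap-around
    acov = np.fft.irfft(f * np.conj(f))[:n] / n
    return acov / acov[0]


def ess(samples):
    """Approximate ESS for samples of shape (n_chains, n_draws)."""
    x = np.asarray(samples, dtype=float)
    m, n = x.shape
    rho = np.mean([autocorr(chain) for chain in x], axis=0)
    # tau = 1 + 2 * sum_{t>=1} rho_t, accumulated as -1 + 2 * sum of pairs
    # (rho_{2k} + rho_{2k+1}), stopping at the first negative pair (Geyer)
    tau = -1.0
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2 * pair
    return m * n / tau


rng = np.random.default_rng(0)
iid = rng.normal(size=(4, 1000))
ar = np.zeros((4, 1000))
for t in range(1, 1000):
    ar[:, t] = 0.9 * ar[:, t - 1] + rng.normal(size=4)  # strongly autocorrelated chains
print(ess(iid) > 2500, ess(ar) < 800)  # True True
```

Dividing the result by the number of chains gives the per-chain ESS that the warning compares against 100.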
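Below, we plot the per-chain posterior densities and MCMC traces of the mutation-rate parameters (μ, δ) and of the phylodynamic parameters; black vertical lines mark the simulation's preset ground-truth values. The posteriors are consistent with the ground truth. Note, however, that PyMC reported r-hat > 1.01 and a per-chain effective sample size below 100 for some parameters in both sampling runs, so convergence is not fully established and longer sampling runs may be warranted.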
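Visual agreement can be complemented by a numerical check. The helper below is a sketch of our own (not part of scPhyloX): for each parameter it reports the posterior mean, an equal-tailed 94% credible interval (ArviZ's summaries default to a 94% HDI instead), and whether the preset value falls inside that interval.

```python
import numpy as np


def check_against_truth(posterior, truth, prob=0.94):
    """Compare posterior draws with preset values.

    posterior: dict mapping name -> array of shape (n_chains, n_draws)
    truth:     dict mapping name -> preset (ground-truth) value
    """
    lo_q, hi_q = (1 - prob) / 2, 1 - (1 - prob) / 2
    report = {}
    for name, value in truth.items():
        draws = np.asarray(posterior[name], dtype=float).ravel()  # pool all chains
        lo, hi = np.quantile(draws, [lo_q, hi_q])
        report[name] = {
            "mean": float(draws.mean()),
            "interval": (float(lo), float(hi)),
            "covers": bool(lo <= value <= hi),
        }
    return report


rng = np.random.default_rng(0)
fake = {"a": rng.normal(1.0, 0.1, size=(4, 500))}  # synthetic posterior draws
print(check_against_truth(fake, {"a": 1.0})["a"]["covers"])  # True
print(check_against_truth(fake, {"a": 2.0})["a"]["covers"])  # False
```

With this notebook's objects, one could build `truth = dict(zip(['ax', 'bx', 'r', 'k', 't0', 'd'], [1.44, 0.56, 0.4, 0.3, 12, 0.01]))` and `posterior = {name: idata.posterior[name].to_numpy() for name in truth}`, then call `check_against_truth(posterior, truth)`.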
[19]:
fig = plt.figure(layout='constrained', figsize=(11, 3))
gs = GridSpec(2, 3, figure=fig)
ls = ['solid', 'dotted', 'dashed', 'dashdot']  # one line style per chain (4 chains)
# Left: posterior density of mu per chain; the black line marks the preset value mu = 2
ax1 = fig.add_subplot(gs[:, 0])
for i, l in enumerate(ls):
    sns.kdeplot(idata_bl.posterior['mu'].to_numpy()[i], linestyle=l, ax=ax1, label=f'Chain {i+1}')
ax1.vlines(2, 0, 1.7, color='black')
ax1.legend(fontsize=10, loc=2)
ax1.set_title(r'$\mu$ distribution')
# Middle: posterior density of delta per chain
ax2 = fig.add_subplot(gs[:, 1])
for i, l in enumerate(ls):
    sns.kdeplot(idata_bl.posterior['delta'].to_numpy()[i], linestyle=l, ax=ax2, label=f'Chain {i+1}')
ax2.set_title(r'$\delta$ distribution')
# Right: MCMC traces of mu (top) and delta (bottom)
ax3 = fig.add_subplot(gs[0, 2])
for i, l in enumerate(ls):
    ax3.plot(idata_bl.posterior['mu'][i], linestyle=l, label=f'Chain {i+1}')
ax3.set_title(r'$\mu$ mcmc trace')
ax3.set_xlabel('step')
ax4 = fig.add_subplot(gs[1, 2])
for i, l in enumerate(ls):
    ax4.plot(idata_bl.posterior['delta'][i], linestyle=l, label=f'Chain {i+1}')
ax4.set_title(r'$\delta$ mcmc trace')
ax4.set_xlabel('step')
[Figure: per-chain posterior densities of μ and δ, with their MCMC traces in the right-hand column]
[20]:
fig = plt.figure(figsize=(15, 7))
plt.rcParams['font.size'] = 12
params = 'ax,bx,r,k,t0,d'.split(',')
gt = [1.44, 0.56, 0.4, 0.3, 12, 0.01]  # ground-truth values, same order as params
gs = GridSpec(6, 4, figure=fig)
ls = ['solid', 'dotted', 'dashed', 'dashdot'] * 2  # one line style per chain (8 chains)
# Posterior densities: ax, bx, r in the top half; k, t0, d in the bottom half
for ind, sym in enumerate(params):
    if ind >= 3:
        ax = fig.add_subplot(gs[3:, ind-3])
    else:
        ax = fig.add_subplot(gs[:3, ind])
    for i, l in enumerate(ls):
        sns.kdeplot(idata.posterior[sym].to_numpy()[i], linestyle=l, ax=ax, label=f'Chain {i+1}')
    ylim = ax.get_ylim()
    ax.vlines(gt[ind], *ylim, color='black')
    ax.set_ylim(ylim)
    ax.set_title(fr'${sym}$ distribution')
# MCMC traces in the right-hand column; d sits in the bottom row and shares its x-axis
ax0 = fig.add_subplot(gs[-1, 3])
for i, l in enumerate(ls):
    ax0.plot(idata.posterior['d'][i], linestyle=l, label=f'Chain {i+1}')
ax0.set_title(r'$d$ mcmc trace')
for ind, sym in enumerate(params[:-1]):
    ax = fig.add_subplot(gs[ind, 3], sharex=ax0)
    for i, l in enumerate(ls):
        ax.plot(idata.posterior[sym][i], linestyle=l, label=f'Chain {i+1}')
    ax.set_title(fr'${sym}$ mcmc trace')
    plt.setp(ax.get_xticklabels(), visible=False)
ax0.set_xlabel('step')
plt.tight_layout()