生存時間解析をやってみる1(PyMC3)

1.生存時間解析

生存時間解析は、医療分野や工学分野を始めとした様々な領域で重要な統統計的手法として広く利用されています。この手法は、時間の経過とともに特定のイベントが発生するまでの時間を推定し、その確率や影響要因を明らかにするための有力な方法です。生存時間解析は、通常、イベントが発生するまでの時間、あるいは生存期間と呼ばれる期間に注目します。例えば、がん患者の場合、治療開始から再発や死亡が起こるまでの期間や製品の故障発生までの期間が生存期間に相当します。来の統計手法では、時間の要素を考慮することが難しいため、生存時間解析が開発されました。この手法では、生存関数とハザード関数という特別な統計的ツールを用いて、生存期間の分布や特定の時点でのイベント発生の危険性を推定することが可能です。生存時間解析の利点は、データの欠損を扱う能力や、イベントが発生しないまま終了する(censored)データを考慮することができる点にあります。これによりイベント発生をまたず、イベントがその時刻まで発生しなかったという情報をモデルに盛り込み、より妥当な生存時間を推定できます。
生存時間解析には指数分布やワイブル分布が用いられます。指数分布を用いる場合は、ハザード(特定の時点でのイベント発生確率)は一定になります。ハザードを時間に依存して変化させたモデル化を行いたい場合、ワイブル分布を用います。$a,b$を正の数とすると、密度分布は以下で表せます($x\geq 0$)。

 f(x)=\frac{bx^{b-1}}{a^b}\exp\left(-\left(\frac{x}{a}\right)^b\right)

ワイブル分布は発生確率に時間依存性を許容した一般化に対応するため、故障現象を説明するためによく用いられます。発生確率は以下の式で表され、$b=1$の場合に時間に対する依存性がなくなり指数分布に合致します。

 \lambda(x)=\frac{b}{a^b}x^{b-1}

ここで、$b$は発生確率の時間依存性の指数部分を決める項で、1より大きければ単調増加、小さければ単調減少となり、それぞれ摩耗的な故障、初期的な故障と対応付けることが可能です。

2.PyMC3による生存時間解析の実装

生存時間解析(故障時間解析)を行う場合、以下のようなデータを使用して解析を行っていくことになる。生存時間解析では、途中で打ち切られるデータ(少なくともその時間までは生存していた時間)も重要な情報となる。たとえば、架空の装置の故障データを使用して装置寿故障に関する生存時間解析を考える場合、Timeカラムは故障発生が発生する時間に相当する。Censoredカラムは観測の打ち切りを表す。Falseの場合、観測は打ち切られず装置がTimeカラムの時間で故障したことを表す。Trueの場合は、観測が打ち切られたことを示し、Timeカラムに記載の時間はその時間までは装置は故障せずに稼働していたことを示す。

No Time Censored
sample_1 27 False
sample_2 50 True
sample_3 10 False
sample_4 40 True
sample_5 33 False

それでは架空の装置に関する生存時間解析(故障時間解析)をPyMC3で行ってみる。 ワイブル分布を用いてモデリングしてみる。 今回打切りなしデータ(実施の故障観測データ)のみでモデル化する場合と打切りなしデータ、打切りありデータ両方を使用してモデル化した場合の結果を比較してみたいと思う。

2.1.データの作成

ワイブル分布を仮定して架空の装置の故障データを生成します。

import numpy as np
import pandas as pd
import arviz as az
import pymc3 as pm
from theano import shared, tensor as tt
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()

# ワイブル分布のパラメータ
shape = 2 # ワイブル分布の形状パラメータ
scale = 28.0  # ワイブル分布の尺度パラメータ

# ランダムな故障時間を生成
np.random.seed(42)  # 乱数のシードを設定(再現性のため)
failures = np.random.weibull(shape, size=100) * scale

plt.hist(failures)
plt.xlabel('Time[Day]')
plt.ylabel('Count')
plt.show()

データ生成後、40日以降のデータは故障発生時刻ではなく、観測打切りデータとして扱って生存時間解析を進めてみます。

#40日よりも後を観測打切りデータとする。
df_failures=pd.DataFrame(failures, columns=['Time'])
#打ち切りかどうかを示すCensoredラムを追加する。
df_failures['censored'] = df_failures['Time'] >= 40
y = df_failures['Time'].values
censored = df_failures['censored'].values
df_failures.head()
Time    censored
0   19.180881   False
1   48.579164   True
2   32.129871   False
3   26.753448   False
4   11.531951   False

2.2.打ち切りなしデータのみを用いた生存時間解析

#打ち切りなしデータのみでベイズ推論
with pm.Model() as model_1:
    alpha_sd = 10.0
    #ワイブル分布の形状パラメータのパラメータの事前分布
    mu = pm.Normal("mu", mu=0, sigma=100)
    #ワイブル分布の形状パラメータのパラメータの事前分布
    alpha_raw = pm.Normal("a0", mu=0, sigma=0.1)
    #ワイブル分布の形状パラメータ
    alpha = pm.Deterministic("alpha", tt.exp(alpha_sd * alpha_raw))
    #ワイブル分布の尺度パラメータ
    beta = pm.Deterministic("beta", tt.exp(mu / alpha))
    #ワイブル分布からのサンプリング
    y_obs = pm.Weibull("y_obs", alpha=alpha, beta=beta, observed=y[~censored])

#MCMCによるサンプリング
with model_1:
    trace_1 = pm.sample(target_accept=0.9, init="adapt_diag")

サマリーの確認

pm.summary(trace_1)
 mean    sd  hdi_3%  hdi_97% mcse_mean   mcse_sd ess_bulk    ess_tail    r_hat
mu  6.946   0.612   5.849   8.083   0.024   0.017   678.0   901.0   1.01
a0  0.079   0.008   0.063   0.094   0.000   0.000   711.0   926.0   1.01
alpha   2.209   0.182   1.885   2.550   0.007   0.005   711.0   926.0   1.01
beta    23.197  1.184   20.902  25.344  0.025   0.017   2300.0  2355.0  1.00

サンプリング結果をプロット

_ = pm.traceplot(trace_1)
_ = pm.plot_posterior(trace_1)

ベイズモデリングで推定したパラメータを使用してワイブル分布の各時刻の故障確率(pdf)と特定の時刻までの生存確率(cdf)を算出する関数を準備しておく。 なお、cfdはその時刻までの各時刻での故障発生確率(pdf)を時間で積分することで得られる。

# pdf
def weibProbDist(x, a, b):
    return (a / b) * (x / b) ** (a - 1) * np.exp(-(x / b) ** a)

# cdf
def weibCumDist(x, a, b):
    return np.exp(-(x / b) ** a)

ベイズモデリングにより推定したパラメータを使用して各時刻の故障発生確率(左)と各時刻までの故障発生確率(右)を可視化する。

shape_ = 2.209
scale_ = 23.197
t = np.linspace(1, 90 ,90)

plt.plot(t, weibProbDist(t, shape_, scale_))
plt.title('pdf')
plt.xlabel('Time[Day]')
plt.ylabel('Probability')
plt.show()

plt.plot(t, weibDumDist(t, shape_, scale_))
plt.title('pdf')
plt.xlabel('Time[Day]')
plt.ylabel('Probability')
plt.show()

打切りなしデータのみでモデルを作成しているためcdfをみると40日以降の生存確率はほぼ0になっており、本来は40日以降にも故障せずに稼働している装置が存在することを考慮できていない。

2.3.打ち切りなし+打ち切りありデータを用いた生存時間解析

続いて、打切りデータも取り入れてモデル化してみる。

def weibull_lccdf(x, alpha, beta):
    """ Log complementary cdf of Weibull distribution. """
    return -((x / beta) ** alpha)

#打ち切りなしデータ+打ち切りありデータ
with pm.Model() as model_2:
    alpha_sd = 10.0
    mu = pm.Normal("mu", mu=0, sigma=100)
    alpha_raw = pm.Normal("a0", mu=0, sigma=0.1)
    alpha = pm.Deterministic("alpha", tt.exp(alpha_sd * alpha_raw))
    beta = pm.Deterministic("beta", tt.exp(mu / alpha))

    y_obs = pm.Weibull("y_obs", alpha=alpha, beta=beta, observed=y[~censored])
    #事後対数尤度に制約としてつけ加える。
    #この項を含めた事後対数尤度が最大になるようにMCMCでサンプリングする。
    y_cens = pm.Potential('y_cens', weibull_lccdf(y[censored], alpha, beta))

with model_2:
    trace_2 = pm.sample(target_accept=0.9, init="adapt_diag")

上記のように打切りデータをモデルにベイズモデリングの中に組み込めるようにワイブル分布の対数補完累積分布関数を定義している。この関数は与えられたワイブル分布パラメータのもとでx以下になる累積確率の対数値を返す。この値をMCMCの事後対数尤度に加えたものが最大になるようにパラメータを決定する。こうすることで打切りデータが持つ「最低限その時点までは生存していた」という情報を確率としてモデルに組み込むことができる。

サンプリング結果を可視化

_ = pm.traceplot(trace_2)
_ = pm.plot_posterior(trace_2)

サマリーの確認

pm.summary(trace_2)
mean sd  hdi_3%  hdi_97% mcse_mean   mcse_sd ess_bulk    ess_tail    r_hat
mu  5.625   0.513   4.622   6.559   0.019   0.014   745.0   628.0   1.0
a0  0.052   0.009   0.037   0.069   0.000   0.000   748.0   622.0   1.0
alpha   1.691   0.145   1.419   1.968   0.005   0.004   748.0   622.0   1.0
beta    27.893  1.840   24.534  31.411  0.034   0.024   2883.0  2072.0  1.0

ベイズモデリングにより推定したパラメータを使用して各時刻の故障発生確率(左)と各時刻までの故障発生確率(右)を可視化する。

shape_2 = 1.691
scale_2 = 27.893

plt.plot(t, weibProbDist(t, shape_2, scale_2))
plt.ylim(0, 0.04)

plt.title('pdf')
plt.xlabel('Time[Day]')
plt.ylabel('Probability')
plt.show()


plt.plot(t, weibCumDist(t, shape_2, scale_2))

plt.title('cdf')
plt.xlabel('Time[Day]')
plt.ylabel('Probability')
plt.show()

続いて、打切りデータのみの場合と打切りデータと打切りありデータでモデル化した場合の比較をしてみる。

打切りデータを考慮してモデリングが行えているため、40日以降の生存確率が打切りなしデータのみでモデリングした場合よりも高くなっている。生存時間解析では、打切りデータもモデル化に組み込むことができるため、より妥当な生存時間の推定が可能になる。故障発生期間が長い場合,、故障発生データを取得していくには長期間を要する。故障発生を待たず打切りデータとして生存情報を考慮できるのは生存時間解析のメリットである。