Traffine I/O

日本語

2022-12-09

ディリクレ分布

ディリクレ分布とは

ディリクレ分布とは、 K 個の事象がそれぞれ \alpha_{i}-1 回発生したときに各事象の起こる確率x_i が従う分布です。ここで、\alpha _{i} は整数である必要はありません。

ディリクレ分布の確率密度関数は次の式で表されます。

f(x_1, x_2,.., x_{K-1}) = \frac{\Gamma(\sum^K_{i=1} \alpha_i)}{\prod_{i=1}^K\Gamma(\alpha_i)}\prod_{i=1}^Kx_i^{\alpha_i-1}
\sum^n_{i=1} = 1, \quad x_i \in \{0, 1 \}

上式において K=2 の場合はベータ分布の確率密度関数に一致します。つまり、ディリクレ分布はベータ分布を多変量に拡張した分布と言えます。

ディリクレ分布の期待値と分散

ディリクレ分布の期待値、分散はそれぞれ以下になります。

E(X_i)=\frac{\alpha_i}{\sum^{K}_{i=1}\alpha_i} \quad (i=1,2,...,K-1)
V(X_i)=\frac{\alpha_i(\sum^{K}_{i=1} \alpha_i-\alpha_i)}{(\sum^{K}_{i=1}\alpha_i)^2 (\sum^{K}_{i=1}\alpha_i + 1)} \quad (i=1,2,...,K-1)

パラメータの影響を確認

3次元(K=3)のディリクレ分布をPythonで可視化してみます。

import numpy as np
from scipy import special
from scipy.stats import dirichlet
import matplotlib
import matplotlib.pyplot as plt

plt.style.use('ggplot')

class Dirichlet():
    def __init__(self, param: list) -> None:
        self.param = np.array(param)

    def pdf(self, x: list) -> np.float:
        x_ar = np.array(x)
        cons = np.prod(special.gamma(self.param))/(special.gamma(np.sum(self.param)))
        p = (1./cons) * np.prod(x_ar**(self.param-1))
        return p

    def plt_3d(self, nrow: int, ncol: int, n: int, zlim=None)->None:
        xdata = np.linspace(0, 1, 200)
        ydata = np.linspace(0, 1, 200)
        X,Y = np.meshgrid(xdata, ydata)
        z = []
        X[X+Y>1] = 0
        Y[X+Y>1] = 0
        for _x, _y, _z in zip(X.flatten(), Y.flatten(), (1-X-Y).flatten()):
            z.append(self.pdf([_x, _y, _z]))

        Z = np.array(z).reshape(X.shape)
        fig = plt.figure(figsize=(10, 3))
        ax = plt.axes(projection='3d')
        ax.plot_surface(X, Y, Z, cmap='plasma')
        ax.set_zlim(zlim)
        ax.set_xlabel("$x_1$")
        ax.set_ylabel("$x_2$")
        ax.set_zlabel("PDF")
        ax.set_title("Dir($\\vec{\\alpha} = $" + "%s)" % self.param)
        plt.show()

Dirichlet([1.,1.,1.]).plt_3d(2, 2, 1, zlim=(0, 2.1))
Dirichlet([5.,5.,5.]).plt_3d(2, 2, 2)
Dirichlet([5.,1.,1.]).plt_3d(2, 2, 3)
Dirichlet([1.,5.,1.]).plt_3d(2, 2, 2)

Dirichlet distribution

3つの \alpha が同じ値の場合、\alpha が大きいほど分散が小さくなっていることが分かります。また、\alpha が大きい確率変数は確率密度が高くなっていることが分かります。

Ryusei Kakujo

researchgatelinkedingithub

Focusing on data science for mobility

Bench Press 100kg!