Traffine I/O

日本語

2023-01-27

DistilBERT

DistilBERT とは

DistilBERTとはHuggingface社が2019年に発表した DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter という論文で登場したモデルです。DistilBERTは BERT をベースとした Transformer モデルであり、BERT-baseよりもパラメータが40%少なく、60%高速に動作し、GLUE Benchmarkで測定されたBERTの97%の性能を維持することができます。

DistilBERTは、Teacherと呼ばれる大きなモデルをStudentと呼ばれる小さなモデルに圧縮する技術である知識蒸留(Distillation Knowledge)を用いて訓練されます。BERTを蒸留することで、元のBERTモデルと多くの類似点を持ちながら、より軽量で実行速度が速いTransformerモデルを得ることができます。

DistilBERT 誕生の背景

近年のNLPでは、大量のパラメータをラベルなしデータで事前学習をすることにより、精度の高い言語モデルを作ることが主流となっています。BERTやGPTシリーズなどはモデルがどんどん巨大化し、1つのサンプルの計算も非常に時間がかかってしまい、一般ユーザーには計算負荷が高すぎるという課題があります。

Parameter counts

そのため、本論文では、精度を維持しつつより軽量なモデルを作ることを目指しています。

Knowledge Distillation

Knowledge Distillationとは、2015年に Distilling the Knowledge in a Neural Network という論文で登場した手法で、計算量の多い大きなモデル(Teacherモデル)から、妥当性を保ちながら小さなモデル(Studentモデル)へ知識を圧縮して転送するというものです。StudentはTeacherの振る舞いを再現するように訓練されます。ただし、StudentはTeacherほど大きな記憶容量(パラメータ数)はありませんので、その少ない記憶容量でできるだけTeacherに近づけるようにします。

  • Teacherモデル
    非常に大きなモデルや、ドロップアウトのような強力な正則化を行なって個別に学習したモデルのアンサンブル

  • Studentモデル
    Teacherモデルの蒸留された知識に依存する小さなモデル

Knowledge distillation
Knowledge Distillation: A Survey

DistilBERTにおいて、TeacherはBERT、StudentはDistilBERTを指します。

DistilBERT の構造

Student のアーキテクチャ

StudentであるDistilBERTは、BERTと同じ一般的なアーキテクチャを踏襲しています。DistilBERTではパラメータを削減するために次の施策を行います。

  • Token embeddings層の削除
  • Pooler層(Transformer Encoderの後の分類用の層)の削除
  • Transformerの層数を半分にする(BERT_BASE: 12, DistilBERT: 6)

これらの施策により、パラメータ数が40%減少することになります。

Student の初期値

Studentの初期値にはTeacherであるBERTのパラメータが使われます。ただし、Studentの層数はTeacherの層数の半分なので、対応する2つの層のうちの1つを初期値としています。

Distillation

DistilBERTでは学習方法を次のように設定しています。

  • バッチサイズを4,000と大きくする
  • ダイナミックにマスキングを行う
  • Next Sentence Predictionは実施しない

トレーニングの損失関数

DistilBERTの損失関数はDistillation Loss (L_{ce})、 Masked Language Modelling Loss (L_{mlm})、 Cosine Embedding Loss (L_{cos}) の線形結合で表されます。

L = L_{ce} + L_{mlm} + L_{cos}

Distillation Loss

Distillation Loss (L_{ce}) はStudentがTeacherを模倣するようにする損失関数であり、次の式で表されます。

L_{ce} = - \sum_i t_i \times \log(s_i)

ここで、t_i はTeacher (BERT) による i 番目の単語の予測確率であり、s_i はStudent (DistilBERT) による i 番目の単語の予測確率です。

上式の損失関数よって、もっとも予測確率が高い単語だけでなく、その次に予測確率が高い単語の予測確率なども真似することができ、Teacherの予測する分布を学習することができます。

例えば次の文章があるとします。

I watched [MASK] yesterday.

TeacherであるBERTの予測確率が次のようになったとします。

予測単語 予測確率
movie 0.8
TV 0.1
baseball 0.05
MMA 0.05

このとき、Studentは「movie」だけでなく、「TV」や「MMA」、「baseball」の予測確率もTeacherであるBERTに近づけるように学習します。

ただし、ここでは予測確率 p_isoftmax–temperatureという関数を使って表します。

p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j) / T}

T を1とすると普通のSoftmax関数になりますが、T を1よりも大きくすると普通のSoftmaxよりも予測確率が低い単語を予測する確率が上がります。

softmax–temperatureを使うことで予測確率の小さい単語の予測確率も学習するように仕向けています。

Masked Language Model Loss

Masked Language Modelling Loss (L_{mlm}) はBERTと同じMasked Language Modelによる損失関数です。

Cosine Embedding Loss

Cosine Embedding Loss (L_{cos}) はTeacher (BERT) とStudent (DistilBERT) の隠れ層のベクトルのコサイン類似度を損失関数とします。Cosine Embedding LossによりStudentの文章の埋め込みベクトルをTeacherの埋め込みベクトルに近づけることができます。

参考

https://arxiv.org/abs/1910.01108
https://arxiv.org/abs/1503.02531
https://arxiv.org/abs/2006.05525
https://arunm8489.medium.com/understanding-distil-bert-in-depth-5f2ca92cf1ed
https://www.analyticsvidhya.com/blog/2022/11/introduction-to-distilbert-in-student-model/
https://medium.com/huggingface/distilbert-8cf3380435b5
https://data-analytics.fun/2020/10/06/understanding-distilbert/

Ryusei Kakujo

researchgatelinkedingithub

Focusing on data science for mobility

Bench Press 100kg!