triplet loss

文章XYは同じ意味か?を判断するようなモデルを学習する際に使われる損失関数。

一つの観察として、3つの文章を用意する

  • Anchor A. 任意の文章
  • Positive P. Aと同じ意味だが別の表現の文章
  • Negative N. Aと別の意味の文章

例えば

  • A: How old is he ?
  • P: What is his age ?
  • N: I’m 10 years old

など。

まず、類似度を計算する関数s(A,X)を用意し、s(A,N)s(A,P)という量を考える。 (s(A,X)が1つのネットワークで実装されている)

これは、s(A,N)が小さく、s(A,P)が大きいほど小さくなる。したがって、この量が小さくなるように学習すれば良い。 ただし、s(A,N)はいくらでも小さくし得るが、そうすることに意味はないので、

max{s(A,N)s(A,P),0}

損失関数とする。さらに、s(A,P)s(A,N)が0でなく有限の量だけ異なるように

max{s(A,N)s(A,P)+α,0}

を使うことが多い。

さらに、学習を早めるために、

max{meannegatives(A,P),0}max{closestnegatives(A,P),0}

といった量を使うこともある。

  • mean negative: バッチの中のnegativeのスコアの平均
  • closest negative: バッチの中のnegativeで一番高いスコア

コースでは、この2つの値の和を損失として使う。

データセット

Quora question answer dataset