こんにちは、AIチームの戸田です
今回は去年Google検索に導入されたことでも話題になったBERTを使った比較実験の記事を書かせていただきます
というのも昨年発表報告を書かせていただいた第10回対話シンポジウム、参加して特に印象に残ったことの一つとして、文章をベクトルに変換するモデルとして BERT^1 を使用するのが当たり前になっていたことがあります
私が遅れているだけなのかもしれませんが、とりあえず文章をベクトル化するときはBERTという雰囲気で、Word2Vecで得られた単語ベクトルをコネコネ…とやっているのは(おそらく)今回の会議では私達だけだったと思います
BERTはファインチューニングにより自然言語処理の多くのタスクでState of the artを達成しましたが、単純な文書ベクトル抽出器としての能力はどうなんでしょうか?
私は手軽に文章の分散表現を得る方法としてWord2Vecから得られた単語ベクトルの平均やmax poolingをとる SWEM^2をよく使うのですが、語順が入れ替わった文章やノイズのある文章などでは、なかなか思うようなベクトルが得られないことが多々あります
本記事ではSWEMで得られたベクトルとBERTで得られたベクトルを比較し、SWEMでの課題をBERTが解決してくれるかを検証したいと思います
Word2VecやSWEM、BERTについての説明は本記事では扱いませんのでご容赦下さい
SWEMとBERTのベクトル比較
比較する文章はAI Shiftが提供しているチャットボットプロダクトAI Messengerを導入しているとあるサイトのユーザー質問から、現在SWEMを使ったときに苦戦している語順が入れ替わった文章とノイズが含まれる文章を抽出して使用します(記事に載せる都合上、一部を変更しています)
SWEMのベクトル化手法はconcatとhier(window size=2) を利用し、各文章ベクトルをPCAで二次元平面上に射影して比較します
語順が入れ替わったもの
SWEMでは基本的に語順が入れ替わった文章を区別できません
下記文章で比較してみいましょう
ユーザー質問 | |
---|---|
0 | 羽田から那覇に行きたい |
1 | 那覇から羽田に行きたい |
2 | 羽田から福岡に行きたい |
3 | 福岡から羽田に行きたい |
4 | 福岡から那覇に行きたい |
5 | 那覇から福岡に行きたい |
ただ平均やmax poolingをとるだけでは、出発地と目的地が入れ替わっていたとしても同じベクトルになってしまいます
hierはn-gramのように窓をとり、平均をとった結果に対してmax poolingする方法なので、ある程度は語順を考慮できるのですが、目的地や出発地でまとまらず、あまりうまく行きません
「〇〇に行きたい」という目的地ごとに近いベクトルが抽出されています
SWEMのベクトルを作るword2vecのベクトルの成分は-1.0〜+1.0に正規化されているのに対して、BERTはされていないので縮尺が違いますが、互いの距離関係を見ると、かなりうまく分かれていると思います(自分でも試していて驚きました)
ノイズのある文章
「〇〇を☓☓したい」のような最小限の文章ならば問題ないのですが、チャットボットには「〇〇を☓☓したいんだけどどうすればいいの?」といったノイズのある文章が度々入力されます
こういったノイズは文意を捉えづらくしてしまいます
下記文章で比較します
ユーザー質問 | |
---|---|
6 | チケットを郵送してもらいたい |
7 | チケットは郵送できませんか? |
8 | 行かなくなったから返金してもらいたい |
9 | 返金はできないんですか? |
10 | キャンセルしたいんだけどどうすればいいの? |
11 | キャンセルさせてください |
比較文章は、チケットの郵送、返金、キャンセル、といった3つのジャンルがあるのですが、SWEMはconcatもhierどちらのベクトルもまとまっていません
BERTもぱっと見すべての文章が離れてしまっていたので、うまくいかなかったのかと思ったのですが、よく見ると
各ジャンルが近い高さにあり「〜〜ですか?」という疑問形の文章と「〜〜してほしい」という願望系の文章の関係が同じ方向を向いていることがわかります
Word2Vecの生みの親、Tomas Mikolov氏の論文^3 に出てくる国と首都の単語ベクトルの関係に似ていて面白いです
まとめ
本記事では語順が入れ替わった文章とノイズが含まれる文章をSWEMとBERTでベクトル変換し、PCAで二次元平面上に射影して、それぞれのベクトルを比較しました
結果、BERTは現在私が使っているSWEMより良いベクトルを抽出できるように見え、私の最初の疑問である、文章ベクトル抽出器としてのBERTは、非常に優れたものだと考えられます
とはいえ、明らかに比較している文章が少ないですし、そもそも文書分類などのタスクに、(fine-tuningするのではなく)文章ベクトル抽出器として応用したときにどうなるのか、といったことを今後検証したいと思います
自然言語処理エンジニアとしてまだまだ知識不足なので、なにか間違いがございましたらtwitter等で指摘していただけると嬉しいです
最後までご覧いただきありがとうございました
実験で使ったコード
import numpy as np
import pandas as pd
import torch
from transformers import (
BertModel,
BertConfig,
BertTokenizer,
BertForPreTraining,
BertConfig
)
from chainer import functions as F
from pyknp import Jumanpp
from gensim import models
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
jm = Jumanpp()
WORD2VEC_ROOT = # Word2Vecのモデルパス
BERT_ROOT = # BERTのモデルパス
TEXT_PATH = # 評価データ
def wakati_jm(text):
result = jm.analysis(text)
tokenized_text =[mrph.midasi for mrph in result.mrph_list()]
return tokenized_text
class SWEM():
def __init__(self, vec_dict):
self.vec_dict = vec_dict
def __call__(self, text, mode):
vecs = self.w2v(text)
if mode == "aver":
return vecs.mean(axis=0)
elif mode == "max":
return vecs.max(axis=0)
elif mode == "concat":
return np.hstack([vecs.mean(axis=0), vecs.max(axis=0)])
elif mode == "hier_2":
return self.hier(vecs, 2)
elif mode == "hier_3":
return self.hier(vecs, 3)
def w2v(self, text):
sep_text = wakati_jm(text)
v = []
for w in sep_text:
try:
v.append(self.vec_dict[w])
except KeyError:
v.append(np.zeros(250))
return np.array(v)
def hier(self, vecs, window):
h, w = vecs.shape
if h < window:
return vecs.max(axis=0)
v = F.average_pooling_1d(vecs.reshape(1, w, h), ksize=window).data
return v.max(axis=2)[0]
class BertVectorizer():
def __init__(self, model, tokenizer):
self.model = model.eval()
self.tokenizer = tokenizer
def __call__(self, text):
tokenized_text = wakati_jm(text)
ids = tokenizer.convert_tokens_to_ids(tokenized_text)
ids = torch.tensor(ids).reshape(1,-1)
with torch.no_grad():
vec = model.bert(ids)[0][0].max(0)[0]
return vec.numpy()
def plot_pca(vecs, label, title):
X_reduced = PCA(n_components=2, random_state=0).fit_transform(vecs)
plt.scatter(X_reduced[:, 0], X_reduced[:, 1])
plt.grid()
plt.title(title)
for label, x, y in zip(label, X_reduced[:, 0], X_reduced[:, 1]):
plt.annotate(label, xy=(x, y), xytext=(0, 0), textcoords='offset points')
plt.savefig(f'{title}.png')
if __name__ == "__main__":
w2v = models.Word2Vec.load(WORD2VEC_ROOT)
swem = SWEM(w2v.wv)
config = BertConfig.from_json_file(BERT_ROOT + '/bert_config.json')
model = BertForPreTraining(config=config)
model.load_state_dict(torch.load(BERT_ROOT+"/pytorch_model.bin"))
tokenizer = BertTokenizer(BERT_ROOT+"/vocab.txt")
bert_vec = BertVectorizer(model, tokenizer)
text_df = pd.read_csv(TEXT_PATH)
label = text_df["ユーザー質問"].values
bert_vecs = np.array(text_df["ユーザー質問"].map(bert_vec).tolist())
swem_concat_vecs = np.array(text_df["ユーザー質問"].map(lambda x: swem(x, "concat")).tolist())
swem_hier_2_vecs = np.array(text_df["ユーザー質問"].map(lambda x: swem(x, "hier_2")).tolist())
plot_pca(bert_vecs[:6], label[:6], "語順: BERT")
plot_pca(bert_vecs[6:], label[6:], "ノイズ: BERT")
plot_pca(swem_concat_vecs[:6], label[:6], "語順: SWEM-concat")
plot_pca(swem_concat_vecs[6:], label[6:], "ノイズ: SWEM-concat")
plot_pca(swem_hier_2_vecs[:6], label[:6], "語順: SWEM-hider_2")
plot_pca(swem_hier_2_vecs[6:], label[6:], "ノイズ: SWEM-hider_2")