こんにちわ
AIチームの戸田です
今回の記事でもBERT^1を扱わせていただきます
BERTの事前学習タスクであるMasked LM、こちらは、入力文のトークンをランダムに[MASK]
シンボルに置き換え出力でその単語を予測する、という学習ですが、このタスクどこか既視感があると思いませんか?
Input: the man went to the [MASK1] . he bought a [MASK2] of milk. Labels: [MASK1] = store; [MASK2] = gallon
そう、センター試験やTOEICで出てくる単語穴埋めの問題です
ということで、今回は事前学習済みのBERTでfine-tuningせずにTOEICのPart 5の単語穴埋め問題を解けるか試してみたいと思います
問題はIIBC公式のサンプル問題^2を使用させていただきました
BERT学習済みモデルの読み込み
huggingfaceのtransformers^3 を利用します
import torch
from transformers import BertTokenizer, BertForPreTraining
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')
問題定義
text
に問題の文章、candidate
に穴埋め候補単語をリストで定義します
問題文章の空欄部分は*
とします
text = "Customer reviews indicate that many modern mobile devices are often unnecessarily * ."
candidate = ["complication", "complicates", "complicate", "complicated"]
こちら、品詞を問う問題で、正解は4番目の"complicated"なのですが、BERTは正解することができるのでしょうか?
トークナイズ
BertTokenizerをつかってトークン分割します
tokens = tokenizer.tokenize(text)
# -> ['customer', 'reviews', 'indicate', 'that', 'many', 'modern', 'mobile', 'devices', 'are', 'often', 'un', '##ne', '##ces', '##sari', '##ly', '*', '.']
トークン分割ができたら、次は元の問題で空欄だった部分を[MASK]
トークンに置き換えます
また、事前学習時と同様に文頭と文末にSpecial Tokenの[CLS]
と[SEP]
を入れます
masked_index = tokens.index("*") # 空欄部分のトークンのインデックスを取得
tokens[masked_index] = "[MASK]"
tokens = ["[CLS]"] + tokens + ["[SEP]"]
# -> ['[CLS]', 'customer', 'reviews', 'indicate', 'that', 'many', 'modern', 'mobile', 'devices', 'are', 'often', 'un', '##ne', '##ces', '##sari', '##ly', '[MASK]', '.', '[SEP]']
BERTで予測
トークナイズされた文章をIDに変換して事前学習済みのBERTに通します
今回は事前学習と解きたい問題が同じなので、fine-tuningは行わずにそのまま予測します
ids = tokenizer.convert_tokens_to_ids(tokens)
ids = torch.tensor(ids).reshape(1,-1) # バッチサイズ1の形に整形
with torch.no_grad():
outputs1, outputs2 = model(ids)
predictions = outputs1[0]
outputs1
にMasked LMの予測結果が入っています
outputs2
はBERTのもう一つの事前学習のNext Sentence Predictionの予測結果が入っていますが今回は使わないので無視します
予測上位のトークンを取得
[MASK]
に入ると予測される単語の上位1000件を取得します
_, predicted_indexes = torch.topk(predictions[masked_index+1], k=1000)
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
# -> ['expensive', 'small', 'priced', 'used', ...
1位はexpensiveで、日本語にすると「カスタマーレビューによると、最近のモバイルデバイスは無駄に高価です」といったところでしょうか
意味に違和感はないと思います
予測単語を順々に見ていき、候補の単語が出てきたところで止めます
for i, v in enumerate(predicted_tokens):
if v in candidate:
print(i, v)
break
# -> 74 complicated
75番目(index=74)でcomplicated
がヒットしました
日本語訳は「カスタマーレビューによると、最近のモバイルデバイスは無駄に複雑です」ですかね
見事正解できました
関数化
ここまでの処理を1つの関数にまとめます
def part5_slover(text, candidate):
tokens = tokenizer.tokenize(text)
masked_index = tokens.index("*")
tokens[masked_index] = "[MASK]"
tokens = ["[CLS]"] + tokens + ["[SEP]"]
ids = tokenizer.convert_tokens_to_ids(tokens)
ids = torch.tensor(ids).reshape(1,-1)
with torch.no_grad():
outputs, _ = model(ids)
predictions = outputs[0]
_, predicted_indexes = torch.topk(predictions[masked_index+1], k=1000)
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
for i, v in enumerate(predicted_tokens):
if v in candidate:
return (i, v)
return "don't know"
こちらの関数を使って残りの問題も解いてみます
なんと全問正解です!
おわりに
本記事では事前学習済みのBERTのモデルを使って、TOEICのPart 5の問題を解いてみました
解きたい問題がほぼ同じなので、fine-tuningすることなく良い結果が得られました
せっかくなのでこれで終わりにせず、リーディング問題をすべてBERTで解いてみようと思います(流石にPart 7はfine-tuningが必要でしょうか。。。)
次回はPart 6に挑戦します
最後まで読んでいただきありがとうございました