trlxを用いた文書生成モデルの学習②~PPO編~

こんにちは
AIチームの戸田です

前回に続き、最近話題のChatGPTの学習に使われているRLHF(Reinforcement Learning from Human Feedback)を行うことができる強化学習フレームワーク、trlxを使った文章生成を試してみたいと思います。

本記事ではいよいよ話題のChatGPTと同じ(と思われる※)PPOの学習を、前回と同様WRIMEのデータで試してみたいと思います。実験設定などは前回の記事をご参照ください。また、本記事はtrlxライブラリを一通り動かすことを目的とし、パラメータ調整やデータクレンジングなどのより良い生成を行うための工夫は行いません。

なお、trlxのバージョンはv0.5.0を使用しています。trlxの最新版ではいくつか仕様変更が行われているため、著者と同じ環境で実験したい方はv0.5.0でお試しください。

※: 2023.02.20現在、ChatGPTの論文は公開されていないのですが、前身のInstructGPTの論文からの予測になります。

reward modelの学習

InstructGPTでのreward modelは回答の順位付けを予測するタスクを解かせていますが、今回は簡略化のために単純な2値分類モデルにしたいと思います。

ベースには日本語の蒸留BERTであるbandainamco-mirai/distilbert-base-japaneseを使います。

コードの全容は非常に大量になってしまいますので、kaggle codeに公開しました。

また学習済みのモデルをHiggingFace Hubにアップロードしました。強化学習のFine-Tuningのみ試したい方はこちらをご利用ください。

  1. joyかそうでないかの2値分類モデル
  2. sadnessかそうでないかの2値分類モデル

以下のようにすると手元にモデルをダウンロードできると思います。

git lfs install
git clone https://huggingface.co/trtd56/wrime-rw-model-joy
git clone https://huggingface.co/trtd56/wrime-rw-model-sadness

ちなみにreward modelの2値分類モデルとしての評価結果(Accuracy)は以下のようになりました。

SentimentValidation BestTest
joy0.73750.776
sadness0.79750.784

こちらのモデルを評価器として、PPOによる言語モデルのFine-Tuningを行なっていきます。

PPOでFine-Tuning

学習コードの書き方は前回のILQLとほぼ同様なので、差分のみ紹介していきたいと思います。

configの上書き

公式のサンプルを利用して、必要なところを上書きします。

model_name = 'rinna/japanese-gpt2-medium'

with open('configs/ppo_config.yml') as f:
    default_config = yaml.safe_load(f)

default_config['train']['batch_size'] = 16

default_config['model']['model_path'] = model_name
default_config['tokenizer']['tokenizer_path'] = model_name

default_config['train']['tracker'] = None

config = TRLConfig.update(default_config, {})

モデルは前回と同様rinna/japanese-gpt2-mediumを使います。

評価関数の定義

学習中にreward modelを呼び出して報酬を計算するための関数を定義します。

モデルの読み込み

学習コードを参考に、同様のモデルを定義して、学習した2値分類モデルの重みを読み込みます。

from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch.nn as nn
import torch

class CFG:
    MODEL_NAME = "bandainamco-mirai/distilbert-base-japanese"
    MAX_LEN = 64
    TOKENIZER = None

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.config = AutoConfig.from_pretrained(CFG.MODEL_NAME, output_hidden_states=True)
        self.model = AutoModel.from_pretrained(CFG.MODEL_NAME, config=self.config)
        self.fc = nn.Linear(self.config.hidden_size, 2)

    def forward(self, inputs):
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0]
        feature = last_hidden_states[:, 0, :]
        output = self.fc(feature)
        return output

senti_model = CustomModel()
state = torch.load("[reward model(2値分類モデル)の重みのPATH]")
senti_model.load_state_dict(state)
senti_model.to(device)

# tokenizerも併せて設定する
TOKENIZER = AutoTokenizer.from_pretrained(CFG.MODEL_NAME)
CFG.TOKENIZER = TOKENIZER

推論関数の定義

テキストを入力するとreward modelの予測値を出力する関数を作ります

def predict_sentiment(text):
    inputs = CFG.TOKENIZER.encode_plus(
            text,
            return_tensors='pt',
            add_special_tokens=True,
            max_length=CFG.MAX_LEN,
            padding='max_length',
            truncation=True
    )
    del inputs['token_type_ids']
    for k in inputs.keys():
        inputs[k] = inputs[k].to(device)

    with torch.no_grad():
        outputs = senti_model(inputs)
    pred = outputs.softmax(1)[:, 1]
    return pred[0].cpu().item()

この関数を使って評価関数を作ります。入力は生成されたテキストのリストになるので、それぞれを上記関数に通して、予測値を得ます。

def reward_fn_wrime(samples: List[str], **kwargs) -> Dict[str, List[float]]:
    sentiments = [predict_sentiment(s) for s in samples]
    return sentiments

出力は予測値のリストになります。

学習

学習用のpromptとして、学習データのテキストの先頭4文字を与えます。参考にしたIMDBを使ったexampleだと4単語だったので、きちんと形態素解析器を通して、区切りの良い単語で与えた方が結果は良くなりそうですが、今回は一通り動かすことを目的としているので、そこはまた今度試してみたいと思います。

wrime = load_dataset("shunk031/wrime", name="ver1")
train_sentence = wrime['train']['sentence']
valid_sentence = wrime['validation']['sentence']
sentence = train_sentence + valid_sentence

prompts = [s[:4] for s in sentence]

ちなみにpromptを設定せずに学習することもできます。この際のpromptは<s>などの先頭文字特殊トークンになります。

以下のコードで学習開始します。

trainer = trlx.train(
    reward_fn=reward_fn_wrime,
    prompts=prompts,
    eval_prompts=eval_prompts,  # こちらの中身は前回の記事をご参照ください
    config=config,
)

出力

学習が完了したモデルの出力結果例を以下に示します。

joy

promptoutput
今週末に良いお天気だったので気合を入れて消化させましたが、味噌汁を作りました。 小豆こしがあったのでとろいのが出来ましたが、 味噌がなかったので、味噌汁は
今日6時半ごろに貴船神社に着いたところですが、 なぜか裏が表になりました。 大島から 海が見えるのは大きく40度ぐらいのようですが 7度くらいですね。 そちら
帰ったらお天気だし、ご飯用意しなくていいんだ。ご飯を作ってくれた両親も大変だもんね。 風が

sadness

promptoutput
今週末に体調をくずし、最高気温の関係で日曜のレースは見に行くことは叶わなかったのだが ...
今日6時半の地震 震度6 大分 津波 (д д
帰ったらクリスマスのプレゼントの仕分け・・・『お父さんラゼェド家に戻るんだよね?

「今週末に」に続く文章を見てみると、joyは「良いお天気〜」とポジティブらしい内容、sadnessは「体調をくずし〜」とネガティブらしい内容になっており、前回のILQLに比べるとフィードバックを反映した生成ができているのではないかと思われます。
しかし文章的におかしなところが多いように見えます。これはモデルサイズを大きくすればある程度は解決するのではないかな、と考えています。

おまけ

SFT(Supervised Fine-Tuning)を学習させる際は以下のような設定にすると良いです。

with open('configs/sft_config.yml') as f:
    default_config = yaml.safe_load(f)
  .
  .
  .
trainer = trlx.train(
    samples=sentence,
    eval_prompts=eval_prompts,
    config=config,
)

与えるのはsentence、つまりテキストだけになるので非常にシンプルですね。

おわりに

本記事では言語モデル向け強化学習フレームワークであるtrlxを使ってPPOによる言語モデルのFine-Tuningを試してみました。

一通り動かすことを目的としていたので、SFTや順位づけによる人間のフィードバックは行いませんでしたが、生成された文章はフィードバックを反映してそうなものもありました。

経験上、パラメーター数1Bが生成文章の品質における一つの壁だと感じているので、今後はパラメータ数の多いモデル(今回使ったrinna/japanese-gpt2-mediumはパラメーター数336M)で試してみて生成結果の変化を確認してみたいです。

最後までお読みいただき、ありがとうございました!

PICK UP

TAG