2024年12月25日水曜日

dspy でプロンプトの最適化をしてみる

dspy でプロンプトの最適化をしてみる

概要

前回 dspy を使って Metrics 機能で判定を行ってみました
今回は dspy の目玉機能でもあるプロンプトの最適化機能を試してみました
簡単に言えば最適化機能はファインチューニングみたいなものでプロンプトをよりよくすることができる再学習機能です

環境

  • Python 3.11.3
  • dspy 2.5.43

コード全体

まずはコード全体を紹介します
以下で機能部分ごとにコードを紹介しています

import dspy
import pandas as pd
from dspy.evaluate.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShotWithRandomSearch

lm = dspy.LM(
    "azure/gpt-4-32k",
    api_key="xxx",
    api_version="",
    api_base="https://your-api-endpoint/ai/chat-ai/gpt4",
)
dspy.configure(lm=lm)


# プロンプトの入出力をカスタムするクラス
class BasicMathQA(dspy.Signature):
    """算数の文章問題を読んで、解答を数値で出力する。"""

    question = dspy.InputField(desc="算数の文章問題")
    answer = dspy.OutputField(
        desc="算数の文章問題の解答。数値のみ出力する。単位や説明、句読点は含まないこと。",
    )


# 計算に特化したChainOfThoughtを管理するカスタムモジュール
class CoTMathQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought(BasicMathQA)

    def forward(self, question):
        return self.generate_answer(question=question)


# データセットのダウンロード
df = pd.read_csv(
    "https://raw.githubusercontent.com/google-research/url-nlp/main/mgsm/mgsm_ja.tsv",
    sep="\t",
    header=None,
)
df.rename(columns={0: "question", 1: "answer"}, inplace=True)


# 前半50を学習データにする、dspyで使えるようにdspy.Example形式に変換する
mgsm_ja_trainset = [
    dspy.Example(question=row.question, answer=row.answer).with_inputs("question")  # type: ignore
    for row in df[:50].itertuples()
]
# 後半50を学習データにする、dspyで使えるようにdspy.Example形式に変換する
mgsm_ja_devset = [
    dspy.Example(question=row.question, answer=row.answer).with_inputs("question")  # type: ignore
    for row in df[51:101].itertuples()
]

# プロンプトを学習するためのモデルを作成 https://dspy.ai/deep-dive/optimizers/bfrs/
# metric には独自の評価関数を設定することが可能でこの関数を制御することで精度を上げることもできる
# num_threadsは1じゃないとRateLimitになる場合が多い、今回のデータセットでだいたい30分くらいかかる
teleprompter = BootstrapFewShotWithRandomSearch(
    metric=dspy.evaluate.answer_exact_match,  # type: ignore
    max_labeled_demos=10,
    max_bootstrapped_demos=8,
    num_threads=4,
)

# teleprompter の学習の開始、これが最適化するということ
# 最適化後のプロンプトを使って質問する
compiled_cot_math_qa = teleprompter.compile(CoTMathQA(), trainset=mgsm_ja_trainset)
compiled_cot_math_qa(
    question="周囲の長さが300メートルの池の周りに木を植えることにした。5メートル間隔で植える場合、木は何本必要か?"
)

# 最適化されたプロンプトの保存
compiled_cot_math_qa.save("./compiled_cot_math_qa/", save_program=True)

# 最適化前と後での精度を比較する
evaluate = Evaluate(
    devset=mgsm_ja_devset,
    num_threads=5,
    display_progress=True,
    display_table=True,
)

accuracy_original = evaluate(CoTMathQA(), metric=dspy.evaluate.answer_exact_match, display_table=0)  # type: ignore
accuracy_compiled = evaluate(compiled_cot_math_qa, metric=dspy.evaluate.answer_exact_match, display_table=0)  # type: ignore
print(f"Original accuracy: {accuracy_original}")
print(f"Compiled accuracy: {accuracy_compiled}")

解説

まずはプロンプトを最適化するために学習データを用意します
今回は https://raw.githubusercontent.com/google-research/url-nlp/main/mgsm/mgsm_ja.tsv を使っていますがこれと同じ形式の tsv を用意すれば独自のデータでプロンプトを最適化することができます

学習/評価するデータは 50 個ずつ用意します
先程の tsv ファイルが全部 100 個あるので前半と後半で分けてそれぞれを学習/訓練データにしています

プロンプトの最適化にはそれに応じたクラスを使います
今回は BootstrapFewShotWithRandomSearch という FewShot + ランダム検索で最適化を行います
一番のポイントは metric でこれが評価関数になります
今回は dspy が用意している評価関数を使っていますがここに独自の評価関数を入れることでプロンプトの精度を向上させることができます
num_threads は増やせば増やすほど最適化が速く終わります

ファインチューニング同様に最適化したプロンプトは保存できます
なので毎回最適化する必要はなく良い評価値になったプロンプトを毎回使い回すことができます

あとは最適化前と最適化後で評価します

Original accuracy: 86.0
Compiled accuracy: 88.0

自分の環境ではそこまで大きな差は出ませんでしたがそれでもまだまだ精度向上の余地はあるかなと思います

最後に

dspy のプロンプト最適化を試してみました
基本的にはファインチューニングのような再学習機能になります
しかし以下のようなメリットがあります

  • Keras + LoRA のような難しい技術を知る必要がない
  • LLM さえ使えればいいので自分でベースとなるモデルを用意する必要がない
  • LLM をコールするだけで最適化できるので最適化のためのマシンスペックが不要

デメリットとしては

  • 優秀な LLM は基本有料なので最適化するためにかなりのコストがかかる可能性がある
  • 最適化中に 429 になる可能性があり面倒

あたりかなと思います
LLM という強力なツールがある上で精度のいい回答を作成することができるようになるのでファインチューニングに比べてかなり楽かなとは思います
結構複雑な回答を学ばせることもできるので一から数値化したりする必要のあるファインチューニングより学習コストも低いです

参考サイト

0 件のコメント:

コメントを投稿