2024年5月31日金曜日

ツイート情報を LoRA に学習させるためのフォーマットに変換する

ツイート情報を LoRA に学習させるためのフォーマットに変換する

概要

過去に mlx で LoRA 学習をしました今回はツイートのアーカイブデータを学習させられるように加工してみました

環境

  • macOS 11.7.10
  • Python 3.11.6

準備

  • pip install pandas scikit-learn

サンプルコード

基本的な作成の流れは前回と同じです
成果物も json ファイル 3 つになります

ツイートのアーカイブ情報にはいろいろありますが今回は full_text のみ学習させます

命令も単純に「ツイートして」にしていますが本来であれば時刻やそのツイートに関するコンテキストも命令に与えたほうがいいです

import json
from io import StringIO

import pandas as pd
from sklearn.model_selection import train_test_split

# tweets.js を読み込んでツイート情報を辞書に変換
tweets = []
with open("./tweets.js", mode="r") as f:
    for tweet in json.loads(f.read().replace("window.YTD.tweets.part0 = ", "")):
        t = {
            "full_text": tweet["tweet"]["full_text"],
            "created_at": tweet["tweet"]["created_at"],
        }
        tweets.append(t)

# JSONデータをPandas DataFrameに読み込む
df = pd.read_json(StringIO(json.dumps(tweets)))
pd.set_option("display.max_colwidth", 1000)
pd.set_option("display.max_rows", 1000)

# リンクを含むツイートは削除
df = df[~df["full_text"].str.contains("https://")]

print(df.head(100))

# テキスト長が短い学習データTOP1300をデータセットとして扱う
# 学習に時間がかかる場合もしくはメモリを使いすぎる場合は ascending=True にして短い順にするか 1300 の数を小さくすること
df["length"] = df.full_text.str.len()
df = df.sort_values(by="length", ascending=False)
df = df.head(1300)

# データフレームをシャッフル
df = df.sample(frac=1).reset_index(drop=True)

# validとtest用のデータを100件ずつ取り出し、残りをtrainに分割
valid_df, remaining_df = train_test_split(df, test_size=len(df) - 100, random_state=42)
test_df, train_df = train_test_split(
    remaining_df, test_size=len(remaining_df) - 100, random_state=42
)


# ヘルパー関数:データフレームを新しい形式でJSON Linesファイルに変換
def df_to_jsonl(df, file_name):
    with open(file_name, "w", encoding="utf-8") as file:
        for _, row in df.iterrows():
            formatted_data = {"text": f"USER:ツイートして ASSISTANT:{row['full_text']}"}
            file.write(json.dumps(formatted_data, ensure_ascii=False) + "\n")


# 各データセットを対応するJSON Linesファイルに変換
df_to_jsonl(train_df, "./train.jsonl")
df_to_jsonl(valid_df, "./valid.jsonl")
df_to_jsonl(test_df, "./test.jsonl")

トラブルシューティング

pandas を使う場合は pipenv 配下だとエラー (Segmentation fault) になります
なので pip でインストールしてその配下で使いましょう

最後に

次回はこのデータを使って MLX + LoRA をしてみたいと思います

0 件のコメント:

コメントを投稿