概要
前回 文章を数値化し faiss に保存可能なベクトル情報を計算しました
今回は faiss にベクトル情報を保存する方法を紹介します
環境
- macOS 15.2
- transformers 4.47.0
- faiss-cpu 1.9.0-post1
コード全体
from torch import Tensor
from transformers import AutoModel, AutoTokenizer
# このaverage_pool関数はbertの出力結果ではよく使われる手法、何をしているかの詳細は以下の説明に記載
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
# faissに保存する文章、この文章と類似度検索をする
input_texts = [
"好きな食べ物は何ですか?",
"どこにお住まいですか?",
"朝の電車は混みますね",
"今日は良いお天気ですね",
"最近景気悪いですね",
"最近、出かけていないので、たまには外で食事でもどうですか?",
]
# モデルとトークナイザの取得、日本語なので日本語に特化したモデルを使用
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
model = AutoModel.from_pretrained("cl-tohoku/bert-base-japanese")
# 文章の数値化
inputs = tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")
# 文章のベクトル化、768 次元のベクトル情報が取得できる (参考: https://github.com/cl-tohoku/bert-japanese)
outputs = model(**inputs)
# bert の結果を更によくするためのおまじない
embeddings = average_pool(outputs.last_hidden_state, inputs["attention_mask"])
# faiss で扱えるように numpy 配列に変換
embeddings_np = embeddings.cpu().detach().numpy()
# なぜか先頭でimportすると segmentation fault になるのでここで import
import faiss
# 次元数(768)で初期化
index_flat_l2 = faiss.IndexFlatL2(embeddings_np.shape[1])
# ベクトル情報を追加、6文章(vectors)登録される
index_flat_l2.add(embeddings_np) # type: ignore
今回追加した faiss にデータを保存する部分は以下の3行です
# なぜか先頭でimportすると segmentation fault になるのでここで import
import faiss
# 次元数(768)で初期化
index_flat_l2 = faiss.IndexFlatL2(embeddings_np.shape[1])
# ベクトル情報を追加、6文章(vectors)登録される
index_flat_l2.add(embeddings_np) # type: ignore
最後に
たったこれだけで faiss にベクトル情報を保存することができました
次回は faiss に保存した文章をもとにクエリとして与えた文章と似ている文章を探す処理を追加します
0 件のコメント:
コメントを投稿