2023年8月10日木曜日

pydantic の BaseModel でデータベースに接続してバリデーションを行うケースを考える

pydantic の BaseModel でデータベースに接続してバリデーションを行うケースを考える

概要

結論からすると pydantic の BaseModel 内で DB を参照するのは好ましくないので context を使いましょう

コードは前回のものを流用します

環境

  • Python 3.11.3
  • fastapi 0.100.1
  • pypdantic 2.1.1

サンプルコード

from enum import Enum

from fastapi import Depends, FastAPI, Query, Request
from fastapi.datastructures import QueryParams
from pydantic import BaseModel, Field, FieldValidationInfo, field_validator
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker

app = FastAPI()


engine = create_engine("mysql+pymysql://root@localhost/test?charset=utf8mb4")
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


class User(Base):
    __tablename__ = "user"

    id = Column(Integer, primary_key=True)
    name = Column(String(50))
    age = Column(Integer)


def get_db():
    db: Session = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
    finally:
        db.close()


# 各種アクション名を管理する列挙型のクラス
class ActionName(Enum):
    CREATE_INSTANCE = "CreateInstance"


# CreateInstance(例) という名前のアクションのバリデーションを管理するモデル
class CreateInstance(BaseModel):
    instance_name: str = Field(alias="InstanceName")

    @field_validator("instance_name")
    @classmethod
    def validate_instance_name(cls, v: str, info: FieldValidationInfo) -> str:
        # context は FieldValidationInfo に格納されているので使用する場合は引数に追加する
        context = info.context
        # データベースから取得した値が取得できていることが確認できる
        print(context)
        if v != "ins01":
            raise ValueError()
        return v


# 起点となる Action 名のバリデーションを管理するモデル
class Action(BaseModel):
    action: str = Field(Query(alias="Action"))

    @field_validator("action")
    @classmethod
    def validate_action(cls, v: str) -> str:
        # ここでは許可するバリデーション名など基本的なバリデーションのみを行う
        if v not in [ActionName.CREATE_INSTANCE.value]:
            raise ValueError()
        return v

    # アクション名に応じて各種バリデーションをコールする
    def validate(self, query_params: QueryParams, context: dict):
        if self.action == ActionName.CREATE_INSTANCE.value:
            CreateInstance.model_validate(query_params, context=context)


@app.get("/")
async def root(
    request: Request, action: Action = Depends(), db: Session = Depends(get_db)
):
    # データベースの値はここで取得してバリデーション対応の値を context として渡す
    users = db.query(User).all()
    names = [user.name for user in users]
    # ルーティングの引数で Request オブジェクトを参照できるのでこれを使って各種アクションのバリデーションをコールする
    action.validate(request.query_params, context={"names": names})
    return action

解説

ポイントは context です
データベースへの参照はルーティング内でのみ行います
そして必要な値をデータベース側で渡すようにしましょう

このケースの場合は model_validate を使うことが前提となっているのでルーティングの引数でモデルを指定して自動でバリデーションしているような場合は少し改修が必要になります

最後に

一応データベースを参照することもできますがその場合はグローバルにアクセスできる関数やデータベースのコンテキストを考える必要があるので少し面倒なのとコードの可読性が下がるかなと思います

参考サイト

0 件のコメント:

コメントを投稿