2023年11月16日木曜日

Pydantic で Union を使って model_validate する場合は必ず discriminator を使うべきである

Pydantic で Union を使って model_validate する場合は必ず discriminator を使うべきである

概要

model_validate を使用すると pydantic が自動で型を推論します
更に Union を使っている場合は Union に指定している複数のクラスのどれにするかを自動的に推論します
同じフィールドを持つ場合には意図しないクラスに割り当てられることもあります
そういったことを避けるために discriminator という機能を使いましょう

環境

  • Python 3.10.2
  • pydantic 2.5.1

サンプルコード

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, field_validator


class Cat(BaseModel):
    # model_validate 時の型推論用フィールド
    animal_type: Literal["cat"] = "cat"
    value: str = Field()

    @field_validator("value")
    @classmethod
    def validate_value(cls, v: str):
        if v != "neko":
            raise ValueError()
        return v


class Dog(BaseModel):
    # model_validate 時の型推論用フィールド
    animal_type: Literal["dog"] = "dog"
    value: str = Field()

    @field_validator("value")
    @classmethod
    def validate_value(cls, v: str):
        if v != "inu":
            raise ValueError()
        return v


Animal = Annotated[
    Union[Cat, Dog],
    Field(discriminator="animal_type"),
]


class Trainer(BaseModel):
    # model_validate 時の型推論用フィールド
    animal: Animal = Field()


if __name__ == "__main__":
    # 成功
    cat = Cat(value="neko")
    dog = Dog(value="inu")
    trainer1 = Trainer(animal=cat)
    trainer2 = Trainer(animal=cat)
    Trainer.model_validate(trainer1.model_dump())
    Trainer.model_validate(trainer2.model_dump())
    # 失敗
    cat = Cat(value="inu")
    trainer1 = Trainer(animal=cat)
    Trainer.model_validate(trainer1.model_dump())

ちょっと解説

一見当たり前のことをしていますが同一のフィールド名をもつクラスを作成しそのクラスに対して model_validate を実行するとフィールドの情報からクラスを推測するため間違ったクラスになる可能性があります

model_dump はクラスの json 情報を返却します
また model_validate で渡す json 情報はすべてのフィールドがプリミティブな値でなければいけないので model_dump される json はすべて文字列だけの情報になります

その状態で discriminator 情報なしで model_validate するとクラスを特定する情報がない状態になってしまいます
なのでクラスを特定する情報として今回 animal_type フィールドを追加し必ずクラスごとに一意になるように値を設定することでプリミティブな値のみを持つ model_dump の情報からでも適切なクラスに変換し model_validate することができるようになります

最後に

Pydantic で Union を使う場合には必ず discriminator を使うようにしましょう

参考サイト

0 件のコメント:

コメントを投稿