2023年7月27日木曜日

SQLAlchemyのautomapを使って自動で生成されたクラスの定義を表示する方法

SQLAlchemyのautomapを使って自動で生成されたクラスの定義を表示する方法

概要

過去に automap の使い方を紹介しました
automap を使った場合クラスのモデル定義がないためコード内でデータベースの定義を確認するすべがありません
またコードエディタなどで補完なども行えず pyright ではそのような属性がないと言われてエラーになってしまいます
なのでクラス定義がないと困るケースが多いです

今回はそんな場合に使えそうな automap からクラスの情報を標準出力に吐き出す方法を紹介します

環境

  • macOS 13.4.1
  • Python 3.11.3
  • sqlalchemy 2.0.19

サンプルコード

  • vim ./app.py
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.automap import automap_base

# automap を使用するための準備
Base = automap_base()

# エンジンの作成
engine = create_engine("mysql+pymysql://root@localhost/test?charset=utf8mb4")

# テーブル定義の読み込み
Base.prepare(autoload_with=engine)

# user テーブルから user モデルの抽出
User = Base.classes.user

# インスペクタの作成
inspector = inspect(User)
# import pprint
# pprint.pprint(dir(inspector))

# クラス定義を標準出力に表示
print(f"class {User.__name__.title()}(Base):")
print(f"    __tablename__ = '{User.__table__}'")

# カラム情報を表示
for column in inspector.columns:
    column_params = []
    if column.primary_key:
        column_params.append("primary_key=True")
    if column.foreign_keys:
        foreign_key = list(column.foreign_keys)[0]
        column_params.append(
            f"ForeignKey('{foreign_key.column.table.name}.{foreign_key.column.name}')"
        )
    if column.unique:
        column_params.append("unique=True")
    if column.index:
        column_params.append("index=True")
    if column.default is not None:
        default_value = column.default.arg
        column_params.append(f"default={default_value}")
    else:
        default_value = column.default
        column_params.append(f"default={default_value}")

    type = repr(column.type)
    if type.startswith("VARCHAR"):
        type = type.replace("VARCHAR", "String")
    else:
        type = type.title()
    column_str = f"{column.name} = Column({type}"
    if column_params:
        column_str += f", {', '.join(column_params)}"
    column_str += ")"
    print(f"    {column_str}")

# ユニークキー情報を表示
if "unique_constraints" in dir(inspector):
    for uc in inspector.unique_constraints:
        uc_columns = ", ".join([column.name for column in uc["column_names"]])
        print(f"    __table_args__ = (UniqueConstraint({uc_columns}),)")

# インデックス情報を表示
if "indexes" in dir(inspector):
    for index in inspector.indexes:
        index_columns = ", ".join([column.name for column in index["column_names"]])
        print(f"    {index['name']} = Index('{index['name']}', {index_columns})")

# プライマリーキー情報を表示
primary_keys = inspector.primary_key
if primary_keys:
    pk_names = ", ".join([pk.name for pk in primary_keys])
    print(f"    __mapper_args__ = {{'primary_key': [{pk_names}]}}")

inspector を使って automap により生成されたクラスに対してカラムやキーの情報を取得しそれを print する流れになります
ユニークキーやインデックスがない場合はそもそも属性自体が inspector に生えないので属性があるかのチェックが必要です

また今回表示している情報がすべてでないケースがあるので表示したい情報は既存のテーブル定義を調べて表示させる必要があります
例えば enum や text, blob 型への変換は対応していないので上記コードに追加する必要があります

これを実行すると以下のようなクラス定義が表示されます

class User(Base):
    __tablename__ = 'user'
    id = Column(Integer(), primary_key=True, default=None)
    name = Column(String(collation='utf8mb4_general_ci', length=50), default=None)
    age = Column(Integer(), default=None)
    profile = Column(String(collation='utf8mb4_general_ci', length=50), default=None)
    __mapper_args__ = {'primary_key': [id]}

動作確認

生成されたクラス情報を使って実際に CRUD 処理してみます

  • vim ./test.py
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.schema import Column
from sqlalchemy.types import Integer, String

engine = create_engine("mysql+pymysql://root@localhost/test?charset=utf8mb4")
SessionClass = sessionmaker(engine)
db_session = SessionClass()

Base = declarative_base()


class User(Base):
    __tablename__ = "user"

    id = Column(Integer(), primary_key=True, default=None)
    name = Column(String(collation="utf8mb4_general_ci", length=50), default=None)
    age = Column(Integer(), default=None)
    profile = Column(String(collation="utf8mb4_general_ci", length=50), default=None)

    __mapper_args__ = {"primary_key": [id]}

    def __repr__(self) -> str:
        return f"id: {self.id}, name: {self.name}, age: {self.age}, profile: {self.profile}"


users = db_session.query(User).all()
for u in users:
    print(u)

改行や repr 関数は手動で追加しています
一応自動生成されたモデルを使ってデータを取得できることは確認できました

最後に

リフレクションを使って愚直にやっていくしか方法がないので完璧にモデルを定義するのであれば結構労力がかかりそうです
もしかすると最初から手動で頑張って既存テーブルの構造を解析してクラスを定義するほうが正確で早いかもしれません

automap は便利なのですがクラス定義がないのでコードの可読性が下がるほかエディタや静的型チェックとの相性も悪いのでプロダクションのコードで採用するのは微妙なのかもしれません

0 件のコメント:

コメントを投稿