sqlalchemy-enums.py (Source)

#!/usr/bin/env python
# sqlalchemy-enums.py
# Note: SQLAlchemy must be installed for this to work (`pip install sqlalchemy`).
from enum import Enum
import sqlalchemy
from sqlalchemy.orm import Mapped, DeclarativeBase, Session, mapped_column
from sqlalchemy import create_engine, select, text
import typing as t
class EnumWithUnknown(sqlalchemy.Enum):
    """Enum column type that maps unrecognised database values to a fallback.

    Works like ``sqlalchemy.Enum``, except that a stored value with no entry
    in the type's lookup table resolves to the configured ``unknown_value``
    instead of surfacing the lookup failure.
    """

    def __init__(self, *enums, **kw: t.Any):
        """Build the type and record the fallback object.

        Args:
            *enums: Forwarded unchanged to ``sqlalchemy.Enum``.
            **kw: Forwarded unchanged to ``sqlalchemy.Enum``. Must contain
                ``unknown_value`` unless ``_adapted_from`` refers to an
                instance that already carries a fallback.

        Raises:
            ValueError: If no fallback can be determined from either
                ``_adapted_from`` or ``unknown_value``.
        """
        super().__init__(*enums, **kw)
        # SQLAlchemy sometimes sets the ``_adapted_from`` keyword argument,
        # which references the original type but does not repeat the original
        # keyword arguments, so inherit the fallback from that instance.
        # ``getattr`` keeps this safe when ``_adapted_from`` is ``None`` or a
        # type that has no ``_unknown_value`` (e.g. a plain ``Enum``).
        adapted_from = kw.get("_adapted_from")
        fallback = getattr(adapted_from, "_unknown_value", None)
        if fallback is None:
            # Direct construction, or the source type carried no fallback:
            # use the explicitly supplied keyword argument instead.
            fallback = kw.get("unknown_value")
        self._unknown_value = fallback
        if self._unknown_value is None:
            raise ValueError("unknown_value should be a member of the enum")

    def _object_value_for_elem(self, elem):
        """Return the object for database value ``elem``.

        This is the hook that resolves the Python object for a value read
        from the database; values missing from the lookup table resolve to
        the configured fallback.
        """
        try:
            return self._object_lookup[elem]
        except LookupError:
            return self._unknown_value
class Color(Enum):
    """Colours stored in ``TestTable.value``.

    The column's ``values_callable`` builds the list of database strings from
    each member's value, so the values below are what the column works with.
    """

    # Passed as ``unknown_value`` to the column type in ``TestTable``.
    UNKNOWN = "unknown"
    # Member name and stored string differ: the database string is "red".
    LEGACY_RED = "red"
    GREEN = "green"
    BLUE = "blue"
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this script."""
class TestTable(Base):
    """ORM model for ``test_table``: an integer key plus one ``Color`` column."""

    __tablename__ = "test_table"

    id: Mapped[int] = mapped_column(primary_key=True)

    # The column type is given the members' values as its string list and
    # Color.UNKNOWN as the fallback for stored strings it does not recognise.
    value: Mapped[Color] = mapped_column(
        EnumWithUnknown(
            Color,
            unknown_value=Color.UNKNOWN,
            values_callable=lambda enum_cls: [
                str(member.value) for member in enum_cls
            ],
        )
    )
# Create the engine and emit the schema for every mapped class.
engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Store one row for each regular colour.
with Session(engine) as session:
    session.add_all(
        [
            TestTable(value=color)
            for color in (Color.LEGACY_RED, Color.GREEN, Color.BLUE)
        ]
    )
    session.commit()

# Store a string that is not one of Color's values.
with Session(engine) as session:
    session.add(TestTable(value="reed"))
    session.commit()

# Load every row through the ORM and print the mapped value of each.
with Session(engine) as session:
    rows = session.scalars(select(TestTable)).all()
    print("We restored the following values in code...")
    for row in rows:
        print(row.value)

# Print the rows returned by a plain SQL query for comparison.
print("But the underlying table contains...")
with engine.connect() as connection:
    raw_rows = connection.execute(text("SELECT * FROM test_table;")).all()
    print(raw_rows)