You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
336 lines
9.6 KiB
336 lines
9.6 KiB
2 years ago
|
from __future__ import annotations
|
||
|
|
||
|
from typing import Any
|
||
|
from typing import Dict
|
||
|
from typing import Set
|
||
|
|
||
|
from sqlalchemy import CHAR
|
||
|
from sqlalchemy import CheckConstraint
|
||
|
from sqlalchemy import Column
|
||
|
from sqlalchemy import event
|
||
|
from sqlalchemy import ForeignKey
|
||
|
from sqlalchemy import Index
|
||
|
from sqlalchemy import inspect
|
||
|
from sqlalchemy import Integer
|
||
|
from sqlalchemy import MetaData
|
||
|
from sqlalchemy import Numeric
|
||
|
from sqlalchemy import String
|
||
|
from sqlalchemy import Table
|
||
|
from sqlalchemy import Text
|
||
|
from sqlalchemy import text
|
||
|
from sqlalchemy import UniqueConstraint
|
||
|
|
||
|
from ... import autogenerate
|
||
|
from ... import util
|
||
|
from ...autogenerate import api
|
||
|
from ...ddl.base import _fk_spec
|
||
|
from ...migration import MigrationContext
|
||
|
from ...operations import ops
|
||
|
from ...testing import config
|
||
|
from ...testing import eq_
|
||
|
from ...testing.env import clear_staging_env
|
||
|
from ...testing.env import staging_env
|
||
|
|
||
|
# Names of every Table attached to a MetaData after this module is imported.
# _default_include_object consults this set so that autogenerate only looks at
# tables declared by these fixtures; nothing in this module clears it.
names_in_this_test: Set[Any] = set()


@event.listens_for(Table, "after_parent_attach")
def new_table(table, parent):
    """Remember ``table.name`` when the table is attached to its parent."""
    names_in_this_test.update((table.name,))
|
||
|
|
||
|
|
||
|
def _default_include_object(obj, name, type_, reflected, compare_to):
|
||
|
if type_ == "table":
|
||
|
return name in names_in_this_test
|
||
|
else:
|
||
|
return True
|
||
|
|
||
|
|
||
|
_default_object_filters: Any = _default_include_object
|
||
|
|
||
|
_default_name_filters: Any = None
|
||
|
|
||
|
|
||
|
class ModelOne:
    """Pair of MetaData builders used by autogenerate comparison tests.

    ``_get_db_schema`` returns the MetaData standing for what already exists
    in the database, and ``_get_model_schema`` returns the target MetaData
    that autogenerate is compared against. Both are built with
    ``MetaData(schema=cls.schema)``.
    """

    __requires__ = ("unique_constraint_reflection",)

    # Forwarded as ``MetaData(schema=...)`` by both builders; subclasses may
    # override it to run the same fixture against a named schema.
    schema: Any = None

    @classmethod
    def _get_db_schema(cls):
        """Build the database-side MetaData (user, address, order, extra)."""
        meta = MetaData(schema=cls.schema)

        # "user" carries a "pw" column and an index on it.
        Table(
            "user",
            meta,
            Column("id", Integer, primary_key=True),
            Column("name", String(50)),
            Column("a1", Text),
            Column("pw", String(50)),
            Index("pw_idx", "pw"),
        )
        Table(
            "address",
            meta,
            Column("id", Integer, primary_key=True),
            Column("email_address", String(100), nullable=False),
        )
        Table(
            "order",
            meta,
            Column("order_id", Integer, primary_key=True),
            Column(
                "amount",
                Numeric(8, 2),
                nullable=False,
                server_default=text("0"),
            ),
            CheckConstraint("amount >= 0", name="ck_order_amount"),
        )
        # "extra" exists only on the database side.
        Table(
            "extra",
            meta,
            Column("x", CHAR),
            Column("uid", Integer, ForeignKey("user.id")),
        )
        return meta

    @classmethod
    def _get_model_schema(cls):
        """Build the model-side MetaData (user, address, order, item).

        Relative to ``_get_db_schema``: "user" loses "pw"/"pw_idx", makes
        "name" NOT NULL and gives "a1" a server default; "address" gains
        "street" and the "uq_email" unique constraint; "order" widens and
        relaxes "amount", adds a "user_id" foreign key and changes the check
        constraint text; "extra" is absent and "item" is new.
        """
        meta = MetaData(schema=cls.schema)

        Table(
            "user",
            meta,
            Column("id", Integer, primary_key=True),
            Column("name", String(50), nullable=False),
            Column("a1", Text, server_default="x"),
        )
        Table(
            "address",
            meta,
            Column("id", Integer, primary_key=True),
            Column("email_address", String(100), nullable=False),
            Column("street", String(50)),
            UniqueConstraint("email_address", name="uq_email"),
        )
        Table(
            "order",
            meta,
            Column("order_id", Integer, primary_key=True),
            Column(
                "amount",
                Numeric(10, 2),
                nullable=True,
                server_default=text("0"),
            ),
            Column("user_id", Integer, ForeignKey("user.id")),
            CheckConstraint("amount > -1", name="ck_order_amount"),
        )
        Table(
            "item",
            meta,
            Column("id", Integer, primary_key=True),
            Column("description", String(100)),
            Column("order_id", Integer, ForeignKey("order.order_id")),
            CheckConstraint("len(description) > 5"),
        )
        return meta
|
||
|
|
||
|
|
||
|
class _ComparesFKs:
    """Mixin providing an assertion helper for foreign-key autogenerate diffs.

    Classes using the ``"servergenerated"`` name check must provide
    ``self.bind``.
    """

    def _assert_fk_diff(
        self,
        diff,
        type_,
        source_table,
        source_columns,
        target_table,
        target_columns,
        name=None,
        conditional_name=None,
        source_schema=None,
        onupdate=None,
        ondelete=None,
        initially=None,
        deferrable=None,
    ):
        """Assert that ``diff`` describes the expected foreign key change.

        :param diff: a two-element sequence; ``diff[0]`` is compared with
            ``type_`` and ``diff[1]`` is the foreign key constraint that is
            unpacked with ``_fk_spec`` (presumably an entry of
            ``UpgradeOps.as_diffs()`` -- confirm against callers).
        :param source_table / source_columns / source_schema: expected
            owning table, local column names and schema of the constraint.
        :param target_table / target_columns: expected referred table and
            referred column names.
        :param name: expected constraint name, used when ``conditional_name``
            is None.
        :param conditional_name: overrides ``name`` when not None; the special
            value ``"servergenerated"`` means the expected name is the one the
            database reports for the first foreign key of ``source_table`` in
            ``source_schema``.
        :param onupdate / ondelete / initially / deferrable: expected
            constraint options.
        """
        # the public API for ForeignKeyConstraint was not very rich
        # in 0.7, 0.8, so here we use the well-known but slightly
        # private API to get at its elements
        (
            fk_source_schema,
            fk_source_table,
            fk_source_columns,
            fk_target_schema,
            fk_target_table,
            fk_target_columns,
            fk_onupdate,
            fk_ondelete,
            fk_deferrable,
            fk_initially,
        ) = _fk_spec(diff[1])

        eq_(diff[0], type_)
        eq_(fk_source_table, source_table)
        eq_(fk_source_columns, source_columns)
        eq_(fk_target_table, target_table)
        eq_(fk_source_schema, source_schema)
        eq_(fk_onupdate, onupdate)
        eq_(fk_ondelete, ondelete)
        eq_(fk_initially, initially)
        eq_(fk_deferrable, deferrable)

        eq_([elem.column.name for elem in diff[1].elements], target_columns)

        if conditional_name is None:
            eq_(diff[1].name, name)
        elif conditional_name == "servergenerated":
            # Look the table up in the schema the constraint lives in;
            # without ``schema=`` the inspector would read the default
            # schema and report the wrong table's foreign keys.
            fks = inspect(self.bind).get_foreign_keys(
                source_table, schema=source_schema
            )
            server_fk_name = fks[0]["name"]
            eq_(diff[1].name, server_fk_name)
        else:
            eq_(diff[1].name, conditional_name)
|
||
|
|
||
|
|
||
|
class AutogenTest(_ComparesFKs):
    """Base for tests comparing one class-wide database schema to a model.

    ``setup_class`` calls ``cls._get_db_schema()`` and
    ``cls._get_model_schema()``, which this class does not define --
    presumably supplied by a mixin such as ``ModelOne``.
    """

    def _flatten_diffs(self, diffs):
        """Yield the non-list entries of ``diffs``, recursing into lists."""
        for d in diffs:
            if isinstance(d, list):
                yield from self._flatten_diffs(d)
            else:
                yield d

    @classmethod
    def _get_bind(cls):
        """Return the engine the test class runs against (``config.db``)."""
        return config.db

    # Extra MigrationContext options merged over the defaults in setUp;
    # subclasses override this. It is only read, never mutated, here.
    configure_opts: Dict[Any, Any] = {}

    @classmethod
    def setup_class(cls):
        """Create the database-side schema and build the model schema."""
        staging_env()
        cls.bind = cls._get_bind()
        cls.m1 = cls._get_db_schema()
        cls.m1.create_all(cls.bind)
        cls.m2 = cls._get_model_schema()

    @classmethod
    def teardown_class(cls):
        """Drop the database-side schema and clear the staging environment."""
        try:
            cls.m1.drop_all(cls.bind)
        finally:
            # undo staging_env() even when dropping the tables fails
            clear_staging_env()

    def setUp(self):
        """Open a connection and build the migration/autogenerate contexts."""
        self.conn = conn = self.bind.connect()
        try:
            ctx_opts = {
                "compare_type": True,
                "compare_server_default": True,
                "target_metadata": self.m2,
                "upgrade_token": "upgrades",
                "downgrade_token": "downgrades",
                "alembic_module_prefix": "op.",
                "sqlalchemy_module_prefix": "sa.",
                "include_object": _default_object_filters,
                "include_name": _default_name_filters,
            }
            if self.configure_opts:
                ctx_opts.update(self.configure_opts)
            self.context = context = MigrationContext.configure(
                connection=conn, opts=ctx_opts
            )

            self.autogen_context = api.AutogenContext(context, self.m2)
        except Exception:
            # tearDown typically does not run when setUp raises, so release
            # the connection here before propagating the error.
            conn.close()
            raise

    def tearDown(self):
        """Close the connection opened in setUp."""
        self.conn.close()

    def _update_context(
        self, object_filters=None, name_filters=None, include_schemas=None
    ):
        """Adjust the autogenerate context in place and return it.

        Each argument left as None leaves the corresponding setting alone;
        filters given here replace (not extend) the existing filter lists.
        """
        if include_schemas is not None:
            self.autogen_context.opts["include_schemas"] = include_schemas
        if object_filters is not None:
            self.autogen_context._object_filters = [object_filters]
        if name_filters is not None:
            self.autogen_context._name_filters = [name_filters]
        return self.autogen_context
|
||
|
|
||
|
|
||
|
class AutogenFixtureTest(_ComparesFKs):
    """Base for tests that build a fresh db/model schema pair per test."""

    def _fixture(
        self,
        m1,
        m2,
        include_schemas=False,
        opts=None,
        object_filters=_default_object_filters,
        name_filters=_default_name_filters,
        return_ops=False,
        max_identifier_length=None,
    ):
        """Create ``m1`` in the database and autogenerate against ``m2``.

        :param m1: MetaData (or a list of them, via ``util.to_list``) created
            in the database; remembered so ``tearDown`` can drop it.
        :param m2: target MetaData handed to the migration context.
        :param include_schemas: value for the ``include_schemas`` option.
        :param opts: extra context options merged over the defaults.
        :param object_filters: ``include_object`` hook.
        :param name_filters: ``include_name`` hook.
        :param return_ops: return the ``UpgradeOps`` object itself instead of
            ``uo.as_diffs()``.
        :param max_identifier_length: when truthy, temporarily overrides the
            dialect's identifier length limit for the duration of the call.
        """
        if max_identifier_length:
            dialect = self.bind.dialect
            # Save both attributes separately: the "user defined" value may
            # be None when no explicit limit was configured, and it must be
            # put back as-is rather than as the effective length.
            existing_length = dialect.max_identifier_length
            existing_user_defined_length = getattr(
                dialect, "_user_defined_max_identifier_length", None
            )
            dialect.max_identifier_length = (
                dialect._user_defined_max_identifier_length
            ) = max_identifier_length
        try:
            self._alembic_metadata, model_metadata = m1, m2
            for m in util.to_list(self._alembic_metadata):
                m.create_all(self.bind)

            with self.bind.connect() as conn:
                ctx_opts = {
                    "compare_type": True,
                    "compare_server_default": True,
                    "target_metadata": model_metadata,
                    "upgrade_token": "upgrades",
                    "downgrade_token": "downgrades",
                    "alembic_module_prefix": "op.",
                    "sqlalchemy_module_prefix": "sa.",
                    "include_object": object_filters,
                    "include_name": name_filters,
                    "include_schemas": include_schemas,
                }
                if opts:
                    ctx_opts.update(opts)
                self.context = context = MigrationContext.configure(
                    connection=conn, opts=ctx_opts
                )

                autogen_context = api.AutogenContext(context, model_metadata)
                uo = ops.UpgradeOps(ops=[])
                autogenerate._produce_net_changes(autogen_context, uo)

                if return_ops:
                    return uo
                else:
                    return uo.as_diffs()
        finally:
            if max_identifier_length:
                dialect.max_identifier_length = existing_length
                dialect._user_defined_max_identifier_length = (
                    existing_user_defined_length
                )

    def setUp(self):
        """Enter the staging environment and bind to ``config.db``."""
        staging_env()
        self.bind = config.db

    def tearDown(self):
        """Drop any schema created by ``_fixture``; always clear staging."""
        try:
            if hasattr(self, "_alembic_metadata"):
                for m in util.to_list(self._alembic_metadata):
                    m.drop_all(self.bind)
        finally:
            # undo staging_env() even when dropping the tables fails
            clear_staging_env()
|