You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
307 lines
8.9 KiB
307 lines
8.9 KiB
from __future__ import annotations
|
|
|
|
import configparser
|
|
from contextlib import contextmanager
|
|
import io
|
|
import re
|
|
from typing import Any
|
|
from typing import Dict
|
|
|
|
from sqlalchemy import Column
|
|
from sqlalchemy import inspect
|
|
from sqlalchemy import MetaData
|
|
from sqlalchemy import String
|
|
from sqlalchemy import Table
|
|
from sqlalchemy import testing
|
|
from sqlalchemy import text
|
|
from sqlalchemy.testing import config
|
|
from sqlalchemy.testing import mock
|
|
from sqlalchemy.testing.assertions import eq_
|
|
from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
|
|
from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
|
|
|
|
import alembic
|
|
from .assertions import _get_dialect
|
|
from ..environment import EnvironmentContext
|
|
from ..migration import MigrationContext
|
|
from ..operations import Operations
|
|
from ..util import sqla_compat
|
|
from ..util.sqla_compat import create_mock_engine
|
|
from ..util.sqla_compat import sqla_14
|
|
from ..util.sqla_compat import sqla_2
|
|
|
|
|
|
# Parser holding settings loaded from "test.cfg" in the current working
# directory.  ``ConfigParser.read`` silently ignores files it cannot open,
# so a missing file leaves the parser empty.
testing_config = configparser.ConfigParser()
testing_config.read("test.cfg")
|
|
|
|
|
|
class TestBase(SQLAlchemyTestBase):
    """Base class for tests in this package, built on SQLAlchemy's TestBase.

    Provides fixtures that hand a test a database connection, a
    ``MigrationContext`` configured on that connection, and an
    ``Operations`` object bound to that context.
    """

    is_sqlalchemy_future = sqla_2

    @testing.fixture
    def connection(self):
        """Yield a connection obtained from ``config.db``.

        The connection is used as a context manager, so the ``with`` block
        is exited when the fixture is torn down.
        """
        with config.db.connect() as db_conn:
            yield db_conn

    @testing.fixture
    def migration_context(self, connection):
        """Return a ``MigrationContext`` configured on ``connection``."""
        context_opts = {"transaction_per_migration": True}
        return MigrationContext.configure(connection, opts=context_opts)

    @testing.fixture()
    def ops_context(self, migration_context):
        """Yield ``Operations`` while a per-migration transaction is open."""
        with migration_context.begin_transaction(_per_migration=True):
            yield Operations(migration_context)
|
|
|
|
|
|
class TablesTest(TestBase, SQLAlchemyTablesTest):
    """Combine this module's ``TestBase`` with SQLAlchemy's ``TablesTest``."""
|
|
|
|
|
|
# When the ``sqla_14`` flag is set, SQLAlchemy's own mixin is used; otherwise
# a placeholder is declared whose ``__requires__`` names "sqlalchemy_14".
if sqla_14:
    from sqlalchemy.testing.fixtures import FutureEngineMixin
else:

    class FutureEngineMixin:  # type:ignore[no-redef]
        """Placeholder used when SQLAlchemy's mixin is not imported."""

        __requires__ = ("sqlalchemy_14",)


# Set on whichever class the branch above selected.
FutureEngineMixin.is_sqlalchemy_future = True
|
|
|
|
|
|
def capture_db(dialect="postgresql://"):
    """Build a mock engine whose executor callback records compiled SQL.

    :param dialect: passed as the first argument to ``create_mock_engine``.
    :return: ``(engine, statements)``.  Every construct handed to the
     callback is compiled against the engine's dialect and its string form
     is appended to the ``statements`` list.
    """
    statements = []

    def record(sql, *multiparams, **params):
        # ``mock_engine`` is looked up when the callback is invoked, so it
        # may be assigned after this function is defined.
        compiled = sql.compile(dialect=mock_engine.dialect)
        statements.append(str(compiled))

    mock_engine = create_mock_engine(dialect, record)
    return mock_engine, statements
|
|
|
|
|
|
# Module-level mapping.  Nothing in the visible part of this module reads or
# writes it -- NOTE(review): confirm whether other modules still rely on it.
_engs: Dict[Any, Any] = dict()
|
|
|
|
|
|
@contextmanager
def capture_context_buffer(**kw):
    """Patch ``EnvironmentContext.configure`` so output goes to a buffer.

    While the context manager is active, each call to
    ``EnvironmentContext.configure`` has its keyword options overridden by
    ``kw`` together with ``dialect_name="sqlite"`` and
    ``output_buffer=<the buffer>``.

    :param bytes_io: optional flag, removed from ``kw``; when true the
     buffer is an ``io.BytesIO``, otherwise an ``io.StringIO``.
    :param kw: remaining options forced onto every ``configure`` call.
    :return: context manager yielding the buffer.
    """
    wants_bytes = kw.pop("bytes_io", False)
    output = io.BytesIO() if wants_bytes else io.StringIO()

    kw["dialect_name"] = "sqlite"
    kw["output_buffer"] = output
    original_configure = EnvironmentContext.configure

    def patched_configure(*arg, **opt):
        # Options captured in ``kw`` win over those given by the caller.
        return original_configure(*arg, **{**opt, **kw})

    with mock.patch.object(
        EnvironmentContext, "configure", patched_configure
    ):
        yield output
|
|
|
|
|
|
@contextmanager
def capture_engine_context_buffer(**kw):
    """Patch ``EnvironmentContext.configure`` to use a logged connection.

    A connection is opened on the engine returned by ``_sqlite_file_db()``
    and a ``before_cursor_execute`` listener appends every statement,
    followed by a newline, to an ``io.StringIO``.  While the context
    manager is active, each call to ``EnvironmentContext.configure`` has
    its keyword options overridden by ``kw`` plus ``connection=<conn>``.

    The connection is closed when the context manager exits, including
    when the managed block raises.

    :param kw: options forced onto every ``configure`` call.
    :return: context manager yielding the ``io.StringIO`` buffer.
    """
    from .env import _sqlite_file_db
    from sqlalchemy import event

    buf = io.StringIO()

    eng = _sqlite_file_db()

    conn = eng.connect()
    try:

        @event.listens_for(conn, "before_cursor_execute")
        def bce(conn, cursor, statement, parameters, context, executemany):
            buf.write(statement + "\n")

        kw.update({"connection": conn})
        conf = EnvironmentContext.configure

        def configure(*arg, **opt):
            opt.update(**kw)
            return conf(*arg, **opt)

        with mock.patch.object(EnvironmentContext, "configure", configure):
            yield buf
    finally:
        # Previously the connection was left open after the block ended.
        conn.close()
|
|
|
|
|
|
def op_fixture(
    dialect="default",
    as_sql=False,
    naming_convention=None,
    literal_binds=False,
    native_boolean=None,
):
    """Install a recording ``MigrationContext`` as the ``alembic.op`` proxy.

    Statements emitted through the returned context are normalized and
    collected in memory so tests can compare them against expected SQL
    with ``assert_()`` / ``assert_contains()``.

    :param dialect: dialect name handed to ``_get_dialect``.
    :param as_sql: when true, the context is configured with ``as_sql`` and
     writes to the recording buffer via ``output_buffer``; otherwise a mock
     connection compiles each executed statement into the buffer.
    :param naming_convention: if given, a ``MetaData`` built with it is
     passed as ``target_metadata``.
    :param literal_binds: if true, passed along as the ``literal_binds``
     option.
    :param native_boolean: if not ``None``, overrides the dialect's
     ``supports_native_boolean`` flag.
    :return: the context object, which additionally provides
     ``get_buf()``, ``clear_assertions()``, ``assert_()`` and
     ``assert_contains()``.
    """
    opts = {}
    if naming_convention:
        opts["target_metadata"] = MetaData(naming_convention=naming_convention)

    class buffer_:
        """File-like sink that stores one normalized string per write."""

        def __init__(self):
            self.lines = []

        def write(self, msg):
            msg = msg.strip()
            msg = re.sub(r"[\n\t]", "", msg)
            if as_sql:
                # the impl produces soft tabs,
                # so search for blocks of 4 spaces.  The quantifier keeps
                # the width explicit; a single-space pattern would strip
                # every space out of the statement.
                msg = re.sub(r" {4}", "", msg)
                msg = re.sub(r"\;\n*$", "", msg)

            self.lines.append(msg)

        def flush(self):
            pass

    buf = buffer_()

    class ctx(MigrationContext):
        """``MigrationContext`` extended with assertion helpers."""

        def get_buf(self):
            return buf

        def clear_assertions(self):
            buf.lines[:] = []

        def assert_(self, *sql):
            # TODO: make this more flexible about
            # whitespace and such
            eq_(buf.lines, [re.sub(r"[\n\t]", "", s) for s in sql])

        def assert_contains(self, sql):
            # Normalize once rather than on every loop iteration.
            fragment = re.sub(r"[\n\t]", "", sql)
            if any(fragment in stmt for stmt in buf.lines):
                return
            # Raised explicitly so the check also fires under ``python -O``.
            raise AssertionError(
                "Could not locate fragment %r in %r" % (sql, buf.lines)
            )

    if as_sql:
        opts["as_sql"] = as_sql
    if literal_binds:
        opts["literal_binds"] = literal_binds
    if not sqla_14 and dialect == "mariadb":
        ctx_dialect = _get_dialect("mysql")
        ctx_dialect.server_version_info = (10, 4, 0, "MariaDB")

    else:
        ctx_dialect = _get_dialect(dialect)
    if native_boolean is not None:
        ctx_dialect.supports_native_boolean = native_boolean
        # this is new as of SQLAlchemy 1.2.7 and is used by SQL Server,
        # which breaks assumptions in the alembic test suite
        ctx_dialect.non_native_boolean_check_constraint = True
    if not as_sql:

        def execute(stmt, *multiparam, **param):
            if isinstance(stmt, str):
                stmt = text(stmt)
            assert stmt.supports_execution
            sql = str(stmt.compile(dialect=ctx_dialect))

            buf.write(sql)

        connection = mock.Mock(dialect=ctx_dialect, execute=execute)
    else:
        opts["output_buffer"] = buf
        connection = None
    context = ctx(ctx_dialect, connection, opts)

    # Module-level side effect: ``alembic.op`` now proxies to this context.
    alembic.op._proxy = Operations(context)
    return context
|
|
|
|
|
|
class AlterColRoundTripFixture:
    """Mixin that creates a one-column table, alters it, and reflects it.

    ``_run_alter_col`` creates table ``"x"``, runs ``op.alter_column`` on
    its column and compares the reflected result with the expected values.
    """

    # since these tests are about syntax, use more recent SQLAlchemy as some of
    # the type / server default compare logic might not work on older
    # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle

    __requires__ = ("alter_column",)

    def setUp(self):
        """Open a connection and build the context, operations, metadata."""
        self.conn = config.db.connect()
        self.ctx = MigrationContext.configure(self.conn)
        self.op = Operations(self.ctx)
        self.metadata = MetaData()

    def _compare_type(self, t1, t2):
        """Assert the impl's ``compare_type`` is falsy for ``t1`` / ``t2``."""
        c1 = Column("q", t1)
        c2 = Column("q", t2)
        assert not self.ctx.impl.compare_type(
            c1, c2
        ), "Type objects %r and %r didn't compare as equivalent" % (t1, t2)

    def _compare_server_default(self, t1, s1, t2, s2):
        """Assert the impl's ``compare_server_default`` result is falsy."""
        c1 = Column("q", t1, server_default=s1)
        c2 = Column("q", t2, server_default=s2)
        assert not self.ctx.impl.compare_server_default(
            c1, c2, s2, s1
        ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)

    def tearDown(self):
        """Roll back, drop the tables in ``self.metadata``, then close."""
        try:
            sqla_compat._safe_rollback_connection_transaction(self.conn)
            with self.conn.begin():
                self.metadata.drop_all(self.conn)
        finally:
            # Close even when the rollback or the drop fails, so that a
            # failing teardown does not leave the connection open.
            self.conn.close()

    def _run_alter_col(self, from_, to_, compare=None):
        """Create a column from ``from_``, alter it per ``to_`` and verify.

        :param from_: dict describing the initial column; optional keys
         ``"name"``, ``"type"``, ``"nullable"``, ``"server_default"``.
        :param to_: dict of changes passed to ``alter_column``; same
         optional keys as ``from_``.
        :param compare: dict of expected values after the change; defaults
         to ``to_``.
        """
        column = Column(
            from_.get("name", "colname"),
            from_.get("type", String(10)),
            nullable=from_.get("nullable", True),
            server_default=from_.get("server_default", None),
            # comment=from_.get("comment", None)
        )
        t = Table("x", self.metadata, column)

        # ``False`` is passed when the column has no server default.
        if column.server_default is not None:
            existing_server_default = column.server_default
        else:
            existing_server_default = False

        with sqla_compat._ensure_scope_for_ddl(self.conn):
            t.create(self.conn)
            insp = inspect(self.conn)
            old_col = insp.get_columns("x")[0]

            # TODO: conditional comment support
            self.op.alter_column(
                "x",
                column.name,
                existing_type=column.type,
                existing_server_default=existing_server_default,
                existing_nullable=bool(column.nullable),
                # existing_comment=column.comment,
                nullable=to_.get("nullable", None),
                # modify_comment=False,
                server_default=to_.get("server_default", False),
                new_column_name=to_.get("name", None),
                type_=to_.get("type", None),
            )

        # A new inspector is created here to read the altered column.
        insp = inspect(self.conn)
        new_col = insp.get_columns("x")[0]

        if compare is None:
            compare = to_

        expected_type = compare.get("type", old_col["type"])

        # Expected default: the explicit one from ``compare``, else the
        # original column's default text, else none at all.
        if "server_default" in compare:
            expected_default = compare["server_default"].text
        elif column.server_default is not None:
            expected_default = column.server_default.arg.text
        else:
            expected_default = None

        eq_(
            new_col["name"],
            compare["name"] if "name" in compare else column.name,
        )
        self._compare_type(new_col["type"], expected_type)
        eq_(new_col["nullable"], compare.get("nullable", column.nullable))
        self._compare_server_default(
            new_col["type"],
            new_col.get("default", None),
            expected_type,
            expected_default,
        )
|