You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
608 lines
17 KiB
608 lines
17 KiB
2 years ago
|
from __future__ import annotations
|
||
|
|
||
|
import contextlib
|
||
|
import re
|
||
|
from typing import Any
|
||
|
from typing import Iterable
|
||
|
from typing import Iterator
|
||
|
from typing import Mapping
|
||
|
from typing import Optional
|
||
|
from typing import TYPE_CHECKING
|
||
|
from typing import TypeVar
|
||
|
from typing import Union
|
||
|
|
||
|
from sqlalchemy import __version__
|
||
|
from sqlalchemy import inspect
|
||
|
from sqlalchemy import schema
|
||
|
from sqlalchemy import sql
|
||
|
from sqlalchemy import types as sqltypes
|
||
|
from sqlalchemy.engine import url
|
||
|
from sqlalchemy.ext.compiler import compiles
|
||
|
from sqlalchemy.schema import CheckConstraint
|
||
|
from sqlalchemy.schema import Column
|
||
|
from sqlalchemy.schema import ForeignKeyConstraint
|
||
|
from sqlalchemy.sql import visitors
|
||
|
from sqlalchemy.sql.elements import BindParameter
|
||
|
from sqlalchemy.sql.elements import ColumnClause
|
||
|
from sqlalchemy.sql.elements import quoted_name
|
||
|
from sqlalchemy.sql.elements import TextClause
|
||
|
from sqlalchemy.sql.elements import UnaryExpression
|
||
|
from sqlalchemy.sql.visitors import traverse
|
||
|
from typing_extensions import TypeGuard
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from sqlalchemy import Index
|
||
|
from sqlalchemy import Table
|
||
|
from sqlalchemy.engine import Connection
|
||
|
from sqlalchemy.engine import Dialect
|
||
|
from sqlalchemy.engine import Transaction
|
||
|
from sqlalchemy.engine.reflection import Inspector
|
||
|
from sqlalchemy.sql.base import ColumnCollection
|
||
|
from sqlalchemy.sql.compiler import SQLCompiler
|
||
|
from sqlalchemy.sql.dml import Insert
|
||
|
from sqlalchemy.sql.elements import ColumnElement
|
||
|
from sqlalchemy.sql.schema import Constraint
|
||
|
from sqlalchemy.sql.schema import SchemaItem
|
||
|
from sqlalchemy.sql.selectable import Select
|
||
|
from sqlalchemy.sql.selectable import TableClause
|
||
|
|
||
|
# Type variable bound to "ColumnElement" / "SchemaItem"; lets helpers declare
# that they return the same kind of object they were given.
_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
|
||
|
|
||
|
|
||
|
def _safe_int(value: str) -> Union[int, str]:
|
||
|
try:
|
||
|
return int(value)
|
||
|
except:
|
||
|
return value
|
||
|
|
||
|
|
||
|
# Split the SQLAlchemy version string into tokens; digit runs become ints,
# tokens such as "b1" are kept as strings by _safe_int().
_vers = tuple(
    _safe_int(token) for token in re.findall(r"(\d+|[abc]\d)", __version__)
)

# Minimum-version feature flags derived from the parsed version tuple.
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
sqla_14_26 = _vers >= (1, 4, 26)
sqla_2 = _vers >= (2,)

# The raw version string, re-exported under a descriptive name.
sqlalchemy_version = __version__
|
||
|
|
||
|
try:
|
||
|
from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME
|
||
|
except ImportError:
|
||
|
from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501
|
||
|
|
||
|
|
||
|
# Default for SQLAlchemy older than 1.4.
sqla_1x = True

if sqla_14:
    # when future engine merges, this can be again based on version string
    from sqlalchemy.engine import Connection as legacy_connection

    # A Connection class exposing ``commit`` is not treated as "1.x style".
    sqla_1x = not hasattr(legacy_connection, "commit")
|
||
|
|
||
|
try:
    from sqlalchemy import Computed  # noqa

    has_computed = True
except ImportError:
    # Placeholder type so that isinstance() checks against ``Computed``
    # remain valid expressions when the construct is unavailable.
    Computed = type(None)  # type: ignore
    has_computed = False

# Reflection support additionally requires SQLAlchemy 1.3.16 or later.
has_computed_reflection = has_computed and _vers >= (1, 3, 16)
|
||
|
|
||
|
try:
    from sqlalchemy import Identity  # noqa
except ImportError:
    # Placeholder type so that isinstance() checks against ``Identity``
    # remain valid expressions when the construct is unavailable.
    Identity = type(None)  # type: ignore
    has_identity = False
else:
    has_identity = True

    # attributes common to Identity and Sequence
    _identity_options_attrs = (
        "start",
        "increment",
        "minvalue",
        "maxvalue",
        "nominvalue",
        "nomaxvalue",
        "cycle",
        "cache",
        "order",
    )

    # attributes of Identity
    _identity_attrs = (*_identity_options_attrs, "on_null")
|
||
|
|
||
|
if sqla_2:
|
||
|
from sqlalchemy.sql.base import _NoneName
|
||
|
else:
|
||
|
from sqlalchemy.util import symbol as _NoneName # type: ignore[assignment]
|
||
|
|
||
|
|
||
|
# A constraint name as it may be given: ``None``, a plain string, or a
# ``_NoneName`` value.
_ConstraintName = Union[None, str, _NoneName]
|
||
|
|
||
|
# A constraint name that is not ``None``: a plain string or a ``_NoneName``
# value.
_ConstraintNameDefined = Union[str, _NoneName]
|
||
|
|
||
|
|
||
|
def constraint_name_defined(
    name: _ConstraintName,
) -> TypeGuard[_ConstraintNameDefined]:
    """Tell whether ``name`` is something other than a missing name.

    :param name: candidate constraint name.
    :return: ``True`` when ``name`` is the ``_NONE_NAME`` object, a ``str``
     or a ``_NoneName`` instance; ``False`` otherwise.
    """
    if name is _NONE_NAME:
        return True
    return isinstance(name, (str, _NoneName))
|
||
|
|
||
|
|
||
|
def constraint_name_string(
    name: _ConstraintName,
) -> TypeGuard[str]:
    """Tell whether ``name`` is a plain string.

    :param name: candidate constraint name.
    :return: ``True`` only when ``name`` is a ``str`` instance.
    """
    is_plain_string = isinstance(name, str)
    return is_plain_string
|
||
|
|
||
|
|
||
|
def constraint_name_or_none(
    name: _ConstraintName,
) -> Optional[str]:
    """Reduce a constraint name to a plain string or ``None``.

    :param name: candidate constraint name.
    :return: ``name`` itself when it is a ``str`` instance, else ``None``.
    """
    if isinstance(name, str):
        return name
    return None
|
||
|
|
||
|
|
||
|
# NOTE(review): presumably the default value of a column's ``autoincrement``
# setting -- confirm against the callers that compare with this constant.
AUTOINCREMENT_DEFAULT = "auto"
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
|
||
|
def _ensure_scope_for_ddl(
|
||
|
connection: Optional[Connection],
|
||
|
) -> Iterator[None]:
|
||
|
try:
|
||
|
in_transaction = connection.in_transaction # type: ignore[union-attr]
|
||
|
except AttributeError:
|
||
|
# catch for MockConnection, None
|
||
|
in_transaction = None
|
||
|
pass
|
||
|
|
||
|
# yield outside the catch
|
||
|
if in_transaction is None:
|
||
|
yield
|
||
|
else:
|
||
|
if not in_transaction():
|
||
|
assert connection is not None
|
||
|
with connection.begin():
|
||
|
yield
|
||
|
else:
|
||
|
yield
|
||
|
|
||
|
|
||
|
def url_render_as_string(url, hide_password=True):
    """Render ``url`` as a string using the API of the running SQLAlchemy.

    :param url: object providing ``render_as_string()`` (SQLAlchemy 1.4+)
     or ``__to_string__()`` (earlier versions).
    :param hide_password: forwarded unchanged to the rendering method.
    :return: whatever the rendering method returns.
    """
    if not sqla_14:
        return url.__to_string__(hide_password=hide_password)
    return url.render_as_string(hide_password=hide_password)
|
||
|
|
||
|
|
||
|
def _safe_begin_connection_transaction(
    connection: Connection,
) -> Transaction:
    """Return the connection's current transaction, beginning one if needed.

    :param connection: the connection to inspect.
    :return: the transaction reported by ``_get_connection_transaction()``
     when it is truthy, otherwise the result of ``connection.begin()``.
    """
    existing = _get_connection_transaction(connection)
    return existing if existing else connection.begin()
|
||
|
|
||
|
|
||
|
def _safe_commit_connection_transaction(
    connection: Connection,
) -> None:
    """Commit the connection's current transaction, if there is one.

    :param connection: the connection to inspect.
    """
    active = _get_connection_transaction(connection)
    if not active:
        # nothing to commit
        return
    active.commit()
|
||
|
|
||
|
|
||
|
def _safe_rollback_connection_transaction(
    connection: Connection,
) -> None:
    """Roll back the connection's current transaction, if there is one.

    :param connection: the connection to inspect.
    """
    active = _get_connection_transaction(connection)
    if not active:
        # nothing to roll back
        return
    active.rollback()
|
||
|
|
||
|
|
||
|
def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
|
||
|
try:
|
||
|
in_transaction = connection.in_transaction # type: ignore
|
||
|
except AttributeError:
|
||
|
# catch for MockConnection
|
||
|
return False
|
||
|
else:
|
||
|
return in_transaction()
|
||
|
|
||
|
|
||
|
def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
|
||
|
return idx.expressions # type: ignore
|
||
|
|
||
|
|
||
|
def _copy(schema_item: _CE, **kw) -> _CE:
|
||
|
if hasattr(schema_item, "_copy"):
|
||
|
return schema_item._copy(**kw) # type: ignore[union-attr]
|
||
|
else:
|
||
|
return schema_item.copy(**kw) # type: ignore[union-attr]
|
||
|
|
||
|
|
||
|
def _get_connection_transaction(
    connection: Connection,
) -> Optional[Transaction]:
    """Fetch the transaction currently associated with ``connection``.

    :param connection: the connection to inspect.
    :return: ``connection.get_transaction()`` on SQLAlchemy 1.4+; on older
     versions the private ``_Connection__transaction`` attribute of the
     connection's ``_root``.
    """
    if sqla_14:
        return connection.get_transaction()

    root = connection._root  # type: ignore[attr-defined]
    return root._Connection__transaction
|
||
|
|
||
|
|
||
|
def _create_url(*arg, **kw) -> url.URL:
    """Build a ``URL`` object from the given arguments.

    Uses ``URL.create()`` when the class provides it, otherwise calls the
    ``URL`` constructor directly.

    :param arg: positional arguments forwarded to the factory.
    :param kw: keyword arguments forwarded to the factory.
    """
    factory = url.URL.create if hasattr(url.URL, "create") else url.URL
    return factory(*arg, **kw)
|
||
|
|
||
|
|
||
|
def _connectable_has_table(
    connectable: Connection, tablename: str, schemaname: Union[str, None]
) -> bool:
    """Check for a table through the API of the running SQLAlchemy.

    :param connectable: the connection to query.
    :param tablename: name of the table to look for.
    :param schemaname: schema name, or ``None``.
    :return: the result of ``Inspector.has_table()`` on SQLAlchemy 1.4+,
     or of ``Dialect.has_table()`` on older versions.
    """
    if not sqla_14:
        dialect = connectable.dialect
        return dialect.has_table(connectable, tablename, schemaname)

    inspector = inspect(connectable)
    return inspector.has_table(tablename, schemaname)
|
||
|
|
||
|
|
||
|
def _exec_on_inspector(inspector, statement, **params):
    """Execute ``statement`` using the connection behind ``inspector``.

    :param inspector: on SQLAlchemy 1.4+ an object providing
     ``_operation_context()``; on older versions one providing ``bind``.
    :param statement: the statement to execute.
    :param params: passed to ``execute()`` as a single parameter dictionary.
    :return: the result of the ``execute()`` call.
    """
    if not sqla_14:
        return inspector.bind.execute(statement, params)

    with inspector._operation_context() as conn:
        return conn.execute(statement, params)
|
||
|
|
||
|
|
||
|
def _nullability_might_be_unset(metadata_column):
    """Indicate whether the column's nullability may not have been set.

    :param metadata_column: the column to inspect.
    :return: on SQLAlchemy 1.4+, ``True`` when the column's
     ``_user_defined_nullable`` is ``NULL_UNSPECIFIED``; on older versions
     the column's ``nullable`` attribute.
    """
    if sqla_14:
        from sqlalchemy.sql import schema as sa_schema

        declared = metadata_column._user_defined_nullable
        return declared is sa_schema.NULL_UNSPECIFIED

    return metadata_column.nullable
|
||
|
|
||
|
|
||
|
def _server_default_is_computed(*server_default) -> bool:
    """Tell whether any given server default is a ``Computed`` construct.

    :param server_default: zero or more server default objects.
    :return: ``False`` when ``Computed`` is unavailable; otherwise ``True``
     if at least one argument is a ``Computed`` instance.
    """
    # The guard matters: without Computed support the name is bound to
    # ``type(None)``, which would match ``None`` defaults.
    return has_computed and any(
        isinstance(candidate, Computed) for candidate in server_default
    )
|
||
|
|
||
|
|
||
|
def _server_default_is_identity(*server_default) -> bool:
    """Tell whether any given server default is an ``Identity`` construct.

    :param server_default: zero or more server default objects.
    :return: ``False`` on SQLAlchemy older than 1.4; otherwise ``True`` if
     at least one argument is an ``Identity`` instance.
    """
    return sqla_14 and any(
        isinstance(candidate, Identity) for candidate in server_default
    )
|
||
|
|
||
|
|
||
|
def _table_for_constraint(constraint: Constraint) -> Table:
    """Return the table that ``constraint`` belongs to.

    :param constraint: the constraint to inspect.
    :return: ``constraint.parent`` for a ``ForeignKeyConstraint`` (asserted
     not to be ``None``), ``constraint.table`` for anything else.
    """
    if not isinstance(constraint, ForeignKeyConstraint):
        return constraint.table

    parent_table = constraint.parent
    assert parent_table is not None
    return parent_table  # type: ignore[return-value]
|
||
|
|
||
|
|
||
|
def _columns_for_constraint(constraint):
    """Collect the columns that ``constraint`` refers to.

    :param constraint: the constraint to inspect.
    :return: for a ``ForeignKeyConstraint``, a list of each element's
     ``parent``; for a ``CheckConstraint``, the result of
     ``_find_columns()`` on its ``sqltext``; otherwise a list of
     ``constraint.columns``.
    """
    if isinstance(constraint, ForeignKeyConstraint):
        return [element.parent for element in constraint.elements]

    if isinstance(constraint, CheckConstraint):
        return _find_columns(constraint.sqltext)

    return list(constraint.columns)
|
||
|
|
||
|
|
||
|
def _reflect_table(
    inspector: Inspector, table: Table, include_cols: None
) -> None:
    """Reflect ``table`` through the API of the running SQLAlchemy.

    :param inspector: the inspector used for reflection.
    :param table: the table to reflect into.
    :param include_cols: not used by this function; ``None`` is always
     passed to the underlying reflection method.
    :return: the result of ``Inspector.reflect_table()`` on SQLAlchemy 1.4+
     or of ``Inspector.reflecttable()`` on older versions.
    """
    if not sqla_14:
        legacy_reflect = inspector.reflecttable  # type: ignore[attr-defined]
        return legacy_reflect(table, None)

    return inspector.reflect_table(table, None)
|
||
|
|
||
|
|
||
|
def _resolve_for_variant(type_, dialect):
    """Pick the form of ``type_`` that applies to ``dialect``.

    :param type_: the type to resolve.
    :param dialect: object whose ``name`` selects the variant.
    :return: ``type_`` itself when it has no variants; otherwise the variant
     registered under ``dialect.name``, falling back to the base type.
    """
    if not _type_has_variants(type_):
        return type_

    default_type, by_dialect = _get_variant_mapping(type_)
    return by_dialect.get(dialect.name, default_type)
|
||
|
|
||
|
|
||
|
if not hasattr(sqltypes.TypeEngine, "_variant_mapping"):
    # TypeEngine has no ``_variant_mapping``: variants are represented by
    # the separate ``Variant`` type.

    def _type_has_variants(type_):
        """Return ``True`` if ``type_`` is exactly a ``Variant``."""
        return type(type_) is sqltypes.Variant

    def _get_variant_mapping(type_):
        """Return ``(base type, mapping)`` taken from a ``Variant``."""
        return type_.impl, type_.mapping

else:
    # TypeEngine carries its own ``_variant_mapping``.

    def _type_has_variants(type_):
        """Return ``True`` if ``type_`` has a non-empty variant mapping."""
        return bool(type_._variant_mapping)

    def _get_variant_mapping(type_):
        """Return ``(type_, its variant mapping)``."""
        return type_, type_._variant_mapping
|
||
|
|
||
|
|
||
|
def _fk_spec(constraint):
|
||
|
source_columns = [
|
||
|
constraint.columns[key].name for key in constraint.column_keys
|
||
|
]
|
||
|
|
||
|
source_table = constraint.parent.name
|
||
|
source_schema = constraint.parent.schema
|
||
|
target_schema = constraint.elements[0].column.table.schema
|
||
|
target_table = constraint.elements[0].column.table.name
|
||
|
target_columns = [element.column.name for element in constraint.elements]
|
||
|
ondelete = constraint.ondelete
|
||
|
onupdate = constraint.onupdate
|
||
|
deferrable = constraint.deferrable
|
||
|
initially = constraint.initially
|
||
|
return (
|
||
|
source_schema,
|
||
|
source_table,
|
||
|
source_columns,
|
||
|
target_schema,
|
||
|
target_table,
|
||
|
target_columns,
|
||
|
onupdate,
|
||
|
ondelete,
|
||
|
deferrable,
|
||
|
initially,
|
||
|
)
|
||
|
|
||
|
|
||
|
def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
|
||
|
spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined]
|
||
|
tokens = spec.split(".")
|
||
|
tokens.pop(-1) # colname
|
||
|
tablekey = ".".join(tokens)
|
||
|
assert constraint.parent is not None
|
||
|
return tablekey == constraint.parent.key
|
||
|
|
||
|
|
||
|
def _is_type_bound(constraint: Constraint) -> bool:
|
||
|
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
|
||
|
# that will be generated by the type.
|
||
|
# new feature added for #3260
|
||
|
return constraint._type_bound # type: ignore[attr-defined]
|
||
|
|
||
|
|
||
|
def _find_columns(clause):
    """Collect the objects visited as "column" within ``clause``.

    :param clause: the expression to traverse.
    :return: a ``set`` of everything the traversal reports under the
     ``"column"`` visit name.
    """
    found = set()
    visit_handlers = {"column": found.add}
    traverse(clause, {}, visit_handlers)
    return found
|
||
|
|
||
|
|
||
|
def _remove_column_from_collection(
|
||
|
collection: ColumnCollection, column: Union[Column, ColumnClause]
|
||
|
) -> None:
|
||
|
"""remove a column from a ColumnCollection."""
|
||
|
|
||
|
# workaround for older SQLAlchemy, remove the
|
||
|
# same object that's present
|
||
|
assert column.key is not None
|
||
|
to_remove = collection[column.key]
|
||
|
|
||
|
# SQLAlchemy 2.0 will use more ReadOnlyColumnCollection
|
||
|
# (renamed from ImmutableColumnCollection)
|
||
|
if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"):
|
||
|
collection._parent.remove(to_remove)
|
||
|
else:
|
||
|
collection.remove(to_remove)
|
||
|
|
||
|
|
||
|
def _textual_index_column(
    table: Table, text_: Union[str, TextClause, ColumnElement]
) -> Union[ColumnElement, Column]:
    """a workaround for the Index construct's severe lack of flexibility

    :param table: the table the index expression belongs to.
    :param text_: a column name, a ``text()`` construct, an existing
     ``_textual_index_element`` or another column expression.
    :return: an object usable as an index expression for ``table``.
    :raises ValueError: if ``text_`` is none of the accepted kinds.
    """
    if isinstance(text_, str):
        # a plain name becomes an untyped column appended to the table
        placeholder = Column(text_, sqltypes.NULLTYPE)
        table.append_column(placeholder)
        return placeholder

    if isinstance(text_, TextClause):
        return _textual_index_element(table, text_)

    if isinstance(text_, _textual_index_element):
        # re-wrap the underlying text() against the given table
        return _textual_index_column(table, text_.text)

    if isinstance(text_, sql.ColumnElement):
        return _copy_expression(text_, table)

    raise ValueError("String or text() construct expected")
|
||
|
|
||
|
|
||
|
def _copy_expression(expression: _CE, target_table: Table) -> _CE:
    """Copy ``expression`` so that its columns belong to ``target_table``.

    Each ``Column`` bound to a different table is replaced by the
    same-named column of ``target_table``; when no such column exists, a
    copy of the column is appended to ``target_table`` and used instead.

    :param expression: the expression to copy.
    :param target_table: the table the resulting expression should refer to.
    :return: the result of the replacement traversal.
    """

    def _retarget(col):
        # Returning None leaves the visited element as it is.
        if not isinstance(col, Column):
            return None
        if col.table is None or col.table is target_table:
            return None

        if col.name in target_table.c:
            return target_table.c[col.name]

        clone = _copy(col)
        target_table.append_column(clone)
        return clone

    return visitors.replacement_traverse(  # type: ignore[call-overload]
        expression, {}, _retarget
    )
|
||
|
|
||
|
|
||
|
class _textual_index_element(sql.ColumnElement):
    """Wrap a sqlalchemy text() construct so that it appears to an Index
    construct as a column-oriented SQL expression.

    The PostgreSQL dialect, the biggest recipient of functional indexes,
    currently keys all the index expressions to the corresponding column
    expressions when rendering CREATE INDEX, so the Index built here needs a
    .columns collection of the same length as its .expressions collection.
    Ultimately SQLAlchemy should support text() expressions in indexes.

    See SQLAlchemy issue 3174.

    """

    __visit_name__ = "_textual_idx_element"

    def __init__(self, table: Table, text: TextClause) -> None:
        """Store the text construct and add a stand-in column to ``table``.

        :param table: table that receives the stand-in column.
        :param text: the ``text()`` construct being wrapped.
        """
        self.table = table
        self.text = text
        self.key = text.text
        # untyped column named after the SQL text, appended to the table
        stand_in = schema.Column(self.text.text, sqltypes.NULLTYPE)
        self.fake_column = stand_in
        table.append_column(stand_in)

    def get_children(self):
        """Return a list holding only the stand-in column."""
        return [self.fake_column]
|
||
|
|
||
|
|
||
|
@compiles(_textual_index_element)
def _render_textual_index_column(
    element: _textual_index_element, compiler: SQLCompiler, **kw
) -> str:
    """Compile a ``_textual_index_element`` by compiling its wrapped text.

    :param element: the element being compiled.
    :param compiler: the active compiler.
    :param kw: compiler keyword arguments, forwarded unchanged.
    """
    wrapped_text = element.text
    return compiler.process(wrapped_text, **kw)
|
||
|
|
||
|
|
||
|
class _literal_bindparam(BindParameter):
    """``BindParameter`` subclass with no behavior of its own; it exists as
    a distinct type so a dedicated compilation rule can be attached to it.
    """
|
||
|
|
||
|
|
||
|
@compiles(_literal_bindparam)
def _render_literal_bindparam(
    element: _literal_bindparam, compiler: SQLCompiler, **kw
) -> str:
    """Compile a ``_literal_bindparam`` via the compiler's
    ``render_literal_bindparam()`` method.

    :param element: the bind parameter being compiled.
    :param compiler: the active compiler.
    :param kw: compiler keyword arguments, forwarded unchanged.
    """
    rendered = compiler.render_literal_bindparam(element, **kw)
    return rendered
|
||
|
|
||
|
|
||
|
def _get_index_expressions(idx):
|
||
|
return list(idx.expressions)
|
||
|
|
||
|
|
||
|
def _get_index_column_names(idx):
    """List the ``name`` of every expression of ``idx``.

    :param idx: the index to inspect.
    :return: one entry per expression; ``None`` for expressions that have
     no ``name`` attribute.
    """
    expressions = _get_index_expressions(idx)
    return [getattr(expression, "name", None) for expression in expressions]
|
||
|
|
||
|
|
||
|
def _column_kwargs(col: Column) -> Mapping:
    """Return the ``kwargs`` of ``col`` where the attribute is supported.

    :param col: the column to inspect.
    :return: ``col.kwargs`` on SQLAlchemy 1.3+, otherwise an empty dict.
    """
    return col.kwargs if sqla_13 else {}
|
||
|
|
||
|
|
||
|
def _get_constraint_final_name(
    constraint: Union[Index, Constraint], dialect: Optional[Dialect]
) -> Optional[str]:
    """Return the final, unquoted name of ``constraint`` for ``dialect``.

    :param constraint: the index or constraint whose name is wanted.
    :param dialect: the dialect used to format the name; must not be
     ``None`` when the constraint has a name.
    :return: ``None`` when ``constraint.name`` is ``None``, otherwise the
     name as produced by the dialect's formatting routines.
    """
    if constraint.name is None:
        return None
    assert dialect is not None

    if sqla_14:
        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; make use of new
        # SQLAlchemy API to return what would be the final compiled form of
        # the name for this dialect.
        return dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )

    # prior to SQLAlchemy 1.4, work around quoting logic to get at the
    # final compiled name without quotes.

    # might be quoted_name, might be truncated_name, keep it the same;
    # plain names fall back to quoted_name.
    quoted_name_cls: type = (
        type(constraint.name)
        if hasattr(constraint.name, "quote")
        else quoted_name
    )

    unquoted = quoted_name_cls(str(constraint.name), quote=False)
    constraint = constraint.__class__(name=unquoted)

    if not isinstance(constraint, schema.Index):
        # name should not be quoted.
        return dialect.identifier_preparer.format_constraint(constraint)

    # name should not be quoted.
    ddl = dialect.ddl_compiler(dialect, None)  # type: ignore[arg-type]
    return ddl._prepared_index_name(  # type: ignore[attr-defined]
        constraint
    )
|
||
|
|
||
|
|
||
|
def _constraint_is_named(
    constraint: Union[Constraint, Index], dialect: Optional[Dialect]
) -> bool:
    """Tell whether ``constraint`` ends up with a name.

    :param constraint: the index or constraint to inspect.
    :param dialect: the dialect used to format the name; must not be
     ``None`` on SQLAlchemy 1.4+ when the constraint has a name.
    :return: before SQLAlchemy 1.4, whether ``constraint.name`` is not
     ``None``; on 1.4+, additionally whether the dialect's formatted name
     is not ``None``.
    """
    if not sqla_14:
        return constraint.name is not None

    if constraint.name is None:
        return False

    assert dialect is not None
    final_name = dialect.identifier_preparer.format_constraint(
        constraint, _alembic_quote=False
    )
    return final_name is not None
|
||
|
|
||
|
|
||
|
def _is_mariadb(mysql_dialect: Dialect) -> bool:
    """Report whether ``mysql_dialect`` identifies itself as MariaDB.

    :param mysql_dialect: the dialect to inspect.
    :return: ``mysql_dialect.is_mariadb`` on SQLAlchemy 1.4+; on older
     versions ``True`` only when ``server_version_info`` is truthy and the
     private ``_is_mariadb`` attribute is truthy.
    """
    if sqla_14:
        return mysql_dialect.is_mariadb  # type: ignore[attr-defined]

    version_info = mysql_dialect.server_version_info
    return bool(
        version_info
        and mysql_dialect._is_mariadb  # type: ignore[attr-defined]
    )
|
||
|
|
||
|
|
||
|
def _mariadb_normalized_version_info(mysql_dialect):
|
||
|
return mysql_dialect._mariadb_normalized_version_info
|
||
|
|
||
|
|
||
|
def _insert_inline(table: Union[TableClause, Table]) -> Insert:
    """Build an "inline" INSERT construct for ``table``.

    :param table: the table to insert into.
    :return: ``table.insert().inline()`` on SQLAlchemy 1.4+, or
     ``table.insert(inline=True)`` on older versions.
    """
    if not sqla_14:
        return table.insert(inline=True)  # type: ignore[call-arg]
    return table.insert().inline()
|
||
|
|
||
|
|
||
|
if sqla_14:
    from sqlalchemy import create_mock_engine
    from sqlalchemy import select as _select
else:
    from sqlalchemy import create_engine

    def create_mock_engine(url, executor, **kw):  # type: ignore[misc]
        """Stand-in for ``sqlalchemy.create_mock_engine``.

        ``url`` and ``kw`` are accepted but not used: the engine is always
        created for ``"postgresql://"`` with the "mock" strategy.

        :param url: ignored.
        :param executor: forwarded to ``create_engine()``.
        """
        mock_engine = create_engine(
            "postgresql://", strategy="mock", executor=executor
        )
        return mock_engine

    def _select(*columns, **kw) -> Select:  # type: ignore[no-redef]
        """Call ``sql.select()`` with the columns gathered into one list.

        :param columns: the columns to select.
        :param kw: keyword arguments forwarded to ``sql.select()``.
        """
        return sql.select([*columns], **kw)  # type: ignore[call-overload]
|
||
|
|
||
|
|
||
|
def is_expression_index(index: Index) -> bool:
    """Tell whether ``index`` contains anything besides plain columns.

    Each expression is first unwrapped from any ``UnaryExpression`` layers.
    The index counts as an expression index as soon as one unwrapped
    expression is not a ``ColumnClause`` or is flagged ``is_literal``.

    :param index: the index to inspect.
    """

    def _is_plain_column(expr: Any) -> bool:
        # peel off unary wrappers to reach the underlying element
        while isinstance(expr, UnaryExpression):
            expr = expr.element
        return isinstance(expr, ColumnClause) and not expr.is_literal

    return not all(_is_plain_column(expr) for expr in index.expressions)
|