You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
168 lines
4.9 KiB
168 lines
4.9 KiB
2 years ago
|
from __future__ import annotations
|
||
|
|
||
|
import contextlib
|
||
|
import re
|
||
|
import sys
|
||
|
from typing import Any
|
||
|
from typing import Dict
|
||
|
|
||
|
from sqlalchemy import exc as sa_exc
|
||
|
from sqlalchemy.engine import default
|
||
|
from sqlalchemy.testing.assertions import _expect_warnings
|
||
|
from sqlalchemy.testing.assertions import eq_ # noqa
|
||
|
from sqlalchemy.testing.assertions import is_ # noqa
|
||
|
from sqlalchemy.testing.assertions import is_false # noqa
|
||
|
from sqlalchemy.testing.assertions import is_not_ # noqa
|
||
|
from sqlalchemy.testing.assertions import is_true # noqa
|
||
|
from sqlalchemy.testing.assertions import ne_ # noqa
|
||
|
from sqlalchemy.util import decorator
|
||
|
|
||
|
from ..util import sqla_compat
|
||
|
|
||
|
|
||
|
def _assert_proper_exception_context(exception):
|
||
|
"""assert that any exception we're catching does not have a __context__
|
||
|
without a __cause__, and that __suppress_context__ is never set.
|
||
|
|
||
|
Python 3 will report nested as exceptions as "during the handling of
|
||
|
error X, error Y occurred". That's not what we want to do. we want
|
||
|
these exceptions in a cause chain.
|
||
|
|
||
|
"""
|
||
|
|
||
|
if (
|
||
|
exception.__context__ is not exception.__cause__
|
||
|
and not exception.__suppress_context__
|
||
|
):
|
||
|
assert False, (
|
||
|
"Exception %r was correctly raised but did not set a cause, "
|
||
|
"within context %r as its cause."
|
||
|
% (exception, exception.__context__)
|
||
|
)
|
||
|
|
||
|
|
||
|
def assert_raises(except_cls, callable_, *args, **kw):
    """Call ``callable_(*args, **kw)`` and assert it raises ``except_cls``.

    The exception-context check is enabled, so the raised exception must
    not carry an implicit ``__context__`` (see
    ``_assert_proper_exception_context``).

    :returns: the caught exception instance.
    """
    positional, keyword = args, kw
    return _assert_raises(
        except_cls, callable_, positional, keyword, check_context=True
    )
|
||
|
|
||
|
|
||
|
def assert_raises_context_ok(except_cls, callable_, *args, **kw):
    """Like ``assert_raises`` but without the exception-context check.

    Use when the raised exception is allowed to carry an implicit
    ``__context__``.  No message pattern is checked.

    :returns: the caught exception instance.
    """
    # Positional ``msg=None`` and ``check_context=False``.
    return _assert_raises(except_cls, callable_, args, kw, None, False)
|
||
|
|
||
|
|
||
|
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    """Assert ``callable_(*args, **kwargs)`` raises ``except_cls`` matching
    ``msg``.

    ``msg`` is a regular expression searched for in ``str(exception)``.
    The exception-context check is enabled.

    :returns: the caught exception instance.
    """
    return _assert_raises(
        except_cls,
        callable_,
        args,
        kwargs,
        check_context=True,
        msg=msg,
    )
|
||
|
|
||
|
|
||
|
def assert_raises_message_context_ok(
    except_cls, msg, callable_, *args, **kwargs
):
    """Like ``assert_raises_message`` but without the context check.

    ``msg`` is a regular expression searched for in ``str(exception)``;
    an implicit ``__context__`` on the raised exception is tolerated.

    :returns: the caught exception instance.
    """
    return _assert_raises(
        except_cls, callable_, args, kwargs, msg=msg, check_context=False
    )
|
||
|
|
||
|
|
||
|
def _assert_raises(
    except_cls, callable_, args, kwargs, msg=None, check_context=False
):
    """Shared engine for the ``assert_raises*`` helpers.

    Runs ``callable_(*args, **kwargs)`` inside ``_expect_raises`` and hands
    back whatever exception that context manager captured.

    :param args: positional arguments tuple for ``callable_``.
    :param kwargs: keyword arguments dict for ``callable_``.
    :param msg: optional regex the exception text must contain.
    :param check_context: whether to verify the exception's cause chain.
    :returns: the caught exception instance.
    """
    expectation = _expect_raises(except_cls, msg, check_context)
    with expectation as container:
        callable_(*args, **kwargs)
    return container.error
|
||
|
|
||
|
|
||
|
class _ErrorContainer:
|
||
|
error: Any = None
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
    """Context manager asserting that its body raises ``except_cls``.

    Yields an ``_ErrorContainer`` whose ``error`` attribute is set to the
    caught exception; that exception is then suppressed.

    :param except_cls: exception class (or tuple of classes) expected from
        the body.
    :param msg: optional regular expression that must be found, via
        ``re.search``, in ``str(err)``.
    :param check_context: when true, and no exception was already being
        handled on entry, check the caught exception with
        ``_assert_proper_exception_context``.
    :raises AssertionError: if the body raises nothing, the message does not
        match, or the context check fails.
    """
    ec = _ErrorContainer()
    # Snapshot taken before the body runs: if we were entered from inside an
    # ``except`` block, an implicit __context__ is unavoidable, so the
    # context check is skipped.
    are_we_already_in_a_traceback = (
        sys.exc_info()[0] if check_context else None
    )
    try:
        yield ec
        success = False
    except except_cls as err:
        ec.error = err
        success = True
        # Explicit raises (not ``assert``) so these checks survive ``-O``.
        if msg is not None and not re.search(msg, str(err), re.UNICODE):
            raise AssertionError(f"{msg} !~ {err}")
        if check_context and not are_we_already_in_a_traceback:
            _assert_proper_exception_context(err)
        # Echo the message as UTF-8 bytes; presumably for visibility in
        # captured test output.
        print(str(err).encode("utf-8"))

    # Checked outside the ``except`` block so this still works when
    # ``except_cls`` is AssertionError itself.
    if not success:
        raise AssertionError("Callable did not raise an exception")
|
||
|
|
||
|
|
||
|
def expect_raises(except_cls, check_context=True):
    """Return a context manager asserting its body raises ``except_cls``.

    No message pattern is checked.  The exception-context check is on by
    default.  The yielded container exposes the caught exception as
    ``.error``.
    """
    return _expect_raises(except_cls, msg=None, check_context=check_context)
|
||
|
|
||
|
|
||
|
def expect_raises_message(except_cls, msg, check_context=True):
    """Return a context manager asserting its body raises ``except_cls``
    with text matching ``msg``.

    ``msg`` is a regular expression searched for in ``str(exception)``.
    The exception-context check is on by default.
    """
    return _expect_raises(except_cls, msg, check_context)
|
||
|
|
||
|
|
||
|
def eq_ignore_whitespace(a, b, msg=None):
    """Assert that ``a`` and ``b`` are equal modulo whitespace layout.

    Both strings are normalized by:

    - stripping all leading whitespace,
    - deleting every newline,
    - collapsing runs of two or more spaces into a single space.

    The normalized forms must then be equal.

    :param msg: optional failure message; defaults to ``"%r != %r"`` of the
        two normalized strings.
    :raises AssertionError: if the normalized strings differ.
    """

    def _normalize(text):
        # ``^\s+`` is greedy, so the whole leading whitespace run is removed;
        # a lazy ``^\s+?`` would strip only a single leading character.
        text = re.sub(r"^\s+|\n", "", text)
        return re.sub(r" {2,}", " ", text)

    a = _normalize(a)
    b = _normalize(b)

    # Explicit raise so the comparison is not stripped under ``-O``.
    if a != b:
        raise AssertionError(msg or "%r != %r" % (a, b))
|
||
|
|
||
|
|
||
|
_dialect_mods: Dict[Any, Any] = {}
|
||
|
|
||
|
|
||
|
def _get_dialect(name):
    """Build a SQLAlchemy dialect instance for ``name``.

    ``None`` or ``"default"`` yields a plain ``DefaultDialect``.  Any other
    name is resolved via ``sqla_compat._create_url(name).get_dialect()``,
    and the returned dialect class is instantiated.  The ``postgresql`` and
    ``mssql`` dialects then get a flag overridden.
    """
    if name is None or name == "default":
        return default.DefaultDialect()

    dialect_cls = sqla_compat._create_url(name).get_dialect()
    dialect = dialect_cls()

    # NOTE(review): per-backend overrides; the rationale is not visible here.
    if name == "postgresql":
        dialect.implicit_returning = True
    elif name == "mssql":
        dialect.legacy_schema_aliasing = False
    return dialect
|
||
|
|
||
|
|
||
|
def expect_warnings(*messages, **kw):
    """Context manager which expects one or more warnings.

    ``messages`` are strings matched against emitted warnings via regex.
    Non-matching warnings are sent through.  With no messages, all warnings
    in scope are squelched.  Extra keyword arguments are forwarded unchanged
    to SQLAlchemy's ``_expect_warnings``.

    The category passed is ``Warning`` itself, so this is not limited to
    SAWarnings (presumably every ``Warning`` subclass is in scope -- confirm
    against ``_expect_warnings``).

    The expect version **asserts** that the warnings were in fact seen.

    Note that the test suite sets SAWarning warnings to raise exceptions.
    """
    category = Warning
    return _expect_warnings(category, messages, **kw)
|
||
|
|
||
|
|
||
|
def emits_python_deprecation_warning(*messages):
    """Decorator that runs the wrapped function while expecting
    ``DeprecationWarning``.

    ``messages`` are regex strings selecting which DeprecationWarnings to
    expect, as with ``expect_warnings()``.

    Note that this decorator does **not** assert that the warnings were in
    fact seen (``assert_=False``).
    """

    @decorator
    def decorate(fn, *args, **kw):
        # Pass ``messages`` as a single sequence argument, the same way
        # ``expect_warnings()`` and ``expect_sqlalchemy_deprecated()`` in this
        # module do.  Unpacking it with ``*messages`` would instead bind the
        # first message string to ``_expect_warnings``'s ``messages``
        # parameter (presumably iterated character by character) and spill
        # any further messages into later positional parameters.
        with _expect_warnings(DeprecationWarning, messages, assert_=False):
            return fn(*args, **kw)

    return decorate
|
||
|
|
||
|
|
||
|
def expect_sqlalchemy_deprecated(*messages, **kw):
    """Context manager expecting ``sqlalchemy.exc.SADeprecationWarning``.

    ``messages`` are passed as one sequence to ``_expect_warnings``.  Extra
    keyword arguments are forwarded unchanged.
    """
    category = sa_exc.SADeprecationWarning
    return _expect_warnings(category, messages, **kw)
|
||
|
|
||
|
|
||
|
def expect_sqlalchemy_deprecated_20(*messages, **kw):
    """Context manager expecting ``sqlalchemy.exc.RemovedIn20Warning``.

    ``messages`` are passed as one sequence to ``_expect_warnings``.  Extra
    keyword arguments are forwarded unchanged.
    """
    removed_in_20 = sa_exc.RemovedIn20Warning
    return _expect_warnings(removed_in_20, messages, **kw)
|