You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
859 lines
27 KiB
859 lines
27 KiB
# mypy: ignore-errors
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import collections
|
|
from functools import update_wrapper
|
|
import inspect
|
|
import itertools
|
|
import operator
|
|
import os
|
|
import re
|
|
import sys
|
|
import uuid
|
|
|
|
import pytest
|
|
|
|
try:
|
|
# installed by bootstrap.py
|
|
import sqla_plugin_base as plugin_base
|
|
except ImportError:
|
|
# assume we're a package, use traditional import
|
|
from . import plugin_base
|
|
|
|
|
|
def pytest_addoption(parser):
|
|
group = parser.getgroup("sqlalchemy")
|
|
|
|
def make_option(name, **kw):
|
|
callback_ = kw.pop("callback", None)
|
|
if callback_:
|
|
|
|
class CallableAction(argparse.Action):
|
|
def __call__(
|
|
self, parser, namespace, values, option_string=None
|
|
):
|
|
callback_(option_string, values, parser)
|
|
|
|
kw["action"] = CallableAction
|
|
|
|
zeroarg_callback = kw.pop("zeroarg_callback", None)
|
|
if zeroarg_callback:
|
|
|
|
class CallableAction(argparse.Action):
|
|
def __init__(
|
|
self,
|
|
option_strings,
|
|
dest,
|
|
default=False,
|
|
required=False,
|
|
help=None, # noqa
|
|
):
|
|
super().__init__(
|
|
option_strings=option_strings,
|
|
dest=dest,
|
|
nargs=0,
|
|
const=True,
|
|
default=default,
|
|
required=required,
|
|
help=help,
|
|
)
|
|
|
|
def __call__(
|
|
self, parser, namespace, values, option_string=None
|
|
):
|
|
zeroarg_callback(option_string, values, parser)
|
|
|
|
kw["action"] = CallableAction
|
|
|
|
group.addoption(name, **kw)
|
|
|
|
plugin_base.setup_options(make_option)
|
|
|
|
|
|
def pytest_configure(config: pytest.Config):
|
|
plugin_base.read_config(config.rootpath)
|
|
if plugin_base.exclude_tags or plugin_base.include_tags:
|
|
new_expr = " and ".join(
|
|
list(plugin_base.include_tags)
|
|
+ [f"not {tag}" for tag in plugin_base.exclude_tags]
|
|
)
|
|
|
|
if config.option.markexpr:
|
|
config.option.markexpr += f" and {new_expr}"
|
|
else:
|
|
config.option.markexpr = new_expr
|
|
|
|
if config.pluginmanager.hasplugin("xdist"):
|
|
config.pluginmanager.register(XDistHooks())
|
|
|
|
if hasattr(config, "workerinput"):
|
|
plugin_base.restore_important_follower_config(config.workerinput)
|
|
plugin_base.configure_follower(config.workerinput["follower_ident"])
|
|
else:
|
|
if config.option.write_idents and os.path.exists(
|
|
config.option.write_idents
|
|
):
|
|
os.remove(config.option.write_idents)
|
|
|
|
plugin_base.pre_begin(config.option)
|
|
|
|
plugin_base.set_coverage_flag(
|
|
bool(getattr(config.option, "cov_source", False))
|
|
)
|
|
|
|
plugin_base.set_fixture_functions(PytestFixtureFunctions)
|
|
|
|
if config.option.dump_pyannotate:
|
|
global DUMP_PYANNOTATE
|
|
DUMP_PYANNOTATE = True
|
|
|
|
|
|
DUMP_PYANNOTATE = False
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def collect_types_fixture():
|
|
if DUMP_PYANNOTATE:
|
|
from pyannotate_runtime import collect_types
|
|
|
|
collect_types.start()
|
|
yield
|
|
if DUMP_PYANNOTATE:
|
|
collect_types.stop()
|
|
|
|
|
|
def _log_sqlalchemy_info(session):
|
|
import sqlalchemy
|
|
from sqlalchemy import __version__
|
|
from sqlalchemy.util import has_compiled_ext
|
|
from sqlalchemy.util._has_cy import _CYEXTENSION_MSG
|
|
|
|
greet = "sqlalchemy installation"
|
|
site = "no user site" if sys.flags.no_user_site else "user site loaded"
|
|
msgs = [
|
|
f"SQLAlchemy {__version__} ({site})",
|
|
f"Path: {sqlalchemy.__file__}",
|
|
]
|
|
|
|
if has_compiled_ext():
|
|
from sqlalchemy.cyextension import util
|
|
|
|
msgs.append(f"compiled extension enabled, e.g. {util.__file__} ")
|
|
else:
|
|
msgs.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}")
|
|
|
|
pm = session.config.pluginmanager.get_plugin("terminalreporter")
|
|
if pm:
|
|
pm.write_sep("=", greet)
|
|
for m in msgs:
|
|
pm.write_line(m)
|
|
else:
|
|
# fancy pants reporter not found, fallback to plain print
|
|
print("=" * 25, greet, "=" * 25)
|
|
for m in msgs:
|
|
print(m)
|
|
|
|
|
|
def pytest_sessionstart(session):
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
_log_sqlalchemy_info(session)
|
|
asyncio._assume_async(plugin_base.post_begin)
|
|
|
|
|
|
def pytest_sessionfinish(session):
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)
|
|
|
|
if session.config.option.dump_pyannotate:
|
|
from pyannotate_runtime import collect_types
|
|
|
|
collect_types.dump_stats(session.config.option.dump_pyannotate)
|
|
|
|
|
|
def pytest_collection_finish(session):
|
|
if session.config.option.dump_pyannotate:
|
|
from pyannotate_runtime import collect_types
|
|
|
|
lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
|
|
|
|
def _filter(filename):
|
|
filename = os.path.normpath(os.path.abspath(filename))
|
|
if "lib/sqlalchemy" not in os.path.commonpath(
|
|
[filename, lib_sqlalchemy]
|
|
):
|
|
return None
|
|
if "testing" in filename:
|
|
return None
|
|
|
|
return filename
|
|
|
|
collect_types.init_types_collection(filter_filename=_filter)
|
|
|
|
|
|
class XDistHooks:
|
|
def pytest_configure_node(self, node):
|
|
from sqlalchemy.testing import provision
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
# the master for each node fills workerinput dictionary
|
|
# which pytest-xdist will transfer to the subprocess
|
|
|
|
plugin_base.memoize_important_follower_config(node.workerinput)
|
|
|
|
node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
|
|
|
|
asyncio._maybe_async_provisioning(
|
|
provision.create_follower_db, node.workerinput["follower_ident"]
|
|
)
|
|
|
|
def pytest_testnodedown(self, node, error):
|
|
from sqlalchemy.testing import provision
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
asyncio._maybe_async_provisioning(
|
|
provision.drop_follower_db, node.workerinput["follower_ident"]
|
|
)
|
|
|
|
|
|
def pytest_collection_modifyitems(session, config, items):
|
|
|
|
# look for all those classes that specify __backend__ and
|
|
# expand them out into per-database test cases.
|
|
|
|
# this is much easier to do within pytest_pycollect_makeitem, however
|
|
# pytest is iterating through cls.__dict__ as makeitem is
|
|
# called which causes a "dictionary changed size" error on py3k.
|
|
# I'd submit a pullreq for them to turn it into a list first, but
|
|
# it's to suit the rather odd use case here which is that we are adding
|
|
# new classes to a module on the fly.
|
|
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
rebuilt_items = collections.defaultdict(
|
|
lambda: collections.defaultdict(list)
|
|
)
|
|
|
|
items[:] = [
|
|
item
|
|
for item in items
|
|
if item.getparent(pytest.Class) is not None
|
|
and not item.getparent(pytest.Class).name.startswith("_")
|
|
]
|
|
|
|
test_classes = {item.getparent(pytest.Class) for item in items}
|
|
|
|
def collect(element):
|
|
for inst_or_fn in element.collect():
|
|
if isinstance(inst_or_fn, pytest.Collector):
|
|
yield from collect(inst_or_fn)
|
|
else:
|
|
yield inst_or_fn
|
|
|
|
def setup_test_classes():
|
|
for test_class in test_classes:
|
|
|
|
# transfer legacy __backend__ and __sparse_backend__ symbols
|
|
# to be markers
|
|
add_markers = set()
|
|
if getattr(test_class.cls, "__backend__", False) or getattr(
|
|
test_class.cls, "__only_on__", False
|
|
):
|
|
add_markers = {"backend"}
|
|
elif getattr(test_class.cls, "__sparse_backend__", False):
|
|
add_markers = {"sparse_backend"}
|
|
else:
|
|
add_markers = frozenset()
|
|
|
|
existing_markers = {
|
|
mark.name for mark in test_class.iter_markers()
|
|
}
|
|
add_markers = add_markers - existing_markers
|
|
all_markers = existing_markers.union(add_markers)
|
|
|
|
for marker in add_markers:
|
|
test_class.add_marker(marker)
|
|
|
|
for sub_cls in plugin_base.generate_sub_tests(
|
|
test_class.cls, test_class.module, all_markers
|
|
):
|
|
if sub_cls is not test_class.cls:
|
|
per_cls_dict = rebuilt_items[test_class.cls]
|
|
|
|
module = test_class.getparent(pytest.Module)
|
|
|
|
new_cls = pytest.Class.from_parent(
|
|
name=sub_cls.__name__, parent=module
|
|
)
|
|
for marker in add_markers:
|
|
new_cls.add_marker(marker)
|
|
|
|
for fn in collect(new_cls):
|
|
per_cls_dict[fn.name].append(fn)
|
|
|
|
# class requirements will sometimes need to access the DB to check
|
|
# capabilities, so need to do this for async
|
|
asyncio._maybe_async_provisioning(setup_test_classes)
|
|
|
|
newitems = []
|
|
for item in items:
|
|
cls_ = item.cls
|
|
if cls_ in rebuilt_items:
|
|
newitems.extend(rebuilt_items[cls_][item.name])
|
|
else:
|
|
newitems.append(item)
|
|
|
|
# seems like the functions attached to a test class aren't sorted already?
|
|
# is that true and why's that? (when using unittest, they're sorted)
|
|
items[:] = sorted(
|
|
newitems,
|
|
key=lambda item: (
|
|
item.getparent(pytest.Module).name,
|
|
item.getparent(pytest.Class).name,
|
|
item.name,
|
|
),
|
|
)
|
|
|
|
|
|
def pytest_pycollect_makeitem(collector, name, obj):
|
|
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
|
|
from sqlalchemy.testing import config
|
|
|
|
if config.any_async:
|
|
obj = _apply_maybe_async(obj)
|
|
|
|
return [
|
|
pytest.Class.from_parent(
|
|
name=parametrize_cls.__name__, parent=collector
|
|
)
|
|
for parametrize_cls in _parametrize_cls(collector.module, obj)
|
|
]
|
|
elif (
|
|
inspect.isfunction(obj)
|
|
and collector.cls is not None
|
|
and plugin_base.want_method(collector.cls, obj)
|
|
):
|
|
# None means, fall back to default logic, which includes
|
|
# method-level parametrize
|
|
return None
|
|
else:
|
|
# empty list means skip this item
|
|
return []
|
|
|
|
|
|
def _is_wrapped_coroutine_function(fn):
|
|
while hasattr(fn, "__wrapped__"):
|
|
fn = fn.__wrapped__
|
|
|
|
return inspect.iscoroutinefunction(fn)
|
|
|
|
|
|
def _apply_maybe_async(obj, recurse=True):
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
for name, value in vars(obj).items():
|
|
if (
|
|
(callable(value) or isinstance(value, classmethod))
|
|
and not getattr(value, "_maybe_async_applied", False)
|
|
and (name.startswith("test_"))
|
|
and not _is_wrapped_coroutine_function(value)
|
|
):
|
|
is_classmethod = False
|
|
if isinstance(value, classmethod):
|
|
value = value.__func__
|
|
is_classmethod = True
|
|
|
|
@_pytest_fn_decorator
|
|
def make_async(fn, *args, **kwargs):
|
|
return asyncio._maybe_async(fn, *args, **kwargs)
|
|
|
|
do_async = make_async(value)
|
|
if is_classmethod:
|
|
do_async = classmethod(do_async)
|
|
do_async._maybe_async_applied = True
|
|
|
|
setattr(obj, name, do_async)
|
|
if recurse:
|
|
for cls in obj.mro()[1:]:
|
|
if cls != object:
|
|
_apply_maybe_async(cls, False)
|
|
return obj
|
|
|
|
|
|
def _parametrize_cls(module, cls):
|
|
"""implement a class-based version of pytest parametrize."""
|
|
|
|
if "_sa_parametrize" not in cls.__dict__:
|
|
return [cls]
|
|
|
|
_sa_parametrize = cls._sa_parametrize
|
|
classes = []
|
|
for full_param_set in itertools.product(
|
|
*[params for argname, params in _sa_parametrize]
|
|
):
|
|
cls_variables = {}
|
|
|
|
for argname, param in zip(
|
|
[_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
|
|
):
|
|
if not argname:
|
|
raise TypeError("need argnames for class-based combinations")
|
|
argname_split = re.split(r",\s*", argname)
|
|
for arg, val in zip(argname_split, param.values):
|
|
cls_variables[arg] = val
|
|
parametrized_name = "_".join(
|
|
re.sub(r"\W", "", token)
|
|
for param in full_param_set
|
|
for token in param.id.split("-")
|
|
)
|
|
name = "%s_%s" % (cls.__name__, parametrized_name)
|
|
newcls = type.__new__(type, name, (cls,), cls_variables)
|
|
setattr(module, name, newcls)
|
|
classes.append(newcls)
|
|
return classes
|
|
|
|
|
|
_current_class = None
|
|
|
|
|
|
def pytest_runtest_setup(item):
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
# pytest_runtest_setup runs *before* pytest fixtures with scope="class".
|
|
# plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
|
|
# for the whole class and has to run things that are across all current
|
|
# databases, so we run this outside of the pytest fixture system altogether
|
|
# and ensure asyncio greenlet if any engines are async
|
|
|
|
global _current_class
|
|
|
|
if isinstance(item, pytest.Function) and _current_class is None:
|
|
asyncio._maybe_async_provisioning(
|
|
plugin_base.start_test_class_outside_fixtures,
|
|
item.cls,
|
|
)
|
|
_current_class = item.getparent(pytest.Class)
|
|
|
|
|
|
@pytest.hookimpl(hookwrapper=True)
|
|
def pytest_runtest_teardown(item, nextitem):
|
|
# runs inside of pytest function fixture scope
|
|
# after test function runs
|
|
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
asyncio._maybe_async(plugin_base.after_test, item)
|
|
|
|
yield
|
|
# this is now after all the fixture teardown have run, the class can be
|
|
# finalized. Since pytest v7 this finalizer can no longer be added in
|
|
# pytest_runtest_setup since the class has not yet been setup at that
|
|
# time.
|
|
# See https://github.com/pytest-dev/pytest/issues/9343
|
|
global _current_class, _current_report
|
|
|
|
if _current_class is not None and (
|
|
# last test or a new class
|
|
nextitem is None
|
|
or nextitem.getparent(pytest.Class) is not _current_class
|
|
):
|
|
_current_class = None
|
|
|
|
try:
|
|
asyncio._maybe_async_provisioning(
|
|
plugin_base.stop_test_class_outside_fixtures, item.cls
|
|
)
|
|
except Exception as e:
|
|
# in case of an exception during teardown attach the original
|
|
# error to the exception message, otherwise it will get lost
|
|
if _current_report.failed:
|
|
if not e.args:
|
|
e.args = (
|
|
"__Original test failure__:\n"
|
|
+ _current_report.longreprtext,
|
|
)
|
|
elif e.args[-1] and isinstance(e.args[-1], str):
|
|
args = list(e.args)
|
|
args[-1] += (
|
|
"\n__Original test failure__:\n"
|
|
+ _current_report.longreprtext
|
|
)
|
|
e.args = tuple(args)
|
|
else:
|
|
e.args += (
|
|
"__Original test failure__",
|
|
_current_report.longreprtext,
|
|
)
|
|
raise
|
|
finally:
|
|
_current_report = None
|
|
|
|
|
|
def pytest_runtest_call(item):
|
|
# runs inside of pytest function fixture scope
|
|
# before test function runs
|
|
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
asyncio._maybe_async(
|
|
plugin_base.before_test,
|
|
item,
|
|
item.module.__name__,
|
|
item.cls,
|
|
item.name,
|
|
)
|
|
|
|
|
|
_current_report = None
|
|
|
|
|
|
def pytest_runtest_logreport(report):
|
|
global _current_report
|
|
if report.when == "call":
|
|
_current_report = report
|
|
|
|
|
|
@pytest.fixture(scope="class")
|
|
def setup_class_methods(request):
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
cls = request.cls
|
|
|
|
if hasattr(cls, "setup_test_class"):
|
|
asyncio._maybe_async(cls.setup_test_class)
|
|
|
|
yield
|
|
|
|
if hasattr(cls, "teardown_test_class"):
|
|
asyncio._maybe_async(cls.teardown_test_class)
|
|
|
|
asyncio._maybe_async(plugin_base.stop_test_class, cls)
|
|
|
|
|
|
@pytest.fixture(scope="function")
|
|
def setup_test_methods(request):
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
# called for each test
|
|
|
|
self = request.instance
|
|
|
|
# before this fixture runs:
|
|
|
|
# 1. function level "autouse" fixtures under py3k (examples: TablesTest
|
|
# define tables / data, MappedTest define tables / mappers / data)
|
|
|
|
# 2. was for p2k. no longer applies
|
|
|
|
# 3. run outer xdist-style setup
|
|
if hasattr(self, "setup_test"):
|
|
asyncio._maybe_async(self.setup_test)
|
|
|
|
# alembic test suite is using setUp and tearDown
|
|
# xdist methods; support these in the test suite
|
|
# for the near term
|
|
if hasattr(self, "setUp"):
|
|
asyncio._maybe_async(self.setUp)
|
|
|
|
# inside the yield:
|
|
# 4. function level fixtures defined on test functions themselves,
|
|
# e.g. "connection", "metadata" run next
|
|
|
|
# 5. pytest hook pytest_runtest_call then runs
|
|
|
|
# 6. test itself runs
|
|
|
|
yield
|
|
|
|
# yield finishes:
|
|
|
|
# 7. function level fixtures defined on test functions
|
|
# themselves, e.g. "connection" rolls back the transaction, "metadata"
|
|
# emits drop all
|
|
|
|
# 8. pytest hook pytest_runtest_teardown hook runs, this is associated
|
|
# with fixtures close all sessions, provisioning.stop_test_class(),
|
|
# engines.testing_reaper -> ensure all connection pool connections
|
|
# are returned, engines created by testing_engine that aren't the
|
|
# config engine are disposed
|
|
|
|
asyncio._maybe_async(plugin_base.after_test_fixtures, self)
|
|
|
|
# 10. run xdist-style teardown
|
|
if hasattr(self, "tearDown"):
|
|
asyncio._maybe_async(self.tearDown)
|
|
|
|
if hasattr(self, "teardown_test"):
|
|
asyncio._maybe_async(self.teardown_test)
|
|
|
|
# 11. was for p2k. no longer applies
|
|
|
|
# 12. function level "autouse" fixtures under py3k (examples: TablesTest /
|
|
# MappedTest delete table data, possibly drop tables and clear mappers
|
|
# depending on the flags defined by the test class)
|
|
|
|
|
|
def _pytest_fn_decorator(target):
|
|
"""Port of langhelpers.decorator with pytest-specific tricks."""
|
|
|
|
from sqlalchemy.util.langhelpers import format_argspec_plus
|
|
from sqlalchemy.util.compat import inspect_getfullargspec
|
|
|
|
def _exec_code_in_env(code, env, fn_name):
|
|
# note this is affected by "from __future__ import annotations" at
|
|
# the top; exec'ed code will use non-evaluated annotations
|
|
# which allows us to be more flexible with code rendering
|
|
# in format_argpsec_plus()
|
|
exec(code, env)
|
|
return env[fn_name]
|
|
|
|
def decorate(fn, add_positional_parameters=()):
|
|
|
|
spec = inspect_getfullargspec(fn)
|
|
if add_positional_parameters:
|
|
spec.args.extend(add_positional_parameters)
|
|
|
|
metadata = dict(
|
|
__target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
|
|
)
|
|
metadata.update(format_argspec_plus(spec, grouped=False))
|
|
code = (
|
|
"""\
|
|
def %(name)s%(grouped_args)s:
|
|
return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
|
|
"""
|
|
% metadata
|
|
)
|
|
decorated = _exec_code_in_env(
|
|
code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
|
|
)
|
|
if not add_positional_parameters:
|
|
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
|
|
decorated.__wrapped__ = fn
|
|
return update_wrapper(decorated, fn)
|
|
else:
|
|
# this is the pytest hacky part. don't do a full update wrapper
|
|
# because pytest is really being sneaky about finding the args
|
|
# for the wrapped function
|
|
decorated.__module__ = fn.__module__
|
|
decorated.__name__ = fn.__name__
|
|
if hasattr(fn, "pytestmark"):
|
|
decorated.pytestmark = fn.pytestmark
|
|
return decorated
|
|
|
|
return decorate
|
|
|
|
|
|
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
|
|
def skip_test_exception(self, *arg, **kw):
|
|
return pytest.skip.Exception(*arg, **kw)
|
|
|
|
@property
|
|
def add_to_marker(self):
|
|
return pytest.mark
|
|
|
|
def mark_base_test_class(self):
|
|
return pytest.mark.usefixtures(
|
|
"setup_class_methods", "setup_test_methods"
|
|
)
|
|
|
|
_combination_id_fns = {
|
|
"i": lambda obj: obj,
|
|
"r": repr,
|
|
"s": str,
|
|
"n": lambda obj: obj.__name__
|
|
if hasattr(obj, "__name__")
|
|
else type(obj).__name__,
|
|
}
|
|
|
|
def combinations(self, *arg_sets, **kw):
|
|
"""Facade for pytest.mark.parametrize.
|
|
|
|
Automatically derives argument names from the callable which in our
|
|
case is always a method on a class with positional arguments.
|
|
|
|
ids for parameter sets are derived using an optional template.
|
|
|
|
"""
|
|
from sqlalchemy.testing import exclusions
|
|
|
|
if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
|
|
arg_sets = list(arg_sets[0])
|
|
|
|
argnames = kw.pop("argnames", None)
|
|
|
|
def _filter_exclusions(args):
|
|
result = []
|
|
gathered_exclusions = []
|
|
for a in args:
|
|
if isinstance(a, exclusions.compound):
|
|
gathered_exclusions.append(a)
|
|
else:
|
|
result.append(a)
|
|
|
|
return result, gathered_exclusions
|
|
|
|
id_ = kw.pop("id_", None)
|
|
|
|
tobuild_pytest_params = []
|
|
has_exclusions = False
|
|
if id_:
|
|
_combination_id_fns = self._combination_id_fns
|
|
|
|
# because itemgetter is not consistent for one argument vs.
|
|
# multiple, make it multiple in all cases and use a slice
|
|
# to omit the first argument
|
|
_arg_getter = operator.itemgetter(
|
|
0,
|
|
*[
|
|
idx
|
|
for idx, char in enumerate(id_)
|
|
if char in ("n", "r", "s", "a")
|
|
],
|
|
)
|
|
fns = [
|
|
(operator.itemgetter(idx), _combination_id_fns[char])
|
|
for idx, char in enumerate(id_)
|
|
if char in _combination_id_fns
|
|
]
|
|
|
|
for arg in arg_sets:
|
|
if not isinstance(arg, tuple):
|
|
arg = (arg,)
|
|
|
|
fn_params, param_exclusions = _filter_exclusions(arg)
|
|
|
|
parameters = _arg_getter(fn_params)[1:]
|
|
|
|
if param_exclusions:
|
|
has_exclusions = True
|
|
|
|
tobuild_pytest_params.append(
|
|
(
|
|
parameters,
|
|
param_exclusions,
|
|
"-".join(
|
|
comb_fn(getter(arg)) for getter, comb_fn in fns
|
|
),
|
|
)
|
|
)
|
|
|
|
else:
|
|
|
|
for arg in arg_sets:
|
|
if not isinstance(arg, tuple):
|
|
arg = (arg,)
|
|
|
|
fn_params, param_exclusions = _filter_exclusions(arg)
|
|
|
|
if param_exclusions:
|
|
has_exclusions = True
|
|
|
|
tobuild_pytest_params.append(
|
|
(fn_params, param_exclusions, None)
|
|
)
|
|
|
|
pytest_params = []
|
|
for parameters, param_exclusions, id_ in tobuild_pytest_params:
|
|
if has_exclusions:
|
|
parameters += (param_exclusions,)
|
|
|
|
param = pytest.param(*parameters, id=id_)
|
|
pytest_params.append(param)
|
|
|
|
def decorate(fn):
|
|
if inspect.isclass(fn):
|
|
if has_exclusions:
|
|
raise NotImplementedError(
|
|
"exclusions not supported for class level combinations"
|
|
)
|
|
if "_sa_parametrize" not in fn.__dict__:
|
|
fn._sa_parametrize = []
|
|
fn._sa_parametrize.append((argnames, pytest_params))
|
|
return fn
|
|
else:
|
|
_fn_argnames = inspect.getfullargspec(fn).args[1:]
|
|
if argnames is None:
|
|
_argnames = _fn_argnames
|
|
else:
|
|
_argnames = re.split(r", *", argnames)
|
|
|
|
if has_exclusions:
|
|
existing_exl = sum(
|
|
1 for n in _fn_argnames if n.startswith("_exclusions")
|
|
)
|
|
current_exclusion_name = f"_exclusions_{existing_exl}"
|
|
_argnames += [current_exclusion_name]
|
|
|
|
@_pytest_fn_decorator
|
|
def check_exclusions(fn, *args, **kw):
|
|
_exclusions = args[-1]
|
|
if _exclusions:
|
|
exlu = exclusions.compound().add(*_exclusions)
|
|
fn = exlu(fn)
|
|
return fn(*args[:-1], **kw)
|
|
|
|
fn = check_exclusions(
|
|
fn, add_positional_parameters=(current_exclusion_name,)
|
|
)
|
|
|
|
return pytest.mark.parametrize(_argnames, pytest_params)(fn)
|
|
|
|
return decorate
|
|
|
|
def param_ident(self, *parameters):
|
|
ident = parameters[0]
|
|
return pytest.param(*parameters[1:], id=ident)
|
|
|
|
def fixture(self, *arg, **kw):
|
|
from sqlalchemy.testing import config
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
# wrapping pytest.fixture function. determine if
|
|
# decorator was called as @fixture or @fixture().
|
|
if len(arg) > 0 and callable(arg[0]):
|
|
# was called as @fixture(), we have the function to wrap.
|
|
fn = arg[0]
|
|
arg = arg[1:]
|
|
else:
|
|
# was called as @fixture, don't have the function yet.
|
|
fn = None
|
|
|
|
# create a pytest.fixture marker. because the fn is not being
|
|
# passed, this is always a pytest.FixtureFunctionMarker()
|
|
# object (or whatever pytest is calling it when you read this)
|
|
# that is waiting for a function.
|
|
fixture = pytest.fixture(*arg, **kw)
|
|
|
|
# now apply wrappers to the function, including fixture itself
|
|
|
|
def wrap(fn):
|
|
if config.any_async:
|
|
fn = asyncio._maybe_async_wrapper(fn)
|
|
# other wrappers may be added here
|
|
|
|
# now apply FixtureFunctionMarker
|
|
fn = fixture(fn)
|
|
|
|
return fn
|
|
|
|
if fn:
|
|
return wrap(fn)
|
|
else:
|
|
return wrap
|
|
|
|
def get_current_test_name(self):
|
|
return os.environ.get("PYTEST_CURRENT_TEST")
|
|
|
|
def async_test(self, fn):
|
|
from sqlalchemy.testing import asyncio
|
|
|
|
@_pytest_fn_decorator
|
|
def decorate(fn, *args, **kwargs):
|
|
asyncio._run_coroutine_function(fn, *args, **kwargs)
|
|
|
|
return decorate(fn)
|