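# NOTE: the lines below are web-page residue captured together with this
# file; they are kept as comments so that the module remains parseable.
# You cannot select more than 25 topics. Topics must start with a letter or
# number, can include dashes ('-'), and can be up to 35 characters long.
# bazarr/libs/sqlalchemy/orm/evaluator.py
#
# 369 lines
# 12 KiB
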
# orm/evaluator.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""Evaluation functions used **INTERNALLY** by ORM DML use cases.

This module is **private, for internal use by SQLAlchemy**.

.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to
   ``_EvaluatorCompiler``.

"""
from __future__ import annotations
from typing import Type
from . import exc as orm_exc
from .base import LoaderCallableStatus
from .base import PassiveFlag
from .. import exc
from .. import inspect
from ..sql import and_
from ..sql import operators
from ..sql.sqltypes import Integer
from ..sql.sqltypes import Numeric
from ..util import warn_deprecated
class UnevaluatableError(exc.InvalidRequestError):
    """Raised when an expression cannot be evaluated in Python.

    ``_EvaluatorCompiler`` raises this whenever it meets a clause or an
    operator for which it has no ``visit_*`` handler, or a column it cannot
    resolve to a mapped attribute.
    """
class _NoObject(operators.ColumnOperators):
    """Operator-capable placeholder whose every operation yields ``None``.

    Both operator hooks ignore their arguments and return ``None``.
    Presumably ``ColumnOperators`` routes Python operators through
    ``operate()`` / ``reverse_operate()``, so that any expression involving
    an instance collapses to ``None`` -- confirm against the
    ``ColumnOperators`` documentation.
    """

    def operate(self, *arg, **kw):
        """Return ``None`` regardless of the operator and operands."""
        return None

    def reverse_operate(self, *arg, **kw):
        """Return ``None`` regardless of the operator and operands."""
        return None
class _ExpiredObject(operators.ColumnOperators):
    """Operator-capable placeholder that propagates itself.

    Both operator hooks ignore their arguments and return ``self``.
    Presumably ``ColumnOperators`` routes Python operators through
    ``operate()`` / ``reverse_operate()``, so that any expression involving
    an instance evaluates to that same instance -- confirm against the
    ``ColumnOperators`` documentation.
    """

    def operate(self, *arg, **kw):
        """Return this same object regardless of operator and operands."""
        return self

    def reverse_operate(self, *arg, **kw):
        """Return this same object regardless of operator and operands."""
        return self
# Sentinel returned by the column evaluator when the object handed to it is
# ``None`` (see ``_EvaluatorCompiler.visit_column``).
_NO_OBJECT = _NoObject()
# Sentinel returned by the column evaluator when the attribute lookup yields
# ``LoaderCallableStatus.PASSIVE_NO_RESULT``; the evaluators built by
# ``_EvaluatorCompiler`` hand it back unchanged as soon as they see it.
_EXPIRED_OBJECT = _ExpiredObject()
class _EvaluatorCompiler:
    """Turn SQL expression constructs into plain Python callables.

    ``process()`` dispatches on a clause's ``__visit_name__`` to a
    ``visit_<name>`` method.  Every visitor returns a one-argument callable
    ``evaluate(obj)`` that computes the value of the clause for ``obj``.

    Conventions shared by the returned callables:

    * ``None`` stands in for the SQL NULL value (see ``visit_null``);
    * ``_EXPIRED_OBJECT`` is handed back unchanged as soon as any operand
      produces it;
    * ``_NO_OBJECT`` is what a column evaluator produces when ``obj`` is
      ``None``.

    Constructs without a matching visitor raise ``UnevaluatableError``.
    """

    def __init__(self, target_cls=None):
        """Store the optional class that columns must belong to.

        :param target_cls: when given, ``visit_column`` refuses columns
         whose mapped class is not ``target_cls`` or one of its bases.
        """
        self.target_cls = target_cls

    def process(self, clause, *clauses):
        """Return an evaluator callable for ``clause``.

        When additional ``clauses`` are passed they are combined with the
        first one using ``and_()`` before being compiled.

        :raises UnevaluatableError: if there is no ``visit_*`` method for
         the clause's ``__visit_name__``.
        """
        if clauses:
            clause = and_(clause, *clauses)

        meth = getattr(self, f"visit_{clause.__visit_name__}", None)
        if not meth:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__}"
            )
        return meth(clause)

    def visit_grouping(self, clause):
        """Evaluate a grouping as its inner element."""
        return self.process(clause.element)

    def visit_null(self, clause):
        """Return an evaluator that always yields ``None``."""
        return lambda obj: None

    def visit_false(self, clause):
        """Return an evaluator that always yields ``False``."""
        return lambda obj: False

    def visit_true(self, clause):
        """Return an evaluator that always yields ``True``."""
        return lambda obj: True

    def visit_column(self, clause):
        """Return an evaluator that reads the mapped attribute for a column.

        The returned callable yields ``_NO_OBJECT`` when given ``None``,
        ``_EXPIRED_OBJECT`` when the attribute implementation reports
        ``PASSIVE_NO_RESULT`` (the lookup is done with
        ``PASSIVE_NO_FETCH``), and the attribute value otherwise.

        :raises UnevaluatableError: if the column carries no
         ``parentmapper`` annotation, belongs to a class that
         ``target_cls`` does not derive from, or is not mapped to a
         property.
        """
        try:
            parentmapper = clause._annotations["parentmapper"]
        except KeyError as ke:
            raise UnevaluatableError(
                f"Cannot evaluate column: {clause}"
            ) from ke

        if self.target_cls and not issubclass(
            self.target_cls, parentmapper.class_
        ):
            raise UnevaluatableError(
                "Can't evaluate criteria against "
                f"alternate class {parentmapper.class_}"
            )

        parentmapper._check_configure()

        # we'd like to use "proxy_key" annotation to get the "key", however
        # in relationship primaryjoin cases proxy_key is sometimes deannotated
        # and sometimes apparently not present in the first place (?).
        # While I can stop it from being deannotated (though need to see if
        # this breaks other things), not sure right now about cases where it's
        # not there in the first place.  can fix at some later point.
        # key = clause._annotations["proxy_key"]

        # for now, use the old way
        try:
            key = parentmapper._columntoproperty[clause].key
        except orm_exc.UnmappedColumnError as err:
            raise UnevaluatableError(
                f"Cannot evaluate expression: {err}"
            ) from err

        # note this used to fall back to a simple `getattr(obj, key)` evaluator
        # if impl was None; as of #8656, we ensure mappers are configured
        # so that impl is available
        impl = parentmapper.class_manager[key].impl

        def get_corresponding_attr(obj):
            if obj is None:
                return _NO_OBJECT
            state = inspect(obj)
            dict_ = state.dict

            value = impl.get(
                state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
            )
            if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
                return _EXPIRED_OBJECT
            return value

        return get_corresponding_attr

    def visit_tuple(self, clause):
        """Evaluate a tuple construct the same way as a clause list."""
        return self.visit_clauselist(clause)

    def visit_expression_clauselist(self, clause):
        """Evaluate an expression clause list the same way as a clause list."""
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Compile each member clause and dispatch on the list's operator.

        The handler looked up is ``visit_<operator name>_clauselist_op``,
        where trailing underscores are stripped from the operator function's
        ``__name__`` (so ``and_`` maps to ``visit_and_clauselist_op``).

        :raises UnevaluatableError: if no handler exists for the operator.
        """
        # use a distinct loop name; the original reused ``clause`` here,
        # shadowing the parameter inside the comprehension
        evaluators = [self.process(member) for member in clause.clauses]

        dispatch = (
            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
        )
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, evaluators, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate clauselist with operator {clause.operator}"
            )

    def visit_binary(self, clause):
        """Compile both operands and dispatch on the binary operator.

        The handler looked up is ``visit_<operator name>_binary_op``, where
        trailing underscores are stripped from the operator function's
        ``__name__``.

        :raises UnevaluatableError: if no handler exists for the operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)

        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, eval_left, eval_right, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__} with "
                f"operator {clause.operator}"
            )

    def visit_or_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator implementing OR over ``evaluators``.

        Members are evaluated in order.  ``_EXPIRED_OBJECT`` is returned as
        soon as a member yields it, ``True`` as soon as a member is truthy;
        otherwise the result is ``None`` if any member yielded ``None`` and
        ``False`` if none did.
        """

        def evaluate(obj):
            has_null = False
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value:
                    return True
                has_null = has_null or value is None
            if has_null:
                return None
            return False

        return evaluate

    def visit_and_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator implementing AND over ``evaluators``.

        Members are evaluated in order.  ``_EXPIRED_OBJECT`` is returned as
        soon as a member yields it.  The first falsy member decides the
        result: ``None`` if it is ``None`` or ``_NO_OBJECT``, ``False``
        otherwise.  If every member is truthy the result is ``True``.
        """

        def evaluate(obj):
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT

                if not value:
                    if value is None or value is _NO_OBJECT:
                        return None
                    return False
            return True

        return evaluate

    def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator producing a tuple of the member values.

        ``_EXPIRED_OBJECT`` is returned as soon as a member yields it, and
        ``None`` as soon as a member yields ``None`` or ``_NO_OBJECT``.
        """

        def evaluate(obj):
            values = []
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    return None
                values.append(value)
            return tuple(values)

        return evaluate

    def visit_custom_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate a custom operator through its ``python_impl``.

        :raises UnevaluatableError: if the operator has no ``python_impl``.
        """
        if operator.python_impl:
            return self._straight_evaluate(
                operator, eval_left, eval_right, clause
            )
        else:
            raise UnevaluatableError(
                f"Custom operator {operator.opstring!r} can't be evaluated "
                "in Python unless it specifies a callable using "
                "`.python_impl`."
            )

    def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator for IS, comparing operand values with ``==``.

        Unlike ``_straight_evaluate`` there is no ``None`` short-circuit, so
        a ``None`` operand takes part in the comparison.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val == right_val

        return evaluate

    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator for IS NOT, comparing values with ``!=``.

        Unlike ``_straight_evaluate`` there is no ``None`` short-circuit, so
        a ``None`` operand takes part in the comparison.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val != right_val

        return evaluate

    def _straight_evaluate(self, operator, eval_left, eval_right, clause):
        """Return an evaluator that applies ``operator`` to both operands.

        ``_EXPIRED_OBJECT`` is returned if either operand yields it, and
        ``None`` if either operand is ``None``; otherwise the result is
        ``operator(left, right)``.

        Each operand evaluator is invoked exactly once per call.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            elif left_val is None or right_val is None:
                return None

            # reuse the values computed above rather than calling the
            # operand evaluators a second time
            return operator(left_val, right_val)

        return evaluate

    def _straight_evaluate_numeric_only(
        self, operator, eval_left, eval_right, clause
    ):
        """Like ``_straight_evaluate`` but restricted to numeric operands.

        :raises UnevaluatableError: unless the type affinity of both sides
         of ``clause`` is ``Numeric`` or ``Integer``.
        """
        if clause.left.type._type_affinity not in (
            Numeric,
            Integer,
        ) or clause.right.type._type_affinity not in (Numeric, Integer):
            raise UnevaluatableError(
                f'Cannot evaluate math operator "{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(operator, eval_left, eval_right, clause)

    # math operators are only evaluated for numeric datatypes
    visit_add_binary_op = _straight_evaluate_numeric_only
    visit_mul_binary_op = _straight_evaluate_numeric_only
    visit_sub_binary_op = _straight_evaluate_numeric_only
    visit_mod_binary_op = _straight_evaluate_numeric_only
    visit_truediv_binary_op = _straight_evaluate_numeric_only
    # comparison operators are applied to the operand values directly
    visit_lt_binary_op = _straight_evaluate
    visit_le_binary_op = _straight_evaluate
    visit_ne_binary_op = _straight_evaluate
    visit_gt_binary_op = _straight_evaluate
    visit_ge_binary_op = _straight_evaluate
    visit_eq_binary_op = _straight_evaluate

    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator for IN using Python's ``in``.

        A left-hand value of ``_NO_OBJECT`` gives ``None``.
        """
        return self._straight_evaluate(
            lambda a, b: a in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_not_in_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator for NOT IN using Python's ``not in``.

        A left-hand value of ``_NO_OBJECT`` gives ``None``.
        """
        return self._straight_evaluate(
            lambda a, b: a not in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_concat_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator for concatenation using Python's ``+``."""
        return self._straight_evaluate(
            lambda a, b: a + b, eval_left, eval_right, clause
        )

    def visit_startswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator that calls ``left.startswith(right)``."""
        return self._straight_evaluate(
            lambda a, b: a.startswith(b), eval_left, eval_right, clause
        )

    def visit_endswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Return an evaluator that calls ``left.endswith(right)``."""
        return self._straight_evaluate(
            lambda a, b: a.endswith(b), eval_left, eval_right, clause
        )

    def visit_unary(self, clause):
        """Return an evaluator for a unary clause; only ``inv`` is handled.

        For ``operators.inv`` the evaluator passes ``_EXPIRED_OBJECT`` and
        ``None`` through unchanged and otherwise returns ``not value``.

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None:
                    return None
                return not value

            return evaluate

        raise UnevaluatableError(
            f"Cannot evaluate {type(clause).__name__} "
            f"with operator {clause.operator}"
        )

    def visit_bindparam(self, clause):
        """Return an evaluator yielding the bound parameter's value.

        The value is resolved once, here: from ``clause.callable()`` when a
        callable is set, otherwise from ``clause.value``.
        """
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
def __getattr__(name: str) -> Type[_EvaluatorCompiler]:
    """Module-level attribute hook providing the deprecated public alias.

    Accessing ``EvaluatorCompiler`` on this module emits a deprecation
    warning and returns ``_EvaluatorCompiler``.

    :param name: the attribute name being looked up on the module.
    :raises AttributeError: for any name other than ``EvaluatorCompiler``.
    """
    if name != "EvaluatorCompiler":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warn_deprecated(
        "Direct use of 'EvaluatorCompiler' is not supported, and this "
        "name will be removed in a future release.  "
        "'_EvaluatorCompiler' is for internal use only",
        "2.0",
    )
    return _EvaluatorCompiler