You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
369 lines
12 KiB
369 lines
12 KiB
# orm/evaluator.py
|
|
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
|
# <see AUTHORS file>
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
|
# mypy: ignore-errors
|
|
|
|
"""Evaluation functions used **INTERNALLY** by ORM DML use cases.
|
|
|
|
|
|
This module is **private, for internal use by SQLAlchemy**.
|
|
|
|
.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to
|
|
``_EvaluatorCompiler``.
|
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Type
|
|
|
|
from . import exc as orm_exc
|
|
from .base import LoaderCallableStatus
|
|
from .base import PassiveFlag
|
|
from .. import exc
|
|
from .. import inspect
|
|
from ..sql import and_
|
|
from ..sql import operators
|
|
from ..sql.sqltypes import Integer
|
|
from ..sql.sqltypes import Numeric
|
|
from ..util import warn_deprecated
|
|
|
|
|
|
class UnevaluatableError(exc.InvalidRequestError):
    """Raised when a SQL expression cannot be evaluated in Python.

    ``_EvaluatorCompiler`` raises this for constructs, operators, or
    columns it has no in-Python evaluation for.  It subclasses
    ``InvalidRequestError``, so handlers for that error also catch it.
    """
|
|
|
|
|
|
class _NoObject(operators.ColumnOperators):
    """Sentinel operand used when there is no object to read a column from.

    ``ColumnOperators`` routes its operator overloads through
    ``operate`` (forward, e.g. ``x == 5``) and ``reverse_operate``
    (reflected, e.g. ``1 + x``).  Both return ``None`` here, so any
    expression built on this sentinel collapses to a NULL-like ``None``.
    """

    def operate(self, *args, **kwargs):
        # forward operators: result is always None
        return None

    def reverse_operate(self, *args, **kwargs):
        # reflected operators: result is always None as well
        return None
|
|
|
|
|
|
class _ExpiredObject(operators.ColumnOperators):
    """Sentinel operand marking an attribute value that was not available.

    Every forward (``operate``) and reflected (``reverse_operate``)
    operator applied to it returns the sentinel itself, so the "expired"
    marker propagates unchanged through any expression that touches it.
    """

    def operate(self, *args, **kwargs):
        # forward operators: propagate the sentinel
        return self

    def reverse_operate(self, *args, **kwargs):
        # reflected operators: propagate the sentinel
        return self
|
|
|
|
|
|
# Module-level singletons that the evaluator compares by identity (``is``):
# _NO_OBJECT is returned for a column read when the object itself is None,
# _EXPIRED_OBJECT when an attribute read reports PASSIVE_NO_RESULT.
_NO_OBJECT, _EXPIRED_OBJECT = _NoObject(), _ExpiredObject()
|
|
|
|
|
|
class _EvaluatorCompiler:
    """Compile SQL expression constructs into per-object Python callables.

    :meth:`.process` walks a clause by dispatching on its
    ``__visit_name__`` to a ``visit_<name>`` method.  Each visitor returns
    a function ``evaluate(obj)`` that computes the expression's value for
    one object.  Binary expressions and clause lists dispatch a second
    time on the operator's ``__name__`` (trailing underscores stripped),
    e.g. ``visit_lt_binary_op`` or ``visit_and_clauselist_op``.

    Evaluation approximates SQL semantics:

    * ``None`` plays the role of SQL NULL and generally propagates
      through comparisons and arithmetic.
    * ``_NO_OBJECT`` is produced for a column read when ``obj`` is
      ``None``.
    * ``_EXPIRED_OBJECT`` is produced when an attribute read with
      ``PASSIVE_NO_FETCH`` comes back as ``PASSIVE_NO_RESULT``.  It
      short-circuits enclosing expressions and is returned in place of a
      result.

    Constructs with no visitor raise :class:`.UnevaluatableError` while
    the evaluator is being built, not when it is called.
    """

    def __init__(self, target_cls=None):
        # When set, every column must come from a mapper whose class is
        # ``target_cls`` or one of its bases (enforced in visit_column).
        self.target_cls = target_cls

    def process(self, clause, *clauses):
        """Return an evaluator callable for ``clause``.

        Any additional ``clauses`` are combined with ``clause`` via
        :func:`.and_` before compilation.

        :raises UnevaluatableError: if there is no
          ``visit_<__visit_name__>`` method for the (combined) clause.
        """
        if clauses:
            clause = and_(clause, *clauses)

        meth = getattr(self, f"visit_{clause.__visit_name__}", None)
        if not meth:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__}"
            )
        return meth(clause)

    def visit_grouping(self, clause):
        # parentheses carry no semantics of their own; compile the inside
        return self.process(clause.element)

    def visit_null(self, clause):
        return lambda obj: None

    def visit_false(self, clause):
        return lambda obj: False

    def visit_true(self, clause):
        return lambda obj: True

    def visit_column(self, clause):
        """Return an evaluator that reads the mapped attribute for a column.

        :raises UnevaluatableError: if the column has no ``parentmapper``
          annotation, if ``target_cls`` is not a subclass of that mapper's
          class, or if the column is not mapped to a property.
        """
        try:
            parentmapper = clause._annotations["parentmapper"]
        except KeyError as ke:
            raise UnevaluatableError(
                f"Cannot evaluate column: {clause}"
            ) from ke

        if self.target_cls and not issubclass(
            self.target_cls, parentmapper.class_
        ):
            raise UnevaluatableError(
                "Can't evaluate criteria against "
                f"alternate class {parentmapper.class_}"
            )

        parentmapper._check_configure()

        # we'd like to use "proxy_key" annotation to get the "key", however
        # in relationship primaryjoin cases proxy_key is sometimes deannotated
        # and sometimes apparently not present in the first place (?).
        # While I can stop it from being deannotated (though need to see if
        # this breaks other things), not sure right now about cases where it's
        # not there in the first place. can fix at some later point.
        # key = clause._annotations["proxy_key"]

        # for now, use the old way
        try:
            key = parentmapper._columntoproperty[clause].key
        except orm_exc.UnmappedColumnError as err:
            raise UnevaluatableError(
                f"Cannot evaluate expression: {err}"
            ) from err

        # note this used to fall back to a simple `getattr(obj, key)` evaluator
        # if impl was None; as of #8656, we ensure mappers are configured
        # so that impl is available
        impl = parentmapper.class_manager[key].impl

        def get_corresponding_attr(obj):
            """Read the attribute from ``obj`` without triggering a fetch.

            Returns ``_NO_OBJECT`` for a ``None`` object and
            ``_EXPIRED_OBJECT`` when the value is not available under
            ``PASSIVE_NO_FETCH``.
            """
            if obj is None:
                return _NO_OBJECT
            state = inspect(obj)
            dict_ = state.dict

            value = impl.get(
                state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
            )
            if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
                return _EXPIRED_OBJECT
            return value

        return get_corresponding_attr

    def visit_tuple(self, clause):
        return self.visit_clauselist(clause)

    def visit_expression_clauselist(self, clause):
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Compile each element, then dispatch on the list's operator.

        :raises UnevaluatableError: if there is no
          ``visit_<operator>_clauselist_op`` method for the operator.
        """
        # distinct loop name so the outer ``clause`` is not shadowed
        evaluators = [
            self.process(sub_clause) for sub_clause in clause.clauses
        ]

        dispatch = (
            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
        )
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, evaluators, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate clauselist with operator {clause.operator}"
            )

    def visit_binary(self, clause):
        """Compile both operands, then dispatch on the binary operator.

        :raises UnevaluatableError: if there is no
          ``visit_<operator>_binary_op`` method for the operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)

        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, eval_left, eval_right, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__} with "
                f"operator {clause.operator}"
            )

    def visit_or_clauselist_op(self, operator, evaluators, clause):
        """OR with NULL handling.

        Operands are evaluated in order; the first expired or truthy
        operand decides the result (``_EXPIRED_OBJECT`` / ``True``).
        Otherwise the result is ``None`` if any operand was ``None``,
        else ``False``.
        """

        def evaluate(obj):
            has_null = False
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value:
                    return True
                has_null = has_null or value is None
            if has_null:
                return None
            return False

        return evaluate

    def visit_and_clauselist_op(self, operator, evaluators, clause):
        """AND with NULL handling.

        Operands are evaluated in order.  An expired operand returns
        ``_EXPIRED_OBJECT``.  The first falsy operand returns ``None`` if
        it is ``None`` or ``_NO_OBJECT``, otherwise ``False``.  If every
        operand is truthy the result is ``True``.
        """

        def evaluate(obj):
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT

                if not value:
                    if value is None or value is _NO_OBJECT:
                        return None
                    return False
            return True

        return evaluate

    def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
        """Evaluate a comma-separated list to a tuple of values.

        Returns ``_EXPIRED_OBJECT`` or ``None`` instead of a tuple as soon
        as an element is expired, ``None``, or ``_NO_OBJECT``.
        """

        def evaluate(obj):
            values = []
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    return None
                values.append(value)
            return tuple(values)

        return evaluate

    def visit_custom_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate a custom operator only if it declares ``python_impl``.

        The operator object itself is passed to ``_straight_evaluate`` and
        called with the two operand values.  Presumably its ``__call__``
        dispatches to ``python_impl``; confirm in ``operators.custom_op``.

        :raises UnevaluatableError: if ``operator.python_impl`` is unset.
        """
        if operator.python_impl:
            return self._straight_evaluate(
                operator, eval_left, eval_right, clause
            )
        else:
            raise UnevaluatableError(
                f"Custom operator {operator.opstring!r} can't be evaluated "
                "in Python unless it specifies a callable using "
                "`.python_impl`."
            )

    def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
        # IS compares with ==, with no None short-circuit, so that
        # ``x IS NULL`` is True when x is None.
        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val == right_val

        return evaluate

    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
        # IS NOT mirrors visit_is_binary_op using !=.
        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val != right_val

        return evaluate

    def _straight_evaluate(self, operator, eval_left, eval_right, clause):
        """Apply ``operator`` to both operand values, with NULL handling.

        Returns ``_EXPIRED_OBJECT`` if either operand is expired, ``None``
        if either operand is ``None``, and otherwise
        ``operator(left, right)``.  Each operand evaluator is called
        exactly once per evaluation.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            elif left_val is None or right_val is None:
                return None

            # Reuse the values already computed above; re-invoking the
            # operand evaluators would repeat attribute reads for no gain.
            return operator(left_val, right_val)

        return evaluate

    def _straight_evaluate_numeric_only(
        self, operator, eval_left, eval_right, clause
    ):
        """Like ``_straight_evaluate``, restricted to numeric operand types.

        :raises UnevaluatableError: unless both ``clause.left`` and
          ``clause.right`` have a type affinity of ``Numeric`` or
          ``Integer``.
        """
        if clause.left.type._type_affinity not in (
            Numeric,
            Integer,
        ) or clause.right.type._type_affinity not in (Numeric, Integer):
            raise UnevaluatableError(
                f'Cannot evaluate math operator "{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(operator, eval_left, eval_right, clause)

    # arithmetic: only for numeric operand types
    visit_add_binary_op = _straight_evaluate_numeric_only
    visit_mul_binary_op = _straight_evaluate_numeric_only
    visit_sub_binary_op = _straight_evaluate_numeric_only
    visit_mod_binary_op = _straight_evaluate_numeric_only
    visit_truediv_binary_op = _straight_evaluate_numeric_only
    # comparisons: applied directly to the Python values
    visit_lt_binary_op = _straight_evaluate
    visit_le_binary_op = _straight_evaluate
    visit_ne_binary_op = _straight_evaluate
    visit_gt_binary_op = _straight_evaluate
    visit_ge_binary_op = _straight_evaluate
    visit_eq_binary_op = _straight_evaluate

    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
        # Python ``in`` against the right-hand collection; a missing object
        # on the left yields None.
        return self._straight_evaluate(
            lambda a, b: a in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_not_in_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        # Python ``not in``; a missing object on the left yields None.
        return self._straight_evaluate(
            lambda a, b: a not in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_concat_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        # string concatenation maps to Python ``+``
        return self._straight_evaluate(
            lambda a, b: a + b, eval_left, eval_right, clause
        )

    def visit_startswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a.startswith(b), eval_left, eval_right, clause
        )

    def visit_endswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        return self._straight_evaluate(
            lambda a, b: a.endswith(b), eval_left, eval_right, clause
        )

    def visit_unary(self, clause):
        """Compile a unary expression; only ``operators.inv`` (NOT) is supported.

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None:
                    # NOT NULL is NULL
                    return None
                return not value

            return evaluate

        raise UnevaluatableError(
            f"Cannot evaluate {type(clause).__name__} "
            f"with operator {clause.operator}"
        )

    def visit_bindparam(self, clause):
        # The bound value (or its callable) is resolved once, when the
        # evaluator is built, and is not re-read per object.
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
|
|
|
|
|
|
def __getattr__(name: str) -> Type[_EvaluatorCompiler]:
    """Module attribute fallback (PEP 562) serving a deprecated alias.

    Looking up ``EvaluatorCompiler`` emits a deprecation warning and hands
    back ``_EvaluatorCompiler``.  Any other missing name raises
    ``AttributeError`` with the standard module message.
    """
    if name != "EvaluatorCompiler":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warn_deprecated(
        "Direct use of 'EvaluatorCompiler' is not supported, and this "
        "name will be removed in a future release. "
        "'_EvaluatorCompiler' is for internal use only",
        "2.0",
    )
    return _EvaluatorCompiler
|