# orm/evaluator.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
""" Evaluation functions used **INTERNALLY** by ORM DML use cases.
This module is * * private , for internal use by SQLAlchemy * * .
. . versionchanged : : 2.0 .4 renamed ` ` EvaluatorCompiler ` ` to
` ` _EvaluatorCompiler ` ` .
"""
from __future__ import annotations
from typing import Type
from . import exc as orm_exc
from . base import LoaderCallableStatus
from . base import PassiveFlag
from . . import exc
from . . import inspect
from . . sql import and_
from . . sql import operators
from . . sql . sqltypes import Integer
from . . sql . sqltypes import Numeric
from . . util import warn_deprecated
class UnevaluatableError(exc.InvalidRequestError):
    """Raised when a SQL expression cannot be evaluated in Python.

    The evaluator compiler in this module raises it for clause types,
    operators, columns or datatypes it has no Python implementation for.
    """
class _NoObject(operators.ColumnOperators):
    """Operand type whose every operation produces ``None``.

    Both the normal and the reflected operator hooks ignore their
    arguments and return ``None``.
    """

    def operate(self, *arg, **kw):
        # any operator applied with this object on the left yields None
        return None

    def reverse_operate(self, *arg, **kw):
        # same result when this object is the right-hand operand
        return None
class _ExpiredObject(operators.ColumnOperators):
    """Operand type that absorbs every operation and returns itself.

    Both the normal and the reflected operator hooks ignore their
    arguments and hand back the same instance, so the marker survives
    being combined with other values through operators.
    """

    def operate(self, *arg, **kw):
        # the result of any operation is still this same object
        return self

    def reverse_operate(self, *arg, **kw):
        # likewise when this object is the right-hand operand
        return self
# Module-level singleton; the column evaluator built by
# ``_EvaluatorCompiler.visit_column`` returns it when the object being
# evaluated is ``None``.  Compared by identity (``is``) elsewhere in this
# module.
_NO_OBJECT = _NoObject ( )
# Module-level singleton; the column evaluator returns it when the attribute
# lookup yields ``LoaderCallableStatus.PASSIVE_NO_RESULT``.  The other
# evaluators check for it by identity and pass it straight through.
_EXPIRED_OBJECT = _ExpiredObject ( )
class _EvaluatorCompiler:
    """Compile a SQL expression into a Python callable.

    ``process()`` dispatches on the clause's ``__visit_name__`` to a
    ``visit_<name>`` method.  Each visitor returns a one-argument callable
    ``evaluate(obj)`` that computes the value of that clause for ``obj``.

    Conventions used by the returned callables:

    * ``None`` stands for SQL NULL.
    * ``_EXPIRED_OBJECT`` is propagated unchanged as soon as any
      sub-expression produces it.
    * ``_NO_OBJECT`` is produced by column evaluators when ``obj`` is
      ``None``.

    Anything that has no visitor / operator handler raises
    :class:`.UnevaluatableError`.
    """

    def __init__(self, target_cls=None):
        """Create a compiler.

        :param target_cls: optional class; when set, columns whose parent
         mapper's class is not a superclass of it are rejected by
         ``visit_column``.
        """
        self.target_cls = target_cls

    def process(self, clause, *clauses):
        """Return an evaluator callable for ``clause``.

        Additional positional clauses are joined to the first with
        ``and_()`` before dispatch.

        :raises UnevaluatableError: if there is no ``visit_`` method for
         the clause's ``__visit_name__``.
        """
        if clauses:
            clause = and_(clause, *clauses)

        meth = getattr(self, f"visit_{clause.__visit_name__}", None)
        if not meth:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__}"
            )
        return meth(clause)

    def visit_grouping(self, clause):
        """Evaluate a grouping as its inner element."""
        return self.process(clause.element)

    def visit_null(self, clause):
        """Return an evaluator that always yields ``None``."""
        return lambda obj: None

    def visit_false(self, clause):
        """Return an evaluator that always yields ``False``."""
        return lambda obj: False

    def visit_true(self, clause):
        """Return an evaluator that always yields ``True``."""
        return lambda obj: True

    def visit_column(self, clause):
        """Return an evaluator that reads the mapped attribute for a column.

        :raises UnevaluatableError: if the column carries no
         ``parentmapper`` annotation, belongs to a class that
         ``target_cls`` does not derive from, or is not mapped to a
         property on its parent mapper.
        """
        try:
            parentmapper = clause._annotations["parentmapper"]
        except KeyError as ke:
            raise UnevaluatableError(
                f"Cannot evaluate column: {clause}"
            ) from ke

        if self.target_cls and not issubclass(
            self.target_cls, parentmapper.class_
        ):
            raise UnevaluatableError(
                "Can't evaluate criteria against "
                f"alternate class {parentmapper.class_}"
            )

        parentmapper._check_configure()

        # we'd like to use "proxy_key" annotation to get the "key", however
        # in relationship primaryjoin cases proxy_key is sometimes deannotated
        # and sometimes apparently not present in the first place (?).
        # While I can stop it from being deannotated (though need to see if
        # this breaks other things), not sure right now about cases where it's
        # not there in the first place.  can fix at some later point.
        # key = clause._annotations["proxy_key"]

        # for now, use the old way
        try:
            key = parentmapper._columntoproperty[clause].key
        except orm_exc.UnmappedColumnError as err:
            raise UnevaluatableError(
                f"Cannot evaluate expression: {err}"
            ) from err

        # note this used to fall back to a simple `getattr(obj, key)` evaluator
        # if impl was None; as of #8656, we ensure mappers are configured
        # so that impl is available
        impl = parentmapper.class_manager[key].impl

        def get_corresponding_attr(obj):
            if obj is None:
                return _NO_OBJECT
            state = inspect(obj)
            dict_ = state.dict

            # PASSIVE_NO_FETCH is passed so that the lookup reports
            # PASSIVE_NO_RESULT rather than a value when none is available
            value = impl.get(
                state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
            )
            if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
                return _EXPIRED_OBJECT
            return value

        return get_corresponding_attr

    def visit_tuple(self, clause):
        """Evaluate a tuple the same way as a clause list."""
        return self.visit_clauselist(clause)

    def visit_expression_clauselist(self, clause):
        """Evaluate an expression clause list as a plain clause list."""
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Compile each member clause and combine them by list operator.

        The handler is looked up as
        ``visit_<operator name without trailing underscores>_clauselist_op``.

        :raises UnevaluatableError: if no handler exists for the operator.
        """
        # distinct loop name: the comprehension must not shadow ``clause``,
        # which is still needed below
        evaluators = [
            self.process(sub_clause) for sub_clause in clause.clauses
        ]

        dispatch = (
            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
        )
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, evaluators, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate clauselist with operator {clause.operator}"
            )

    def visit_binary(self, clause):
        """Compile both sides of a binary expression and combine them.

        The handler is looked up as
        ``visit_<operator name without trailing underscores>_binary_op``.

        :raises UnevaluatableError: if no handler exists for the operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)

        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
        meth = getattr(self, dispatch, None)
        if meth:
            return meth(clause.operator, eval_left, eval_right, clause)
        else:
            raise UnevaluatableError(
                f"Cannot evaluate {type(clause).__name__} with "
                f"operator {clause.operator}"
            )

    def visit_or_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator implementing three-valued OR.

        Yields ``True`` at the first truthy member, ``None`` if no member
        was truthy but at least one was ``None``, otherwise ``False``.
        """

        def evaluate(obj):
            has_null = False
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value:
                    return True
                has_null = has_null or value is None
            if has_null:
                return None
            return False

        return evaluate

    def visit_and_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator implementing three-valued AND.

        Stops at the first falsy member, yielding ``None`` when that member
        is ``None`` or ``_NO_OBJECT`` and ``False`` otherwise; yields
        ``True`` when every member is truthy.
        """

        def evaluate(obj):
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT

                if not value:
                    if value is None or value is _NO_OBJECT:
                        return None
                    return False
            return True

        return evaluate

    def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
        """Return an evaluator producing a tuple of the member values.

        Yields ``None`` instead if any member is ``None`` or ``_NO_OBJECT``.
        """

        def evaluate(obj):
            values = []
            for sub_evaluate in evaluators:
                value = sub_evaluate(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None or value is _NO_OBJECT:
                    return None
                values.append(value)
            return tuple(values)

        return evaluate

    def visit_custom_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate a custom operator that provides ``python_impl``.

        :raises UnevaluatableError: if the operator has no ``python_impl``.
        """
        if operator.python_impl:
            return self._straight_evaluate(
                operator, eval_left, eval_right, clause
            )
        else:
            raise UnevaluatableError(
                f"Custom operator {operator.opstring!r} can't be evaluated "
                "in Python unless it specifies a callable using "
                "`.python_impl`."
            )

    def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator for IS, compared with Python ``==``.

        Unlike ``_straight_evaluate`` there is no ``None`` short-circuit,
        so ``None`` operands are compared like any other value.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val == right_val

        return evaluate

    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
        """Return an evaluator for IS NOT, compared with Python ``!=``.

        As with ``visit_is_binary_op``, ``None`` operands are compared
        rather than short-circuited.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            return left_val != right_val

        return evaluate

    def _straight_evaluate(self, operator, eval_left, eval_right, clause):
        """Return an evaluator that applies ``operator`` to both operands.

        If either operand is ``None`` the result is ``None``; otherwise the
        result is ``operator(left, right)``.  Each operand evaluator is
        invoked exactly once per call.
        """

        def evaluate(obj):
            left_val = eval_left(obj)
            right_val = eval_right(obj)
            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
                return _EXPIRED_OBJECT
            elif left_val is None or right_val is None:
                return None

            # reuse the values computed above rather than running the
            # sub-evaluators a second time
            return operator(left_val, right_val)

        return evaluate

    def _straight_evaluate_numeric_only(
        self, operator, eval_left, eval_right, clause
    ):
        """Like ``_straight_evaluate`` but restricted to numeric operands.

        :raises UnevaluatableError: if the type affinity of either side is
         not ``Numeric`` or ``Integer``.
        """
        if clause.left.type._type_affinity not in (
            Numeric,
            Integer,
        ) or clause.right.type._type_affinity not in (Numeric, Integer):
            raise UnevaluatableError(
                f'Cannot evaluate math operator "{operator.__name__}" for '
                f"datatypes {clause.left.type}, {clause.right.type}"
            )

        return self._straight_evaluate(operator, eval_left, eval_right, clause)

    # math operators: only evaluated when both sides are numeric
    visit_add_binary_op = _straight_evaluate_numeric_only
    visit_mul_binary_op = _straight_evaluate_numeric_only
    visit_sub_binary_op = _straight_evaluate_numeric_only
    visit_mod_binary_op = _straight_evaluate_numeric_only
    visit_truediv_binary_op = _straight_evaluate_numeric_only

    # comparison operators: applied directly to the operand values
    visit_lt_binary_op = _straight_evaluate
    visit_le_binary_op = _straight_evaluate
    visit_ne_binary_op = _straight_evaluate
    visit_gt_binary_op = _straight_evaluate
    visit_ge_binary_op = _straight_evaluate
    visit_eq_binary_op = _straight_evaluate

    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
        """Evaluate IN via Python ``in``; ``_NO_OBJECT`` on the left gives
        ``None``.
        """
        return self._straight_evaluate(
            lambda a, b: a in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_not_in_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate NOT IN via Python ``not in``; ``_NO_OBJECT`` on the left
        gives ``None``.
        """
        return self._straight_evaluate(
            lambda a, b: a not in b if a is not _NO_OBJECT else None,
            eval_left,
            eval_right,
            clause,
        )

    def visit_concat_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate concatenation using Python ``+``."""
        return self._straight_evaluate(
            lambda a, b: a + b, eval_left, eval_right, clause
        )

    def visit_startswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate startswith using the left value's ``startswith()``."""
        return self._straight_evaluate(
            lambda a, b: a.startswith(b), eval_left, eval_right, clause
        )

    def visit_endswith_op_binary_op(
        self, operator, eval_left, eval_right, clause
    ):
        """Evaluate endswith using the left value's ``endswith()``."""
        return self._straight_evaluate(
            lambda a, b: a.endswith(b), eval_left, eval_right, clause
        )

    def visit_unary(self, clause):
        """Return an evaluator for a unary expression.

        Only ``operators.inv`` is supported: ``None`` stays ``None``,
        anything else is negated with ``not``.

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is _EXPIRED_OBJECT:
                    return _EXPIRED_OBJECT
                elif value is None:
                    return None
                return not value

            return evaluate
        raise UnevaluatableError(
            f"Cannot evaluate {type(clause).__name__} "
            f"with operator {clause.operator}"
        )

    def visit_bindparam(self, clause):
        """Return an evaluator yielding the bound parameter's value.

        The value is resolved once, here: from ``clause.callable()`` when a
        callable is set, otherwise from ``clause.value``.
        """
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
def __getattr__(name: str) -> Type[_EvaluatorCompiler]:
    """Module-level attribute hook for the legacy ``EvaluatorCompiler`` name.

    Looking up ``EvaluatorCompiler`` on this module emits a deprecation
    warning and returns ``_EvaluatorCompiler``; any other missing name
    raises ``AttributeError``.
    """
    # guard clause: reject every name other than the one legacy alias
    if name != "EvaluatorCompiler":
        message = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(message)

    warn_deprecated(
        "Direct use of 'EvaluatorCompiler' is not supported, and this "
        "name will be removed in a future release.  "
        "'_EvaluatorCompiler' is for internal use only",
        "2.0",
    )
    return _EvaluatorCompiler