You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1941 lines
63 KiB
1941 lines
63 KiB
# orm/bulk_persistence.py
|
|
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
|
# <see AUTHORS file>
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
|
# mypy: ignore-errors
|
|
|
|
|
|
"""additional ORM persistence classes related to "bulk" operations,
|
|
specifically outside of the flush() process.
|
|
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
from typing import cast
|
|
from typing import Dict
|
|
from typing import Iterable
|
|
from typing import Optional
|
|
from typing import overload
|
|
from typing import TYPE_CHECKING
|
|
from typing import TypeVar
|
|
from typing import Union
|
|
|
|
from . import attributes
|
|
from . import context
|
|
from . import evaluator
|
|
from . import exc as orm_exc
|
|
from . import loading
|
|
from . import persistence
|
|
from .base import NO_VALUE
|
|
from .context import AbstractORMCompileState
|
|
from .context import FromStatement
|
|
from .context import ORMFromStatementCompileState
|
|
from .context import QueryContext
|
|
from .. import exc as sa_exc
|
|
from .. import util
|
|
from ..engine import Dialect
|
|
from ..engine import result as _result
|
|
from ..sql import coercions
|
|
from ..sql import dml
|
|
from ..sql import expression
|
|
from ..sql import roles
|
|
from ..sql import select
|
|
from ..sql import sqltypes
|
|
from ..sql.base import _entity_namespace_key
|
|
from ..sql.base import CompileState
|
|
from ..sql.base import Options
|
|
from ..sql.dml import DeleteDMLState
|
|
from ..sql.dml import InsertDMLState
|
|
from ..sql.dml import UpdateDMLState
|
|
from ..util import EMPTY_DICT
|
|
from ..util.typing import Literal
|
|
|
|
if TYPE_CHECKING:
|
|
from ._typing import DMLStrategyArgument
|
|
from ._typing import OrmExecuteOptionsParameter
|
|
from ._typing import SynchronizeSessionArgument
|
|
from .mapper import Mapper
|
|
from .session import _BindArguments
|
|
from .session import ORMExecuteState
|
|
from .session import Session
|
|
from .session import SessionTransaction
|
|
from .state import InstanceState
|
|
from ..engine import Connection
|
|
from ..engine import cursor
|
|
from ..engine.interfaces import _CoreAnyExecuteParams
|
|
|
|
_O = TypeVar("_O", bound=object)
|
|
|
|
|
|
@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Literal[None] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> None:
    ...


@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> cursor.CursorResult[Any]:
    ...


def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = None,
    execution_options: Optional[OrmExecuteOptionsParameter] = None,
) -> Optional[cursor.CursorResult[Any]]:
    """Emit INSERT statements for ``mappings`` outside of the flush process.

    :param mapper: mapper whose tables are the INSERT targets.
    :param mappings: instance states when ``isstates`` is true; otherwise
     plain dictionaries, which are copied and have composite attributes
     expanded before use.
    :param session_transaction: supplies the connection used for all
     statements.
    :param isstates: whether ``mappings`` contains instance states.
    :param return_defaults: enables bookkeeping for every table and, for
     states, assigns ``state.key`` once all statements have run.
    :param render_nulls: passed through to
     ``persistence._collect_insert_commands()``.
    :param use_orm_insert_stmt: optional ORM-enabled INSERT; when given,
     a result object is returned.
    :param execution_options: passed through to
     ``persistence._emit_insert_statements()``.
    :return: ``None`` when ``use_orm_insert_stmt`` is ``None``; otherwise
     the result of the per-table statements, spliced horizontally when the
     statement has RETURNING and more than one table returned rows.
    :raises NotImplementedError: if the session has a
     ``connection_callable``.

    """
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if not isstates:
        # copy caller-owned dictionaries before composites are expanded
        # in place
        mappings = [dict(m) for m in mappings]
        _expand_composites(mapper, mappings)
    elif return_defaults:
        # keep the states around; their identity keys are assigned below
        states = [(state, state.dict) for state in mappings]
        mappings = [dict_ for _, dict_ in states]
    else:
        mappings = [state.dict for state in mappings]

    connection = session_transaction.connection(base_mapper)

    return_result: Optional[cursor.CursorResult[Any]] = None

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
            continue

        is_joined_inh_supertable = super_mapper is not mapper
        bookkeeping = (
            is_joined_inh_supertable
            or return_defaults
            or (
                use_orm_insert_stmt is not None
                and bool(use_orm_insert_stmt._returning)
            )
        )

        insert_commands = persistence._collect_insert_commands(
            table,
            ((None, mapping, mapper, connection) for mapping in mappings),
            bulk=True,
            return_defaults=bookkeeping,
            render_nulls=render_nulls,
        )

        # re-emit each collected record with no state and with the
        # enclosing mapper / connection in place of the collected ones
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                _state,
                state_dict,
                params,
                _mp,
                _conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in insert_commands
        )

        result = persistence._emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=bookkeeping,
            use_orm_insert_stmt=use_orm_insert_stmt,
            execution_options=execution_options,
        )

        if use_orm_insert_stmt is None:
            continue

        if not use_orm_insert_stmt._returning or return_result is None:
            return_result = result
        elif result.returns_rows:
            # a later table also returned rows; join them column-wise
            return_result = return_result.splice_horizontally(result)

    if return_defaults and isstates:
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple(dict_[key] for key in identity_props),
            )

    if use_orm_insert_stmt is None:
        return None

    assert return_result is not None
    return return_result
|
|
|
|
|
|
@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Literal[None] = ...,
) -> None:
    ...


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = ...,
) -> _result.Result[Any]:
    ...


def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = None,
) -> Optional[_result.Result[Any]]:
    """Emit UPDATE statements for ``mappings`` outside of the flush process.

    :param mapper: mapper whose tables are the UPDATE targets.
    :param mappings: instance states when ``isstates`` is true; otherwise
     plain dictionaries, which are copied and have composite attributes
     expanded before use.
    :param session_transaction: supplies the connection used for all
     statements.
    :param isstates: whether ``mappings`` contains instance states.
    :param update_changed_only: for states, restrict each mapping to keys
     found in ``state.committed_state`` plus the primary key (and version
     id, if mapped) keys.
    :param use_orm_update_stmt: optional ORM-enabled UPDATE passed through
     to the persistence helpers.
    :return: ``_result.null_result()`` when ``use_orm_update_stmt`` is
     given, otherwise ``None``.
    :raises NotImplementedError: if the session has a
     ``connection_callable``.

    """
    base_mapper = mapper.base_mapper

    # keys that are always kept in a "changed only" mapping
    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    if not isstates:
        mappings = [dict(m) for m in mappings]
        _expand_composites(mapper, mappings)
    elif update_changed_only:
        mappings = [
            {
                k: v
                for k, v in state.dict.items()
                if k in state.committed_state or k in search_keys
            }
            for state in mappings
        ]
    else:
        mappings = [state.dict for state in mappings]

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
            continue

        # one 5-tuple per mapping; the last element is the current
        # version id value when the mapper has a version id property
        update_params = (
            (
                None,
                mapping,
                mapper,
                connection,
                (
                    mapping[mapper._version_id_prop.key]
                    if mapper._version_id_prop
                    else None
                ),
            )
            for mapping in mappings
        )

        records = persistence._collect_update_commands(
            None,
            table,
            update_params,
            bulk=True,
            use_orm_update_stmt=use_orm_update_stmt,
        )
        persistence._emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
            use_orm_update_stmt=use_orm_update_stmt,
        )

    if use_orm_update_stmt is None:
        return None

    return _result.null_result()
|
|
|
|
|
|
def _expand_composites(mapper, mappings):
|
|
composite_attrs = mapper.composites
|
|
if not composite_attrs:
|
|
return
|
|
|
|
composite_keys = set(composite_attrs.keys())
|
|
populators = {
|
|
key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
|
|
for key in composite_keys
|
|
}
|
|
for mapping in mappings:
|
|
for key in composite_keys.intersection(mapping):
|
|
populators[key](mapping)
|
|
|
|
|
|
class ORMDMLState(AbstractORMCompileState):
    """Compile-state base holding helpers shared by ORM-enabled DML.

    Provides translation of ORM attribute keys in VALUES / SET collections
    and the setup / consumption of ORM-level RETURNING.

    """

    # NOTE(review): consumed outside this portion of the file; exact
    # meaning not visible here
    is_dml_returning = True

    # assigned by _setup_orm_returning() when the statement has RETURNING;
    # read by _return_orm_returning() to decide whether ORM loading applies
    from_statement_ctx: Optional[ORMFromStatementCompileState] = None
|
|
|
|
@classmethod
def _get_orm_crud_kv_pairs(
    cls, mapper, statement, kv_iterator, needs_to_be_cacheable
):
    """Yield Core-level ``(column, value)`` pairs for ORM-level pairs.

    Each key from ``kv_iterator`` is coerced to a DML column role.  String
    keys are looked up on ``mapper``; keys carrying an
    ``"entity_namespace"`` annotation are looked up on that namespace.  In
    both cases the located attribute's ``_bulk_update_tuples()`` output is
    handed to ``UpdateDMLState._get_crud_kv_pairs()``.  Any other key, or a
    string key not found on the mapper, is yielded directly; its value is
    coerced to an expression only when ``needs_to_be_cacheable`` is true.

    """
    core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

    def _crud_value(value):
        # plain pass-through unless a cacheable expression is required
        if not needs_to_be_cacheable:
            return value
        return coercions.expect(
            roles.ExpressionElementRole,
            value,
            type_=sqltypes.NullType(),
            is_crud=True,
        )

    for key, value in kv_iterator:
        key = coercions.expect(roles.DMLColumnRole, key)

        if isinstance(key, str):
            desc = _entity_namespace_key(mapper, key, default=NO_VALUE)
            if desc is NO_VALUE:
                # not an ORM attribute name; pass the key along
                yield (
                    coercions.expect(roles.DMLColumnRole, key),
                    _crud_value(value),
                )
                continue
            bulk_tuples = desc._bulk_update_tuples(value)
        elif "entity_namespace" in key._annotations:
            annotations = key._annotations
            attr = _entity_namespace_key(
                annotations["entity_namespace"], annotations["proxy_key"]
            )
            bulk_tuples = attr._bulk_update_tuples(value)
        else:
            yield (key, _crud_value(value))
            continue

        yield from core_get_crud_kv_pairs(
            statement,
            bulk_tuples,
            needs_to_be_cacheable,
        )
|
|
|
|
@classmethod
|
|
def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
|
|
plugin_subject = statement._propagate_attrs["plugin_subject"]
|
|
|
|
if not plugin_subject or not plugin_subject.mapper:
|
|
return UpdateDMLState._get_multi_crud_kv_pairs(
|
|
statement, kv_iterator
|
|
)
|
|
|
|
return [
|
|
dict(
|
|
cls._get_orm_crud_kv_pairs(
|
|
plugin_subject.mapper, statement, value_dict.items(), False
|
|
)
|
|
)
|
|
for value_dict in kv_iterator
|
|
]
|
|
|
|
@classmethod
|
|
def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
|
|
assert (
|
|
needs_to_be_cacheable
|
|
), "no test coverage for needs_to_be_cacheable=False"
|
|
|
|
plugin_subject = statement._propagate_attrs["plugin_subject"]
|
|
|
|
if not plugin_subject or not plugin_subject.mapper:
|
|
return UpdateDMLState._get_crud_kv_pairs(
|
|
statement, kv_iterator, needs_to_be_cacheable
|
|
)
|
|
|
|
return list(
|
|
cls._get_orm_crud_kv_pairs(
|
|
plugin_subject.mapper,
|
|
statement,
|
|
kv_iterator,
|
|
needs_to_be_cacheable,
|
|
)
|
|
)
|
|
|
|
@classmethod
def get_entity_description(cls, statement):
    """Return a dictionary describing the entity targeted by ``statement``.

    The entity is read from the ``"parententity"`` annotation of
    ``statement.table``.  The ``"name"`` entry is the alias name for an
    aliased class, otherwise the mapped class's ``__name__``.

    """
    ext_info = statement.table._annotations["parententity"]
    mapper = ext_info.mapper

    label = (
        ext_info.name
        if ext_info.is_aliased_class
        else mapper.class_.__name__
    )

    description = {"name": label}
    description["type"] = mapper.class_
    description["expr"] = ext_info.entity
    description["entity"] = ext_info.entity
    description["table"] = mapper.local_table
    return description
|
|
|
|
@classmethod
def get_returning_column_descriptions(cls, statement):
    """Return one description dictionary per selected (RETURNING) column.

    Each dictionary has the keys ``"name"``, ``"type"``, ``"expr"``,
    ``"aliased"`` and ``"entity"``.  The owning entity is read from the
    column's ``"parententity"`` annotation.  For a column without that
    annotation, ``"expr"`` is the column itself, ``"aliased"`` is
    ``False`` and ``"entity"`` is ``None``; previously such a column
    raised ``AttributeError``.

    """

    def _attr_for_col(col, ent):
        # prefer the attribute named by "proxy_key" on the entity, falling
        # back to the column when there is no entity, key or attribute
        if ent is None:
            return col
        proxy_key = col._annotations.get("proxy_key", None)
        if not proxy_key:
            return col
        return getattr(ent.entity, proxy_key, col)

    descriptions = []
    for col in statement._all_selected_columns:
        ent = col._annotations.get("parententity", None)
        descriptions.append(
            {
                "name": col.key,
                "type": col.type,
                "expr": _attr_for_col(col, ent),
                # columns with no parent entity are never "aliased"
                "aliased": ent.is_aliased_class if ent is not None else False,
                "entity": ent.entity if ent is not None else None,
            }
        )
    return descriptions
|
|
|
|
def _setup_orm_returning(
|
|
self,
|
|
compiler,
|
|
orm_level_statement,
|
|
dml_level_statement,
|
|
use_supplemental_cols=True,
|
|
dml_mapper=None,
|
|
):
|
|
"""establish ORM column handlers for an INSERT, UPDATE, or DELETE
|
|
which uses explicit returning().
|
|
|
|
called within compilation level create_for_statement.
|
|
|
|
The _return_orm_returning() method then receives the Result
|
|
after the statement was executed, and applies ORM loading to the
|
|
state that we first established here.
|
|
|
|
"""
|
|
|
|
if orm_level_statement._returning:
|
|
|
|
fs = FromStatement(
|
|
orm_level_statement._returning,
|
|
dml_level_statement,
|
|
_adapt_on_names=False,
|
|
)
|
|
fs = fs.options(*orm_level_statement._with_options)
|
|
self.select_statement = fs
|
|
self.from_statement_ctx = (
|
|
fsc
|
|
) = ORMFromStatementCompileState.create_for_statement(fs, compiler)
|
|
fsc.setup_dml_returning_compile_state(dml_mapper)
|
|
|
|
dml_level_statement = dml_level_statement._generate()
|
|
dml_level_statement._returning = ()
|
|
|
|
cols_to_return = [c for c in fsc.primary_columns if c is not None]
|
|
|
|
# since we are splicing result sets together, make sure there
|
|
# are columns of some kind returned in each result set
|
|
if not cols_to_return:
|
|
cols_to_return.extend(dml_mapper.primary_key)
|
|
|
|
if use_supplemental_cols:
|
|
dml_level_statement = dml_level_statement.return_defaults(
|
|
supplemental_cols=cols_to_return
|
|
)
|
|
else:
|
|
dml_level_statement = dml_level_statement.returning(
|
|
*cols_to_return
|
|
)
|
|
|
|
return dml_level_statement
|
|
|
|
@classmethod
|
|
def _return_orm_returning(
|
|
cls,
|
|
session,
|
|
statement,
|
|
params,
|
|
execution_options,
|
|
bind_arguments,
|
|
result,
|
|
):
|
|
|
|
execution_context = result.context
|
|
compile_state = execution_context.compiled.compile_state
|
|
|
|
if compile_state.from_statement_ctx:
|
|
load_options = execution_options.get(
|
|
"_sa_orm_load_options", QueryContext.default_load_options
|
|
)
|
|
querycontext = QueryContext(
|
|
compile_state.from_statement_ctx,
|
|
compile_state.select_statement,
|
|
params,
|
|
session,
|
|
load_options,
|
|
execution_options,
|
|
bind_arguments,
|
|
)
|
|
return loading.instances(result, querycontext)
|
|
else:
|
|
return result
|
|
|
|
|
|
class BulkUDCompileState(ORMDMLState):
    """Compile-state base for ORM-enabled statements that update or delete
    rows, including synchronization of the Session with the statement.

    """

    class default_update_options(Options):
        """Option defaults, populated from execution options in
        orm_pre_session_exec() and stored under ``_sa_orm_update_options``.

        """

        # "auto" is resolved to "orm" or "bulk" in orm_pre_session_exec()
        # depending on whether ``params`` is a list
        _dml_strategy: DMLStrategyArgument = "auto"

        # one of 'auto', 'evaluate', 'fetch', False (validated in
        # orm_pre_session_exec())
        _synchronize_session: SynchronizeSessionArgument = "auto"

        # assigned by _do_pre_synchronize_fetch()
        _can_use_returning: bool = False

        # forwarded to can_use_returning()
        _is_delete_using: bool = False
        _is_update_from: bool = False

        # whether session._autoflush() is invoked before execution
        _autoflush: bool = True

        # mapper of the statement's plugin_subject
        _subject_mapper: Optional[Mapper[Any]] = None

        # NOTE(review): not assigned in the visible portion of this file
        _resolved_values = EMPTY_DICT

        # callable assigned by the "evaluate" pre-synchronize steps
        _eval_condition = None

        # rows fetched by _do_pre_synchronize_fetch()
        _matched_rows = None

        # when not None, restricts matched states by identity token
        _identity_token = None
|
|
|
|
@classmethod
def can_use_returning(
    cls,
    dialect: Dialect,
    mapper: Mapper[Any],
    *,
    is_multitable: bool = False,
    is_update_from: bool = False,
    is_delete_using: bool = False,
) -> bool:
    """Report whether RETURNING may be used for ``mapper`` on ``dialect``.

    Not implemented on this class; it always raises.  It is called per
    bind from _do_pre_synchronize_fetch(), so it is presumably overridden
    by subclasses - confirm against the rest of the file.

    :raises NotImplementedError: always.

    """
    raise NotImplementedError()
|
|
|
|
@classmethod
def orm_pre_session_exec(
    cls,
    session,
    statement,
    params,
    execution_options,
    bind_arguments,
    is_pre_event,
):
    """Resolve update options and prepare the statement before execution.

    Populates ``bind_arguments`` with ``"clause"`` and ``"mapper"``,
    resolves the "auto" DML strategy from the type of ``params``, validates
    the synchronize_session setting and, unless ``is_pre_event`` is true,
    autoflushes and runs the pre-synchronize step for the chosen strategy.

    :return: a 2-tuple of the annotated statement and the execution
     options with ``_sa_orm_update_options`` added.
    :raises sqlalchemy.exc.InvalidRequestError: for a strategy that
     conflicts with the form of ``params``, or for 'fetch' with "bulk".
    :raises sqlalchemy.exc.ArgumentError: for an unknown
     synchronize_session value.

    """
    option_names = {
        "synchronize_session",
        "autoflush",
        "identity_token",
        "is_delete_using",
        "is_update_from",
        "dml_strategy",
    }
    options_cls = BulkUDCompileState.default_update_options
    update_options, execution_options = options_cls.from_execution_options(
        "_sa_orm_update_options",
        option_names,
        execution_options,
        statement._execution_options,
    )

    bind_arguments["clause"] = statement
    try:
        plugin_subject = statement._propagate_attrs["plugin_subject"]
    except KeyError:
        assert False, "statement had 'orm' plugin but no plugin_subject"
    else:
        bind_arguments["mapper"] = plugin_subject.mapper

    update_options += {"_subject_mapper": plugin_subject.mapper}

    # NOTE(review): both messages below say "insert" although this is the
    # update/delete path; text left unchanged in case callers match on it
    requested_strategy = update_options._dml_strategy
    if isinstance(params, list):
        if requested_strategy == "auto":
            update_options += {"_dml_strategy": "bulk"}
        elif requested_strategy == "orm":
            raise sa_exc.InvalidRequestError(
                'Can\'t use "orm" ORM insert strategy with a '
                "separate parameter list"
            )
    elif requested_strategy == "auto":
        update_options += {"_dml_strategy": "orm"}
    elif requested_strategy == "bulk":
        raise sa_exc.InvalidRequestError(
            'Can\'t use "bulk" ORM insert strategy without '
            "passing separate parameters"
        )

    sync = update_options._synchronize_session
    if sync is not None:
        if sync not in ("auto", "evaluate", "fetch", False):
            raise sa_exc.ArgumentError(
                "Valid strategies for session synchronization "
                "are 'auto', 'evaluate', 'fetch', False"
            )
        if update_options._dml_strategy == "bulk" and sync == "fetch":
            raise sa_exc.InvalidRequestError(
                "The 'fetch' synchronization strategy is not available "
                "for 'bulk' ORM updates (i.e. multiple parameter sets)"
            )

    if not is_pre_event:
        if update_options._autoflush:
            session._autoflush()

        resolved_strategy = update_options._dml_strategy
        sync_strategy = update_options._synchronize_session

        if resolved_strategy == "orm":
            pre_sync_args = (
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )
            if sync_strategy == "auto":
                update_options = cls._do_pre_synchronize_auto(
                    *pre_sync_args
                )
            elif sync_strategy == "evaluate":
                update_options = cls._do_pre_synchronize_evaluate(
                    *pre_sync_args
                )
            elif sync_strategy == "fetch":
                update_options = cls._do_pre_synchronize_fetch(
                    *pre_sync_args
                )
        elif resolved_strategy == "bulk" and sync_strategy == "auto":
            update_options += {"_synchronize_session": "evaluate"}

    # indicators from the "pre exec" step that are then
    # added to the DML statement, which will also be part of the cache
    # key.  The compile level create_for_statement() method will then
    # consume these at compiler time.
    annotations = {
        "synchronize_session": update_options._synchronize_session,
        "is_delete_using": update_options._is_delete_using,
        "is_update_from": update_options._is_update_from,
        "dml_strategy": update_options._dml_strategy,
        "can_use_returning": update_options._can_use_returning,
    }
    statement = statement._annotate(annotations)

    merged_options = util.immutabledict(execution_options).union(
        {"_sa_orm_update_options": update_options}
    )
    return (statement, merged_options)
|
|
|
|
@classmethod
def orm_setup_cursor_result(
    cls,
    session,
    statement,
    params,
    execution_options,
    bind_arguments,
    result,
):
    """Run post-execution session synchronization and return the result.

    For the "bulk" strategy the raw ``result`` is returned, after the
    bulk-evaluate step when synchronize_session is "evaluate".  For the
    "orm" strategy the "evaluate" or "fetch" post-synchronize step runs
    first; the return value then comes from _return_orm_returning().

    """
    # this stage of the execution is called after the
    # do_orm_execute event hook.  meaning for an extension like
    # horizontal sharding, this step happens *within* the horizontal
    # sharding event handler which calls session.execute() re-entrantly
    # and will occur for each backend individually.
    # the sharding extension then returns its own merged result from the
    # individual ones we return here.

    update_options = execution_options["_sa_orm_update_options"]
    strategy = update_options._dml_strategy
    sync = update_options._synchronize_session

    if strategy == "bulk":
        if sync == "evaluate":
            cls._do_post_synchronize_bulk_evaluate(
                session, params, result, update_options
            )
        return result

    if strategy == "orm":
        if sync == "evaluate":
            cls._do_post_synchronize_evaluate(
                session, statement, result, update_options
            )
        elif sync == "fetch":
            cls._do_post_synchronize_fetch(
                session, statement, result, update_options
            )

    return cls._return_orm_returning(
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    )
|
|
|
|
@classmethod
|
|
def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
|
|
"""Apply extra criteria filtering.
|
|
|
|
For all distinct single-table-inheritance mappers represented in the
|
|
table being updated or deleted, produce additional WHERE criteria such
|
|
that only the appropriate subtypes are selected from the total results.
|
|
|
|
Additionally, add WHERE criteria originating from LoaderCriteriaOptions
|
|
collected from the statement.
|
|
|
|
"""
|
|
|
|
return_crit = ()
|
|
|
|
adapter = ext_info._adapter if ext_info.is_aliased_class else None
|
|
|
|
if (
|
|
"additional_entity_criteria",
|
|
ext_info.mapper,
|
|
) in global_attributes:
|
|
return_crit += tuple(
|
|
ae._resolve_where_criteria(ext_info)
|
|
for ae in global_attributes[
|
|
("additional_entity_criteria", ext_info.mapper)
|
|
]
|
|
if ae.include_aliases or ae.entity is ext_info
|
|
)
|
|
|
|
if ext_info.mapper._single_table_criterion is not None:
|
|
return_crit += (ext_info.mapper._single_table_criterion,)
|
|
|
|
if adapter:
|
|
return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
|
|
|
|
return return_crit
|
|
|
|
@classmethod
|
|
def _interpret_returning_rows(cls, mapper, rows):
|
|
"""translate from local inherited table columns to base mapper
|
|
primary key columns.
|
|
|
|
Joined inheritance mappers always establish the primary key in terms of
|
|
the base table. When we UPDATE a sub-table, we can only get
|
|
RETURNING for the sub-table's columns.
|
|
|
|
Here, we create a lookup from the local sub table's primary key
|
|
columns to the base table PK columns so that we can get identity
|
|
key values from RETURNING that's against the joined inheritance
|
|
sub-table.
|
|
|
|
the complexity here is to support more than one level deep of
|
|
inheritance, where we have to link columns to each other across
|
|
the inheritance hierarchy.
|
|
|
|
"""
|
|
|
|
if mapper.local_table is not mapper.base_mapper.local_table:
|
|
return rows
|
|
|
|
# this starts as a mapping of
|
|
# local_pk_col: local_pk_col.
|
|
# we will then iteratively rewrite the "value" of the dict with
|
|
# each successive superclass column
|
|
local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key}
|
|
|
|
for mp in mapper.iterate_to_root():
|
|
if mp.inherits is None:
|
|
break
|
|
elif mp.local_table is mp.inherits.local_table:
|
|
continue
|
|
|
|
t_to_e = dict(mp._table_to_equated[mp.inherits.local_table])
|
|
col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]}
|
|
for pk, super_ in local_pk_to_base_pk.items():
|
|
local_pk_to_base_pk[pk] = col_to_col[super_]
|
|
|
|
lookup = {
|
|
local_pk_to_base_pk[lpk]: idx
|
|
for idx, lpk in enumerate(mapper.local_table.primary_key)
|
|
}
|
|
primary_key_convert = [
|
|
lookup[bpk] for bpk in mapper.base_mapper.primary_key
|
|
]
|
|
return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
|
|
|
|
@classmethod
def _get_matched_objects_on_criteria(cls, update_options, states):
    """Return the in-session objects matched by the evaluated criteria.

    Only non-expired states whose mapper ``isa()`` the subject mapper are
    considered, further restricted by ``_identity_token`` when it is set.

    :return: a list of ``(obj, state, dict_, is_expired)`` tuples where
     ``is_expired`` is true when the condition evaluated to
     ``evaluator._EXPIRED_OBJECT`` rather than ``True``.

    """
    mapper = update_options._subject_mapper
    eval_condition = update_options._eval_condition

    candidates = [
        (state.obj(), state, state.dict)
        for state in states
        if state.mapper.isa(mapper) and not state.expired
    ]

    identity_token = update_options._identity_token
    if identity_token is not None:
        candidates = [
            entry
            for entry in candidates
            if entry[1].identity_token == identity_token
        ]

    matched = []
    for obj, state, dict_ in candidates:
        outcome = eval_condition(obj)

        # caution: don't use "in ()" or == here, _EXPIRED_OBJECT
        # evaluates as True for all comparisons
        is_expired = outcome is evaluator._EXPIRED_OBJECT
        if outcome is True or is_expired:
            matched.append((obj, state, dict_, is_expired))
    return matched
|
|
|
|
@classmethod
def _eval_condition_from_statement(cls, update_options, statement):
    """Build a Python callable evaluating the statement's criteria.

    The criteria are the statement's WHERE criteria plus, when any
    criteria options contributed global attributes, the output of
    _adjust_for_extra_criteria().  With no criteria at all, the returned
    callable accepts one object and always returns ``True``.

    """
    mapper = update_options._subject_mapper
    evaluator_compiler = evaluator._EvaluatorCompiler(mapper.class_)

    where_criteria = statement._where_criteria
    crit = () + where_criteria if where_criteria else ()

    global_attributes = {}
    for opt in statement._with_options:
        if opt._is_criteria_option:
            opt.get_global_criteria(global_attributes)

    if global_attributes:
        crit += cls._adjust_for_extra_criteria(global_attributes, mapper)

    if crit:
        return evaluator_compiler.process(*crit)

    # workaround for mypy https://github.com/python/mypy/issues/14027
    def _eval_condition(obj):
        return True

    return _eval_condition
|
|
|
|
@classmethod
|
|
def _do_pre_synchronize_auto(
|
|
cls,
|
|
session,
|
|
statement,
|
|
params,
|
|
execution_options,
|
|
bind_arguments,
|
|
update_options,
|
|
):
|
|
"""setup auto sync strategy
|
|
|
|
|
|
"auto" checks if we can use "evaluate" first, then falls back
|
|
to "fetch"
|
|
|
|
evaluate is vastly more efficient for the common case
|
|
where session is empty, only has a few objects, and the UPDATE
|
|
statement can potentially match thousands/millions of rows.
|
|
|
|
OTOH more complex criteria that fails to work with "evaluate"
|
|
we would hope usually correlates with fewer net rows.
|
|
|
|
"""
|
|
|
|
try:
|
|
eval_condition = cls._eval_condition_from_statement(
|
|
update_options, statement
|
|
)
|
|
|
|
except evaluator.UnevaluatableError:
|
|
pass
|
|
else:
|
|
return update_options + {
|
|
"_eval_condition": eval_condition,
|
|
"_synchronize_session": "evaluate",
|
|
}
|
|
|
|
update_options += {"_synchronize_session": "fetch"}
|
|
return cls._do_pre_synchronize_fetch(
|
|
session,
|
|
statement,
|
|
params,
|
|
execution_options,
|
|
bind_arguments,
|
|
update_options,
|
|
)
|
|
|
|
@classmethod
|
|
def _do_pre_synchronize_evaluate(
|
|
cls,
|
|
session,
|
|
statement,
|
|
params,
|
|
execution_options,
|
|
bind_arguments,
|
|
update_options,
|
|
):
|
|
|
|
try:
|
|
eval_condition = cls._eval_condition_from_statement(
|
|
update_options, statement
|
|
)
|
|
|
|
except evaluator.UnevaluatableError as err:
|
|
raise sa_exc.InvalidRequestError(
|
|
'Could not evaluate current criteria in Python: "%s". '
|
|
"Specify 'fetch' or False for the "
|
|
"synchronize_session execution option." % err
|
|
) from err
|
|
|
|
return update_options + {
|
|
"_eval_condition": eval_condition,
|
|
}
|
|
|
|
@classmethod
|
|
def _get_resolved_values(cls, mapper, statement):
|
|
if statement._multi_values:
|
|
return []
|
|
elif statement._ordered_values:
|
|
return list(statement._ordered_values)
|
|
elif statement._values:
|
|
return list(statement._values.items())
|
|
else:
|
|
return []
|
|
|
|
@classmethod
def _resolved_keys_as_propnames(cls, mapper, resolved_values):
    """Translate column keys into mapped attribute names.

    Columns that are not mapped on ``mapper`` are skipped.

    :return: a list of ``(attribute key, value)`` pairs.
    :raises sqlalchemy.exc.InvalidRequestError: if there is no mapper or
     a key is not a column element.

    """
    named_values = []
    for column, value in resolved_values:
        if not (mapper and isinstance(column, expression.ColumnElement)):
            raise sa_exc.InvalidRequestError(
                "Attribute name not found, can't be "
                "synchronized back to objects: %r" % column
            )

        try:
            attr = mapper._columntoproperty[column]
        except orm_exc.UnmappedColumnError:
            # column isn't mapped; nothing to synchronize for it
            continue

        named_values.append((attr.key, value))
    return named_values
|
|
|
|
@classmethod
def _do_pre_synchronize_fetch(
    cls,
    session,
    statement,
    params,
    execution_options,
    bind_arguments,
    update_options,
):
    """Set up the "fetch" synchronization strategy.

    Builds a SELECT of the subject mapper's primary key columns plus its
    identity token, using the statement's WHERE criteria and options, and
    executes it through the session with a per-bind event hook that
    decides whether RETURNING can be used instead.

    :return: update options with ``_matched_rows`` set to the rows fetched
     from that execution and ``_can_use_returning`` set to the hook's
     decision (``None`` if the hook never ran).
    :raises sqlalchemy.exc.InvalidRequestError: if different binds
     disagree on RETURNING support.

    """
    mapper = update_options._subject_mapper

    select_stmt = (
        select(*(mapper.primary_key + (mapper.select_identity_token,)))
        .select_from(mapper)
        .options(*statement._with_options)
    )
    select_stmt._where_criteria = statement._where_criteria

    # conditionally run the SELECT statement for pre-fetch, testing the
    # "bind" for if we can use RETURNING or not using the do_orm_execute
    # event.  If RETURNING is available, the do_orm_execute event
    # will cancel the SELECT from being actually run.
    #
    # The way this is organized seems strange, why don't we just
    # call can_use_returning() before invoking the statement and get
    # answer?, why does this go through the whole execute phase using an
    # event?  Answer: because we are integrating with extensions such
    # as the horizontal sharding extension that "multiplexes" an individual
    # statement run through multiple engines, and it uses
    # do_orm_execute() to do that.

    # shared across every invocation of the hook below; stays None until
    # the hook has run for at least one bind
    can_use_returning = None

    def skip_for_returning(orm_context: ORMExecuteState) -> Any:
        bind = orm_context.session.get_bind(**orm_context.bind_arguments)
        nonlocal can_use_returning

        per_bind_result = cls.can_use_returning(
            bind.dialect,
            mapper,
            is_update_from=update_options._is_update_from,
            is_delete_using=update_options._is_delete_using,
        )

        if can_use_returning is not None:
            # a previous bind already answered; all binds must agree
            if can_use_returning != per_bind_result:
                raise sa_exc.InvalidRequestError(
                    "For synchronize_session='fetch', can't mix multiple "
                    "backends where some support RETURNING and others "
                    "don't"
                )
        else:
            can_use_returning = per_bind_result

        if per_bind_result:
            # a null result stands in for the SELECT on this bind
            return _result.null_result()
        else:
            return None

    result = session.execute(
        select_stmt,
        params,
        execution_options=execution_options,
        bind_arguments=bind_arguments,
        _add_event=skip_for_returning,
    )
    matched_rows = result.fetchall()

    return update_options + {
        "_matched_rows": matched_rows,
        "_can_use_returning": can_use_returning,
    }
|
|
|
|
|
|
@CompileState.plugin_for("orm", "insert")
class BulkORMInsert(ORMDMLState, InsertDMLState):
    """Compile state registered for the ("orm", "insert") plugin."""

    class default_insert_options(Options):
        """Option defaults, populated from execution options in
        orm_pre_session_exec() and stored under ``_sa_orm_insert_options``.

        """

        # "auto" is resolved to "orm" or "bulk" in orm_pre_session_exec()
        # depending on whether ``params`` is empty
        _dml_strategy: DMLStrategyArgument = "auto"

        # forwarded to _bulk_insert() as ``render_nulls``
        _render_nulls: bool = False

        # forwarded to _bulk_insert() as ``return_defaults``
        _return_defaults: bool = False

        # mapper of the statement's plugin_subject
        _subject_mapper: Optional[Mapper[Any]] = None

        # whether session._autoflush() is invoked before execution
        _autoflush: bool = True

    # NOTE(review): presumably assigned via _setup_orm_returning(); the
    # call site is not in the visible portion of this file
    select_statement: Optional[FromStatement] = None
|
|
|
|
@classmethod
def orm_pre_session_exec(
    cls,
    session,
    statement,
    params,
    execution_options,
    bind_arguments,
    is_pre_event,
):
    """Resolve insert options and prepare the statement before execution.

    Populates ``bind_arguments`` with ``"clause"`` and ``"mapper"``,
    resolves the "auto" DML strategy from whether ``params`` is empty,
    merges the ORM load execution options for any strategy other than
    "raw", and autoflushes unless ``is_pre_event`` is true or autoflush is
    disabled.

    :return: a 2-tuple of the annotated statement and the execution
     options with ``_sa_orm_insert_options`` added.
    :raises sqlalchemy.exc.InvalidRequestError: for a strategy that
     conflicts with the presence or absence of ``params``.

    """
    options_cls = BulkORMInsert.default_insert_options
    insert_options, execution_options = options_cls.from_execution_options(
        "_sa_orm_insert_options",
        {"dml_strategy", "autoflush"},
        execution_options,
        statement._execution_options,
    )

    bind_arguments["clause"] = statement
    try:
        plugin_subject = statement._propagate_attrs["plugin_subject"]
    except KeyError:
        assert False, "statement had 'orm' plugin but no plugin_subject"
    else:
        bind_arguments["mapper"] = plugin_subject.mapper

    insert_options += {"_subject_mapper": plugin_subject.mapper}

    requested_strategy = insert_options._dml_strategy
    if params:
        if requested_strategy == "auto":
            insert_options += {"_dml_strategy": "bulk"}
        elif requested_strategy == "orm":
            raise sa_exc.InvalidRequestError(
                'Can\'t use "orm" ORM insert strategy with a '
                "separate parameter list"
            )
    elif requested_strategy == "auto":
        insert_options += {"_dml_strategy": "orm"}
    elif requested_strategy == "bulk":
        raise sa_exc.InvalidRequestError(
            'Can\'t use "bulk" ORM insert strategy without '
            "passing separate parameters"
        )

    if insert_options._dml_strategy != "raw":
        # for ORM object loading, like ORMContext, we have to disable
        # result set adapt_to_context, because we will be generating a
        # new statement with specific columns that's cached inside of
        # an ORMFromStatementCompileState, which we will re-use for
        # each result.
        execution_options = (
            execution_options.union(context._orm_load_exec_options)
            if execution_options
            else context._orm_load_exec_options
        )

    if not is_pre_event and insert_options._autoflush:
        session._autoflush()

    statement = statement._annotate(
        {"dml_strategy": insert_options._dml_strategy}
    )

    merged_options = util.immutabledict(execution_options).union(
        {"_sa_orm_insert_options": insert_options}
    )
    return (statement, merged_options)
|
|
|
|
@classmethod
def orm_execute_statement(
    cls,
    session: Session,
    statement: dml.Insert,
    params: _CoreAnyExecuteParams,
    execution_options: OrmExecuteOptionsParameter,
    bind_arguments: _BindArguments,
    conn: Connection,
) -> _result.Result:
    """Execute an ORM-enabled INSERT using the selected DML strategy.

    The strategy is read from the ``_sa_orm_insert_options`` execution
    option, falling back to ``cls.default_insert_options``:

    * ``"raw"`` - the statement is executed on ``conn`` and the result is
      returned unchanged.
    * ``"bulk"`` - the parameters (a single dictionary is wrapped into a
      one-element list) are handed to ``_bulk_insert()`` together with the
      session's current transaction.
    * ``"orm"`` - the statement is executed on ``conn``.

    For the ``"bulk"`` and ``"orm"`` strategies, the result is returned
    directly when the statement has no RETURNING; otherwise it is passed
    through ``cls._return_orm_returning()``.

    :raises ArgumentError: the strategy is not one of the known names.
    :raises InvalidRequestError: a bulk INSERT carries a "post values"
     clause while the mapper has multiple persistence tables.
    :raises AssertionError: the strategy is still ``"auto"`` at execution
     time, or the bulk strategy has no subject mapper / transaction.

    """

    insert_options = execution_options.get(
        "_sa_orm_insert_options", cls.default_insert_options
    )

    if insert_options._dml_strategy not in (
        "raw",
        "bulk",
        "orm",
        "auto",
    ):
        raise sa_exc.ArgumentError(
            "Valid strategies for ORM insert strategy "
            "are 'raw', 'orm', 'bulk', 'auto'"
        )

    result: _result.Result[Any]

    if insert_options._dml_strategy == "raw":
        result = conn.execute(
            statement, params or {}, execution_options=execution_options
        )
        return result

    if insert_options._dml_strategy == "bulk":
        mapper = insert_options._subject_mapper

        # check the preconditions before the mapper is dereferenced below
        assert mapper is not None
        assert session._transaction is not None

        if (
            statement._post_values_clause is not None
            and mapper._multiple_persistence_tables
        ):
            raise sa_exc.InvalidRequestError(
                "bulk INSERT with a 'post values' clause "
                "(typically upsert) not supported for multi-table "
                f"mapper {mapper}"
            )

        result = _bulk_insert(
            mapper,
            cast(
                "Iterable[Dict[str, Any]]",
                [params] if isinstance(params, dict) else params,
            ),
            session._transaction,
            isstates=False,
            return_defaults=insert_options._return_defaults,
            render_nulls=insert_options._render_nulls,
            use_orm_insert_stmt=statement,
            execution_options=execution_options,
        )
    elif insert_options._dml_strategy == "orm":
        result = conn.execute(
            statement, params or {}, execution_options=execution_options
        )
    else:
        # "auto" passes the validation above but is expected to have been
        # resolved to a concrete strategy before this point
        raise AssertionError()

    if not bool(statement._returning):
        return result

    return cls._return_orm_returning(
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    )
|
|
|
|
@classmethod
def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
    """Build the compile state for an ORM-enabled INSERT.

    The state produced by the superclass is returned as is when the
    INSERT is not the top-level statement, i.e. a compiler is present and
    its stack is non-empty.  Otherwise the ``dml_strategy`` annotation of
    the statement (``"raw"`` when absent) selects the additional setup:
    ``"bulk"`` or ``"orm"``; any other value performs no further setup.

    """
    state = cast(
        BulkORMInsert,
        super().create_for_statement(statement, compiler, **kw),
    )

    # nested INSERT: only the outermost statement gets the ORM setup
    if compiler is not None and compiler.stack:
        return state

    # looked up unconditionally; a statement without a plugin subject
    # raises KeyError here
    subject = statement._propagate_attrs["plugin_subject"]
    strategy = statement._annotations.get("dml_strategy", "raw")

    if strategy == "bulk":
        state._setup_for_bulk_insert(compiler)
    elif strategy == "orm":
        state._setup_for_orm_insert(compiler, subject)

    return state
|
|
|
|
@classmethod
|
|
def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
|
|
return {
|
|
col.key if col is not None else k: v
|
|
for col, k, v in (
|
|
(mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
|
|
)
|
|
}
|
|
|
|
def _setup_for_orm_insert(self, compiler, mapper):
    """Set up ORM RETURNING handling for an "orm" strategy INSERT.

    The current statement serves as both the ORM-level statement and the
    statement to be adjusted; supplemental columns are not requested.
    ``mapper`` is accepted but not used by this method.

    """
    orm_level_statement = cast(dml.Insert, self.statement)

    self.statement = self._setup_orm_returning(
        compiler,
        orm_level_statement,
        orm_level_statement,
        use_supplemental_cols=False,
    )
|
|
|
|
def _setup_for_bulk_insert(self, compiler):
    """Prepare the INSERT statement for use within a bulk insert.

    This method runs inside the "conn.execute()" call that is invoked
    by persistence._emit_insert_statement().

    The target table and mapper are taken from the statement's
    ``_emit_insert_table`` / ``_emit_insert_mapper`` annotations.  A clone
    of the statement is pointed at that table, and any dictionary
    parameters are narrowed to the columns belonging to it.

    """
    orm_level_statement = cast(dml.Insert, self.statement)
    annotations = orm_level_statement._annotations

    target_table = annotations["_emit_insert_table"]
    target_mapper = annotations["_emit_insert_mapper"]

    # work on a copy so the ORM-level statement keeps its original table
    table_statement = orm_level_statement._clone()
    table_statement.table = target_table

    if self._dict_parameters:
        self._dict_parameters = {
            column: value
            for column, value in self._dict_parameters.items()
            if column.table is target_table
        }

    self.statement = self._setup_orm_returning(
        compiler,
        orm_level_statement,
        table_statement,
        use_supplemental_cols=True,
        dml_mapper=target_mapper,
    )
|
|
|
|
|
|
@CompileState.plugin_for("orm", "update")
class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
    """Compile state registered for the "orm" plugin's "update" statements.

    Handles both the "orm" strategy (an UPDATE with criteria whose effects
    are synchronized into the session) and the "bulk" strategy (UPDATE by
    primary key from a list of parameter dictionaries).

    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Create the compile state for ``statement``.

        The ``dml_strategy`` annotation selects the setup routine: "bulk"
        uses :meth:`_setup_for_bulk_update`; "orm", or a missing
        annotation, uses :meth:`_setup_for_orm_update`.  Any other value
        returns the new object without further setup.

        """

        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        if dml_strategy == "bulk":
            self._setup_for_bulk_update(statement, compiler)
        elif dml_strategy in ("orm", "unspecified"):
            self._setup_for_orm_update(statement, compiler)

        return self

    def _setup_for_orm_update(self, statement, compiler, **kw):
        """Set up an UPDATE that runs with the "orm" strategy.

        A clone of the statement is built against the mapper's local
        table with resolved SET values and any extra criteria contributed
        by criteria options.  When ``synchronize_session`` is "fetch" and
        RETURNING is usable, the primary key columns are added via
        ``return_defaults()``.

        """
        orm_level_statement = statement

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        self._resolved_values = self._get_resolved_values(mapper, statement)

        extra_criteria_attributes = {}

        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        if statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = statement._clone()
        new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = self._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        # do this first as we need to determine if there is
        # UPDATE..FROM

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        synchronize_session = compiler._annotations.get(
            "synchronize_session", None
        )
        can_use_returning = compiler._annotations.get(
            "can_use_returning", None
        )
        if can_use_returning is not False:
            # even though pre_exec has determined basic
            # can_use_returning for the dialect, if we are to use
            # RETURNING we need to run can_use_returning() at this level
            # as well, because self.is_multitable is only available
            # once UpdateDMLState.__init__() has run above
            can_use_returning = (
                synchronize_session == "fetch"
                and self.can_use_returning(
                    compiler.dialect, mapper, is_multitable=self.is_multitable
                )
            )

        if synchronize_session == "fetch" and can_use_returning:
            use_supplemental_cols = True

            # NOTE: we might want to RETURNING the actual columns to be
            # synchronized also.  however this is complicated and difficult
            # to align against the behavior of "evaluate".  Additionally,
            # in a large number (if not the majority) of cases, we have the
            # "evaluate" answer, usually a fixed value, in memory already and
            # there's no need to re-fetch the same value
            # over and over again.   so perhaps if it could be RETURNING just
            # the elements that were based on a SQL expression and not
            # a constant.   For now it doesn't quite seem worth it
            new_stmt = new_stmt.return_defaults(
                *(list(mapper.local_table.primary_key))
            )

        new_stmt = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            new_stmt,
            use_supplemental_cols=use_supplemental_cols,
        )

        self.statement = new_stmt

    def _setup_for_bulk_update(self, statement, compiler, **kw):
        """establish an UPDATE statement within the context of
        bulk update.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_update_statement().

        :raises InvalidRequestError: the statement uses ordered values.

        """
        statement = cast(dml.Update, statement)
        an = statement._annotations

        # both annotations are looked up (a missing one raises KeyError);
        # only the table is used below
        emit_update_table, _ = (
            an["_emit_update_table"],
            an["_emit_update_mapper"],
        )

        statement = statement._clone()
        statement.table = emit_update_table

        UpdateDMLState.__init__(self, statement, compiler, **kw)

        if self._ordered_values:
            raise sa_exc.InvalidRequestError(
                "bulk ORM UPDATE does not support ordered_values() for "
                "custom UPDATE statements with bulk parameter sets.  Use a "
                "non-bulk UPDATE statement or use values()."
            )

        if self._dict_parameters:
            # keep only the SET values that target the table being updated
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_update_table
            }
        self.statement = statement

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Update,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Execute an ORM-enabled UPDATE using the selected DML strategy.

        With the "bulk" strategy the parameters (a single dictionary is
        wrapped into a one-element list) are handed to ``_bulk_update()``
        and the result is post-processed by
        ``cls.orm_setup_cursor_result()``.  Other valid strategies defer
        to the superclass implementation.

        :raises ArgumentError: the strategy is not "orm", "auto" or "bulk".
        :raises InvalidRequestError: a bulk UPDATE has WHERE criteria.

        """

        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy not in ("orm", "auto", "bulk"):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM UPDATE strategy "
                "are 'orm', 'auto', 'bulk'"
            )

        result: _result.Result[Any]

        if update_options._dml_strategy == "bulk":
            if statement._where_criteria:
                raise sa_exc.InvalidRequestError(
                    "WHERE clause with bulk ORM UPDATE not "
                    "supported right now.   Statement may be invoked at the "
                    "Core level using "
                    "session.connection().execute(stmt, parameters)"
                )
            mapper = update_options._subject_mapper
            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_update(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                update_changed_only=False,
                use_orm_update_stmt=statement,
            )
            return cls.orm_setup_cursor_result(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                result,
            )
        else:
            return super().orm_execute_statement(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                conn,
            )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
    ) -> bool:
        """Return whether RETURNING may be used for this UPDATE.

        ``False`` is returned unless the dialect supports UPDATE RETURNING
        and the mapper's local table has ``implicit_returning`` enabled.
        With the ``is_update_from`` hint the answer is the dialect's
        ``update_returning_multifrom`` flag.  ``is_delete_using`` is
        accepted but not consulted here.

        :raises CompileError: the statement is multi-table, the hint was
         not given and the dialect lacks ``update_returning_multifrom``.

        """

        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.update_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        # these workarounds are currently hypothetical for UPDATE,
        # unlike DELETE where they impact MariaDB
        if is_update_from:
            return dialect.update_returning_multifrom

        elif is_multitable and not dialect.update_returning_multifrom:
            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with UPDATE..FROM; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_update_from=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_bulk_evaluate(
        cls, session, params, result, update_options
    ):
        """Apply bulk UPDATE parameter values to objects in the session.

        For each parameter dictionary the identity key is built from the
        values of the mapper's identity key attributes; objects not found
        in the identity map are skipped.  Unmodified attributes that are
        loaded on the object receive the parameter value, while loaded
        attributes that had pending changes are expired instead.

        ``params`` may be a list of dictionaries or a single dictionary.

        """
        if not params:
            return

        # orm_execute_statement() accepts a single parameter dictionary
        # and wraps it for _bulk_update(); do the same here so that the
        # loop below never iterates over the keys of a lone dictionary
        if isinstance(params, dict):
            params = [params]

        mapper = update_options._subject_mapper
        pk_keys = [prop.key for prop in mapper._identity_key_props]

        identity_map = session.identity_map

        for param in params:
            identity_key = mapper.identity_key_from_primary_key(
                (param[key] for key in pk_keys),
                update_options._identity_token,
            )
            state = identity_map.fast_get_state(identity_key)
            if not state:
                continue

            evaluated_keys = set(param).difference(pk_keys)

            dict_ = state.dict
            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = param[key]

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = evaluated_keys.intersection(dict_).difference(
                to_evaluate
            )
            if to_expire:
                state._expire_attributes(dict_, to_expire)

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        """Synchronize the session using the "evaluate" approach.

        Objects in the identity map that match the UPDATE criteria receive
        the statement's SET values.

        """

        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            # the fourth element of each matched entry is not needed here
            [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
        )

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        """Synchronize the session using the "fetch" approach.

        Primary keys come from the result's returned-defaults rows when
        present, otherwise from ``update_options._matched_rows``.  Objects
        for those keys that are present in the identity map receive the
        statement's SET values.

        """
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows
        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                target_mapper, returned_defaults_rows
            )

            # append the identity token so both sources share one row shape
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        # each row is (pk col, pk col, ..., identity token); rows whose
        # token differs from a configured identity token are skipped
        objs = [
            session.identity_map[identity_key]
            for identity_key in [
                target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                for primary_key, identity_token in [
                    (row[0:-1], row[-1]) for row in matched_rows
                ]
                if update_options._identity_token is None
                or identity_token == update_options._identity_token
            ]
            if identity_key in session.identity_map
        ]

        if not objs:
            return

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            [
                (
                    obj,
                    attributes.instance_state(obj),
                    attributes.instance_dict(obj),
                )
                for obj in objs
            ],
        )

    @classmethod
    def _apply_update_set_values_to_objects(
        cls, session, update_options, statement, matched_objects
    ):
        """apply values to objects derived from an update statement, e.g.
        UPDATE..SET <values>

        ``matched_objects`` is an iterable of ``(obj, state, dict_)``
        tuples.  SET values that the evaluator cannot process are left out
        of the in-Python evaluation; loaded attributes for such values, as
        well as attributes with pending changes, are expired instead.

        """
        mapper = update_options._subject_mapper
        target_cls = mapper.class_
        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        resolved_values = cls._get_resolved_values(mapper, statement)
        resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
            mapper, resolved_values
        )
        value_evaluators = {}
        for key, value in resolved_keys_as_propnames:
            try:
                _evaluator = evaluator_compiler.process(
                    coercions.expect(roles.ExpressionElementRole, value)
                )
            except evaluator.UnevaluatableError:
                # no evaluator for this key; it can only be expired below
                pass
            else:
                value_evaluators[key] = _evaluator

        evaluated_keys = list(value_evaluators.keys())
        attrib = {k for k, v in resolved_keys_as_propnames}

        states = set()
        for obj, state, dict_ in matched_objects:

            to_evaluate = state.unmodified.intersection(evaluated_keys)

            for key in to_evaluate:
                if key in dict_:
                    # only run eval for attributes that are present.
                    dict_[key] = value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)
|
|
|
|
|
|
@CompileState.plugin_for("orm", "delete")
class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
    """Compile state registered for the "orm" plugin's "delete" statements.

    Only the non-bulk strategies are supported; after execution, objects
    affected by the DELETE are removed from the session.

    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Create the compile state for an ORM-enabled DELETE.

        A clone of the statement is built against the mapper's local
        table with any extra criteria contributed by criteria options.
        When RETURNING is usable for ``synchronize_session='fetch'``, the
        primary key columns are added via ``return_defaults()``.

        """
        self = cls.__new__(cls)

        orm_level_statement = statement

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        extra_criteria_attributes = {}

        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        new_stmt = statement._clone()
        new_stmt.table = mapper.local_table

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # do this first as we need to determine if there is
        # DELETE..FROM
        DeleteDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        synchronize_session = compiler._annotations.get(
            "synchronize_session", None
        )
        can_use_returning = compiler._annotations.get(
            "can_use_returning", None
        )
        if can_use_returning is not False:
            # even though pre_exec has determined basic
            # can_use_returning for the dialect, if we are to use
            # RETURNING we need to run can_use_returning() at this level
            # unconditionally because is_delete_using was not known
            # at the pre_exec level
            can_use_returning = (
                synchronize_session == "fetch"
                and self.can_use_returning(
                    compiler.dialect,
                    mapper,
                    is_multitable=self.is_multitable,
                    is_delete_using=compiler._annotations.get(
                        "is_delete_using", False
                    ),
                )
            )

        if can_use_returning:
            use_supplemental_cols = True

            new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        new_stmt = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            new_stmt,
            use_supplemental_cols=use_supplemental_cols,
        )

        self.statement = new_stmt

        return self

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Delete,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        """Execute an ORM-enabled DELETE.

        Valid strategies defer to the superclass implementation.

        :raises InvalidRequestError: the "bulk" strategy was requested.
        :raises ArgumentError: the strategy is neither "orm" nor "auto".

        """

        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy == "bulk":
            raise sa_exc.InvalidRequestError(
                "Bulk ORM DELETE not supported right now. "
                "Statement may be invoked at the "
                "Core level using "
                "session.connection().execute(stmt, parameters)"
            )

        if update_options._dml_strategy not in (
            "orm",
            "auto",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
            )

        return super().orm_execute_statement(
            session, statement, params, execution_options, bind_arguments, conn
        )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
    ) -> bool:
        """Return whether RETURNING may be used for this DELETE.

        ``False`` is returned unless the dialect supports DELETE RETURNING
        and the mapper's local table has ``implicit_returning`` enabled.
        With the ``is_delete_using`` hint the answer is the dialect's
        ``delete_returning_multifrom`` flag.  ``is_update_from`` is
        accepted but not consulted here.

        :raises CompileError: the statement is multi-table, the hint was
         not given and the dialect lacks ``delete_returning_multifrom``.

        """

        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.delete_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        # now get into special workarounds because MariaDB supports
        # DELETE...RETURNING but not DELETE...USING...RETURNING.
        if is_delete_using:
            # is_delete_using hint was passed.   use
            # additional dialect feature (True for PG, False for MariaDB)
            return dialect.delete_returning_multifrom

        elif is_multitable and not dialect.delete_returning_multifrom:
            # is_delete_using hint was not passed, but we determined
            # at compile time that this is in fact a DELETE..USING.
            # it's too late to continue since we did not pre-SELECT.
            # raise that we need that hint up front.

            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with DELETE..USING; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_delete_using=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        """Synchronize the session using the "evaluate" approach.

        Matched objects flagged as partially expired are expired; all
        other matched objects are removed from the session as deleted.

        """
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        to_delete = []

        for _, state, dict_, is_partially_expired in matched_objects:
            if is_partially_expired:
                state._expire(dict_, session.identity_map._modified)
            else:
                to_delete.append(state)

        if to_delete:
            session._remove_newly_deleted(to_delete)

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        """Synchronize the session using the "fetch" approach.

        Primary keys come from the result's returned-defaults rows when
        present, otherwise from ``update_options._matched_rows``.  The
        states of the objects found in the identity map for those keys
        are removed from the session in a single call.

        """
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows

        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                target_mapper, returned_defaults_rows
            )

            # append the identity token so both sources share one row shape
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        identity_map = session.identity_map
        seen_keys = set()
        to_delete = []

        for row in matched_rows:
            # each row is (pk col, pk col, ..., identity token)
            primary_key = row[0:-1]
            identity_token = row[-1]

            identity_key = target_mapper.identity_key_from_primary_key(
                list(primary_key),
                identity_token=identity_token,
            )

            # a repeated row must not hand the same state over twice
            if identity_key in seen_keys:
                continue
            seen_keys.add(identity_key)

            if identity_key in identity_map:
                to_delete.append(
                    attributes.instance_state(identity_map[identity_key])
                )

        if to_delete:
            session._remove_newly_deleted(to_delete)
|