You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
542 lines
16 KiB
542 lines
16 KiB
# util/_py_collections.py
|
|
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
|
# <see AUTHORS file>
|
|
#
|
|
# This module is part of SQLAlchemy and is released under
|
|
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
|
# mypy: allow-untyped-defs, allow-untyped-calls
|
|
|
|
from __future__ import annotations
|
|
|
|
from itertools import filterfalse
|
|
from typing import AbstractSet
|
|
from typing import Any
|
|
from typing import Callable
|
|
from typing import cast
|
|
from typing import Collection
|
|
from typing import Dict
|
|
from typing import Iterable
|
|
from typing import Iterator
|
|
from typing import List
|
|
from typing import Mapping
|
|
from typing import NoReturn
|
|
from typing import Optional
|
|
from typing import Set
|
|
from typing import Tuple
|
|
from typing import TYPE_CHECKING
|
|
from typing import TypeVar
|
|
from typing import Union
|
|
|
|
from ..util.typing import Self
|
|
|
|
_T = TypeVar("_T", bound=Any)
|
|
_S = TypeVar("_S", bound=Any)
|
|
_KT = TypeVar("_KT", bound=Any)
|
|
_VT = TypeVar("_VT", bound=Any)
|
|
|
|
|
|
class ReadOnlyContainer:
    """Mixin whose item and attribute mutators always raise ``TypeError``.

    ``__setitem__``, ``__delitem__`` and ``__setattr__`` all funnel through
    :meth:`_readonly`, so a subclass can change the error raised for every
    mutator by overriding that single hook.
    """

    __slots__ = ()

    def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
        """Raise ``TypeError`` describing the object as immutable/readonly.

        Positional and keyword arguments are accepted and ignored so that
        this hook can stand in for any mutator signature.
        """
        owner = self.__class__.__name__
        raise TypeError("%s object is immutable and/or readonly" % owner)

    def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
        """Raise ``TypeError`` describing the object as immutable."""
        owner = self.__class__.__name__
        raise TypeError("%s object is immutable" % owner)

    def __setattr__(self, key: str, value: Any) -> NoReturn:
        """Reject attribute assignment."""
        self._readonly()

    def __setitem__(self, key: Any, value: Any) -> NoReturn:
        """Reject ``obj[key] = value``."""
        self._readonly()

    def __delitem__(self, key: Any) -> NoReturn:
        """Reject ``del obj[key]``."""
        self._readonly()
|
|
|
|
|
|
class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]):
    """``dict`` subclass whose mutating methods all raise ``TypeError``.

    Every mutator delegates to :meth:`_readonly`, which is overridden here
    to raise the "immutable" error from ``_immutable()``.
    """

    if TYPE_CHECKING:
        # Typing-only stubs; nothing is defined here at runtime.

        def __new__(cls, *args: Any) -> Self: ...

        def __init__(cls, *args: Any): ...

    def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
        """Route every rejected mutation to the "immutable" error."""
        self._immutable()

    def update(self, *arg: Any, **kw: Any) -> NoReturn:
        """Reject ``update()``."""
        self._readonly()

    def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
        """Reject ``setdefault()``."""
        self._readonly()

    def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
        """Reject ``pop()``."""
        self._readonly()

    def popitem(self) -> NoReturn:
        """Reject ``popitem()``."""
        self._readonly()

    def clear(self) -> NoReturn:
        """Reject ``clear()``."""
        self._readonly()
|
|
|
|
|
|
class immutabledict(ImmutableDictBase[_KT, _VT]):
    """Immutable dictionary whose "mutators" return new instances.

    The contents are populated once in ``__new__`` through
    ``dict.__init__``; the mutating ``dict`` API inherited from
    ``ImmutableDictBase`` raises ``TypeError``.  ``union()``,
    ``merge_with()`` and the ``|`` operator build a new dictionary
    instead.
    """

    def __new__(cls, *args):
        """Create the instance and load *args* into it via ``dict.__init__``."""
        new = ImmutableDictBase.__new__(cls)
        dict.__init__(new, *args)
        return new

    def __init__(
        self, *args: Union[Mapping[_KT, _VT], Iterable[Tuple[_KT, _VT]]]
    ):
        # Intentionally empty: the contents were already loaded in __new__.
        pass

    def __reduce__(self):
        """Pickle as ``immutabledict(<plain dict copy>)``."""
        return immutabledict, (dict(self),)

    def union(
        self, __d: Optional[Mapping[_KT, _VT]] = None
    ) -> immutabledict[_KT, _VT]:
        """Return a new dictionary with the entries of ``__d`` applied.

        When ``__d`` is ``None`` or empty, ``self`` is returned unchanged
        rather than copied.
        """
        if not __d:
            return self

        new = ImmutableDictBase.__new__(self.__class__)
        dict.__init__(new, self)
        dict.update(new, __d)  # type: ignore
        return new

    def _union_w_kw(
        self, __d: Optional[Mapping[_KT, _VT]] = None, **kw: _VT
    ) -> immutabledict[_KT, _VT]:
        """Like :meth:`union`, additionally applying keyword entries.

        ``__d`` is applied first and ``kw`` second.  ``self`` is returned
        unchanged when both are empty.
        """
        # not sure if C version works correctly w/ this yet
        if not __d and not kw:
            return self

        new = ImmutableDictBase.__new__(self.__class__)
        dict.__init__(new, self)
        if __d:
            dict.update(new, __d)  # type: ignore
        dict.update(new, kw)  # type: ignore
        return new

    def merge_with(
        self, *dicts: Optional[Mapping[_KT, _VT]]
    ) -> immutabledict[_KT, _VT]:
        """Return a new dictionary with each non-empty mapping applied in order.

        ``None`` and empty mappings are skipped; if nothing remains to
        merge, ``self`` is returned unchanged.
        """
        new = None
        for d in dicts:
            if d:
                if new is None:
                    # Copy lazily, only once there is something to merge.
                    new = ImmutableDictBase.__new__(self.__class__)
                    dict.__init__(new, self)
                dict.update(new, d)  # type: ignore
        if new is None:
            return self

        return new

    def __repr__(self) -> str:
        return "immutabledict(%s)" % dict.__repr__(self)

    # PEP 584
    def __ior__(self, __value: Any) -> NoReturn:  # type: ignore
        """Reject in-place ``|=``."""
        self._readonly()

    def __or__(  # type: ignore[override]
        self, __value: Mapping[_KT, _VT]
    ) -> immutabledict[_KT, _VT]:
        """Return ``self | __value`` as an ``immutabledict``.

        ``NotImplemented`` from ``dict.__or__`` is passed through so that
        Python can try the reflected operation or raise its standard
        "unsupported operand" ``TypeError``.
        """
        merged = super().__or__(__value)  # type: ignore[operator]
        if merged is NotImplemented:
            return NotImplemented
        return immutabledict(merged)

    def __ror__(  # type: ignore[override]
        self, __value: Mapping[_KT, _VT]
    ) -> immutabledict[_KT, _VT]:
        """Return ``__value | self`` as an ``immutabledict``.

        ``NotImplemented`` from ``dict.__ror__`` is passed through rather
        than being fed to the constructor.
        """
        merged = super().__ror__(__value)  # type: ignore[operator]
        if merged is NotImplemented:
            return NotImplemented
        return immutabledict(merged)
|
|
|
|
|
|
class OrderedSet(Set[_T]):
    """``set`` subclass that also remembers insertion order.

    Membership tests use the underlying ``set``.  Iteration, indexing and
    ``repr()`` use the parallel ``_list``, which holds each member once in
    the order it was first added.
    """

    __slots__ = ("_list",)

    _list: List[_T]

    def __init__(self, d: Optional[Iterable[_T]] = None) -> None:
        """Initialise from *d*, keeping the first occurrence of each item."""
        if d is not None:
            self._list = unique_list(d)
            super().update(self._list)
        else:
            self._list = []

    def copy(self) -> OrderedSet[_T]:
        """Return a shallow copy with the same members and order."""
        cp = self.__class__()
        cp._list = self._list.copy()
        set.update(cp, cp._list)
        return cp

    def add(self, element: _T) -> None:
        """Append *element* unless it is already a member."""
        if element not in self:
            self._list.append(element)
        super().add(element)

    def remove(self, element: _T) -> None:
        """Remove *element*; raises ``KeyError`` if it is not a member."""
        # The set removal goes first so a missing element raises KeyError
        # before the list is touched.
        super().remove(element)
        self._list.remove(element)

    def pop(self) -> _T:
        """Remove and return the most recently added member.

        Raises ``KeyError`` when the set is empty.
        """
        try:
            value = self._list.pop()
        except IndexError:
            raise KeyError("pop from an empty set") from None
        super().remove(value)
        return value

    def insert(self, pos: int, element: _T) -> None:
        """Insert *element* at list position *pos* unless already a member."""
        if element not in self:
            self._list.insert(pos, element)
        super().add(element)

    def discard(self, element: _T) -> None:
        """Remove *element* if present; do nothing otherwise."""
        if element in self:
            self._list.remove(element)
            super().remove(element)

    def clear(self) -> None:
        """Remove all members."""
        super().clear()
        self._list = []

    def __getitem__(self, key: int) -> _T:
        return self._list[key]

    def __iter__(self) -> Iterator[_T]:
        return iter(self._list)

    def __add__(self, other: Iterator[_T]) -> OrderedSet[_T]:
        return self.union(other)

    def __repr__(self) -> str:
        return "%s(%r)" % (self.__class__.__name__, self._list)

    __str__ = __repr__

    def update(self, *iterables: Iterable[_T]) -> None:
        """Add every not-yet-present element of each iterable, in order."""
        for iterable in iterables:
            for e in iterable:
                if e not in self:
                    self._list.append(e)
                    super().add(e)

    def __ior__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        self.update(other)
        return self

    def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]:
        """Return a copy extended with the new elements of *other*."""
        result: OrderedSet[Union[_T, _S]] = self.copy()
        result.update(*other)
        return result

    def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        return self.union(other)

    def intersection(self, *other: Iterable[Any]) -> OrderedSet[_T]:
        """Return the members found in the combined *other* iterables.

        The result keeps this set's ordering.
        """
        other_set: Set[Any] = set()
        other_set.update(*other)
        return self.__class__(a for a in self if a in other_set)

    def __and__(self, other: AbstractSet[object]) -> OrderedSet[_T]:
        return self.intersection(other)

    def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]:
        """Return elements that are in exactly one of ``self`` and *other*.

        This set's surviving members come first, followed by the new
        elements in the order *other* yields them.
        """
        collection: Collection[_T]
        if isinstance(other, set):
            collection = other_set = other
        elif isinstance(other, Collection):
            collection = other
            other_set = set(other)
        else:
            # One-shot iterables are materialised so they can be walked
            # twice: once to build the set, once to preserve their order.
            collection = list(other)
            other_set = set(collection)
        result = self.__class__(a for a in self if a not in other_set)
        result.update(a for a in collection if a not in self)
        return result

    def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        return cast(OrderedSet[Union[_T, _S]], self).symmetric_difference(
            other
        )

    def difference(self, *other: Iterable[Any]) -> OrderedSet[_T]:
        """Return the members not found in any of *other*, in order."""
        remaining = super().difference(*other)
        return self.__class__(a for a in self._list if a in remaining)

    def __sub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]:
        return self.difference(other)

    def intersection_update(self, *other: Iterable[Any]) -> None:
        """Keep only the members also found in every one of *other*."""
        super().intersection_update(*other)
        self._list = [a for a in self._list if a in self]

    def __iand__(self, other: AbstractSet[object]) -> OrderedSet[_T]:
        self.intersection_update(other)
        return self

    def symmetric_difference_update(self, other: Iterable[Any]) -> None:
        """In-place symmetric difference with *other*.

        Surviving members keep their order; elements new to this set are
        appended once each, in the order *other* first yields them.
        """
        collection = other if isinstance(other, Collection) else list(other)
        super().symmetric_difference_update(collection)
        kept = [a for a in self._list if a in self]
        # ``collection`` may repeat an element; dict.fromkeys keeps only the
        # first occurrence so ``_list`` stays in step with the set contents.
        kept.extend(dict.fromkeys(a for a in collection if a in self))
        self._list = kept

    def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
        self.symmetric_difference_update(other)
        return cast(OrderedSet[Union[_T, _S]], self)

    def difference_update(self, *other: Iterable[Any]) -> None:
        """Remove every member found in any of *other*."""
        super().difference_update(*other)
        self._list = [a for a in self._list if a in self]

    def __isub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]:  # type: ignore  # noqa: E501
        self.difference_update(other)
        return self
|
|
|
|
|
|
class IdentitySet:
    """A set that considers only object id() for uniqueness.

    This strategy has edge cases for builtin types- it's possible to have
    two 'foo' strings in one of these sets, for example.  Use sparingly.

    Members are held in ``_members``, a dict mapping ``id(obj)`` to
    ``obj``.

    """

    _members: Dict[int, Any]

    def __init__(self, iterable: Optional[Iterable[Any]] = None):
        """Create the set, optionally adding the items of *iterable*."""
        self._members = dict()
        if iterable:
            self.update(iterable)

    def add(self, value: Any) -> None:
        """Add *value*, keyed by its ``id()``."""
        self._members[id(value)] = value

    def __contains__(self, value: Any) -> bool:
        return id(value) in self._members

    def remove(self, value: Any) -> None:
        """Remove *value*; raises ``KeyError`` if it is not a member."""
        del self._members[id(value)]

    def discard(self, value: Any) -> None:
        """Remove *value* if present; do nothing otherwise."""
        try:
            self.remove(value)
        except KeyError:
            pass

    def pop(self) -> Any:
        """Remove and return a member; raises ``KeyError`` when empty."""
        try:
            pair = self._members.popitem()
        except KeyError:
            # Hide the internal dict error, as OrderedSet.pop() does.
            raise KeyError("pop from an empty set") from None
        return pair[1]

    def clear(self) -> None:
        """Remove all members."""
        self._members.clear()

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, IdentitySet):
            return self._members == other._members
        else:
            return False

    def __ne__(self, other: Any) -> bool:
        if isinstance(other, IdentitySet):
            return self._members != other._members
        else:
            return True

    def issubset(self, iterable: Iterable[Any]) -> bool:
        """Return True if every member is, by identity, in *iterable*."""
        if isinstance(iterable, self.__class__):
            other = iterable
        else:
            other = self.__class__(iterable)

        # dict key views compare like sets of ids.
        return self._members.keys() <= other._members.keys()

    def __le__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) < len(other) and self.issubset(other)

    def issuperset(self, iterable: Iterable[Any]) -> bool:
        """Return True if every item of *iterable* is, by identity, a member."""
        if isinstance(iterable, self.__class__):
            other = iterable
        else:
            other = self.__class__(iterable)

        return self._members.keys() >= other._members.keys()

    def __ge__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.issuperset(other)

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return len(self) > len(other) and self.issuperset(other)

    def union(self, iterable: Iterable[Any]) -> IdentitySet:
        """Return a new set holding the members plus the items of *iterable*."""
        result = self.__class__()
        members = self._members
        result._members.update(members)
        result._members.update((id(obj), obj) for obj in iterable)
        return result

    def __or__(self, other: Any) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.union(other)

    def update(self, iterable: Iterable[Any]) -> None:
        """Add every item of *iterable*."""
        self._members.update((id(obj), obj) for obj in iterable)

    def __ior__(self, other: Any) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.update(other)
        return self

    def difference(self, iterable: Iterable[Any]) -> IdentitySet:
        """Return a new set of the members not present in *iterable*."""
        # __new__ skips __init__; _members is assigned directly below.
        result = self.__new__(self.__class__)
        other: Collection[Any]

        if isinstance(iterable, self.__class__):
            other = iterable._members
        else:
            other = {id(obj) for obj in iterable}
        result._members = {
            k: v for k, v in self._members.items() if k not in other
        }
        return result

    def __sub__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.difference(other)

    def difference_update(self, iterable: Iterable[Any]) -> None:
        """Remove every member that is present in *iterable*."""
        self._members = self.difference(iterable)._members

    def __isub__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.difference_update(other)
        return self

    def intersection(self, iterable: Iterable[Any]) -> IdentitySet:
        """Return a new set of the members also present in *iterable*."""
        result = self.__new__(self.__class__)

        other: Collection[Any]

        if isinstance(iterable, self.__class__):
            other = iterable._members
        else:
            other = {id(obj) for obj in iterable}
        result._members = {
            k: v for k, v in self._members.items() if k in other
        }
        return result

    def __and__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.intersection(other)

    def intersection_update(self, iterable: Iterable[Any]) -> None:
        """Keep only the members that are also present in *iterable*."""
        self._members = self.intersection(iterable)._members

    def __iand__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        self.intersection_update(other)
        return self

    def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet:
        """Return a new set of objects in exactly one of self and *iterable*."""
        result = self.__new__(self.__class__)
        if isinstance(iterable, self.__class__):
            other = iterable._members
        else:
            other = {id(obj): obj for obj in iterable}
        result._members = {
            k: v for k, v in self._members.items() if k not in other
        }
        result._members.update(
            (k, v) for k, v in other.items() if k not in self._members
        )
        return result

    def __xor__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        return self.symmetric_difference(other)

    def symmetric_difference_update(self, iterable: Iterable[Any]) -> None:
        """Replace the contents with the symmetric difference with *iterable*."""
        self._members = self.symmetric_difference(iterable)._members

    def __ixor__(self, other: IdentitySet) -> IdentitySet:
        if not isinstance(other, IdentitySet):
            return NotImplemented
        # Must be the *_update variant: symmetric_difference() only returns
        # a new set and would leave ``self`` unchanged.
        self.symmetric_difference_update(other)
        return self

    def copy(self) -> IdentitySet:
        """Return a shallow copy."""
        result = self.__new__(self.__class__)
        result._members = self._members.copy()
        return result

    __copy__ = copy

    def __len__(self) -> int:
        return len(self._members)

    def __iter__(self) -> Iterator[Any]:
        return iter(self._members.values())

    def __hash__(self) -> NoReturn:
        raise TypeError("set objects are unhashable")

    def __repr__(self) -> str:
        return "%s(%r)" % (type(self).__name__, list(self._members.values()))
|
|
|
|
|
|
def unique_list(
    seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
) -> List[_T]:
    """Return the items of *seq* with duplicates removed, order preserved.

    The first occurrence of each item is kept.

    :param seq: iterable of items to de-duplicate.
    :param hashfunc: optional callable producing the key used to decide
     whether two items are duplicates; when omitted the items themselves
     are the keys and so must be hashable.  It is called exactly once per
     item.
    :return: a new list.
    """
    if not hashfunc:
        # dicts preserve insertion order and keep the first key object
        # seen, which is exactly an ordered de-duplication.
        return list(dict.fromkeys(seq))

    seen: Set[Any] = set()
    result: List[_T] = []
    for item in seq:
        key = hashfunc(item)
        if key not in seen:
            seen.add(key)
            result.append(item)
    return result
|