You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
173 lines
5.2 KiB
173 lines
5.2 KiB
import operator
|
|
|
|
from peewee import *
|
|
from peewee import Expression
|
|
from playhouse.fields import PickleField
|
|
try:
|
|
from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase
|
|
except ImportError:
|
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
|
|
|
|
|
Sentinel = type('Sentinel', (object,), {})
|
|
|
|
|
|
class KeyValue(object):
|
|
"""
|
|
Persistent dictionary.
|
|
|
|
:param Field key_field: field to use for key. Defaults to CharField.
|
|
:param Field value_field: field to use for value. Defaults to PickleField.
|
|
:param bool ordered: data should be returned in key-sorted order.
|
|
:param Database database: database where key/value data is stored.
|
|
:param str table_name: table name for data.
|
|
"""
|
|
def __init__(self, key_field=None, value_field=None, ordered=False,
|
|
database=None, table_name='keyvalue'):
|
|
if key_field is None:
|
|
key_field = CharField(max_length=255, primary_key=True)
|
|
if not key_field.primary_key:
|
|
raise ValueError('key_field must have primary_key=True.')
|
|
|
|
if value_field is None:
|
|
value_field = PickleField()
|
|
|
|
self._key_field = key_field
|
|
self._value_field = value_field
|
|
self._ordered = ordered
|
|
self._database = database or SqliteExtDatabase(':memory:')
|
|
self._table_name = table_name
|
|
if isinstance(self._database, PostgresqlDatabase):
|
|
self.upsert = self._postgres_upsert
|
|
self.update = self._postgres_update
|
|
else:
|
|
self.upsert = self._upsert
|
|
self.update = self._update
|
|
|
|
self.model = self.create_model()
|
|
self.key = self.model.key
|
|
self.value = self.model.value
|
|
|
|
# Ensure table exists.
|
|
self.model.create_table()
|
|
|
|
def create_model(self):
|
|
class KeyValue(Model):
|
|
key = self._key_field
|
|
value = self._value_field
|
|
class Meta:
|
|
database = self._database
|
|
table_name = self._table_name
|
|
return KeyValue
|
|
|
|
def query(self, *select):
|
|
query = self.model.select(*select).tuples()
|
|
if self._ordered:
|
|
query = query.order_by(self.key)
|
|
return query
|
|
|
|
def convert_expression(self, expr):
|
|
if not isinstance(expr, Expression):
|
|
return (self.key == expr), True
|
|
return expr, False
|
|
|
|
def __contains__(self, key):
|
|
expr, _ = self.convert_expression(key)
|
|
return self.model.select().where(expr).exists()
|
|
|
|
def __len__(self):
|
|
return len(self.model)
|
|
|
|
def __getitem__(self, expr):
|
|
converted, is_single = self.convert_expression(expr)
|
|
query = self.query(self.value).where(converted)
|
|
item_getter = operator.itemgetter(0)
|
|
result = [item_getter(row) for row in query]
|
|
if len(result) == 0 and is_single:
|
|
raise KeyError(expr)
|
|
elif is_single:
|
|
return result[0]
|
|
return result
|
|
|
|
def _upsert(self, key, value):
|
|
(self.model
|
|
.insert(key=key, value=value)
|
|
.on_conflict('replace')
|
|
.execute())
|
|
|
|
def _postgres_upsert(self, key, value):
|
|
(self.model
|
|
.insert(key=key, value=value)
|
|
.on_conflict(conflict_target=[self.key],
|
|
preserve=[self.value])
|
|
.execute())
|
|
|
|
def __setitem__(self, expr, value):
|
|
if isinstance(expr, Expression):
|
|
self.model.update(value=value).where(expr).execute()
|
|
else:
|
|
self.upsert(expr, value)
|
|
|
|
def __delitem__(self, expr):
|
|
converted, _ = self.convert_expression(expr)
|
|
self.model.delete().where(converted).execute()
|
|
|
|
def __iter__(self):
|
|
return iter(self.query().execute())
|
|
|
|
def keys(self):
|
|
return map(operator.itemgetter(0), self.query(self.key))
|
|
|
|
def values(self):
|
|
return map(operator.itemgetter(0), self.query(self.value))
|
|
|
|
def items(self):
|
|
return iter(self.query().execute())
|
|
|
|
def _update(self, __data=None, **mapping):
|
|
if __data is not None:
|
|
mapping.update(__data)
|
|
return (self.model
|
|
.insert_many(list(mapping.items()),
|
|
fields=[self.key, self.value])
|
|
.on_conflict('replace')
|
|
.execute())
|
|
|
|
def _postgres_update(self, __data=None, **mapping):
|
|
if __data is not None:
|
|
mapping.update(__data)
|
|
return (self.model
|
|
.insert_many(list(mapping.items()),
|
|
fields=[self.key, self.value])
|
|
.on_conflict(conflict_target=[self.key],
|
|
preserve=[self.value])
|
|
.execute())
|
|
|
|
def get(self, key, default=None):
|
|
try:
|
|
return self[key]
|
|
except KeyError:
|
|
return default
|
|
|
|
def setdefault(self, key, default=None):
|
|
try:
|
|
return self[key]
|
|
except KeyError:
|
|
self[key] = default
|
|
return default
|
|
|
|
def pop(self, key, default=Sentinel):
|
|
with self._database.atomic():
|
|
try:
|
|
result = self[key]
|
|
except KeyError:
|
|
if default is Sentinel:
|
|
raise
|
|
return default
|
|
del self[key]
|
|
|
|
return result
|
|
|
|
def clear(self):
|
|
self.model.delete().execute()
|