You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
227 lines
9.0 KiB
227 lines
9.0 KiB
import functools
|
|
import re
|
|
import sys
|
|
|
|
from peewee import *
|
|
from peewee import _atomic
|
|
from peewee import _manual
|
|
from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default)
|
|
from peewee import EnclosedNodeList
|
|
from peewee import Entity
|
|
from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table).
|
|
from peewee import IndexMetadata
|
|
from peewee import NodeList
|
|
from playhouse.pool import _PooledPostgresqlDatabase
|
|
try:
|
|
from playhouse.postgres_ext import ArrayField
|
|
from playhouse.postgres_ext import BinaryJSONField
|
|
from playhouse.postgres_ext import IntervalField
|
|
JSONField = BinaryJSONField
|
|
except ImportError: # psycopg2 not installed, ignore.
|
|
ArrayField = BinaryJSONField = IntervalField = JSONField = None
|
|
|
|
# Python 3 has no ``basestring`` builtin; alias it to ``str`` so that the
# isinstance() checks in this module work on both major versions.
if sys.version_info[0] >= 3:
    basestring = str


# Server versions are compared as integers built from major/minor/patch
# (e.g. 19.1.5 -> 190105), so this value corresponds to 20.1.0. Servers
# reporting a lower version are treated as lacking nested-transaction support.
NESTED_TX_MIN_VERSION = 200100

# Message attached to the NotImplementedError raised when nested
# transactions or savepoints are attempted on an older server.
TXN_ERR_MSG = (
    'CockroachDB does not support nested transactions. '
    'You may alternatively use the @transaction context-manager/decorator, '
    'which only wraps the outer-most block in transactional logic. '
    'To run a transaction with automatic retries, use the run_transaction() '
    'helper.')
|
|
|
|
class ExceededMaxAttempts(OperationalError):
    """
    Raised by run_transaction() when the callback could not be committed
    within the permitted number of attempts.
    """
|
|
|
|
|
|
class UUIDKeyField(UUIDField):
    """
    UUID field whose value is produced by the server.

    The column is declared with ``DEFAULT gen_random_uuid()`` and, unless the
    caller says otherwise, is the primary key. Because the field installs its
    own constraint, user-supplied ``constraints`` are rejected.
    """
    auto_increment = True

    def __init__(self, *args, **kwargs):
        """
        :raises ValueError: if a non-empty ``constraints`` keyword is given.
        """
        user_constraints = kwargs.get('constraints')
        if user_constraints:
            raise ValueError('%s cannot specify constraints.' % type(self))

        # Primary key by default; an explicit primary_key=... still wins.
        kwargs.setdefault('primary_key', True)
        kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')]
        super(UUIDKeyField, self).__init__(*args, **kwargs)
|
|
|
|
|
|
class RowIDField(AutoField):
    """
    Auto-generated integer key backed by ``unique_rowid()``.

    The column type is ``INT`` and the column is declared with
    ``DEFAULT unique_rowid()``. Because the field installs its own
    constraint, user-supplied ``constraints`` are rejected.
    """
    field_type = 'INT'

    def __init__(self, *args, **kwargs):
        """
        :raises ValueError: if a non-empty ``constraints`` keyword is given.
        """
        user_constraints = kwargs.get('constraints')
        if user_constraints:
            raise ValueError('%s cannot specify constraints.' % type(self))

        kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')]
        super(RowIDField, self).__init__(*args, **kwargs)
|
|
|
|
|
|
class CockroachDatabase(PostgresqlDatabase):
    """
    PostgresqlDatabase subclass with CockroachDB-specific behaviour.

    What this class changes relative to its base:

    * ``BLOB`` fields are mapped to the ``BYTES`` column type.
    * Connections default to user ``root`` and port ``26257`` unless a DSN
      or a ``postgresql://`` URL is supplied.
    * The server version is parsed from ``select version()``.
    * The primary-key index is filtered out of ``get_indexes()``.
    * REPLACE / UPSERT conflict resolution is emitted as an ``UPSERT``
      statement; conflict targets must be columns, not constraint names.
    * On servers older than ``NESTED_TX_MIN_VERSION`` nested ``atomic()``
      blocks and savepoints are refused.
    * ``retry_transaction()`` / ``run_transaction()`` expose the module-level
      retry helper as methods.
    """
    field_types = PostgresqlDatabase.field_types.copy()
    field_types.update({
        'BLOB': 'BYTES',
    })

    # Feature flags consumed by the peewee base classes; their exact effect
    # is defined there rather than in this module.
    for_update = False
    nulls_ordering = False
    release_after_rollback = True

    def __init__(self, database, *args, **kwargs):
        """
        :param database: database name or ``postgresql://`` connection URL;
            may be falsy, in which case no defaults are injected.
        :param kwargs: connection parameters forwarded to the base class.
        """
        # Unless a DSN or database connection-url were specified, provide
        # convenient defaults for the user and port.
        if 'dsn' not in kwargs and (database and
                                    not database.startswith('postgresql://')):
            kwargs.setdefault('user', 'root')
            kwargs.setdefault('port', 26257)
        super(CockroachDatabase, self).__init__(database, *args, **kwargs)

    def _set_server_version(self, conn):
        """
        Populate ``self.server_version`` from the server's version string.

        The version reported by ``select version()`` is encoded as
        ``major * 10000 + minor * 100 + patch`` (19.1.5 -> 190105). If the
        string does not look like a CockroachDB version, the base-class
        implementation is used instead.
        """
        curs = conn.cursor()
        try:
            curs.execute('select version()')
            raw, = curs.fetchone()
        finally:
            # Always release the cursor, including when the query fails.
            curs.close()

        match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw)
        if match_obj is not None:
            clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups())
            self.server_version = int(clean)  # 19.1.5 -> 190105.
        else:
            # Fallback to use whatever cockroachdb tells us via protocol.
            super(CockroachDatabase, self)._set_server_version(conn)

    def _get_pk_constraint(self, table, schema=None):
        """
        Return the name of the table's PRIMARY KEY constraint.

        :param table: table name.
        :param schema: schema name; ``'public'`` is used when falsy.
        :returns: the constraint name, or None when no row is found.
        """
        query = ('SELECT constraint_name '
                 'FROM information_schema.table_constraints '
                 'WHERE table_name = %s AND table_schema = %s '
                 'AND constraint_type = %s')
        cursor = self.execute_sql(query, (table, schema or 'public',
                                          'PRIMARY KEY'))
        row = cursor.fetchone()
        return row[0] if row else None

    def get_indexes(self, table, schema=None):
        """
        Return the table's indexes, excluding the primary-key index.
        """
        # The primary-key index is returned by default, so we will just strip
        # it out here.
        indexes = super(CockroachDatabase, self).get_indexes(table, schema)
        pkc = self._get_pk_constraint(table, schema)
        if not pkc:
            return list(indexes)
        return [idx for idx in indexes if idx.name != pkc]

    def conflict_statement(self, on_conflict, query):
        """
        Return the node that replaces the leading INSERT keyword, if any.

        :returns: ``SQL('UPSERT')`` for REPLACE / UPSERT, otherwise None.
        :raises ValueError: for actions other than REPLACE, UPSERT, IGNORE,
            NOTHING or UPDATE.
        """
        if not on_conflict._action:
            return None

        action = on_conflict._action.lower()
        if action in ('replace', 'upsert'):
            return SQL('UPSERT')
        elif action not in ('ignore', 'nothing', 'update'):
            raise ValueError('Un-supported action for conflict resolution. '
                             'CockroachDB supports REPLACE (UPSERT), IGNORE '
                             'and UPDATE.')
        return None

    def conflict_update(self, oc, query):
        """
        Build the ON CONFLICT clause for the given conflict specification.

        * IGNORE / NOTHING: ``ON CONFLICT [(targets)] DO NOTHING``.
        * REPLACE / UPSERT: None, the UPSERT keyword already covers it.
        * Anything else: delegated to the base class, after rejecting
          constraint-name conflict targets.

        :raises ValueError: if a constraint name is used as conflict target.
        """
        action = oc._action.lower() if oc._action else ''
        if action in ('ignore', 'nothing'):
            parts = [SQL('ON CONFLICT')]
            if oc._conflict_target:
                # Plain strings are column names and must be quoted as
                # entities; other nodes are passed through untouched.
                parts.append(EnclosedNodeList([
                    Entity(col) if isinstance(col, basestring) else col
                    for col in oc._conflict_target]))
            parts.append(SQL('DO NOTHING'))
            return NodeList(parts)
        elif action in ('replace', 'upsert'):
            # No special stuff is necessary, this is just indicated by starting
            # the statement with UPSERT instead of INSERT.
            return None
        elif oc._conflict_constraint:
            raise ValueError('CockroachDB does not support the usage of a '
                             'constraint name. Use the column(s) instead.')

        return super(CockroachDatabase, self).conflict_update(oc, query)

    def extract_date(self, date_part, date_field):
        """Return an ``extract(date_part, date_field)`` function call."""
        return fn.extract(date_part, date_field)

    def from_timestamp(self, date_field):
        """Return an expression converting a numeric timestamp to timestamptz."""
        # CRDB does not allow casting a decimal/float to timestamp, so we first
        # cast to int, then to timestamptz.
        return date_field.cast('int').cast('timestamptz')

    def begin(self, system_time=None, priority=None):
        """
        Begin a transaction, optionally setting its system time and priority.

        :param system_time: if not None, issued as
            ``SET TRANSACTION AS OF SYSTEM TIME``.
        :param priority: if not None, one of ``low``, ``normal`` or ``high``
            (case-insensitive), issued as ``SET TRANSACTION PRIORITY``.
        :raises ValueError: if ``priority`` is not one of the allowed values.
        """
        # Validate before touching the connection so that a bad argument
        # cannot leave a half-configured transaction open.
        if priority is not None:
            priority = priority.lower()
            if priority not in ('low', 'normal', 'high'):
                raise ValueError('priority must be low, normal or high')

        super(CockroachDatabase, self).begin()
        if system_time is not None:
            self.execute_sql('SET TRANSACTION AS OF SYSTEM TIME %s',
                             (system_time,), commit=False)
        if priority is not None:
            # Interpolation is safe: the value was checked against a fixed
            # whitelist above.
            self.execute_sql('SET TRANSACTION PRIORITY %s' % priority,
                             commit=False)

    def atomic(self, system_time=None, priority=None):
        """
        Return an atomic context-manager/decorator.

        On servers older than ``NESTED_TX_MIN_VERSION`` a ``_crdb_atomic`` is
        returned; otherwise the base-class implementation is used.
        """
        if self.is_closed(): self.connect()  # Side-effect, set server version.
        if self.server_version < NESTED_TX_MIN_VERSION:
            return _crdb_atomic(self, system_time, priority)
        return super(CockroachDatabase, self).atomic(system_time, priority)

    def savepoint(self):
        """
        Return a savepoint context-manager.

        :raises NotImplementedError: on servers older than
            ``NESTED_TX_MIN_VERSION``.
        """
        if self.is_closed(): self.connect()  # Side-effect, set server version.
        if self.server_version < NESTED_TX_MIN_VERSION:
            raise NotImplementedError(TXN_ERR_MSG)
        return super(CockroachDatabase, self).savepoint()

    def retry_transaction(self, max_attempts=None, system_time=None,
                          priority=None):
        """
        Decorator factory wrapping a callback in ``run_transaction()``.

        The decorated function becomes a zero-argument callable; when invoked
        it runs the original function through ``run_transaction()``, which
        passes this database as the function's single argument.
        """
        def deco(cb):
            @functools.wraps(cb)
            def new_fn():
                return run_transaction(self, cb, max_attempts, system_time,
                                       priority)
            return new_fn
        return deco

    def run_transaction(self, cb, max_attempts=None, system_time=None,
                        priority=None):
        """Run ``cb`` via the module-level ``run_transaction()`` helper."""
        return run_transaction(self, cb, max_attempts, system_time, priority)
|
|
|
|
|
|
class _crdb_atomic(_atomic):
    """
    ``_atomic`` variant that refuses to nest inside an open transaction.

    Entering while the database already has a transaction on its stack raises
    NotImplementedError, unless the top-most transaction is a ``_manual``
    block.
    """
    def __enter__(self):
        already_in_txn = self.db.transaction_depth() > 0
        # top_transaction() is only consulted when a transaction is open.
        if already_in_txn and not isinstance(self.db.top_transaction(),
                                             _manual):
            raise NotImplementedError(TXN_ERR_MSG)
        return super(_crdb_atomic, self).__enter__()
|
|
|
|
|
|
def run_transaction(db, callback, max_attempts=None, system_time=None,
                    priority=None):
    """
    Run transactional SQL in a transaction with automatic retries.

    User-provided `callback`:
    * Must accept one parameter, the `db` instance representing the connection
      the transaction is running under.
    * Must not attempt to commit, rollback or otherwise manage transactions.
    * May be called more than once.
    * Should ideally only contain SQL operations.

    Additionally, the database must not have any open transaction at the time
    this function is called, as CRDB does not support nested transactions.

    :param db: database instance; must provide ``atomic()`` and
        ``execute_sql()``.
    :param callback: callable invoked as ``callback(db)``.
    :param max_attempts: number of retryable failures tolerated before giving
        up. ``None`` or ``0`` (and any negative value) means retry without
        limit.
    :param system_time: forwarded to ``db.atomic()``.
    :param priority: forwarded to ``db.atomic()``.
    :returns: whatever ``callback`` returns on the successful attempt.
    :raises ExceededMaxAttempts: when ``max_attempts`` retryable failures
        have occurred.
    :raises OperationalError: re-raised unchanged when the failure is not a
        retryable (SQLSTATE 40001) error.
    """
    max_attempts = max_attempts or -1
    with db.atomic(system_time=system_time, priority=priority):
        db.execute_sql('SAVEPOINT cockroach_restart')
        while max_attempts != 0:
            try:
                result = callback(db)
                db.execute_sql('RELEASE SAVEPOINT cockroach_restart')
                return result
            except OperationalError as exc:
                # Not every OperationalError wraps a driver exception, and
                # not every wrapped exception carries a pgcode. Look both up
                # defensively so the original error is re-raised instead of
                # being masked by an AttributeError.
                orig = getattr(exc, 'orig', None)
                if getattr(orig, 'pgcode', None) == '40001':
                    max_attempts -= 1
                    db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart')
                    continue
                raise
    raise ExceededMaxAttempts(None, 'unable to commit transaction')
|
|
|
|
|
|
class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase):
    """
    CockroachDatabase combined with the ``_PooledPostgresqlDatabase`` mixin
    from ``playhouse.pool``; adds no behaviour of its own.
    """
|