parent c2eec34aff
commit 2b9d892ca9
@ -1,88 +0,0 @@
|
||||
BEGIN TRANSACTION;
|
||||
CREATE TABLE "table_shows" (
|
||||
`tvdbId` INTEGER NOT NULL UNIQUE,
|
||||
`title` TEXT NOT NULL,
|
||||
`path` TEXT NOT NULL UNIQUE,
|
||||
`languages` TEXT,
|
||||
`hearing_impaired` TEXT,
|
||||
`sonarrSeriesId` INTEGER NOT NULL UNIQUE,
|
||||
`overview` TEXT,
|
||||
`poster` TEXT,
|
||||
`fanart` TEXT,
|
||||
`audio_language` "text",
|
||||
`sortTitle` "text",
|
||||
PRIMARY KEY(`tvdbId`)
|
||||
);
|
||||
CREATE TABLE "table_settings_providers" (
|
||||
`name` TEXT NOT NULL UNIQUE,
|
||||
`enabled` INTEGER,
|
||||
`username` "text",
|
||||
`password` "text",
|
||||
PRIMARY KEY(`name`)
|
||||
);
|
||||
CREATE TABLE "table_settings_notifier" (
|
||||
`name` TEXT,
|
||||
`url` TEXT,
|
||||
`enabled` INTEGER,
|
||||
PRIMARY KEY(`name`)
|
||||
);
|
||||
CREATE TABLE "table_settings_languages" (
|
||||
`code3` TEXT NOT NULL UNIQUE,
|
||||
`code2` TEXT,
|
||||
`name` TEXT NOT NULL,
|
||||
`enabled` INTEGER,
|
||||
`code3b` TEXT,
|
||||
PRIMARY KEY(`code3`)
|
||||
);
|
||||
CREATE TABLE "table_history" (
|
||||
`id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE,
|
||||
`action` INTEGER NOT NULL,
|
||||
`sonarrSeriesId` INTEGER NOT NULL,
|
||||
`sonarrEpisodeId` INTEGER NOT NULL,
|
||||
`timestamp` INTEGER NOT NULL,
|
||||
`description` TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE "table_episodes" (
|
||||
`sonarrSeriesId` INTEGER NOT NULL,
|
||||
`sonarrEpisodeId` INTEGER NOT NULL UNIQUE,
|
||||
`title` TEXT NOT NULL,
|
||||
`path` TEXT NOT NULL,
|
||||
`season` INTEGER NOT NULL,
|
||||
`episode` INTEGER NOT NULL,
|
||||
`subtitles` TEXT,
|
||||
`missing_subtitles` TEXT,
|
||||
`scene_name` TEXT,
|
||||
`monitored` TEXT,
|
||||
`failedAttempts` "text"
|
||||
);
|
||||
CREATE TABLE "table_movies" (
|
||||
`tmdbId` TEXT NOT NULL UNIQUE,
|
||||
`title` TEXT NOT NULL,
|
||||
`path` TEXT NOT NULL UNIQUE,
|
||||
`languages` TEXT,
|
||||
`subtitles` TEXT,
|
||||
`missing_subtitles` TEXT,
|
||||
`hearing_impaired` TEXT,
|
||||
`radarrId` INTEGER NOT NULL UNIQUE,
|
||||
`overview` TEXT,
|
||||
`poster` TEXT,
|
||||
`fanart` TEXT,
|
||||
`audio_language` "text",
|
||||
`sceneName` TEXT,
|
||||
`monitored` TEXT,
|
||||
`failedAttempts` "text",
|
||||
PRIMARY KEY(`tmdbId`)
|
||||
);
|
||||
CREATE TABLE "table_history_movie" (
|
||||
`id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE,
|
||||
`action` INTEGER NOT NULL,
|
||||
`radarrId` INTEGER NOT NULL,
|
||||
`timestamp` INTEGER NOT NULL,
|
||||
`description` TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE "system" (
|
||||
`configured` TEXT,
|
||||
`updated` TEXT
|
||||
);
|
||||
INSERT INTO `system` (configured, updated) VALUES ('0', '0');
|
||||
COMMIT;
|
File diff suppressed because it is too large.
@ -0,0 +1,48 @@
## Playhouse

The `playhouse` namespace contains numerous extensions to Peewee. These include vendor-specific database extensions, high-level abstractions to simplify working with databases, and tools for low-level database operations and introspection.

### Vendor extensions

* [SQLite extensions](http://docs.peewee-orm.com/en/latest/peewee/sqlite_ext.html)
    * Full-text search (FTS3/4/5)
    * BM25 ranking algorithm implemented as a SQLite C extension, backported to FTS4
    * Virtual tables and C extensions
    * Closure tables
    * JSON extension support
    * LSM1 (key/value database) support
    * BLOB API
    * Online backup API
* [APSW extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#apsw): use Peewee with the powerful [APSW](https://github.com/rogerbinns/apsw) SQLite driver.
* [SQLCipher](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqlcipher-ext): encrypted SQLite databases.
* [SqliteQ](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#sqliteq): dedicated writer thread for multi-threaded SQLite applications. [More info here](http://charlesleifer.com/blog/multi-threaded-sqlite-without-the-operationalerrors/).
* [Postgresql extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#postgres-ext)
    * JSON and JSONB
    * HStore
    * Arrays
    * Server-side cursors
    * Full-text search
* [MySQL extensions](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#mysql-ext)

### High-level libraries

* [Extra fields](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#extra-fields)
    * Compressed field
    * PickleField
* [Shortcuts / helpers](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#shortcuts)
    * Model-to-dict serializer
    * Dict-to-model deserializer
* [Hybrid attributes](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#hybrid)
* [Signals](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#signals): pre/post-save, pre/post-delete, pre-init.
* [Dataset](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#dataset): high-level API for working with databases, popularized by the [project of the same name](https://dataset.readthedocs.io/).
* [Key/Value Store](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#kv): key/value store using SQLite. Supports *smart indexing*, for *Pandas*-style queries.

### Database management and framework support

* [pwiz](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pwiz): generate model code from a pre-existing database.
* [Schema migrations](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#migrate): modify your schema using high-level APIs. Even supports dropping or renaming columns in SQLite.
* [Connection pool](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#pool): simple connection pooling.
* [Reflection](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#reflection): low-level, cross-platform database introspection.
* [Database URLs](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#db-url): use URLs to connect to databases (a short sketch follows this list).
* [Test utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#test-utils): helpers for unit-testing Peewee applications.
* [Flask utils](http://docs.peewee-orm.com/en/latest/peewee/playhouse.html#flask-utils): paginated object lists, database connection management, and more.
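Of the helpers listed above, `playhouse.db_url` (added later in this commit) is often the quickest way to start experimenting. A minimal sketch, assuming only that peewee is installed; the in-memory URL and the `Person` model are illustrative and not part of the diff:

```python
# Sketch: connect via a database URL and define a model against it.
# The URL and the Person model are examples, not part of this commit.
from peewee import Model, TextField
from playhouse.db_url import connect

db = connect('sqlite:///:memory:')  # 'sqlite' scheme -> SqliteDatabase


class Person(Model):
    name = TextField()

    class Meta:
        database = db


db.create_tables([Person])
Person.create(name='Huey')
print(Person.select().count())  # -> 1
```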
@ -0,0 +1,73 @@
|
||||
/* cache.h - definitions for the LRU cache
|
||||
*
|
||||
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
|
||||
*
|
||||
* This file is part of pysqlite.
|
||||
*
|
||||
* This software is provided 'as-is', without any express or implied
|
||||
* warranty. In no event will the authors be held liable for any damages
|
||||
* arising from the use of this software.
|
||||
*
|
||||
* Permission is granted to anyone to use this software for any purpose,
|
||||
* including commercial applications, and to alter it and redistribute it
|
||||
* freely, subject to the following restrictions:
|
||||
*
|
||||
* 1. The origin of this software must not be misrepresented; you must not
|
||||
* claim that you wrote the original software. If you use this software
|
||||
* in a product, an acknowledgment in the product documentation would be
|
||||
* appreciated but is not required.
|
||||
* 2. Altered source versions must be plainly marked as such, and must not be
|
||||
* misrepresented as being the original software.
|
||||
* 3. This notice may not be removed or altered from any source distribution.
|
||||
*/
|
||||
|
||||
#ifndef PYSQLITE_CACHE_H
|
||||
#define PYSQLITE_CACHE_H
|
||||
#include "Python.h"
|
||||
|
||||
/* The LRU cache is implemented as a combination of a doubly-linked with a
|
||||
* dictionary. The list items are of type 'Node' and the dictionary has the
|
||||
* nodes as values. */
|
||||
|
||||
typedef struct _pysqlite_Node
|
||||
{
|
||||
PyObject_HEAD
|
||||
PyObject* key;
|
||||
PyObject* data;
|
||||
long count;
|
||||
struct _pysqlite_Node* prev;
|
||||
struct _pysqlite_Node* next;
|
||||
} pysqlite_Node;
|
||||
|
||||
typedef struct
|
||||
{
|
||||
PyObject_HEAD
|
||||
int size;
|
||||
|
||||
/* a dictionary mapping keys to Node entries */
|
||||
PyObject* mapping;
|
||||
|
||||
/* the factory callable */
|
||||
PyObject* factory;
|
||||
|
||||
pysqlite_Node* first;
|
||||
pysqlite_Node* last;
|
||||
|
||||
/* if set, decrement the factory function when the Cache is deallocated.
|
||||
* this is almost always desirable, but not in the pysqlite context */
|
||||
int decref_factory;
|
||||
} pysqlite_Cache;
|
||||
|
||||
extern PyTypeObject pysqlite_NodeType;
|
||||
extern PyTypeObject pysqlite_CacheType;
|
||||
|
||||
int pysqlite_node_init(pysqlite_Node* self, PyObject* args, PyObject* kwargs);
|
||||
void pysqlite_node_dealloc(pysqlite_Node* self);
|
||||
|
||||
int pysqlite_cache_init(pysqlite_Cache* self, PyObject* args, PyObject* kwargs);
|
||||
void pysqlite_cache_dealloc(pysqlite_Cache* self);
|
||||
PyObject* pysqlite_cache_get(pysqlite_Cache* self, PyObject* args);
|
||||
|
||||
int pysqlite_cache_setup_types(void);
|
||||
|
||||
#endif
|
@ -0,0 +1,129 @@
|
||||
/* connection.h - definitions for the connection type
|
||||
*
|
||||
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
|
||||
*
|
||||
* This file is part of pysqlite.
|
||||
*
|
||||
* This software is provided 'as-is', without any express or implied
|
||||
* warranty. In no event will the authors be held liable for any damages
|
||||
* arising from the use of this software.
|
||||
*
|
||||
* Permission is granted to anyone to use this software for any purpose,
|
||||
* including commercial applications, and to alter it and redistribute it
|
||||
* freely, subject to the following restrictions:
|
||||
*
|
||||
* 1. The origin of this software must not be misrepresented; you must not
|
||||
* claim that you wrote the original software. If you use this software
|
||||
* in a product, an acknowledgment in the product documentation would be
|
||||
* appreciated but is not required.
|
||||
* 2. Altered source versions must be plainly marked as such, and must not be
|
||||
* misrepresented as being the original software.
|
||||
* 3. This notice may not be removed or altered from any source distribution.
|
||||
*/
|
||||
|
||||
#ifndef PYSQLITE_CONNECTION_H
|
||||
#define PYSQLITE_CONNECTION_H
|
||||
#include "Python.h"
|
||||
#include "pythread.h"
|
||||
#include "structmember.h"
|
||||
|
||||
#include "cache.h"
|
||||
#include "module.h"
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
typedef struct
|
||||
{
|
||||
PyObject_HEAD
|
||||
sqlite3* db;
|
||||
|
||||
/* the type detection mode. Only 0, PARSE_DECLTYPES, PARSE_COLNAMES or a
|
||||
* bitwise combination thereof makes sense */
|
||||
int detect_types;
|
||||
|
||||
/* the timeout value in seconds for database locks */
|
||||
double timeout;
|
||||
|
||||
/* for internal use in the timeout handler: when did the timeout handler
|
||||
* first get called with count=0? */
|
||||
double timeout_started;
|
||||
|
||||
/* None for autocommit, otherwise a PyString with the isolation level */
|
||||
PyObject* isolation_level;
|
||||
|
||||
/* NULL for autocommit, otherwise a string with the BEGIN statement; will be
|
||||
* freed in connection destructor */
|
||||
char* begin_statement;
|
||||
|
||||
/* 1 if a check should be performed for each API call if the connection is
|
||||
* used from the same thread it was created in */
|
||||
int check_same_thread;
|
||||
|
||||
int initialized;
|
||||
|
||||
/* thread identification of the thread the connection was created in */
|
||||
long thread_ident;
|
||||
|
||||
pysqlite_Cache* statement_cache;
|
||||
|
||||
/* Lists of weak references to statements and cursors used within this connection */
|
||||
PyObject* statements;
|
||||
PyObject* cursors;
|
||||
|
||||
/* Counters for how many statements/cursors were created in the connection. May be
|
||||
* reset to 0 at certain intervals */
|
||||
int created_statements;
|
||||
int created_cursors;
|
||||
|
||||
PyObject* row_factory;
|
||||
|
||||
/* Determines how bytestrings from SQLite are converted to Python objects:
|
||||
* - PyUnicode_Type: Python Unicode objects are constructed from UTF-8 bytestrings
|
||||
* - OptimizedUnicode: Like before, but for ASCII data, only PyStrings are created.
|
||||
* - PyString_Type: PyStrings are created as-is.
|
||||
* - Any custom callable: Any object returned from the callable called with the bytestring
|
||||
* as single parameter.
|
||||
*/
|
||||
PyObject* text_factory;
|
||||
|
||||
/* remember references to functions/classes used in
|
||||
* create_function/create/aggregate, use these as dictionary keys, so we
|
||||
* can keep the total system refcount constant by clearing that dictionary
|
||||
* in connection_dealloc */
|
||||
PyObject* function_pinboard;
|
||||
|
||||
/* a dictionary of registered collation name => collation callable mappings */
|
||||
PyObject* collations;
|
||||
|
||||
/* Exception objects */
|
||||
PyObject* Warning;
|
||||
PyObject* Error;
|
||||
PyObject* InterfaceError;
|
||||
PyObject* DatabaseError;
|
||||
PyObject* DataError;
|
||||
PyObject* OperationalError;
|
||||
PyObject* IntegrityError;
|
||||
PyObject* InternalError;
|
||||
PyObject* ProgrammingError;
|
||||
PyObject* NotSupportedError;
|
||||
} pysqlite_Connection;
|
||||
|
||||
extern PyTypeObject pysqlite_ConnectionType;
|
||||
|
||||
PyObject* pysqlite_connection_alloc(PyTypeObject* type, int aware);
|
||||
void pysqlite_connection_dealloc(pysqlite_Connection* self);
|
||||
PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs);
|
||||
PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args);
|
||||
PyObject* _pysqlite_connection_begin(pysqlite_Connection* self);
|
||||
PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args);
|
||||
PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args);
|
||||
PyObject* pysqlite_connection_new(PyTypeObject* type, PyObject* args, PyObject* kw);
|
||||
int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject* kwargs);
|
||||
|
||||
int pysqlite_connection_register_cursor(pysqlite_Connection* connection, PyObject* cursor);
|
||||
int pysqlite_check_thread(pysqlite_Connection* self);
|
||||
int pysqlite_check_connection(pysqlite_Connection* con);
|
||||
|
||||
int pysqlite_connection_setup_types(void);
|
||||
|
||||
#endif
|
@ -0,0 +1,58 @@
|
||||
/* module.h - definitions for the module
|
||||
*
|
||||
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
|
||||
*
|
||||
* This file is part of pysqlite.
|
||||
*
|
||||
* This software is provided 'as-is', without any express or implied
|
||||
* warranty. In no event will the authors be held liable for any damages
|
||||
* arising from the use of this software.
|
||||
*
|
||||
* Permission is granted to anyone to use this software for any purpose,
|
||||
* including commercial applications, and to alter it and redistribute it
|
||||
* freely, subject to the following restrictions:
|
||||
*
|
||||
* 1. The origin of this software must not be misrepresented; you must not
|
||||
* claim that you wrote the original software. If you use this software
|
||||
* in a product, an acknowledgment in the product documentation would be
|
||||
* appreciated but is not required.
|
||||
* 2. Altered source versions must be plainly marked as such, and must not be
|
||||
* misrepresented as being the original software.
|
||||
* 3. This notice may not be removed or altered from any source distribution.
|
||||
*/
|
||||
|
||||
#ifndef PYSQLITE_MODULE_H
|
||||
#define PYSQLITE_MODULE_H
|
||||
#include "Python.h"
|
||||
|
||||
#define PYSQLITE_VERSION "2.8.2"
|
||||
|
||||
extern PyObject* pysqlite_Error;
|
||||
extern PyObject* pysqlite_Warning;
|
||||
extern PyObject* pysqlite_InterfaceError;
|
||||
extern PyObject* pysqlite_DatabaseError;
|
||||
extern PyObject* pysqlite_InternalError;
|
||||
extern PyObject* pysqlite_OperationalError;
|
||||
extern PyObject* pysqlite_ProgrammingError;
|
||||
extern PyObject* pysqlite_IntegrityError;
|
||||
extern PyObject* pysqlite_DataError;
|
||||
extern PyObject* pysqlite_NotSupportedError;
|
||||
|
||||
extern PyObject* pysqlite_OptimizedUnicode;
|
||||
|
||||
/* the functions time.time() and time.sleep() */
|
||||
extern PyObject* time_time;
|
||||
extern PyObject* time_sleep;
|
||||
|
||||
/* A dictionary, mapping column types (INTEGER, VARCHAR, etc.) to converter
|
||||
 * functions that convert the SQL value to the appropriate Python value.
|
||||
* The key is uppercase.
|
||||
*/
|
||||
extern PyObject* converters;
|
||||
|
||||
extern int _enable_callback_tracebacks;
|
||||
extern int pysqlite_BaseTypeAdapted;
|
||||
|
||||
#define PARSE_DECLTYPES 1
|
||||
#define PARSE_COLNAMES 2
|
||||
#endif
|
File diff suppressed because it is too large.
@ -0,0 +1,137 @@
|
||||
import sys
|
||||
from difflib import SequenceMatcher
|
||||
from random import randint
|
||||
|
||||
|
||||
IS_PY3K = sys.version_info[0] == 3
|
||||
|
||||
# String UDF.
|
||||
def damerau_levenshtein_dist(s1, s2):
|
||||
cdef:
|
||||
int i, j, del_cost, add_cost, sub_cost
|
||||
int s1_len = len(s1), s2_len = len(s2)
|
||||
list one_ago, two_ago, current_row
|
||||
list zeroes = [0] * (s2_len + 1)
|
||||
|
||||
if IS_PY3K:
|
||||
current_row = list(range(1, s2_len + 2))
|
||||
else:
|
||||
current_row = range(1, s2_len + 2)
|
||||
|
||||
current_row[-1] = 0
|
||||
one_ago = None
|
||||
|
||||
for i in range(s1_len):
|
||||
two_ago = one_ago
|
||||
one_ago = current_row
|
||||
current_row = list(zeroes)
|
||||
current_row[-1] = i + 1
|
||||
for j in range(s2_len):
|
||||
del_cost = one_ago[j] + 1
|
||||
add_cost = current_row[j - 1] + 1
|
||||
sub_cost = one_ago[j - 1] + (s1[i] != s2[j])
|
||||
current_row[j] = min(del_cost, add_cost, sub_cost)
|
||||
|
||||
# Handle transpositions.
|
||||
if (i > 0 and j > 0 and s1[i] == s2[j - 1]
|
||||
and s1[i-1] == s2[j] and s1[i] != s2[j]):
|
||||
current_row[j] = min(current_row[j], two_ago[j - 2] + 1)
|
||||
|
||||
return current_row[s2_len - 1]
|
||||
|
||||
# String UDF.
|
||||
def levenshtein_dist(a, b):
|
||||
cdef:
|
||||
int add, delete, change
|
||||
int i, j
|
||||
int n = len(a), m = len(b)
|
||||
list current, previous
|
||||
list zeroes
|
||||
|
||||
if n > m:
|
||||
a, b = b, a
|
||||
n, m = m, n
|
||||
|
||||
zeroes = [0] * (m + 1)
|
||||
|
||||
if IS_PY3K:
|
||||
current = list(range(n + 1))
|
||||
else:
|
||||
current = range(n + 1)
|
||||
|
||||
for i in range(1, m + 1):
|
||||
previous = current
|
||||
current = list(zeroes)
|
||||
current[0] = i
|
||||
|
||||
for j in range(1, n + 1):
|
||||
add = previous[j] + 1
|
||||
delete = current[j - 1] + 1
|
||||
change = previous[j - 1]
|
||||
if a[j - 1] != b[i - 1]:
|
||||
change += 1
|
||||
current[j] = min(add, delete, change)
|
||||
|
||||
return current[n]
|
||||
|
||||
# String UDF.
|
||||
def str_dist(a, b):
|
||||
cdef:
|
||||
int t = 0
|
||||
|
||||
for i in SequenceMatcher(None, a, b).get_opcodes():
|
||||
if i[0] == 'equal':
|
||||
continue
|
||||
t = t + max(i[4] - i[3], i[2] - i[1])
|
||||
return t
|
||||
|
||||
# Math Aggregate.
|
||||
cdef class median(object):
|
||||
cdef:
|
||||
int ct
|
||||
list items
|
||||
|
||||
def __init__(self):
|
||||
self.ct = 0
|
||||
self.items = []
|
||||
|
||||
cdef selectKth(self, int k, int s=0, int e=-1):
|
||||
cdef:
|
||||
int idx
|
||||
if e < 0:
|
||||
e = len(self.items)
|
||||
idx = randint(s, e-1)
|
||||
idx = self.partition_k(idx, s, e)
|
||||
if idx > k:
|
||||
return self.selectKth(k, s, idx)
|
||||
elif idx < k:
|
||||
return self.selectKth(k, idx + 1, e)
|
||||
else:
|
||||
return self.items[idx]
|
||||
|
||||
cdef int partition_k(self, int pi, int s, int e):
|
||||
cdef:
|
||||
int i, x
|
||||
|
||||
val = self.items[pi]
|
||||
# Swap pivot w/last item.
|
||||
self.items[e - 1], self.items[pi] = self.items[pi], self.items[e - 1]
|
||||
x = s
|
||||
for i in range(s, e):
|
||||
if self.items[i] < val:
|
||||
self.items[i], self.items[x] = self.items[x], self.items[i]
|
||||
x += 1
|
||||
self.items[x], self.items[e-1] = self.items[e-1], self.items[x]
|
||||
return x
|
||||
|
||||
def step(self, item):
|
||||
self.items.append(item)
|
||||
self.ct += 1
|
||||
|
||||
def finalize(self):
|
||||
if self.ct == 0:
|
||||
return None
|
||||
elif self.ct < 3:
|
||||
return self.items[0]
|
||||
else:
|
||||
return self.selectKth(self.ct / 2)
|
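The Cython helpers above (string-distance UDFs and a `median` aggregate) are compiled into peewee's optional C extension. For readers without that extension, the sketch below shows the analogous pure-Python registration path using peewee's `func()`/`aggregate()` decorators; the toy `levenshtein` body and the in-memory database are assumptions, not the optimized implementations above:

```python
# Sketch: registering a scalar UDF and an aggregate in pure Python, using
# peewee's standard decorators. The levenshtein() body is a toy stand-in,
# not the optimized Cython version above.
from playhouse.sqlite_ext import SqliteExtDatabase

db = SqliteExtDatabase(':memory:')


@db.func('levenshtein')
def levenshtein(a, b):
    # Toy distance: differing positions plus the length difference.
    return sum(x != y for x, y in zip(a, b)) + abs(len(a) - len(b))


@db.aggregate('median')
class Median(object):
    def __init__(self):
        self.values = []

    def step(self, value):
        self.values.append(value)

    def finalize(self):
        if not self.values:
            return None
        return sorted(self.values)[len(self.values) // 2]


curs = db.execute_sql("SELECT levenshtein('kitten', 'sitting')")
print(curs.fetchone()[0])  # -> 3 with this toy metric
```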
@ -0,0 +1,146 @@
|
||||
"""
|
||||
Peewee integration with APSW, "another python sqlite wrapper".
|
||||
|
||||
Project page: https://rogerbinns.github.io/apsw/
|
||||
|
||||
APSW is a really neat library that provides a thin wrapper on top of SQLite's
|
||||
C interface.
|
||||
|
||||
Here are just a few reasons to use APSW, taken from the documentation:
|
||||
|
||||
* APSW gives all functionality of SQLite, including virtual tables, virtual
|
||||
file system, blob i/o, backups and file control.
|
||||
* Connections can be shared across threads without any additional locking.
|
||||
* Transactions are managed explicitly by your code.
|
||||
* APSW can handle nested transactions.
|
||||
* Unicode is handled correctly.
|
||||
* APSW is faster.
|
||||
"""
|
||||
import apsw
|
||||
from peewee import *
|
||||
from peewee import __exception_wrapper__
|
||||
from peewee import BooleanField as _BooleanField
|
||||
from peewee import DateField as _DateField
|
||||
from peewee import DateTimeField as _DateTimeField
|
||||
from peewee import DecimalField as _DecimalField
|
||||
from peewee import TimeField as _TimeField
|
||||
from peewee import logger
|
||||
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
|
||||
class APSWDatabase(SqliteExtDatabase):
|
||||
server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.'))
|
||||
|
||||
def __init__(self, database, **kwargs):
|
||||
self._modules = {}
|
||||
super(APSWDatabase, self).__init__(database, **kwargs)
|
||||
|
||||
def register_module(self, mod_name, mod_inst):
|
||||
self._modules[mod_name] = mod_inst
|
||||
if not self.is_closed():
|
||||
self.connection().createmodule(mod_name, mod_inst)
|
||||
|
||||
def unregister_module(self, mod_name):
|
||||
del(self._modules[mod_name])
|
||||
|
||||
def _connect(self):
|
||||
conn = apsw.Connection(self.database, **self.connect_params)
|
||||
if self._timeout is not None:
|
||||
conn.setbusytimeout(self._timeout * 1000)
|
||||
try:
|
||||
self._add_conn_hooks(conn)
|
||||
except:
|
||||
conn.close()
|
||||
raise
|
||||
return conn
|
||||
|
||||
def _add_conn_hooks(self, conn):
|
||||
super(APSWDatabase, self)._add_conn_hooks(conn)
|
||||
self._load_modules(conn) # APSW-only.
|
||||
|
||||
def _load_modules(self, conn):
|
||||
for mod_name, mod_inst in self._modules.items():
|
||||
conn.createmodule(mod_name, mod_inst)
|
||||
return conn
|
||||
|
||||
def _load_aggregates(self, conn):
|
||||
for name, (klass, num_params) in self._aggregates.items():
|
||||
def make_aggregate():
|
||||
return (klass(), klass.step, klass.finalize)
|
||||
conn.createaggregatefunction(name, make_aggregate)
|
||||
|
||||
def _load_collations(self, conn):
|
||||
for name, fn in self._collations.items():
|
||||
conn.createcollation(name, fn)
|
||||
|
||||
def _load_functions(self, conn):
|
||||
for name, (fn, num_params) in self._functions.items():
|
||||
conn.createscalarfunction(name, fn, num_params)
|
||||
|
||||
def _load_extensions(self, conn):
|
||||
conn.enableloadextension(True)
|
||||
for extension in self._extensions:
|
||||
conn.loadextension(extension)
|
||||
|
||||
def load_extension(self, extension):
|
||||
self._extensions.add(extension)
|
||||
if not self.is_closed():
|
||||
conn = self.connection()
|
||||
conn.enableloadextension(True)
|
||||
conn.loadextension(extension)
|
||||
|
||||
def last_insert_id(self, cursor, query_type=None):
|
||||
return cursor.getconnection().last_insert_rowid()
|
||||
|
||||
def rows_affected(self, cursor):
|
||||
return cursor.getconnection().changes()
|
||||
|
||||
def begin(self, lock_type='deferred'):
|
||||
self.cursor().execute('begin %s;' % lock_type)
|
||||
|
||||
def commit(self):
|
||||
with __exception_wrapper__:
|
||||
curs = self.cursor()
|
||||
if curs.getconnection().getautocommit():
|
||||
return False
|
||||
curs.execute('commit;')
|
||||
return True
|
||||
|
||||
def rollback(self):
|
||||
with __exception_wrapper__:
|
||||
curs = self.cursor()
|
||||
if curs.getconnection().getautocommit():
|
||||
return False
|
||||
curs.execute('rollback;')
|
||||
return True
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=True):
|
||||
logger.debug((sql, params))
|
||||
with __exception_wrapper__:
|
||||
cursor = self.cursor()
|
||||
cursor.execute(sql, params or ())
|
||||
return cursor
|
||||
|
||||
|
||||
def nh(s, v):
|
||||
if v is not None:
|
||||
return str(v)
|
||||
|
||||
class BooleanField(_BooleanField):
|
||||
def db_value(self, v):
|
||||
v = super(BooleanField, self).db_value(v)
|
||||
if v is not None:
|
||||
return v and 1 or 0
|
||||
|
||||
class DateField(_DateField):
|
||||
db_value = nh
|
||||
|
||||
class TimeField(_TimeField):
|
||||
db_value = nh
|
||||
|
||||
class DateTimeField(_DateTimeField):
|
||||
db_value = nh
|
||||
|
||||
class DecimalField(_DecimalField):
|
||||
db_value = nh
|
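A minimal sketch of using the `APSWDatabase` defined above, assuming the `apsw` package is installed; the `Note` model and the in-memory database name are illustrative:

```python
# Sketch of using the APSWDatabase defined above; assumes the `apsw`
# package is installed. Model and database name are illustrative.
from peewee import Model, TextField
from playhouse.apsw_ext import APSWDatabase

db = APSWDatabase(':memory:')


class Note(Model):
    content = TextField()

    class Meta:
        database = db


db.create_tables([Note])
with db.atomic():
    Note.create(content='hello from APSW')
print(Note.select().count())  # -> 1
```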
@ -0,0 +1,207 @@
|
||||
import functools
|
||||
import re
|
||||
|
||||
from peewee import *
|
||||
from peewee import _atomic
|
||||
from peewee import _manual
|
||||
from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default)
|
||||
from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table).
|
||||
from peewee import IndexMetadata
|
||||
from playhouse.pool import _PooledPostgresqlDatabase
|
||||
try:
|
||||
from playhouse.postgres_ext import ArrayField
|
||||
from playhouse.postgres_ext import BinaryJSONField
|
||||
from playhouse.postgres_ext import IntervalField
|
||||
JSONField = BinaryJSONField
|
||||
except ImportError: # psycopg2 not installed, ignore.
|
||||
ArrayField = BinaryJSONField = IntervalField = JSONField = None
|
||||
|
||||
|
||||
NESTED_TX_MIN_VERSION = 200100
|
||||
|
||||
TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may '
|
||||
'alternatively use the @transaction context-manager/decorator, '
|
||||
'which only wraps the outer-most block in transactional logic. '
|
||||
'To run a transaction with automatic retries, use the '
|
||||
'run_transaction() helper.')
|
||||
|
||||
class ExceededMaxAttempts(OperationalError): pass
|
||||
|
||||
|
||||
class UUIDKeyField(UUIDField):
|
||||
auto_increment = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if kwargs.get('constraints'):
|
||||
raise ValueError('%s cannot specify constraints.' % type(self))
|
||||
kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')]
|
||||
kwargs.setdefault('primary_key', True)
|
||||
super(UUIDKeyField, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class RowIDField(AutoField):
|
||||
field_type = 'INT'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if kwargs.get('constraints'):
|
||||
raise ValueError('%s cannot specify constraints.' % type(self))
|
||||
kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')]
|
||||
super(RowIDField, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CockroachDatabase(PostgresqlDatabase):
|
||||
field_types = PostgresqlDatabase.field_types.copy()
|
||||
field_types.update({
|
||||
'BLOB': 'BYTES',
|
||||
})
|
||||
|
||||
for_update = False
|
||||
nulls_ordering = False
|
||||
release_after_rollback = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault('user', 'root')
|
||||
kwargs.setdefault('port', 26257)
|
||||
super(CockroachDatabase, self).__init__(*args, **kwargs)
|
||||
|
||||
def _set_server_version(self, conn):
|
||||
curs = conn.cursor()
|
||||
curs.execute('select version()')
|
||||
raw, = curs.fetchone()
|
||||
match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw)
|
||||
if match_obj is not None:
|
||||
clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups())
|
||||
self.server_version = int(clean) # 19.1.5 -> 190105.
|
||||
else:
|
||||
# Fallback to use whatever cockroachdb tells us via protocol.
|
||||
super(CockroachDatabase, self)._set_server_version(conn)
|
||||
|
||||
def _get_pk_constraint(self, table, schema=None):
|
||||
query = ('SELECT constraint_name '
|
||||
'FROM information_schema.table_constraints '
|
||||
'WHERE table_name = %s AND table_schema = %s '
|
||||
'AND constraint_type = %s')
|
||||
cursor = self.execute_sql(query, (table, schema or 'public',
|
||||
'PRIMARY KEY'))
|
||||
row = cursor.fetchone()
|
||||
return row and row[0] or None
|
||||
|
||||
def get_indexes(self, table, schema=None):
|
||||
# The primary-key index is returned by default, so we will just strip
|
||||
# it out here.
|
||||
indexes = super(CockroachDatabase, self).get_indexes(table, schema)
|
||||
pkc = self._get_pk_constraint(table, schema)
|
||||
return [idx for idx in indexes if (not pkc) or (idx.name != pkc)]
|
||||
|
||||
def conflict_statement(self, on_conflict, query):
|
||||
if not on_conflict._action: return
|
||||
|
||||
action = on_conflict._action.lower()
|
||||
if action in ('replace', 'upsert'):
|
||||
return SQL('UPSERT')
|
||||
elif action not in ('ignore', 'nothing', 'update'):
|
||||
raise ValueError('Un-supported action for conflict resolution. '
|
||||
'CockroachDB supports REPLACE (UPSERT), IGNORE '
|
||||
'and UPDATE.')
|
||||
|
||||
def conflict_update(self, oc, query):
|
||||
action = oc._action.lower() if oc._action else ''
|
||||
if action in ('ignore', 'nothing'):
|
||||
return SQL('ON CONFLICT DO NOTHING')
|
||||
elif action in ('replace', 'upsert'):
|
||||
# No special stuff is necessary, this is just indicated by starting
|
||||
# the statement with UPSERT instead of INSERT.
|
||||
return
|
||||
elif oc._conflict_constraint:
|
||||
raise ValueError('CockroachDB does not support the usage of a '
|
||||
'constraint name. Use the column(s) instead.')
|
||||
|
||||
return super(CockroachDatabase, self).conflict_update(oc, query)
|
||||
|
||||
def extract_date(self, date_part, date_field):
|
||||
return fn.extract(date_part, date_field)
|
||||
|
||||
def from_timestamp(self, date_field):
|
||||
# CRDB does not allow casting a decimal/float to timestamp, so we first
|
||||
# cast to int, then to timestamptz.
|
||||
return date_field.cast('int').cast('timestamptz')
|
||||
|
||||
def begin(self, system_time=None, priority=None):
|
||||
super(CockroachDatabase, self).begin()
|
||||
if system_time is not None:
|
||||
self.execute_sql('SET TRANSACTION AS OF SYSTEM TIME %s',
|
||||
(system_time,), commit=False)
|
||||
if priority is not None:
|
||||
priority = priority.lower()
|
||||
if priority not in ('low', 'normal', 'high'):
|
||||
raise ValueError('priority must be low, normal or high')
|
||||
self.execute_sql('SET TRANSACTION PRIORITY %s' % priority,
|
||||
commit=False)
|
||||
|
||||
def atomic(self, system_time=None, priority=None):
|
||||
if self.server_version < NESTED_TX_MIN_VERSION:
|
||||
return _crdb_atomic(self, system_time, priority)
|
||||
return super(CockroachDatabase, self).atomic(system_time, priority)
|
||||
|
||||
def savepoint(self):
|
||||
if self.server_version < NESTED_TX_MIN_VERSION:
|
||||
raise NotImplementedError(TXN_ERR_MSG)
|
||||
return super(CockroachDatabase, self).savepoint()
|
||||
|
||||
def retry_transaction(self, max_attempts=None, system_time=None,
|
||||
priority=None):
|
||||
def deco(cb):
|
||||
@functools.wraps(cb)
|
||||
def new_fn():
|
||||
return run_transaction(self, cb, max_attempts, system_time,
|
||||
priority)
|
||||
return new_fn
|
||||
return deco
|
||||
|
||||
def run_transaction(self, cb, max_attempts=None, system_time=None,
|
||||
priority=None):
|
||||
return run_transaction(self, cb, max_attempts, system_time, priority)
|
||||
|
||||
|
||||
class _crdb_atomic(_atomic):
|
||||
def __enter__(self):
|
||||
if self.db.transaction_depth() > 0:
|
||||
if not isinstance(self.db.top_transaction(), _manual):
|
||||
raise NotImplementedError(TXN_ERR_MSG)
|
||||
return super(_crdb_atomic, self).__enter__()
|
||||
|
||||
|
||||
def run_transaction(db, callback, max_attempts=None, system_time=None,
|
||||
priority=None):
|
||||
"""
|
||||
Run transactional SQL in a transaction with automatic retries.
|
||||
|
||||
User-provided `callback`:
|
||||
* Must accept one parameter, the `db` instance representing the connection
|
||||
the transaction is running under.
|
||||
* Must not attempt to commit, rollback or otherwise manage transactions.
|
||||
* May be called more than once.
|
||||
* Should ideally only contain SQL operations.
|
||||
|
||||
Additionally, the database must not have any open transaction at the time
|
||||
this function is called, as CRDB does not support nested transactions.
|
||||
"""
|
||||
max_attempts = max_attempts or -1
|
||||
with db.atomic(system_time=system_time, priority=priority) as txn:
|
||||
db.execute_sql('SAVEPOINT cockroach_restart')
|
||||
while max_attempts != 0:
|
||||
try:
|
||||
result = callback(db)
|
||||
db.execute_sql('RELEASE SAVEPOINT cockroach_restart')
|
||||
return result
|
||||
except OperationalError as exc:
|
||||
if exc.orig.pgcode == '40001':
|
||||
max_attempts -= 1
|
||||
db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart')
|
||||
continue
|
||||
raise
|
||||
raise ExceededMaxAttempts(None, 'unable to commit transaction')
|
||||
|
||||
|
||||
class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase):
|
||||
pass
|
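A minimal sketch of the `run_transaction()` retry helper defined above. The connection parameters and the `account` table referenced in the callback are hypothetical; per the docstring, the callback only issues SQL and may be invoked more than once:

```python
# Sketch: client-side retries with run_transaction(). The connection
# parameters and the `account` table are hypothetical.
from playhouse.cockroachdb import CockroachDatabase, run_transaction

db = CockroachDatabase('bank', user='root', host='localhost', port=26257)


def transfer_funds(from_id, to_id, amount):
    def callback(db_in_txn):
        # SQL only -- no explicit commit/rollback; may be re-run on conflict.
        db_in_txn.execute_sql(
            'UPDATE account SET balance = balance - %s WHERE id = %s',
            (amount, from_id))
        db_in_txn.execute_sql(
            'UPDATE account SET balance = balance + %s WHERE id = %s',
            (amount, to_id))

    return run_transaction(db, callback, max_attempts=10)
```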
@ -0,0 +1,451 @@
|
||||
import csv
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
import json
|
||||
import operator
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse
|
||||
import sys
|
||||
|
||||
from peewee import *
|
||||
from playhouse.db_url import connect
|
||||
from playhouse.migrate import migrate
|
||||
from playhouse.migrate import SchemaMigrator
|
||||
from playhouse.reflection import Introspector
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
basestring = str
|
||||
from functools import reduce
|
||||
def open_file(f, mode):
|
||||
return open(f, mode, encoding='utf8')
|
||||
else:
|
||||
open_file = open
|
||||
|
||||
|
||||
class DataSet(object):
|
||||
def __init__(self, url, **kwargs):
|
||||
if isinstance(url, Database):
|
||||
self._url = None
|
||||
self._database = url
|
||||
self._database_path = self._database.database
|
||||
else:
|
||||
self._url = url
|
||||
parse_result = urlparse(url)
|
||||
self._database_path = parse_result.path[1:]
|
||||
|
||||
# Connect to the database.
|
||||
self._database = connect(url)
|
||||
|
||||
self._database.connect()
|
||||
|
||||
# Introspect the database and generate models.
|
||||
self._introspector = Introspector.from_database(self._database)
|
||||
self._models = self._introspector.generate_models(
|
||||
skip_invalid=True,
|
||||
literal_column_names=True,
|
||||
**kwargs)
|
||||
self._migrator = SchemaMigrator.from_database(self._database)
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = self._database
|
||||
self._base_model = BaseModel
|
||||
self._export_formats = self.get_export_formats()
|
||||
self._import_formats = self.get_import_formats()
|
||||
|
||||
def __repr__(self):
|
||||
return '<DataSet: %s>' % self._database_path
|
||||
|
||||
def get_export_formats(self):
|
||||
return {
|
||||
'csv': CSVExporter,
|
||||
'json': JSONExporter,
|
||||
'tsv': TSVExporter}
|
||||
|
||||
def get_import_formats(self):
|
||||
return {
|
||||
'csv': CSVImporter,
|
||||
'json': JSONImporter,
|
||||
'tsv': TSVImporter}
|
||||
|
||||
def __getitem__(self, table):
|
||||
if table not in self._models and table in self.tables:
|
||||
self.update_cache(table)
|
||||
return Table(self, table, self._models.get(table))
|
||||
|
||||
@property
|
||||
def tables(self):
|
||||
return self._database.get_tables()
|
||||
|
||||
def __contains__(self, table):
|
||||
return table in self.tables
|
||||
|
||||
def connect(self):
|
||||
self._database.connect()
|
||||
|
||||
def close(self):
|
||||
self._database.close()
|
||||
|
||||
def update_cache(self, table=None):
|
||||
if table:
|
||||
dependencies = [table]
|
||||
if table in self._models:
|
||||
model_class = self._models[table]
|
||||
dependencies.extend([
|
||||
related._meta.table_name for _, related, _ in
|
||||
model_class._meta.model_graph()])
|
||||
else:
|
||||
dependencies.extend(self.get_table_dependencies(table))
|
||||
else:
|
||||
dependencies = None # Update all tables.
|
||||
self._models = {}
|
||||
updated = self._introspector.generate_models(
|
||||
skip_invalid=True,
|
||||
table_names=dependencies,
|
||||
literal_column_names=True)
|
||||
self._models.update(updated)
|
||||
|
||||
def get_table_dependencies(self, table):
|
||||
stack = [table]
|
||||
accum = []
|
||||
seen = set()
|
||||
while stack:
|
||||
table = stack.pop()
|
||||
for fk_meta in self._database.get_foreign_keys(table):
|
||||
dest = fk_meta.dest_table
|
||||
if dest not in seen:
|
||||
stack.append(dest)
|
||||
accum.append(dest)
|
||||
return accum
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if not self._database.is_closed():
|
||||
self.close()
|
||||
|
||||
def query(self, sql, params=None, commit=True):
|
||||
return self._database.execute_sql(sql, params, commit)
|
||||
|
||||
def transaction(self):
|
||||
if self._database.transaction_depth() == 0:
|
||||
return self._database.transaction()
|
||||
else:
|
||||
return self._database.savepoint()
|
||||
|
||||
def _check_arguments(self, filename, file_obj, format, format_dict):
|
||||
if filename and file_obj:
|
||||
raise ValueError('file is over-specified. Please use either '
|
||||
'filename or file_obj, but not both.')
|
||||
if not filename and not file_obj:
|
||||
raise ValueError('A filename or file-like object must be '
|
||||
'specified.')
|
||||
if format not in format_dict:
|
||||
valid_formats = ', '.join(sorted(format_dict.keys()))
|
||||
raise ValueError('Unsupported format "%s". Use one of %s.' % (
|
||||
format, valid_formats))
|
||||
|
||||
def freeze(self, query, format='csv', filename=None, file_obj=None,
|
||||
**kwargs):
|
||||
self._check_arguments(filename, file_obj, format, self._export_formats)
|
||||
if filename:
|
||||
file_obj = open_file(filename, 'w')
|
||||
|
||||
exporter = self._export_formats[format](query)
|
||||
exporter.export(file_obj, **kwargs)
|
||||
|
||||
if filename:
|
||||
file_obj.close()
|
||||
|
||||
def thaw(self, table, format='csv', filename=None, file_obj=None,
|
||||
strict=False, **kwargs):
|
||||
self._check_arguments(filename, file_obj, format, self._export_formats)
|
||||
if filename:
|
||||
file_obj = open_file(filename, 'r')
|
||||
|
||||
importer = self._import_formats[format](self[table], strict)
|
||||
count = importer.load(file_obj, **kwargs)
|
||||
|
||||
if filename:
|
||||
file_obj.close()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class Table(object):
|
||||
def __init__(self, dataset, name, model_class):
|
||||
self.dataset = dataset
|
||||
self.name = name
|
||||
if model_class is None:
|
||||
model_class = self._create_model()
|
||||
model_class.create_table()
|
||||
self.dataset._models[name] = model_class
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return self.dataset._models[self.name]
|
||||
|
||||
def __repr__(self):
|
||||
return '<Table: %s>' % self.name
|
||||
|
||||
def __len__(self):
|
||||
return self.find().count()
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.find().iterator())
|
||||
|
||||
def _create_model(self):
|
||||
class Meta:
|
||||
table_name = self.name
|
||||
return type(
|
||||
str(self.name),
|
||||
(self.dataset._base_model,),
|
||||
{'Meta': Meta})
|
||||
|
||||
def create_index(self, columns, unique=False):
|
||||
index = ModelIndex(self.model_class, columns, unique=unique)
|
||||
self.model_class.add_index(index)
|
||||
self.dataset._database.execute(index)
|
||||
|
||||
def _guess_field_type(self, value):
|
||||
if isinstance(value, basestring):
|
||||
return TextField
|
||||
if isinstance(value, (datetime.date, datetime.datetime)):
|
||||
return DateTimeField
|
||||
elif value is True or value is False:
|
||||
return BooleanField
|
||||
elif isinstance(value, int):
|
||||
return IntegerField
|
||||
elif isinstance(value, float):
|
||||
return FloatField
|
||||
elif isinstance(value, Decimal):
|
||||
return DecimalField
|
||||
return TextField
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return [f.name for f in self.model_class._meta.sorted_fields]
|
||||
|
||||
def _migrate_new_columns(self, data):
|
||||
new_keys = set(data) - set(self.model_class._meta.fields)
|
||||
if new_keys:
|
||||
operations = []
|
||||
for key in new_keys:
|
||||
field_class = self._guess_field_type(data[key])
|
||||
field = field_class(null=True)
|
||||
operations.append(
|
||||
self.dataset._migrator.add_column(self.name, key, field))
|
||||
field.bind(self.model_class, key)
|
||||
|
||||
migrate(*operations)
|
||||
|
||||
self.dataset.update_cache(self.name)
|
||||
|
||||
def __getitem__(self, item):
|
||||
try:
|
||||
return self.model_class[item]
|
||||
except self.model_class.DoesNotExist:
|
||||
pass
|
||||
|
||||
def __setitem__(self, item, value):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('Table.__setitem__() value must be a dict')
|
||||
|
||||
pk = self.model_class._meta.primary_key
|
||||
value[pk.name] = item
|
||||
|
||||
try:
|
||||
with self.dataset.transaction() as txn:
|
||||
self.insert(**value)
|
||||
except IntegrityError:
|
||||
self.dataset.update_cache(self.name)
|
||||
self.update(columns=[pk.name], **value)
|
||||
|
||||
def __delitem__(self, item):
|
||||
del self.model_class[item]
|
||||
|
||||
def insert(self, **data):
|
||||
self._migrate_new_columns(data)
|
||||
return self.model_class.insert(**data).execute()
|
||||
|
||||
def _apply_where(self, query, filters, conjunction=None):
|
||||
conjunction = conjunction or operator.and_
|
||||
if filters:
|
||||
expressions = [
|
||||
(self.model_class._meta.fields[column] == value)
|
||||
for column, value in filters.items()]
|
||||
query = query.where(reduce(conjunction, expressions))
|
||||
return query
|
||||
|
||||
def update(self, columns=None, conjunction=None, **data):
|
||||
self._migrate_new_columns(data)
|
||||
filters = {}
|
||||
if columns:
|
||||
for column in columns:
|
||||
filters[column] = data.pop(column)
|
||||
|
||||
return self._apply_where(
|
||||
self.model_class.update(**data),
|
||||
filters,
|
||||
conjunction).execute()
|
||||
|
||||
def _query(self, **query):
|
||||
return self._apply_where(self.model_class.select(), query)
|
||||
|
||||
def find(self, **query):
|
||||
return self._query(**query).dicts()
|
||||
|
||||
def find_one(self, **query):
|
||||
try:
|
||||
return self.find(**query).get()
|
||||
except self.model_class.DoesNotExist:
|
||||
return None
|
||||
|
||||
def all(self):
|
||||
return self.find()
|
||||
|
||||
def delete(self, **query):
|
||||
return self._apply_where(self.model_class.delete(), query).execute()
|
||||
|
||||
def freeze(self, *args, **kwargs):
|
||||
return self.dataset.freeze(self.all(), *args, **kwargs)
|
||||
|
||||
def thaw(self, *args, **kwargs):
|
||||
return self.dataset.thaw(self.name, *args, **kwargs)
|
||||
|
||||
|
||||
class Exporter(object):
|
||||
def __init__(self, query):
|
||||
self.query = query
|
||||
|
||||
def export(self, file_obj):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class JSONExporter(Exporter):
|
||||
def __init__(self, query, iso8601_datetimes=False):
|
||||
super(JSONExporter, self).__init__(query)
|
||||
self.iso8601_datetimes = iso8601_datetimes
|
||||
|
||||
def _make_default(self):
|
||||
datetime_types = (datetime.datetime, datetime.date, datetime.time)
|
||||
|
||||
if self.iso8601_datetimes:
|
||||
def default(o):
|
||||
if isinstance(o, datetime_types):
|
||||
return o.isoformat()
|
||||
elif isinstance(o, Decimal):
|
||||
return str(o)
|
||||
raise TypeError('Unable to serialize %r as JSON' % o)
|
||||
else:
|
||||
def default(o):
|
||||
if isinstance(o, datetime_types + (Decimal,)):
|
||||
return str(o)
|
||||
raise TypeError('Unable to serialize %r as JSON' % o)
|
||||
return default
|
||||
|
||||
def export(self, file_obj, **kwargs):
|
||||
json.dump(
|
||||
list(self.query),
|
||||
file_obj,
|
||||
default=self._make_default(),
|
||||
**kwargs)
|
||||
|
||||
|
||||
class CSVExporter(Exporter):
|
||||
def export(self, file_obj, header=True, **kwargs):
|
||||
writer = csv.writer(file_obj, **kwargs)
|
||||
tuples = self.query.tuples().execute()
|
||||
tuples.initialize()
|
||||
if header and getattr(tuples, 'columns', None):
|
||||
writer.writerow([column for column in tuples.columns])
|
||||
for row in tuples:
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
class TSVExporter(CSVExporter):
|
||||
def export(self, file_obj, header=True, **kwargs):
|
||||
kwargs.setdefault('delimiter', '\t')
|
||||
return super(TSVExporter, self).export(file_obj, header, **kwargs)
|
||||
|
||||
|
||||
class Importer(object):
|
||||
def __init__(self, table, strict=False):
|
||||
self.table = table
|
||||
self.strict = strict
|
||||
|
||||
model = self.table.model_class
|
||||
self.columns = model._meta.columns
|
||||
self.columns.update(model._meta.fields)
|
||||
|
||||
def load(self, file_obj):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class JSONImporter(Importer):
|
||||
def load(self, file_obj, **kwargs):
|
||||
data = json.load(file_obj, **kwargs)
|
||||
count = 0
|
||||
|
||||
for row in data:
|
||||
if self.strict:
|
||||
obj = {}
|
||||
for key in row:
|
||||
field = self.columns.get(key)
|
||||
if field is not None:
|
||||
obj[field.name] = field.python_value(row[key])
|
||||
else:
|
||||
obj = row
|
||||
|
||||
if obj:
|
||||
self.table.insert(**obj)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class CSVImporter(Importer):
|
||||
def load(self, file_obj, header=True, **kwargs):
|
||||
count = 0
|
||||
reader = csv.reader(file_obj, **kwargs)
|
||||
if header:
|
||||
try:
|
||||
header_keys = next(reader)
|
||||
except StopIteration:
|
||||
return count
|
||||
|
||||
if self.strict:
|
||||
header_fields = []
|
||||
for idx, key in enumerate(header_keys):
|
||||
if key in self.columns:
|
||||
header_fields.append((idx, self.columns[key]))
|
||||
else:
|
||||
header_fields = list(enumerate(header_keys))
|
||||
else:
|
||||
header_fields = list(enumerate(self.model._meta.sorted_fields))
|
||||
|
||||
if not header_fields:
|
||||
return count
|
||||
|
||||
for row in reader:
|
||||
obj = {}
|
||||
for idx, field in header_fields:
|
||||
if self.strict:
|
||||
obj[field.name] = field.python_value(row[idx])
|
||||
else:
|
||||
obj[field] = row[idx]
|
||||
|
||||
self.table.insert(**obj)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class TSVImporter(CSVImporter):
|
||||
def load(self, file_obj, header=True, **kwargs):
|
||||
kwargs.setdefault('delimiter', '\t')
|
||||
return super(TSVImporter, self).load(file_obj, header, **kwargs)
|
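A minimal sketch of the `DataSet`/`Table` API defined above, backed by an in-memory SQLite URL; the table and column names are illustrative:

```python
# Sketch: the high-level DataSet/Table API defined above, backed by an
# in-memory SQLite URL. Table and column names are illustrative.
from playhouse.dataset import DataSet

ds = DataSet('sqlite:///:memory:')

users = ds['user']                        # Table is created on first access.
users.insert(name='Huey', favorite_color='white')
users.insert(name='Mickey', age=5)        # New columns are added automatically.

print(users.find_one(name='Huey'))        # -> {'id': 1, 'name': 'Huey', ...}
print(len(users))                         # -> 2

users.freeze(format='json', filename='users.json')  # 'csv' and 'tsv' also work
```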
@ -0,0 +1,130 @@
|
||||
try:
|
||||
from urlparse import parse_qsl, unquote, urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import parse_qsl, unquote, urlparse
|
||||
|
||||
from peewee import *
|
||||
from playhouse.cockroachdb import CockroachDatabase
|
||||
from playhouse.cockroachdb import PooledCockroachDatabase
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
from playhouse.pool import PooledPostgresqlDatabase
|
||||
from playhouse.pool import PooledSqliteDatabase
|
||||
from playhouse.pool import PooledSqliteExtDatabase
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
|
||||
schemes = {
|
||||
'cockroachdb': CockroachDatabase,
|
||||
'cockroachdb+pool': PooledCockroachDatabase,
|
||||
'crdb': CockroachDatabase,
|
||||
'crdb+pool': PooledCockroachDatabase,
|
||||
'mysql': MySQLDatabase,
|
||||
'mysql+pool': PooledMySQLDatabase,
|
||||
'postgres': PostgresqlDatabase,
|
||||
'postgresql': PostgresqlDatabase,
|
||||
'postgres+pool': PooledPostgresqlDatabase,
|
||||
'postgresql+pool': PooledPostgresqlDatabase,
|
||||
'sqlite': SqliteDatabase,
|
||||
'sqliteext': SqliteExtDatabase,
|
||||
'sqlite+pool': PooledSqliteDatabase,
|
||||
'sqliteext+pool': PooledSqliteExtDatabase,
|
||||
}
|
||||
|
||||
def register_database(db_class, *names):
|
||||
global schemes
|
||||
for name in names:
|
||||
schemes[name] = db_class
|
||||
|
||||
def parseresult_to_dict(parsed, unquote_password=False):
|
||||
|
||||
# urlparse in python 2.6 is broken so query will be empty and instead
|
||||
# appended to path complete with '?'
|
||||
path_parts = parsed.path[1:].split('?')
|
||||
try:
|
||||
query = path_parts[1]
|
||||
except IndexError:
|
||||
query = parsed.query
|
||||
|
||||
connect_kwargs = {'database': path_parts[0]}
|
||||
if parsed.username:
|
||||
connect_kwargs['user'] = parsed.username
|
||||
if parsed.password:
|
||||
connect_kwargs['password'] = parsed.password
|
||||
if unquote_password:
|
||||
connect_kwargs['password'] = unquote(connect_kwargs['password'])
|
||||
if parsed.hostname:
|
||||
connect_kwargs['host'] = parsed.hostname
|
||||
if parsed.port:
|
||||
connect_kwargs['port'] = parsed.port
|
||||
|
||||
# Adjust parameters for MySQL.
|
||||
if parsed.scheme == 'mysql' and 'password' in connect_kwargs:
|
||||
connect_kwargs['passwd'] = connect_kwargs.pop('password')
|
||||
elif 'sqlite' in parsed.scheme and not connect_kwargs['database']:
|
||||
connect_kwargs['database'] = ':memory:'
|
||||
|
||||
# Get additional connection args from the query string
|
||||
qs_args = parse_qsl(query, keep_blank_values=True)
|
||||
for key, value in qs_args:
|
||||
if value.lower() == 'false':
|
||||
value = False
|
||||
elif value.lower() == 'true':
|
||||
value = True
|
||||
elif value.isdigit():
|
||||
value = int(value)
|
||||
elif '.' in value and all(p.isdigit() for p in value.split('.', 1)):
|
||||
try:
|
||||
value = float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
elif value.lower() in ('null', 'none'):
|
||||
value = None
|
||||
|
||||
connect_kwargs[key] = value
|
||||
|
||||
return connect_kwargs
|
||||
|
||||
def parse(url, unquote_password=False):
|
||||
parsed = urlparse(url)
|
||||
return parseresult_to_dict(parsed, unquote_password)
|
||||
|
||||
def connect(url, unquote_password=False, **connect_params):
|
||||
parsed = urlparse(url)
|
||||
connect_kwargs = parseresult_to_dict(parsed, unquote_password)
|
||||
connect_kwargs.update(connect_params)
|
||||
database_class = schemes.get(parsed.scheme)
|
||||
|
||||
if database_class is None:
|
||||
if database_class in schemes:
|
||||
raise RuntimeError('Attempted to use "%s" but a required library '
|
||||
'could not be imported.' % parsed.scheme)
|
||||
else:
|
||||
raise RuntimeError('Unrecognized or unsupported scheme: "%s".' %
|
||||
parsed.scheme)
|
||||
|
||||
return database_class(**connect_kwargs)
|
||||
|
||||
# Conditionally register additional databases.
|
||||
try:
|
||||
from playhouse.pool import PooledPostgresqlExtDatabase
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
register_database(
|
||||
PooledPostgresqlExtDatabase,
|
||||
'postgresext+pool',
|
||||
'postgresqlext+pool')
|
||||
|
||||
try:
|
||||
from playhouse.apsw_ext import APSWDatabase
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
register_database(APSWDatabase, 'apsw')
|
||||
|
||||
try:
|
||||
from playhouse.postgres_ext import PostgresqlExtDatabase
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext')
|
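To illustrate how `parseresult_to_dict()` coerces query-string values (digits to ints, `true`/`false` to booleans, `null`/`none` to None), here is a small sketch; the URLs below are examples only:

```python
# Sketch: query-string coercion performed by parseresult_to_dict().
# The URLs are examples only.
from playhouse.db_url import connect, parse

kwargs = parse('postgresql://postgres:secret@db.example.com:5432/app'
               '?sslmode=require&connect_timeout=10&autorollback=true')
print(kwargs)
# {'database': 'app', 'user': 'postgres', 'password': 'secret',
#  'host': 'db.example.com', 'port': 5432, 'sslmode': 'require',
#  'connect_timeout': 10, 'autorollback': True}

db = connect('sqlite:///:memory:')  # returns a SqliteDatabase instance
```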
@ -0,0 +1,64 @@
|
||||
try:
|
||||
import bz2
|
||||
except ImportError:
|
||||
bz2 = None
|
||||
try:
|
||||
import zlib
|
||||
except ImportError:
|
||||
zlib = None
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
from peewee import BlobField
|
||||
from peewee import buffer_type
|
||||
|
||||
|
||||
PY2 = sys.version_info[0] == 2
|
||||
|
||||
|
||||
class CompressedField(BlobField):
|
||||
ZLIB = 'zlib'
|
||||
BZ2 = 'bz2'
|
||||
algorithm_to_import = {
|
||||
ZLIB: zlib,
|
||||
BZ2: bz2,
|
||||
}
|
||||
|
||||
def __init__(self, compression_level=6, algorithm=ZLIB, *args,
|
||||
**kwargs):
|
||||
self.compression_level = compression_level
|
||||
if algorithm not in self.algorithm_to_import:
|
||||
raise ValueError('Unrecognized algorithm %s' % algorithm)
|
||||
compress_module = self.algorithm_to_import[algorithm]
|
||||
if compress_module is None:
|
||||
raise ValueError('Missing library required for %s.' % algorithm)
|
||||
|
||||
self.algorithm = algorithm
|
||||
self.compress = compress_module.compress
|
||||
self.decompress = compress_module.decompress
|
||||
super(CompressedField, self).__init__(*args, **kwargs)
|
||||
|
||||
def python_value(self, value):
|
||||
if value is not None:
|
||||
return self.decompress(value)
|
||||
|
||||
def db_value(self, value):
|
||||
if value is not None:
|
||||
return self._constructor(
|
||||
self.compress(value, self.compression_level))
|
||||
|
||||
|
||||
class PickleField(BlobField):
|
||||
def python_value(self, value):
|
||||
if value is not None:
|
||||
if isinstance(value, buffer_type):
|
||||
value = bytes(value)
|
||||
return pickle.loads(value)
|
||||
|
||||
def db_value(self, value):
|
||||
if value is not None:
|
||||
pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
|
||||
return self._constructor(pickled)
|
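A minimal sketch of a model using the `CompressedField` and `PickleField` defined above; the database and model are illustrative:

```python
# Sketch: a model using the CompressedField and PickleField defined above.
# Database and model are illustrative.
from peewee import Model, SqliteDatabase, TextField
from playhouse.fields import CompressedField, PickleField

db = SqliteDatabase(':memory:')


class CacheEntry(Model):
    key = TextField(unique=True)
    blob = CompressedField(algorithm=CompressedField.ZLIB)  # zlib-compressed bytes
    meta = PickleField(null=True)                           # arbitrary pickled object

    class Meta:
        database = db


db.create_tables([CacheEntry])
CacheEntry.create(key='a', blob=b'x' * 10000, meta={'hits': 0})
row = CacheEntry.get(CacheEntry.key == 'a')
assert row.blob == b'x' * 10000 and row.meta == {'hits': 0}
```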
@ -0,0 +1,185 @@
|
||||
import math
|
||||
import sys
|
||||
|
||||
from flask import abort
|
||||
from flask import render_template
|
||||
from flask import request
|
||||
from peewee import Database
|
||||
from peewee import DoesNotExist
|
||||
from peewee import Model
|
||||
from peewee import Proxy
|
||||
from peewee import SelectQuery
|
||||
from playhouse.db_url import connect as db_url_connect
|
||||
|
||||
|
||||
class PaginatedQuery(object):
|
||||
def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
|
||||
check_bounds=False):
|
||||
self.paginate_by = paginate_by
|
||||
self.page_var = page_var
|
||||
self.page = page or None
|
||||
self.check_bounds = check_bounds
|
||||
|
||||
if isinstance(query_or_model, SelectQuery):
|
||||
self.query = query_or_model
|
||||
self.model = self.query.model
|
||||
else:
|
||||
self.model = query_or_model
|
||||
self.query = self.model.select()
|
||||
|
||||
def get_page(self):
|
||||
if self.page is not None:
|
||||
return self.page
|
||||
|
||||
curr_page = request.args.get(self.page_var)
|
||||
if curr_page and curr_page.isdigit():
|
||||
return max(1, int(curr_page))
|
||||
return 1
|
||||
|
||||
def get_page_count(self):
|
||||
if not hasattr(self, '_page_count'):
|
||||
self._page_count = int(math.ceil(
|
||||
float(self.query.count()) / self.paginate_by))
|
||||
return self._page_count
|
||||
|
||||
def get_object_list(self):
|
||||
if self.check_bounds and self.get_page() > self.get_page_count():
|
||||
abort(404)
|
||||
return self.query.paginate(self.get_page(), self.paginate_by)
|
||||
|
||||
|
||||
def get_object_or_404(query_or_model, *query):
|
||||
if not isinstance(query_or_model, SelectQuery):
|
||||
query_or_model = query_or_model.select()
|
||||
try:
|
||||
return query_or_model.where(*query).get()
|
||||
except DoesNotExist:
|
||||
abort(404)
|
||||
|
||||
def object_list(template_name, query, context_variable='object_list',
|
||||
paginate_by=20, page_var='page', page=None, check_bounds=True,
|
||||
**kwargs):
|
||||
paginated_query = PaginatedQuery(
|
||||
query,
|
||||
paginate_by=paginate_by,
|
||||
page_var=page_var,
|
||||
page=page,
|
||||
check_bounds=check_bounds)
|
||||
kwargs[context_variable] = paginated_query.get_object_list()
|
||||
return render_template(
|
||||
template_name,
|
||||
pagination=paginated_query,
|
||||
page=paginated_query.get_page(),
|
||||
**kwargs)
|
||||
|
||||
def get_current_url():
|
||||
if not request.query_string:
|
||||
return request.path
|
||||
return '%s?%s' % (request.path, request.query_string)
|
||||
|
||||
def get_next_url(default='/'):
|
||||
if request.args.get('next'):
|
||||
return request.args['next']
|
||||
elif request.form.get('next'):
|
||||
return request.form['next']
|
||||
return default
|
||||
|
||||
class FlaskDB(object):
|
||||
def __init__(self, app=None, database=None, model_class=Model):
|
||||
self.database = None # Reference to actual Peewee database instance.
|
||||
self.base_model_class = model_class
|
||||
self._app = app
|
||||
self._db = database # dict, url, Database, or None (default).
|
||||
if app is not None:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app):
|
||||
self._app = app
|
||||
|
||||
if self._db is None:
|
||||
if 'DATABASE' in app.config:
|
||||
initial_db = app.config['DATABASE']
|
||||
elif 'DATABASE_URL' in app.config:
|
||||
initial_db = app.config['DATABASE_URL']
|
||||
else:
|
||||
raise ValueError('Missing required configuration data for '
|
||||
'database: DATABASE or DATABASE_URL.')
|
||||
else:
|
||||
initial_db = self._db
|
||||
|
||||
self._load_database(app, initial_db)
|
||||
self._register_handlers(app)
|
||||
|
||||
def _load_database(self, app, config_value):
|
||||
if isinstance(config_value, Database):
|
||||
database = config_value
|
||||
elif isinstance(config_value, dict):
|
||||
database = self._load_from_config_dict(dict(config_value))
|
||||
else:
|
||||
# Assume a database connection URL.
|
||||
database = db_url_connect(config_value)
|
||||
|
||||
if isinstance(self.database, Proxy):
|
||||
self.database.initialize(database)
|
||||
else:
|
||||
self.database = database
|
||||
|
||||
def _load_from_config_dict(self, config_dict):
|
||||
try:
|
||||
name = config_dict.pop('name')
|
||||
engine = config_dict.pop('engine')
|
||||
except KeyError:
|
||||
raise RuntimeError('DATABASE configuration must specify a '
|
||||
'`name` and `engine`.')
|
||||
|
||||
if '.' in engine:
|
||||
path, class_name = engine.rsplit('.', 1)
|
||||
else:
|
||||
path, class_name = 'peewee', engine
|
||||
|
||||
try:
|
||||
__import__(path)
|
||||
module = sys.modules[path]
|
||||
database_class = getattr(module, class_name)
|
||||
assert issubclass(database_class, Database)
|
||||
except ImportError:
|
||||
raise RuntimeError('Unable to import %s' % engine)
|
||||
except AttributeError:
|
||||
raise RuntimeError('Database engine not found %s' % engine)
|
||||
except AssertionError:
|
||||
raise RuntimeError('Database engine not a subclass of '
|
||||
'peewee.Database: %s' % engine)
|
||||
|
||||
return database_class(name, **config_dict)
|
||||
|
||||
def _register_handlers(self, app):
|
||||
app.before_request(self.connect_db)
|
||||
app.teardown_request(self.close_db)
|
||||
|
||||
def get_model_class(self):
|
||||
if self.database is None:
|
||||
raise RuntimeError('Database must be initialized.')
|
||||
|
||||
class BaseModel(self.base_model_class):
|
||||
class Meta:
|
||||
database = self.database
|
||||
|
||||
return BaseModel
|
||||
|
||||
@property
|
||||
def Model(self):
|
||||
if self._app is None:
|
||||
database = getattr(self, 'database', None)
|
||||
if database is None:
|
||||
self.database = Proxy()
|
||||
|
||||
if not hasattr(self, '_model_class'):
|
||||
self._model_class = self.get_model_class()
|
||||
return self._model_class
|
||||
|
||||
def connect_db(self):
|
||||
self.database.connect()
|
||||
|
||||
def close_db(self, exc):
|
||||
if not self.database.is_closed():
|
||||
self.database.close()
|
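A hedged sketch of how FlaskDB and object_list are typically wired together (not part of the diff): the app, Note model, notes.html template, and the DATABASE URL below are illustrative assumptions.

from flask import Flask
from peewee import TextField
from playhouse.flask_utils import FlaskDB, object_list

app = Flask(__name__)
app.config['DATABASE'] = 'sqlite:///notes.db'  # consumed by FlaskDB.init_app()
flask_db = FlaskDB(app)  # registers per-request connect/close handlers

class Note(flask_db.Model):
    content = TextField()

flask_db.database.create_tables([Note])

@app.route('/notes/')
def note_index():
    # Paginates Note.select() and renders notes.html with `object_list`,
    # `pagination`, and `page` in the template context.
    return object_list('notes.html', Note.select(), paginate_by=10)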
@@ -0,0 +1,53 @@
from peewee import ModelDescriptor


# Hybrid methods/attributes, based on similar functionality in SQLAlchemy:
# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html
class hybrid_method(ModelDescriptor):
    def __init__(self, func, expr=None):
        self.func = func
        self.expr = expr or func

    def __get__(self, instance, instance_type):
        if instance is None:
            return self.expr.__get__(instance_type, instance_type.__class__)
        return self.func.__get__(instance, instance_type)

    def expression(self, expr):
        self.expr = expr
        return self


class hybrid_property(ModelDescriptor):
    def __init__(self, fget, fset=None, fdel=None, expr=None):
        self.fget = fget
        self.fset = fset
        self.fdel = fdel
        self.expr = expr or fget

    def __get__(self, instance, instance_type):
        if instance is None:
            return self.expr(instance_type)
        return self.fget(instance)

    def __set__(self, instance, value):
        if self.fset is None:
            raise AttributeError('Cannot set attribute.')
        self.fset(instance, value)

    def __delete__(self, instance):
        if self.fdel is None:
            raise AttributeError('Cannot delete attribute.')
        self.fdel(instance)

    def setter(self, fset):
        self.fset = fset
        return self

    def deleter(self, fdel):
        self.fdel = fdel
        return self

    def expression(self, expr):
        self.expr = expr
        return self
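An illustrative sketch of the hybrid descriptors above (not part of the diff): it assumes the module is importable as playhouse.hybrid, and the Interval model and its fields are made-up examples.

from peewee import IntegerField, Model, SqliteDatabase
from playhouse.hybrid import hybrid_property

db = SqliteDatabase(':memory:')

class Interval(Model):
    start = IntegerField()
    end = IntegerField()

    class Meta:
        database = db

    @hybrid_property
    def length(self):
        return self.end - self.start

assert Interval(start=1, end=5).length == 4          # evaluated in Python
wide = Interval.select().where(Interval.length >= 10)  # compiled to SQL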
@@ -0,0 +1,172 @@
|
||||
import operator
|
||||
|
||||
from peewee import *
|
||||
from peewee import Expression
|
||||
from playhouse.fields import PickleField
|
||||
try:
|
||||
from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase
|
||||
except ImportError:
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
|
||||
Sentinel = type('Sentinel', (object,), {})
|
||||
|
||||
|
||||
class KeyValue(object):
|
||||
"""
|
||||
Persistent dictionary.
|
||||
|
||||
:param Field key_field: field to use for key. Defaults to CharField.
|
||||
:param Field value_field: field to use for value. Defaults to PickleField.
|
||||
:param bool ordered: data should be returned in key-sorted order.
|
||||
:param Database database: database where key/value data is stored.
|
||||
:param str table_name: table name for data.
|
||||
"""
|
||||
def __init__(self, key_field=None, value_field=None, ordered=False,
|
||||
database=None, table_name='keyvalue'):
|
||||
if key_field is None:
|
||||
key_field = CharField(max_length=255, primary_key=True)
|
||||
if not key_field.primary_key:
|
||||
raise ValueError('key_field must have primary_key=True.')
|
||||
|
||||
if value_field is None:
|
||||
value_field = PickleField()
|
||||
|
||||
self._key_field = key_field
|
||||
self._value_field = value_field
|
||||
self._ordered = ordered
|
||||
self._database = database or SqliteExtDatabase(':memory:')
|
||||
self._table_name = table_name
|
||||
if isinstance(self._database, PostgresqlDatabase):
|
||||
self.upsert = self._postgres_upsert
|
||||
self.update = self._postgres_update
|
||||
else:
|
||||
self.upsert = self._upsert
|
||||
self.update = self._update
|
||||
|
||||
self.model = self.create_model()
|
||||
self.key = self.model.key
|
||||
self.value = self.model.value
|
||||
|
||||
# Ensure table exists.
|
||||
self.model.create_table()
|
||||
|
||||
def create_model(self):
|
||||
class KeyValue(Model):
|
||||
key = self._key_field
|
||||
value = self._value_field
|
||||
class Meta:
|
||||
database = self._database
|
||||
table_name = self._table_name
|
||||
return KeyValue
|
||||
|
||||
def query(self, *select):
|
||||
query = self.model.select(*select).tuples()
|
||||
if self._ordered:
|
||||
query = query.order_by(self.key)
|
||||
return query
|
||||
|
||||
def convert_expression(self, expr):
|
||||
if not isinstance(expr, Expression):
|
||||
return (self.key == expr), True
|
||||
return expr, False
|
||||
|
||||
def __contains__(self, key):
|
||||
expr, _ = self.convert_expression(key)
|
||||
return self.model.select().where(expr).exists()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.model)
|
||||
|
||||
def __getitem__(self, expr):
|
||||
converted, is_single = self.convert_expression(expr)
|
||||
query = self.query(self.value).where(converted)
|
||||
item_getter = operator.itemgetter(0)
|
||||
result = [item_getter(row) for row in query]
|
||||
if len(result) == 0 and is_single:
|
||||
raise KeyError(expr)
|
||||
elif is_single:
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
def _upsert(self, key, value):
|
||||
(self.model
|
||||
.insert(key=key, value=value)
|
||||
.on_conflict('replace')
|
||||
.execute())
|
||||
|
||||
def _postgres_upsert(self, key, value):
|
||||
(self.model
|
||||
.insert(key=key, value=value)
|
||||
.on_conflict(conflict_target=[self.key],
|
||||
preserve=[self.value])
|
||||
.execute())
|
||||
|
||||
def __setitem__(self, expr, value):
|
||||
if isinstance(expr, Expression):
|
||||
self.model.update(value=value).where(expr).execute()
|
||||
else:
|
||||
self.upsert(expr, value)
|
||||
|
||||
def __delitem__(self, expr):
|
||||
converted, _ = self.convert_expression(expr)
|
||||
self.model.delete().where(converted).execute()
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.query().execute())
|
||||
|
||||
def keys(self):
|
||||
return map(operator.itemgetter(0), self.query(self.key))
|
||||
|
||||
def values(self):
|
||||
return map(operator.itemgetter(0), self.query(self.value))
|
||||
|
||||
def items(self):
|
||||
return iter(self.query().execute())
|
||||
|
||||
def _update(self, __data=None, **mapping):
|
||||
if __data is not None:
|
||||
mapping.update(__data)
|
||||
return (self.model
|
||||
.insert_many(list(mapping.items()),
|
||||
fields=[self.key, self.value])
|
||||
.on_conflict('replace')
|
||||
.execute())
|
||||
|
||||
def _postgres_update(self, __data=None, **mapping):
|
||||
if __data is not None:
|
||||
mapping.update(__data)
|
||||
return (self.model
|
||||
.insert_many(list(mapping.items()),
|
||||
fields=[self.key, self.value])
|
||||
.on_conflict(conflict_target=[self.key],
|
||||
preserve=[self.value])
|
||||
.execute())
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def setdefault(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
self[key] = default
|
||||
return default
|
||||
|
||||
def pop(self, key, default=Sentinel):
|
||||
with self._database.atomic():
|
||||
try:
|
||||
result = self[key]
|
||||
except KeyError:
|
||||
if default is Sentinel:
|
||||
raise
|
||||
return default
|
||||
del self[key]
|
||||
|
||||
return result
|
||||
|
||||
def clear(self):
|
||||
self.model.delete().execute()
|
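A brief usage sketch for KeyValue (not part of the diff): it assumes the module is importable as playhouse.kv and relies on the default in-memory SQLite backend; key names and values are arbitrary.

from playhouse.kv import KeyValue

kv = KeyValue()                  # CharField key + PickleField value by default
kv['config'] = {'debug': True}   # single-key upsert
kv.update(alpha=1, beta=2)       # bulk upsert via insert_many()
assert kv['config'] == {'debug': True}
assert kv.get('missing', 'fallback') == 'fallback'
del kv[kv.key << ['alpha', 'beta']]  # expressions address multiple rows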
@@ -0,0 +1,886 @@
|
||||
"""
|
||||
Lightweight schema migrations.
|
||||
|
||||
NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some
|
||||
features.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
|
||||
Instantiate a migrator:
|
||||
|
||||
# Postgres example:
|
||||
my_db = PostgresqlDatabase(...)
|
||||
migrator = PostgresqlMigrator(my_db)
|
||||
|
||||
# SQLite example:
|
||||
my_db = SqliteDatabase('my_database.db')
|
||||
migrator = SqliteMigrator(my_db)
|
||||
|
||||
Then you will use the `migrate` function to run various `Operation`s which
|
||||
are generated by the migrator:
|
||||
|
||||
migrate(
|
||||
migrator.add_column('some_table', 'column_name', CharField(default=''))
|
||||
)
|
||||
|
||||
Migrations are not run inside a transaction, so if you wish the migration to
|
||||
run in a transaction you will need to wrap the call to `migrate` in a
|
||||
transaction block, e.g.:
|
||||
|
||||
with my_db.transaction():
|
||||
migrate(...)
|
||||
|
||||
Supported Operations
|
||||
--------------------
|
||||
|
||||
Add new field(s) to an existing model:
|
||||
|
||||
# Create your field instances. For non-null fields you must specify a
|
||||
# default value.
|
||||
pubdate_field = DateTimeField(null=True)
|
||||
comment_field = TextField(default='')
|
||||
|
||||
# Run the migration, specifying the database table, field name and field.
|
||||
migrate(
|
||||
migrator.add_column('comment_tbl', 'pub_date', pubdate_field),
|
||||
migrator.add_column('comment_tbl', 'comment', comment_field),
|
||||
)
|
||||
|
||||
Renaming a field:
|
||||
|
||||
# Specify the table, original name of the column, and its new name.
|
||||
migrate(
|
||||
migrator.rename_column('story', 'pub_date', 'publish_date'),
|
||||
migrator.rename_column('story', 'mod_date', 'modified_date'),
|
||||
)
|
||||
|
||||
Dropping a field:
|
||||
|
||||
migrate(
|
||||
migrator.drop_column('story', 'some_old_field'),
|
||||
)
|
||||
|
||||
Making a field nullable or not nullable:
|
||||
|
||||
# Note that when making a field not null that field must not have any
|
||||
# NULL values present.
|
||||
migrate(
|
||||
# Make `pub_date` allow NULL values.
|
||||
migrator.drop_not_null('story', 'pub_date'),
|
||||
|
||||
# Prevent `modified_date` from containing NULL values.
|
||||
migrator.add_not_null('story', 'modified_date'),
|
||||
)
|
||||
|
||||
Renaming a table:
|
||||
|
||||
migrate(
|
||||
migrator.rename_table('story', 'stories_tbl'),
|
||||
)
|
||||
|
||||
Adding an index:
|
||||
|
||||
# Specify the table, column names, and whether the index should be
|
||||
# UNIQUE or not.
|
||||
migrate(
|
||||
# Create an index on the `pub_date` column.
|
||||
migrator.add_index('story', ('pub_date',), False),
|
||||
|
||||
# Create a multi-column index on the `pub_date` and `status` fields.
|
||||
migrator.add_index('story', ('pub_date', 'status'), False),
|
||||
|
||||
# Create a unique index on the category and title fields.
|
||||
migrator.add_index('story', ('category_id', 'title'), True),
|
||||
)
|
||||
|
||||
Dropping an index:
|
||||
|
||||
# Specify the index name.
|
||||
migrate(migrator.drop_index('story', 'story_pub_date_status'))
|
||||
|
||||
Adding or dropping table constraints:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Add a CHECK() constraint to enforce the price cannot be negative.
|
||||
migrate(migrator.add_constraint(
|
||||
'products',
|
||||
'price_check',
|
||||
Check('price >= 0')))
|
||||
|
||||
# Remove the price check constraint.
|
||||
migrate(migrator.drop_constraint('products', 'price_check'))
|
||||
|
||||
# Add a UNIQUE constraint on the first and last names.
|
||||
migrate(migrator.add_unique('person', 'first_name', 'last_name'))
|
||||
"""
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from peewee import *
|
||||
from peewee import CommaNodeList
|
||||
from peewee import EnclosedNodeList
|
||||
from peewee import Entity
|
||||
from peewee import Expression
|
||||
from peewee import Node
|
||||
from peewee import NodeList
|
||||
from peewee import OP
|
||||
from peewee import callable_
|
||||
from peewee import sort_models
|
||||
from peewee import _truncate_constraint_name
|
||||
try:
|
||||
from playhouse.cockroachdb import CockroachDatabase
|
||||
except ImportError:
|
||||
CockroachDatabase = None
|
||||
|
||||
|
||||
class Operation(object):
|
||||
"""Encapsulate a single schema altering operation."""
|
||||
def __init__(self, migrator, method, *args, **kwargs):
|
||||
self.migrator = migrator
|
||||
self.method = method
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def execute(self, node):
|
||||
self.migrator.database.execute(node)
|
||||
|
||||
def _handle_result(self, result):
|
||||
if isinstance(result, (Node, Context)):
|
||||
self.execute(result)
|
||||
elif isinstance(result, Operation):
|
||||
result.run()
|
||||
elif isinstance(result, (list, tuple)):
|
||||
for item in result:
|
||||
self._handle_result(item)
|
||||
|
||||
def run(self):
|
||||
kwargs = self.kwargs.copy()
|
||||
kwargs['with_context'] = True
|
||||
method = getattr(self.migrator, self.method)
|
||||
self._handle_result(method(*self.args, **kwargs))
|
||||
|
||||
|
||||
def operation(fn):
|
||||
@functools.wraps(fn)
|
||||
def inner(self, *args, **kwargs):
|
||||
with_context = kwargs.pop('with_context', False)
|
||||
if with_context:
|
||||
return fn(self, *args, **kwargs)
|
||||
return Operation(self, fn.__name__, *args, **kwargs)
|
||||
return inner
|
||||
|
||||
|
||||
def make_index_name(table_name, columns):
|
||||
index_name = '_'.join((table_name,) + tuple(columns))
|
||||
if len(index_name) > 64:
|
||||
index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest()
|
||||
index_name = '%s_%s' % (index_name[:56], index_hash[:7])
|
||||
return index_name
|
||||
|
||||
|
||||
class SchemaMigrator(object):
|
||||
explicit_create_foreign_key = False
|
||||
explicit_delete_foreign_key = False
|
||||
|
||||
def __init__(self, database):
|
||||
self.database = database
|
||||
|
||||
def make_context(self):
|
||||
return self.database.get_sql_context()
|
||||
|
||||
@classmethod
|
||||
def from_database(cls, database):
|
||||
if CockroachDatabase and isinstance(database, CockroachDatabase):
|
||||
return CockroachDBMigrator(database)
|
||||
elif isinstance(database, PostgresqlDatabase):
|
||||
return PostgresqlMigrator(database)
|
||||
elif isinstance(database, MySQLDatabase):
|
||||
return MySQLMigrator(database)
|
||||
elif isinstance(database, SqliteDatabase):
|
||||
return SqliteMigrator(database)
|
||||
raise ValueError('Unsupported database: %s' % database)
|
||||
|
||||
@operation
|
||||
def apply_default(self, table, column_name, field):
|
||||
default = field.default
|
||||
if callable_(default):
|
||||
default = default()
|
||||
|
||||
return (self.make_context()
|
||||
.literal('UPDATE ')
|
||||
.sql(Entity(table))
|
||||
.literal(' SET ')
|
||||
.sql(Expression(
|
||||
Entity(column_name),
|
||||
OP.EQ,
|
||||
field.db_value(default),
|
||||
flat=True)))
|
||||
|
||||
def _alter_table(self, ctx, table):
|
||||
return ctx.literal('ALTER TABLE ').sql(Entity(table))
|
||||
|
||||
def _alter_column(self, ctx, table, column):
|
||||
return (self
|
||||
._alter_table(ctx, table)
|
||||
.literal(' ALTER COLUMN ')
|
||||
.sql(Entity(column)))
|
||||
|
||||
@operation
|
||||
def alter_add_column(self, table, column_name, field):
|
||||
# Make field null at first.
|
||||
ctx = self.make_context()
|
||||
field_null, field.null = field.null, True
|
||||
|
||||
# Set the field's column-name and name, if it is not set or doesn't
|
||||
# match the new value.
|
||||
if field.column_name != column_name:
|
||||
field.name = field.column_name = column_name
|
||||
|
||||
(self
|
||||
._alter_table(ctx, table)
|
||||
.literal(' ADD COLUMN ')
|
||||
.sql(field.ddl(ctx)))
|
||||
|
||||
field.null = field_null
|
||||
if isinstance(field, ForeignKeyField):
|
||||
self.add_inline_fk_sql(ctx, field)
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def add_constraint(self, table, name, constraint):
|
||||
return (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' ADD CONSTRAINT ')
|
||||
.sql(Entity(name))
|
||||
.literal(' ')
|
||||
.sql(constraint))
|
||||
|
||||
@operation
|
||||
def add_unique(self, table, *column_names):
|
||||
constraint_name = 'uniq_%s' % '_'.join(column_names)
|
||||
constraint = NodeList((
|
||||
SQL('UNIQUE'),
|
||||
EnclosedNodeList([Entity(column) for column in column_names])))
|
||||
return self.add_constraint(table, constraint_name, constraint)
|
||||
|
||||
@operation
|
||||
def drop_constraint(self, table, name):
|
||||
return (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' DROP CONSTRAINT ')
|
||||
.sql(Entity(name)))
|
||||
|
||||
def add_inline_fk_sql(self, ctx, field):
|
||||
ctx = (ctx
|
||||
.literal(' REFERENCES ')
|
||||
.sql(Entity(field.rel_model._meta.table_name))
|
||||
.literal(' ')
|
||||
.sql(EnclosedNodeList((Entity(field.rel_field.column_name),))))
|
||||
if field.on_delete is not None:
|
||||
ctx = ctx.literal(' ON DELETE %s' % field.on_delete)
|
||||
if field.on_update is not None:
|
||||
ctx = ctx.literal(' ON UPDATE %s' % field.on_update)
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def add_foreign_key_constraint(self, table, column_name, rel, rel_column,
|
||||
on_delete=None, on_update=None):
|
||||
constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel)
|
||||
ctx = (self
|
||||
.make_context()
|
||||
.literal('ALTER TABLE ')
|
||||
.sql(Entity(table))
|
||||
.literal(' ADD CONSTRAINT ')
|
||||
.sql(Entity(_truncate_constraint_name(constraint)))
|
||||
.literal(' FOREIGN KEY ')
|
||||
.sql(EnclosedNodeList((Entity(column_name),)))
|
||||
.literal(' REFERENCES ')
|
||||
.sql(Entity(rel))
|
||||
.literal(' (')
|
||||
.sql(Entity(rel_column))
|
||||
.literal(')'))
|
||||
if on_delete is not None:
|
||||
ctx = ctx.literal(' ON DELETE %s' % on_delete)
|
||||
if on_update is not None:
|
||||
ctx = ctx.literal(' ON UPDATE %s' % on_update)
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def add_column(self, table, column_name, field):
|
||||
# Adding a column is complicated by the fact that if there are rows
|
||||
# present and the field is non-null, then we need to first add the
|
||||
# column as a nullable field, then set the value, then add a not null
|
||||
# constraint.
|
||||
if not field.null and field.default is None:
|
||||
raise ValueError('%s is not null but has no default' % column_name)
|
||||
|
||||
is_foreign_key = isinstance(field, ForeignKeyField)
|
||||
if is_foreign_key and not field.rel_field:
|
||||
raise ValueError('Foreign keys must specify a `field`.')
|
||||
|
||||
operations = [self.alter_add_column(table, column_name, field)]
|
||||
|
||||
# In the event the field is *not* nullable, update with the default
|
||||
# value and set not null.
|
||||
if not field.null:
|
||||
operations.extend([
|
||||
self.apply_default(table, column_name, field),
|
||||
self.add_not_null(table, column_name)])
|
||||
|
||||
if is_foreign_key and self.explicit_create_foreign_key:
|
||||
operations.append(
|
||||
self.add_foreign_key_constraint(
|
||||
table,
|
||||
column_name,
|
||||
field.rel_model._meta.table_name,
|
||||
field.rel_field.column_name,
|
||||
field.on_delete,
|
||||
field.on_update))
|
||||
|
||||
if field.index or field.unique:
|
||||
using = getattr(field, 'index_type', None)
|
||||
operations.append(self.add_index(table, (column_name,),
|
||||
field.unique, using))
|
||||
|
||||
return operations
|
||||
|
||||
@operation
|
||||
def drop_foreign_key_constraint(self, table, column_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@operation
|
||||
def drop_column(self, table, column_name, cascade=True):
|
||||
ctx = self.make_context()
|
||||
(self._alter_table(ctx, table)
|
||||
.literal(' DROP COLUMN ')
|
||||
.sql(Entity(column_name)))
|
||||
|
||||
if cascade:
|
||||
ctx.literal(' CASCADE')
|
||||
|
||||
fk_columns = [
|
||||
foreign_key.column
|
||||
for foreign_key in self.database.get_foreign_keys(table)]
|
||||
if column_name in fk_columns and self.explicit_delete_foreign_key:
|
||||
return [self.drop_foreign_key_constraint(table, column_name), ctx]
|
||||
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def rename_column(self, table, old_name, new_name):
|
||||
return (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' RENAME COLUMN ')
|
||||
.sql(Entity(old_name))
|
||||
.literal(' TO ')
|
||||
.sql(Entity(new_name)))
|
||||
|
||||
@operation
|
||||
def add_not_null(self, table, column):
|
||||
return (self
|
||||
._alter_column(self.make_context(), table, column)
|
||||
.literal(' SET NOT NULL'))
|
||||
|
||||
@operation
|
||||
def drop_not_null(self, table, column):
|
||||
return (self
|
||||
._alter_column(self.make_context(), table, column)
|
||||
.literal(' DROP NOT NULL'))
|
||||
|
||||
@operation
|
||||
def alter_column_type(self, table, column, field, cast=None):
|
||||
# ALTER TABLE <table> ALTER COLUMN <column>
|
||||
ctx = self.make_context()
|
||||
ctx = (self
|
||||
._alter_column(ctx, table, column)
|
||||
.literal(' TYPE ')
|
||||
.sql(field.ddl_datatype(ctx)))
|
||||
if cast is not None:
|
||||
if not isinstance(cast, Node):
|
||||
cast = SQL(cast)
|
||||
ctx = ctx.literal(' USING ').sql(cast)
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def rename_table(self, old_name, new_name):
|
||||
return (self
|
||||
._alter_table(self.make_context(), old_name)
|
||||
.literal(' RENAME TO ')
|
||||
.sql(Entity(new_name)))
|
||||
|
||||
@operation
|
||||
def add_index(self, table, columns, unique=False, using=None):
|
||||
ctx = self.make_context()
|
||||
index_name = make_index_name(table, columns)
|
||||
table_obj = Table(table)
|
||||
cols = [getattr(table_obj.c, column) for column in columns]
|
||||
index = Index(index_name, table_obj, cols, unique=unique, using=using)
|
||||
return ctx.sql(index)
|
||||
|
||||
@operation
|
||||
def drop_index(self, table, index_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('DROP INDEX ')
|
||||
.sql(Entity(index_name)))
|
||||
|
||||
|
||||
class PostgresqlMigrator(SchemaMigrator):
|
||||
def _primary_key_columns(self, tbl):
|
||||
query = """
|
||||
SELECT pg_attribute.attname
|
||||
FROM pg_index, pg_class, pg_attribute
|
||||
WHERE
|
||||
pg_class.oid = '%s'::regclass AND
|
||||
indrelid = pg_class.oid AND
|
||||
pg_attribute.attrelid = pg_class.oid AND
|
||||
pg_attribute.attnum = any(pg_index.indkey) AND
|
||||
indisprimary;
|
||||
"""
|
||||
cursor = self.database.execute_sql(query % tbl)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
@operation
|
||||
def set_search_path(self, schema_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('SET search_path TO %s' % schema_name))
|
||||
|
||||
@operation
|
||||
def rename_table(self, old_name, new_name):
|
||||
pk_names = self._primary_key_columns(old_name)
|
||||
ParentClass = super(PostgresqlMigrator, self)
|
||||
|
||||
operations = [
|
||||
ParentClass.rename_table(old_name, new_name, with_context=True)]
|
||||
|
||||
if len(pk_names) == 1:
|
||||
# Check for existence of primary key sequence.
|
||||
seq_name = '%s_%s_seq' % (old_name, pk_names[0])
|
||||
query = """
|
||||
SELECT 1
|
||||
FROM information_schema.sequences
|
||||
WHERE LOWER(sequence_name) = LOWER(%s)
|
||||
"""
|
||||
cursor = self.database.execute_sql(query, (seq_name,))
|
||||
if bool(cursor.fetchone()):
|
||||
new_seq_name = '%s_%s_seq' % (new_name, pk_names[0])
|
||||
operations.append(ParentClass.rename_table(
|
||||
seq_name, new_seq_name))
|
||||
|
||||
return operations
|
||||
|
||||
|
||||
class CockroachDBMigrator(PostgresqlMigrator):
|
||||
explicit_create_foreign_key = True
|
||||
|
||||
def add_inline_fk_sql(self, ctx, field):
|
||||
pass
|
||||
|
||||
@operation
|
||||
def drop_index(self, table, index_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('DROP INDEX ')
|
||||
.sql(Entity(index_name))
|
||||
.literal(' CASCADE'))
|
||||
|
||||
|
||||
class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk',
|
||||
'default', 'extra'))):
|
||||
@property
|
||||
def is_pk(self):
|
||||
return self.pk == 'PRI'
|
||||
|
||||
@property
|
||||
def is_unique(self):
|
||||
return self.pk == 'UNI'
|
||||
|
||||
@property
|
||||
def is_null(self):
|
||||
return self.null == 'YES'
|
||||
|
||||
def sql(self, column_name=None, is_null=None):
|
||||
if is_null is None:
|
||||
is_null = self.is_null
|
||||
if column_name is None:
|
||||
column_name = self.name
|
||||
parts = [
|
||||
Entity(column_name),
|
||||
SQL(self.definition)]
|
||||
if self.is_unique:
|
||||
parts.append(SQL('UNIQUE'))
|
||||
if is_null:
|
||||
parts.append(SQL('NULL'))
|
||||
else:
|
||||
parts.append(SQL('NOT NULL'))
|
||||
if self.is_pk:
|
||||
parts.append(SQL('PRIMARY KEY'))
|
||||
if self.extra:
|
||||
parts.append(SQL(self.extra))
|
||||
return NodeList(parts)
|
||||
|
||||
|
||||
class MySQLMigrator(SchemaMigrator):
|
||||
explicit_create_foreign_key = True
|
||||
explicit_delete_foreign_key = True
|
||||
|
||||
def _alter_column(self, ctx, table, column):
|
||||
return (self
|
||||
._alter_table(ctx, table)
|
||||
.literal(' MODIFY ')
|
||||
.sql(Entity(column)))
|
||||
|
||||
@operation
|
||||
def rename_table(self, old_name, new_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('RENAME TABLE ')
|
||||
.sql(Entity(old_name))
|
||||
.literal(' TO ')
|
||||
.sql(Entity(new_name)))
|
||||
|
||||
def _get_column_definition(self, table, column_name):
|
||||
cursor = self.database.execute_sql('DESCRIBE `%s`;' % table)
|
||||
rows = cursor.fetchall()
|
||||
for row in rows:
|
||||
column = MySQLColumn(*row)
|
||||
if column.name == column_name:
|
||||
return column
|
||||
return False
|
||||
|
||||
def get_foreign_key_constraint(self, table, column_name):
|
||||
cursor = self.database.execute_sql(
|
||||
('SELECT constraint_name '
|
||||
'FROM information_schema.key_column_usage WHERE '
|
||||
'table_schema = DATABASE() AND '
|
||||
'table_name = %s AND '
|
||||
'column_name = %s AND '
|
||||
'referenced_table_name IS NOT NULL AND '
|
||||
'referenced_column_name IS NOT NULL;'),
|
||||
(table, column_name))
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
raise AttributeError(
|
||||
'Unable to find foreign key constraint for '
|
||||
'"%s" on table "%s".' % (table, column_name))
|
||||
return result[0]
|
||||
|
||||
@operation
|
||||
def drop_foreign_key_constraint(self, table, column_name):
|
||||
fk_constraint = self.get_foreign_key_constraint(table, column_name)
|
||||
return (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' DROP FOREIGN KEY ')
|
||||
.sql(Entity(fk_constraint)))
|
||||
|
||||
def add_inline_fk_sql(self, ctx, field):
|
||||
pass
|
||||
|
||||
@operation
|
||||
def add_not_null(self, table, column):
|
||||
column_def = self._get_column_definition(table, column)
|
||||
add_not_null = (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' MODIFY ')
|
||||
.sql(column_def.sql(is_null=False)))
|
||||
|
||||
fk_objects = dict(
|
||||
(fk.column, fk)
|
||||
for fk in self.database.get_foreign_keys(table))
|
||||
if column not in fk_objects:
|
||||
return add_not_null
|
||||
|
||||
fk_metadata = fk_objects[column]
|
||||
return (self.drop_foreign_key_constraint(table, column),
|
||||
add_not_null,
|
||||
self.add_foreign_key_constraint(
|
||||
table,
|
||||
column,
|
||||
fk_metadata.dest_table,
|
||||
fk_metadata.dest_column))
|
||||
|
||||
@operation
|
||||
def drop_not_null(self, table, column):
|
||||
column = self._get_column_definition(table, column)
|
||||
if column.is_pk:
|
||||
raise ValueError('Primary keys can not be null')
|
||||
return (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' MODIFY ')
|
||||
.sql(column.sql(is_null=True)))
|
||||
|
||||
@operation
|
||||
def rename_column(self, table, old_name, new_name):
|
||||
fk_objects = dict(
|
||||
(fk.column, fk)
|
||||
for fk in self.database.get_foreign_keys(table))
|
||||
is_foreign_key = old_name in fk_objects
|
||||
|
||||
column = self._get_column_definition(table, old_name)
|
||||
rename_ctx = (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' CHANGE ')
|
||||
.sql(Entity(old_name))
|
||||
.literal(' ')
|
||||
.sql(column.sql(column_name=new_name)))
|
||||
if is_foreign_key:
|
||||
fk_metadata = fk_objects[old_name]
|
||||
return [
|
||||
self.drop_foreign_key_constraint(table, old_name),
|
||||
rename_ctx,
|
||||
self.add_foreign_key_constraint(
|
||||
table,
|
||||
new_name,
|
||||
fk_metadata.dest_table,
|
||||
fk_metadata.dest_column),
|
||||
]
|
||||
else:
|
||||
return rename_ctx
|
||||
|
||||
@operation
|
||||
def alter_column_type(self, table, column, field, cast=None):
|
||||
if cast is not None:
|
||||
raise ValueError('alter_column_type() does not support cast with '
|
||||
'MySQL.')
|
||||
ctx = self.make_context()
|
||||
return (self
|
||||
._alter_table(ctx, table)
|
||||
.literal(' MODIFY ')
|
||||
.sql(Entity(column))
|
||||
.literal(' ')
|
||||
.sql(field.ddl(ctx)))
|
||||
|
||||
@operation
|
||||
def drop_index(self, table, index_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('DROP INDEX ')
|
||||
.sql(Entity(index_name))
|
||||
.literal(' ON ')
|
||||
.sql(Entity(table)))
|
||||
|
||||
|
||||
class SqliteMigrator(SchemaMigrator):
|
||||
"""
|
||||
SQLite supports a subset of ALTER TABLE queries; see the docs for the
full details: http://sqlite.org/lang_altertable.html
|
||||
"""
|
||||
column_re = re.compile(r'(.+?)\((.+)\)')
|
||||
column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+')
|
||||
column_name_re = re.compile(r'''["`']?([\w]+)''')
|
||||
fk_re = re.compile(r'FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I)
|
||||
|
||||
def _get_column_names(self, table):
|
||||
res = self.database.execute_sql('select * from "%s" limit 1' % table)
|
||||
return [item[0] for item in res.description]
|
||||
|
||||
def _get_create_table(self, table):
|
||||
res = self.database.execute_sql(
|
||||
('select name, sql from sqlite_master '
|
||||
'where type=? and LOWER(name)=?'),
|
||||
['table', table.lower()])
|
||||
return res.fetchone()
|
||||
|
||||
@operation
|
||||
def _update_column(self, table, column_to_update, fn):
|
||||
columns = set(column.name.lower()
|
||||
for column in self.database.get_columns(table))
|
||||
if column_to_update.lower() not in columns:
|
||||
raise ValueError('Column "%s" does not exist on "%s"' %
|
||||
(column_to_update, table))
|
||||
|
||||
# Get the SQL used to create the given table.
|
||||
table, create_table = self._get_create_table(table)
|
||||
|
||||
# Get the indexes and SQL to re-create indexes.
|
||||
indexes = self.database.get_indexes(table)
|
||||
|
||||
# Find any foreign keys we may need to remove.
|
||||
self.database.get_foreign_keys(table)
|
||||
|
||||
# Make sure the create_table does not contain any newlines or tabs,
|
||||
# allowing the regex to work correctly.
|
||||
create_table = re.sub(r'\s+', ' ', create_table)
|
||||
|
||||
# Parse out the `CREATE TABLE` and column list portions of the query.
|
||||
raw_create, raw_columns = self.column_re.search(create_table).groups()
|
||||
|
||||
# Clean up the individual column definitions.
|
||||
split_columns = self.column_split_re.findall(raw_columns)
|
||||
column_defs = [col.strip() for col in split_columns]
|
||||
|
||||
new_column_defs = []
|
||||
new_column_names = []
|
||||
original_column_names = []
|
||||
constraint_terms = ('foreign ', 'primary ', 'constraint ', 'check ')
|
||||
|
||||
for column_def in column_defs:
|
||||
column_name, = self.column_name_re.match(column_def).groups()
|
||||
|
||||
if column_name == column_to_update:
|
||||
new_column_def = fn(column_name, column_def)
|
||||
if new_column_def:
|
||||
new_column_defs.append(new_column_def)
|
||||
original_column_names.append(column_name)
|
||||
column_name, = self.column_name_re.match(
|
||||
new_column_def).groups()
|
||||
new_column_names.append(column_name)
|
||||
else:
|
||||
new_column_defs.append(column_def)
|
||||
|
||||
# Avoid treating constraints as columns.
|
||||
if not column_def.lower().startswith(constraint_terms):
|
||||
new_column_names.append(column_name)
|
||||
original_column_names.append(column_name)
|
||||
|
||||
# Create a mapping of original columns to new columns.
|
||||
original_to_new = dict(zip(original_column_names, new_column_names))
|
||||
new_column = original_to_new.get(column_to_update)
|
||||
|
||||
fk_filter_fn = lambda column_def: column_def
|
||||
if not new_column:
|
||||
# Remove any foreign keys associated with this column.
|
||||
fk_filter_fn = lambda column_def: None
|
||||
elif new_column != column_to_update:
|
||||
# Update any foreign keys for this column.
|
||||
fk_filter_fn = lambda column_def: self.fk_re.sub(
|
||||
'FOREIGN KEY ("%s") ' % new_column,
|
||||
column_def)
|
||||
|
||||
cleaned_columns = []
|
||||
for column_def in new_column_defs:
|
||||
match = self.fk_re.match(column_def)
|
||||
if match is not None and match.groups()[0] == column_to_update:
|
||||
column_def = fk_filter_fn(column_def)
|
||||
if column_def:
|
||||
cleaned_columns.append(column_def)
|
||||
|
||||
# Update the name of the new CREATE TABLE query.
|
||||
temp_table = table + '__tmp__'
|
||||
rgx = re.compile('("?)%s("?)' % table, re.I)
|
||||
create = rgx.sub(
|
||||
'\\1%s\\2' % temp_table,
|
||||
raw_create)
|
||||
|
||||
# Create the new table.
|
||||
columns = ', '.join(cleaned_columns)
|
||||
queries = [
|
||||
NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]),
|
||||
SQL('%s (%s)' % (create.strip(), columns))]
|
||||
|
||||
# Populate new table.
|
||||
populate_table = NodeList((
|
||||
SQL('INSERT INTO'),
|
||||
Entity(temp_table),
|
||||
EnclosedNodeList([Entity(col) for col in new_column_names]),
|
||||
SQL('SELECT'),
|
||||
CommaNodeList([Entity(col) for col in original_column_names]),
|
||||
SQL('FROM'),
|
||||
Entity(table)))
|
||||
drop_original = NodeList([SQL('DROP TABLE'), Entity(table)])
|
||||
|
||||
# Drop existing table and rename temp table.
|
||||
queries += [
|
||||
populate_table,
|
||||
drop_original,
|
||||
self.rename_table(temp_table, table)]
|
||||
|
||||
# Re-create user-defined indexes. User-defined indexes will have a
|
||||
# non-empty SQL attribute.
|
||||
for index in filter(lambda idx: idx.sql, indexes):
|
||||
if column_to_update not in index.columns:
|
||||
queries.append(SQL(index.sql))
|
||||
elif new_column:
|
||||
sql = self._fix_index(index.sql, column_to_update, new_column)
|
||||
if sql is not None:
|
||||
queries.append(SQL(sql))
|
||||
|
||||
return queries
|
||||
|
||||
def _fix_index(self, sql, column_to_update, new_column):
|
||||
# Split on the name of the column to update. If it splits into two
|
||||
# pieces, then there's no ambiguity and we can simply replace the
|
||||
# old with the new.
|
||||
parts = sql.split(column_to_update)
|
||||
if len(parts) == 2:
|
||||
return sql.replace(column_to_update, new_column)
|
||||
|
||||
# Find the list of columns in the index expression.
|
||||
lhs, rhs = sql.rsplit('(', 1)
|
||||
|
||||
# Apply the same "split in two" logic to the column list portion of
|
||||
# the query.
|
||||
if len(rhs.split(column_to_update)) == 2:
|
||||
return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column))
|
||||
|
||||
# Strip off the trailing parentheses and go through each column.
|
||||
parts = rhs.rsplit(')', 1)[0].split(',')
|
||||
columns = [part.strip('"`[]\' ') for part in parts]
|
||||
|
||||
# `columns` looks something like: ['status', 'timestamp" DESC']
|
||||
# https://www.sqlite.org/lang_keywords.html
|
||||
# Strip out any junk after the column name.
|
||||
clean = []
|
||||
for column in columns:
|
||||
if re.match(r'%s(?:[\'"`\]]?\s|$)' % column_to_update, column):
|
||||
column = new_column + column[len(column_to_update):]
|
||||
clean.append(column)
|
||||
|
||||
return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean))
|
||||
|
||||
@operation
|
||||
def drop_column(self, table, column_name, cascade=True):
|
||||
return self._update_column(table, column_name, lambda a, b: None)
|
||||
|
||||
@operation
|
||||
def rename_column(self, table, old_name, new_name):
|
||||
def _rename(column_name, column_def):
|
||||
return column_def.replace(column_name, new_name)
|
||||
return self._update_column(table, old_name, _rename)
|
||||
|
||||
@operation
|
||||
def add_not_null(self, table, column):
|
||||
def _add_not_null(column_name, column_def):
|
||||
return column_def + ' NOT NULL'
|
||||
return self._update_column(table, column, _add_not_null)
|
||||
|
||||
@operation
|
||||
def drop_not_null(self, table, column):
|
||||
def _drop_not_null(column_name, column_def):
|
||||
return column_def.replace('NOT NULL', '')
|
||||
return self._update_column(table, column, _drop_not_null)
|
||||
|
||||
@operation
|
||||
def alter_column_type(self, table, column, field, cast=None):
|
||||
if cast is not None:
|
||||
raise ValueError('alter_column_type() does not support cast with '
|
||||
'Sqlite.')
|
||||
ctx = self.make_context()
|
||||
def _alter_column_type(column_name, column_def):
|
||||
node_list = field.ddl(ctx)
|
||||
sql, _ = ctx.sql(Entity(column)).sql(node_list).query()
|
||||
return sql
|
||||
return self._update_column(table, column, _alter_column_type)
|
||||
|
||||
@operation
|
||||
def add_constraint(self, table, name, constraint):
|
||||
raise NotImplementedError
|
||||
|
||||
@operation
|
||||
def drop_constraint(self, table, name):
|
||||
raise NotImplementedError
|
||||
|
||||
@operation
|
||||
def add_foreign_key_constraint(self, table, column_name, field,
|
||||
on_delete=None, on_update=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def migrate(*operations, **kwargs):
|
||||
for operation in operations:
|
||||
operation.run()
|
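A compact sketch tying the pieces above together (not part of the diff): it assumes the module is importable as playhouse.migrate, and the database file, table, and column names are placeholders.

from peewee import DateTimeField, SqliteDatabase
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteDatabase('app.db')
migrator = SqliteMigrator(db)

# migrate() does not open a transaction itself, so wrap it explicitly:
with db.atomic():
    migrate(
        migrator.add_column('comment_tbl', 'pub_date', DateTimeField(null=True)),
        migrator.rename_column('comment_tbl', 'body', 'content'),
        migrator.add_index('comment_tbl', ('pub_date',), False),
    )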
@@ -0,0 +1,49 @@
import json

try:
    import mysql.connector as mysql_connector
except ImportError:
    mysql_connector = None

from peewee import ImproperlyConfigured
from peewee import InterfaceError  # raised by cursor() when autoconnect is off
from peewee import MySQLDatabase
from peewee import NodeList
from peewee import SQL
from peewee import TextField
from peewee import fn


class MySQLConnectorDatabase(MySQLDatabase):
    def _connect(self):
        if mysql_connector is None:
            raise ImproperlyConfigured('MySQL connector not installed!')
        return mysql_connector.connect(db=self.database, **self.connect_params)

    def cursor(self, commit=None):
        if self.is_closed():
            if self.autoconnect:
                self.connect()
            else:
                raise InterfaceError('Error, database connection not opened.')
        return self._state.conn.cursor(buffered=True)


class JSONField(TextField):
    field_type = 'JSON'

    def db_value(self, value):
        if value is not None:
            return json.dumps(value)

    def python_value(self, value):
        if value is not None:
            return json.loads(value)


def Match(columns, expr, modifier=None):
    if isinstance(columns, (list, tuple)):
        match = fn.MATCH(*columns)  # Tuple of one or more columns / fields.
    else:
        match = fn.MATCH(columns)  # Single column / field.
    args = expr if modifier is None else NodeList((expr, SQL(modifier)))
    return NodeList((match, fn.AGAINST(args)))
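An illustrative sketch (not part of the diff) of building a MATCH ... AGAINST query and a JSON column with the helpers above. It assumes a reachable MySQL server, the mysql-connector driver, and a FULLTEXT index on Document.content; all names and credentials are placeholders.

from peewee import Model, TextField
from playhouse.mysql_ext import JSONField, Match, MySQLConnectorDatabase

db = MySQLConnectorDatabase('my_app', user='app', password='secret',
                            host='127.0.0.1')

class Document(Model):
    content = TextField()
    metadata = JSONField(null=True)

    class Meta:
        database = db

# Renders as: MATCH(`content`) AGAINST('peewee orm' IN BOOLEAN MODE)
docs = Document.select().where(
    Match(Document.content, 'peewee orm', modifier='IN BOOLEAN MODE'))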
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Lightweight connection pooling for peewee.
|
||||
|
||||
In a multi-threaded application, up to `max_connections` will be opened. Each
|
||||
thread (or, if using gevent, greenlet) will have its own connection.
|
||||
|
||||
In a single-threaded application, only one connection will be created. It will
|
||||
be continually recycled until either it exceeds the stale timeout or is closed
|
||||
explicitly (using `.manual_close()`).
|
||||
|
||||
By default, all your application needs to do is ensure that connections are
|
||||
closed when you are finished with them, and they will be returned to the pool.
|
||||
For web applications, this typically means that at the beginning of a request,
|
||||
you will open a connection, and when you return a response, you will close the
|
||||
connection.
|
||||
|
||||
Simple Postgres pool example code:
|
||||
|
||||
# Use the special postgresql extensions.
|
||||
from playhouse.pool import PooledPostgresqlExtDatabase
|
||||
|
||||
db = PooledPostgresqlExtDatabase(
|
||||
'my_app',
|
||||
max_connections=32,
|
||||
stale_timeout=300, # 5 minutes.
|
||||
user='postgres')
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = db
|
||||
|
||||
That's it!
|
||||
"""
|
||||
import heapq
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from itertools import chain
|
||||
|
||||
try:
|
||||
from psycopg2.extensions import TRANSACTION_STATUS_IDLE
|
||||
from psycopg2.extensions import TRANSACTION_STATUS_INERROR
|
||||
from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN
|
||||
except ImportError:
|
||||
TRANSACTION_STATUS_IDLE = \
|
||||
TRANSACTION_STATUS_INERROR = \
|
||||
TRANSACTION_STATUS_UNKNOWN = None
|
||||
|
||||
from peewee import MySQLDatabase
|
||||
from peewee import PostgresqlDatabase
|
||||
from peewee import SqliteDatabase
|
||||
|
||||
logger = logging.getLogger('peewee.pool')
|
||||
|
||||
|
||||
def make_int(val):
|
||||
if val is not None and not isinstance(val, (int, float)):
|
||||
return int(val)
|
||||
return val
|
||||
|
||||
|
||||
class MaxConnectionsExceeded(ValueError): pass
|
||||
|
||||
|
||||
PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection',
|
||||
'checked_out'))
|
||||
|
||||
|
||||
class PooledDatabase(object):
|
||||
def __init__(self, database, max_connections=20, stale_timeout=None,
|
||||
timeout=None, **kwargs):
|
||||
self._max_connections = make_int(max_connections)
|
||||
self._stale_timeout = make_int(stale_timeout)
|
||||
self._wait_timeout = make_int(timeout)
|
||||
if self._wait_timeout == 0:
|
||||
self._wait_timeout = float('inf')
|
||||
|
||||
# Available / idle connections stored in a heap, sorted oldest first.
|
||||
self._connections = []
|
||||
|
||||
# Mapping of connection id to PoolConnection. Ordinarily we would want
|
||||
# to use something like a WeakKeyDictionary, but Python typically won't
|
||||
# allow us to create weak references to connection objects.
|
||||
self._in_use = {}
|
||||
|
||||
# Use the memory address of the connection as the key in the event the
|
||||
# connection object is not hashable. Connections will not get
|
||||
# garbage-collected, however, because a reference to them will persist
|
||||
# in "_in_use" as long as the conn has not been closed.
|
||||
self.conn_key = id
|
||||
|
||||
super(PooledDatabase, self).__init__(database, **kwargs)
|
||||
|
||||
def init(self, database, max_connections=None, stale_timeout=None,
|
||||
timeout=None, **connect_kwargs):
|
||||
super(PooledDatabase, self).init(database, **connect_kwargs)
|
||||
if max_connections is not None:
|
||||
self._max_connections = make_int(max_connections)
|
||||
if stale_timeout is not None:
|
||||
self._stale_timeout = make_int(stale_timeout)
|
||||
if timeout is not None:
|
||||
self._wait_timeout = make_int(timeout)
|
||||
if self._wait_timeout == 0:
|
||||
self._wait_timeout = float('inf')
|
||||
|
||||
def connect(self, reuse_if_open=False):
|
||||
if not self._wait_timeout:
|
||||
return super(PooledDatabase, self).connect(reuse_if_open)
|
||||
|
||||
expires = time.time() + self._wait_timeout
|
||||
while expires > time.time():
|
||||
try:
|
||||
ret = super(PooledDatabase, self).connect(reuse_if_open)
|
||||
except MaxConnectionsExceeded:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
return ret
|
||||
raise MaxConnectionsExceeded('Max connections exceeded, timed out '
|
||||
'attempting to connect.')
|
||||
|
||||
def _connect(self):
|
||||
while True:
|
||||
try:
|
||||
# Remove the oldest connection from the heap.
|
||||
ts, conn = heapq.heappop(self._connections)
|
||||
key = self.conn_key(conn)
|
||||
except IndexError:
|
||||
ts = conn = None
|
||||
logger.debug('No connection available in pool.')
|
||||
break
|
||||
else:
|
||||
if self._is_closed(conn):
|
||||
# This connection was closed, but since it was not stale
|
||||
# it got added back to the queue of available conns. We
|
||||
# then closed it and marked it as explicitly closed, so
|
||||
# it's safe to throw it away now.
|
||||
# (Because Database.close() calls Database._close()).
|
||||
logger.debug('Connection %s was closed.', key)
|
||||
ts = conn = None
|
||||
elif self._stale_timeout and self._is_stale(ts):
|
||||
# If we are attempting to check out a stale connection,
|
||||
# then close it. We don't need to mark it in the "closed"
|
||||
# set, because it is not in the list of available conns
|
||||
# anymore.
|
||||
logger.debug('Connection %s was stale, closing.', key)
|
||||
self._close(conn, True)
|
||||
ts = conn = None
|
||||
else:
|
||||
break
|
||||
|
||||
if conn is None:
|
||||
if self._max_connections and (
|
||||
len(self._in_use) >= self._max_connections):
|
||||
raise MaxConnectionsExceeded('Exceeded maximum connections.')
|
||||
conn = super(PooledDatabase, self)._connect()
|
||||
ts = time.time() - random.random() / 1000
|
||||
key = self.conn_key(conn)
|
||||
logger.debug('Created new connection %s.', key)
|
||||
|
||||
self._in_use[key] = PoolConnection(ts, conn, time.time())
|
||||
return conn
|
||||
|
||||
def _is_stale(self, timestamp):
|
||||
# Called on check-out and check-in to ensure the connection has
|
||||
# not outlived the stale timeout.
|
||||
return (time.time() - timestamp) > self._stale_timeout
|
||||
|
||||
def _is_closed(self, conn):
|
||||
return False
|
||||
|
||||
def _can_reuse(self, conn):
|
||||
# Called on check-in to make sure the connection can be re-used.
|
||||
return True
|
||||
|
||||
def _close(self, conn, close_conn=False):
|
||||
key = self.conn_key(conn)
|
||||
if close_conn:
|
||||
super(PooledDatabase, self)._close(conn)
|
||||
elif key in self._in_use:
|
||||
pool_conn = self._in_use.pop(key)
|
||||
if self._stale_timeout and self._is_stale(pool_conn.timestamp):
|
||||
logger.debug('Closing stale connection %s.', key)
|
||||
super(PooledDatabase, self)._close(conn)
|
||||
elif self._can_reuse(conn):
|
||||
logger.debug('Returning %s to pool.', key)
|
||||
heapq.heappush(self._connections, (pool_conn.timestamp, conn))
|
||||
else:
|
||||
logger.debug('Closed %s.', key)
|
||||
|
||||
def manual_close(self):
|
||||
"""
|
||||
Close the underlying connection without returning it to the pool.
|
||||
"""
|
||||
if self.is_closed():
|
||||
return False
|
||||
|
||||
# Obtain reference to the connection in-use by the calling thread.
|
||||
conn = self.connection()
|
||||
|
||||
# A connection will only be re-added to the available list if it is
|
||||
# marked as "in use" at the time it is closed. We will explicitly
|
||||
# remove it from the "in use" list, call "close()" for the
|
||||
# side-effects, and then explicitly close the connection.
|
||||
self._in_use.pop(self.conn_key(conn), None)
|
||||
self.close()
|
||||
self._close(conn, close_conn=True)
|
||||
|
||||
def close_idle(self):
|
||||
# Close any open connections that are not currently in-use.
|
||||
with self._lock:
|
||||
for _, conn in self._connections:
|
||||
self._close(conn, close_conn=True)
|
||||
self._connections = []
|
||||
|
||||
def close_stale(self, age=600):
|
||||
# Close any connections that are in-use but were checked out quite some
|
||||
# time ago and can be considered stale.
|
||||
with self._lock:
|
||||
in_use = {}
|
||||
cutoff = time.time() - age
|
||||
n = 0
|
||||
for key, pool_conn in self._in_use.items():
|
||||
if pool_conn.checked_out < cutoff:
|
||||
self._close(pool_conn.connection, close_conn=True)
|
||||
n += 1
|
||||
else:
|
||||
in_use[key] = pool_conn
|
||||
self._in_use = in_use
|
||||
return n
|
||||
|
||||
def close_all(self):
|
||||
# Close all connections -- available and in-use. Warning: may break any
|
||||
# active connections used by other threads.
|
||||
self.close()
|
||||
with self._lock:
|
||||
for _, conn in self._connections:
|
||||
self._close(conn, close_conn=True)
|
||||
for pool_conn in self._in_use.values():
|
||||
self._close(pool_conn.connection, close_conn=True)
|
||||
self._connections = []
|
||||
self._in_use = {}
|
||||
|
||||
|
||||
class PooledMySQLDatabase(PooledDatabase, MySQLDatabase):
|
||||
def _is_closed(self, conn):
|
||||
try:
|
||||
conn.ping(False)
|
||||
except:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class _PooledPostgresqlDatabase(PooledDatabase):
|
||||
def _is_closed(self, conn):
|
||||
if conn.closed:
|
||||
return True
|
||||
|
||||
txn_status = conn.get_transaction_status()
|
||||
if txn_status == TRANSACTION_STATUS_UNKNOWN:
|
||||
return True
|
||||
elif txn_status != TRANSACTION_STATUS_IDLE:
|
||||
conn.rollback()
|
||||
return False
|
||||
|
||||
def _can_reuse(self, conn):
|
||||
txn_status = conn.get_transaction_status()
|
||||
# Do not return connection in an error state, as subsequent queries
|
||||
# will all fail. If the status is unknown then we lost the connection
|
||||
# to the server and the connection should not be re-used.
|
||||
if txn_status == TRANSACTION_STATUS_UNKNOWN:
|
||||
return False
|
||||
elif txn_status == TRANSACTION_STATUS_INERROR:
|
||||
conn.reset()
|
||||
elif txn_status != TRANSACTION_STATUS_IDLE:
|
||||
conn.rollback()
|
||||
return True
|
||||
|
||||
class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase):
|
||||
pass
|
||||
|
||||
try:
|
||||
from playhouse.postgres_ext import PostgresqlExtDatabase
|
||||
|
||||
class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase):
|
||||
pass
|
||||
except ImportError:
|
||||
PooledPostgresqlExtDatabase = None
|
||||
|
||||
|
||||
class _PooledSqliteDatabase(PooledDatabase):
|
||||
def _is_closed(self, conn):
|
||||
try:
|
||||
conn.total_changes
|
||||
except:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase):
|
||||
pass
|
||||
|
||||
try:
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase):
|
||||
pass
|
||||
except ImportError:
|
||||
PooledSqliteExtDatabase = None
|
||||
|
||||
try:
|
||||
from playhouse.sqlite_ext import CSqliteExtDatabase
|
||||
|
||||
class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase):
|
||||
pass
|
||||
except ImportError:
|
||||
PooledCSqliteExtDatabase = None
|
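A short sketch of the check-out/check-in cycle (not part of the diff): the database filename, timeouts, and the handle_request() function are illustrative assumptions.

from playhouse.pool import PooledSqliteDatabase

db = PooledSqliteDatabase('app.db', max_connections=8, stale_timeout=300)

def handle_request():
    db.connect(reuse_if_open=True)   # check a connection out of the pool
    try:
        pass  # run queries here
    finally:
        db.close()  # returns the connection to the pool; no real disconnect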
@@ -0,0 +1,493 @@
|
||||
"""
|
||||
Collection of postgres-specific extensions, currently including:
|
||||
|
||||
* Support for hstore, a key/value type storage
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from peewee import *
|
||||
from peewee import ColumnBase
|
||||
from peewee import Expression
|
||||
from peewee import Node
|
||||
from peewee import NodeList
|
||||
from peewee import SENTINEL
|
||||
from peewee import __exception_wrapper__
|
||||
|
||||
try:
|
||||
from psycopg2cffi import compat
|
||||
compat.register()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from psycopg2.extras import register_hstore
|
||||
except ImportError:
|
||||
def register_hstore(c, globally):
|
||||
pass
|
||||
try:
|
||||
from psycopg2.extras import Json
|
||||
except:
|
||||
Json = None
|
||||
|
||||
|
||||
logger = logging.getLogger('peewee')
|
||||
|
||||
|
||||
HCONTAINS_DICT = '@>'
|
||||
HCONTAINS_KEYS = '?&'
|
||||
HCONTAINS_KEY = '?'
|
||||
HCONTAINS_ANY_KEY = '?|'
|
||||
HKEY = '->'
|
||||
HUPDATE = '||'
|
||||
ACONTAINS = '@>'
|
||||
ACONTAINS_ANY = '&&'
|
||||
TS_MATCH = '@@'
|
||||
JSONB_CONTAINS = '@>'
|
||||
JSONB_CONTAINED_BY = '<@'
|
||||
JSONB_CONTAINS_KEY = '?'
|
||||
JSONB_CONTAINS_ANY_KEY = '?|'
|
||||
JSONB_CONTAINS_ALL_KEYS = '?&'
|
||||
JSONB_EXISTS = '?'
|
||||
JSONB_REMOVE = '-'
|
||||
|
||||
|
||||
class _LookupNode(ColumnBase):
|
||||
def __init__(self, node, parts):
|
||||
self.node = node
|
||||
self.parts = parts
|
||||
super(_LookupNode, self).__init__()
|
||||
|
||||
def clone(self):
|
||||
return type(self)(self.node, list(self.parts))
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.__class__.__name__, id(self)))
|
||||
|
||||
|
||||
class _JsonLookupBase(_LookupNode):
|
||||
def __init__(self, node, parts, as_json=False):
|
||||
super(_JsonLookupBase, self).__init__(node, parts)
|
||||
self._as_json = as_json
|
||||
|
||||
def clone(self):
|
||||
return type(self)(self.node, list(self.parts), self._as_json)
|
||||
|
||||
@Node.copy
|
||||
def as_json(self, as_json=True):
|
||||
self._as_json = as_json
|
||||
|
||||
def concat(self, rhs):
|
||||
if not isinstance(rhs, Node):
|
||||
rhs = Json(rhs)
|
||||
return Expression(self.as_json(True), OP.CONCAT, rhs)
|
||||
|
||||
def contains(self, other):
|
||||
clone = self.as_json(True)
|
||||
if isinstance(other, (list, dict)):
|
||||
return Expression(clone, JSONB_CONTAINS, Json(other))
|
||||
return Expression(clone, JSONB_EXISTS, other)
|
||||
|
||||
def contains_any(self, *keys):
|
||||
return Expression(
|
||||
self.as_json(True),
|
||||
JSONB_CONTAINS_ANY_KEY,
|
||||
Value(list(keys), unpack=False))
|
||||
|
||||
def contains_all(self, *keys):
|
||||
return Expression(
|
||||
self.as_json(True),
|
||||
JSONB_CONTAINS_ALL_KEYS,
|
||||
Value(list(keys), unpack=False))
|
||||
|
||||
def has_key(self, key):
|
||||
return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key)
|
||||
|
||||
|
||||
class JsonLookup(_JsonLookupBase):
|
||||
def __getitem__(self, value):
|
||||
return JsonLookup(self.node, self.parts + [value], self._as_json)
|
||||
|
||||
def __sql__(self, ctx):
|
||||
ctx.sql(self.node)
|
||||
for part in self.parts[:-1]:
|
||||
ctx.literal('->').sql(part)
|
||||
if self.parts:
|
||||
(ctx
|
||||
.literal('->' if self._as_json else '->>')
|
||||
.sql(self.parts[-1]))
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
class JsonPath(_JsonLookupBase):
|
||||
def __sql__(self, ctx):
|
||||
return (ctx
|
||||
.sql(self.node)
|
||||
.literal('#>' if self._as_json else '#>>')
|
||||
.sql(Value('{%s}' % ','.join(map(str, self.parts)))))
|
||||
|
||||
|
||||
class ObjectSlice(_LookupNode):
|
||||
@classmethod
|
||||
def create(cls, node, value):
|
||||
if isinstance(value, slice):
|
||||
parts = [value.start or 0, value.stop or 0]
|
||||
elif isinstance(value, int):
|
||||
parts = [value]
|
||||
elif isinstance(value, Node):
|
||||
parts = value
|
||||
else:
|
||||
# Assumes colon-separated integer indexes.
|
||||
parts = [int(i) for i in value.split(':')]
|
||||
return cls(node, parts)
|
||||
|
||||
def __sql__(self, ctx):
|
||||
ctx.sql(self.node)
|
||||
if isinstance(self.parts, Node):
|
||||
ctx.literal('[').sql(self.parts).literal(']')
|
||||
else:
|
||||
ctx.literal('[%s]' % ':'.join(str(p + 1) for p in self.parts))
|
||||
return ctx
|
||||
|
||||
def __getitem__(self, value):
|
||||
return ObjectSlice.create(self, value)
|
||||
|
||||
|
||||
class IndexedFieldMixin(object):
|
||||
default_index_type = 'GIN'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault('index', True) # By default, use an index.
|
||||
super(IndexedFieldMixin, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ArrayField(IndexedFieldMixin, Field):
|
||||
passthrough = True
|
||||
|
||||
def __init__(self, field_class=IntegerField, field_kwargs=None,
|
||||
dimensions=1, convert_values=False, *args, **kwargs):
|
||||
self.__field = field_class(**(field_kwargs or {}))
|
||||
self.dimensions = dimensions
|
||||
self.convert_values = convert_values
|
||||
self.field_type = self.__field.field_type
|
||||
super(ArrayField, self).__init__(*args, **kwargs)
|
||||
|
||||
def bind(self, model, name, set_attribute=True):
|
||||
ret = super(ArrayField, self).bind(model, name, set_attribute)
|
||||
self.__field.bind(model, '__array_%s' % name, False)
|
||||
return ret
|
||||
|
||||
def ddl_datatype(self, ctx):
|
||||
data_type = self.__field.ddl_datatype(ctx)
|
||||
return NodeList((data_type, SQL('[]' * self.dimensions)), glue='')
|
||||
|
||||
def db_value(self, value):
|
||||
if value is None or isinstance(value, Node):
|
||||
return value
|
||||
elif self.convert_values:
|
||||
return self._process(self.__field.db_value, value, self.dimensions)
|
||||
else:
|
||||
return value if isinstance(value, list) else list(value)
|
||||
|
||||
def python_value(self, value):
|
||||
if self.convert_values and value is not None:
|
||||
conv = self.__field.python_value
|
||||
if isinstance(value, list):
|
||||
return self._process(conv, value, self.dimensions)
|
||||
else:
|
||||
return conv(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def _process(self, conv, value, dimensions):
|
||||
dimensions -= 1
|
||||
if dimensions == 0:
|
||||
return [conv(v) for v in value]
|
||||
else:
|
||||
return [self._process(conv, v, dimensions) for v in value]
|
||||
|
||||
def __getitem__(self, value):
|
||||
return ObjectSlice.create(self, value)
|
||||
|
||||
def _e(op):
|
||||
def inner(self, rhs):
|
||||
return Expression(self, op, ArrayValue(self, rhs))
|
||||
return inner
|
||||
__eq__ = _e(OP.EQ)
|
||||
__ne__ = _e(OP.NE)
|
||||
__gt__ = _e(OP.GT)
|
||||
__ge__ = _e(OP.GTE)
|
||||
__lt__ = _e(OP.LT)
|
||||
__le__ = _e(OP.LTE)
|
||||
__hash__ = Field.__hash__
|
||||
|
||||
def contains(self, *items):
|
||||
return Expression(self, ACONTAINS, ArrayValue(self, items))
|
||||
|
||||
def contains_any(self, *items):
|
||||
return Expression(self, ACONTAINS_ANY, ArrayValue(self, items))
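# Usage sketch (hypothetical model): querying an ArrayField. "BlogPost" and
# its "tags" column (an ArrayField of TextField) are assumed for illustration.
#
#     BlogPost.select().where(BlogPost.tags.contains('python'))
#     BlogPost.select().where(BlogPost.tags.contains_any('python', 'sql'))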
|
||||
|
||||
|
||||
class ArrayValue(Node):
|
||||
def __init__(self, field, value):
|
||||
self.field = field
|
||||
self.value = value
|
||||
|
||||
def __sql__(self, ctx):
|
||||
return (ctx
|
||||
.sql(Value(self.value, unpack=False))
|
||||
.literal('::')
|
||||
.sql(self.field.ddl_datatype(ctx)))
|
||||
|
||||
|
||||
class DateTimeTZField(DateTimeField):
|
||||
field_type = 'TIMESTAMPTZ'
|
||||
|
||||
|
||||
class HStoreField(IndexedFieldMixin, Field):
|
||||
field_type = 'HSTORE'
|
||||
__hash__ = Field.__hash__
|
||||
|
||||
def __getitem__(self, key):
|
||||
return Expression(self, HKEY, Value(key))
|
||||
|
||||
def keys(self):
|
||||
return fn.akeys(self)
|
||||
|
||||
def values(self):
|
||||
return fn.avals(self)
|
||||
|
||||
def items(self):
|
||||
return fn.hstore_to_matrix(self)
|
||||
|
||||
def slice(self, *args):
|
||||
return fn.slice(self, Value(list(args), unpack=False))
|
||||
|
||||
def exists(self, key):
|
||||
return fn.exist(self, key)
|
||||
|
||||
def defined(self, key):
|
||||
return fn.defined(self, key)
|
||||
|
||||
def update(self, **data):
|
||||
return Expression(self, HUPDATE, data)
|
||||
|
||||
def delete(self, *keys):
|
||||
return fn.delete(self, Value(list(keys), unpack=False))
|
||||
|
||||
def contains(self, value):
|
||||
if isinstance(value, dict):
|
||||
rhs = Value(value, unpack=False)
|
||||
return Expression(self, HCONTAINS_DICT, rhs)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
rhs = Value(value, unpack=False)
|
||||
return Expression(self, HCONTAINS_KEYS, rhs)
|
||||
return Expression(self, HCONTAINS_KEY, value)
|
||||
|
||||
def contains_any(self, *keys):
|
||||
return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys),
|
||||
unpack=False))
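# Usage sketch (hypothetical model): "House" with an HStoreField named
# "features" is assumed for illustration.
#
#     House.select().where(House.features.contains({'garage': '2 cars'}))
#     House.select(House.address, House.features['sqft'].alias('sqft'))
#     House.update(features=House.features.update(sqft='2000')).execute()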
|
||||
|
||||
|
||||
class JSONField(Field):
|
||||
field_type = 'JSON'
|
||||
_json_datatype = 'json'
|
||||
|
||||
def __init__(self, dumps=None, *args, **kwargs):
|
||||
if Json is None:
|
||||
raise Exception('Your version of psycopg2 does not support JSON.')
|
||||
self.dumps = dumps or json.dumps
|
||||
super(JSONField, self).__init__(*args, **kwargs)
|
||||
|
||||
def db_value(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, Json):
|
||||
return Cast(self.dumps(value), self._json_datatype)
|
||||
return value
|
||||
|
||||
def __getitem__(self, value):
|
||||
return JsonLookup(self, [value])
|
||||
|
||||
def path(self, *keys):
|
||||
return JsonPath(self, keys)
|
||||
|
||||
def concat(self, value):
|
||||
if not isinstance(value, Node):
|
||||
value = Json(value)
|
||||
return super(JSONField, self).concat(value)
|
||||
|
||||
|
||||
def cast_jsonb(node):
|
||||
return NodeList((node, SQL('::jsonb')), glue='')
|
||||
|
||||
|
||||
class BinaryJSONField(IndexedFieldMixin, JSONField):
|
||||
field_type = 'JSONB'
|
||||
_json_datatype = 'jsonb'
|
||||
__hash__ = Field.__hash__
|
||||
|
||||
def contains(self, other):
|
||||
if isinstance(other, (list, dict)):
|
||||
return Expression(self, JSONB_CONTAINS, Json(other))
|
||||
elif isinstance(other, JSONField):
|
||||
return Expression(self, JSONB_CONTAINS, other)
|
||||
return Expression(cast_jsonb(self), JSONB_EXISTS, other)
|
||||
|
||||
def contained_by(self, other):
|
||||
return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other))
|
||||
|
||||
def contains_any(self, *items):
|
||||
return Expression(
|
||||
cast_jsonb(self),
|
||||
JSONB_CONTAINS_ANY_KEY,
|
||||
Value(list(items), unpack=False))
|
||||
|
||||
def contains_all(self, *items):
|
||||
return Expression(
|
||||
cast_jsonb(self),
|
||||
JSONB_CONTAINS_ALL_KEYS,
|
||||
Value(list(items), unpack=False))
|
||||
|
||||
def has_key(self, key):
|
||||
return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key)
|
||||
|
||||
def remove(self, *items):
|
||||
return Expression(
|
||||
cast_jsonb(self),
|
||||
JSONB_REMOVE,
|
||||
Value(list(items), unpack=False))
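# Usage sketch (hypothetical model): "APIResponse" with a BinaryJSONField
# named "data" is assumed for illustration.
#
#     APIResponse.select().where(APIResponse.data['status'] == 'ok')
#     APIResponse.select().where(APIResponse.data.contains({'source': 'web'}))
#     APIResponse.select().where(APIResponse.data.has_key('error'))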
|
||||
|
||||
|
||||
class TSVectorField(IndexedFieldMixin, TextField):
|
||||
field_type = 'TSVECTOR'
|
||||
__hash__ = Field.__hash__
|
||||
|
||||
def match(self, query, language=None, plain=False):
|
||||
params = (language, query) if language is not None else (query,)
|
||||
func = fn.plainto_tsquery if plain else fn.to_tsquery
|
||||
return Expression(self, TS_MATCH, func(*params))
|
||||
|
||||
|
||||
def Match(field, query, language=None):
|
||||
params = (language, query) if language is not None else (query,)
|
||||
field_params = (language, field) if language is not None else (field,)
|
||||
return Expression(
|
||||
fn.to_tsvector(*field_params),
|
||||
TS_MATCH,
|
||||
fn.to_tsquery(*params))
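# Usage sketch (hypothetical model): full-text search against "Blog", where
# "search_content" is a TSVectorField and "content" a plain TextField.
#
#     Blog.select().where(Blog.search_content.match('python & postgres'))
#     Blog.select().where(Match(Blog.content, 'python'))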
|
||||
|
||||
|
||||
class IntervalField(Field):
|
||||
field_type = 'INTERVAL'
|
||||
|
||||
|
||||
class FetchManyCursor(object):
|
||||
__slots__ = ('cursor', 'array_size', 'exhausted', 'iterable')
|
||||
|
||||
def __init__(self, cursor, array_size=None):
|
||||
self.cursor = cursor
|
||||
self.array_size = array_size or cursor.itersize
|
||||
self.exhausted = False
|
||||
self.iterable = self.row_gen()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self.cursor.description
|
||||
|
||||
def close(self):
|
||||
self.cursor.close()
|
||||
|
||||
def row_gen(self):
|
||||
while True:
|
||||
rows = self.cursor.fetchmany(self.array_size)
|
||||
if not rows:
|
||||
return
|
||||
for row in rows:
|
||||
yield row
|
||||
|
||||
def fetchone(self):
|
||||
if self.exhausted:
|
||||
return
|
||||
try:
|
||||
return next(self.iterable)
|
||||
except StopIteration:
|
||||
self.exhausted = True
|
||||
|
||||
|
||||
class ServerSideQuery(Node):
|
||||
def __init__(self, query, array_size=None):
|
||||
self.query = query
|
||||
self.array_size = array_size
|
||||
self._cursor_wrapper = None
|
||||
|
||||
def __sql__(self, ctx):
|
||||
return self.query.__sql__(ctx)
|
||||
|
||||
def __iter__(self):
|
||||
if self._cursor_wrapper is None:
|
||||
self._execute(self.query._database)
|
||||
return iter(self._cursor_wrapper.iterator())
|
||||
|
||||
def _execute(self, database):
|
||||
if self._cursor_wrapper is None:
|
||||
cursor = database.execute(self.query, named_cursor=True,
|
||||
array_size=self.array_size)
|
||||
self._cursor_wrapper = self.query._get_cursor_wrapper(cursor)
|
||||
return self._cursor_wrapper
|
||||
|
||||
|
||||
def ServerSide(query, database=None, array_size=None):
|
||||
if database is None:
|
||||
database = query._database
|
||||
with database.transaction():
|
||||
server_side_query = ServerSideQuery(query, array_size=array_size)
|
||||
for row in server_side_query:
|
||||
yield row
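# Usage sketch: stream a large result set through a named (server-side)
# cursor so rows are fetched in batches rather than loaded all at once.
# The database and model names are hypothetical.
#
#     db = PostgresqlExtDatabase('my_app')
#     huge_query = Registration.select().order_by(Registration.id)
#     for row in ServerSide(huge_query):
#         process(row)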
|
||||
|
||||
|
||||
class _empty_object(object):
|
||||
__slots__ = ()
|
||||
def __nonzero__(self):
|
||||
return False
|
||||
__bool__ = __nonzero__
|
||||
|
||||
__named_cursor__ = _empty_object()
|
||||
|
||||
|
||||
class PostgresqlExtDatabase(PostgresqlDatabase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._register_hstore = kwargs.pop('register_hstore', False)
|
||||
self._server_side_cursors = kwargs.pop('server_side_cursors', False)
|
||||
super(PostgresqlExtDatabase, self).__init__(*args, **kwargs)
|
||||
|
||||
def _connect(self):
|
||||
conn = super(PostgresqlExtDatabase, self)._connect()
|
||||
if self._register_hstore:
|
||||
register_hstore(conn, globally=True)
|
||||
return conn
|
||||
|
||||
def cursor(self, commit=None):
|
||||
if self.is_closed():
|
||||
if self.autoconnect:
|
||||
self.connect()
|
||||
else:
|
||||
raise InterfaceError('Error, database connection not opened.')
|
||||
if commit is __named_cursor__:
|
||||
return self._state.conn.cursor(name=str(uuid.uuid1()))
|
||||
return self._state.conn.cursor()
|
||||
|
||||
def execute(self, query, commit=SENTINEL, named_cursor=False,
|
||||
array_size=None, **context_options):
|
||||
ctx = self.get_sql_context(**context_options)
|
||||
sql, params = ctx.sql(query).query()
|
||||
named_cursor = named_cursor or (self._server_side_cursors and
|
||||
sql[:6].lower() == 'select')
|
||||
if named_cursor:
|
||||
commit = __named_cursor__
|
||||
cursor = self.execute_sql(sql, params, commit=commit)
|
||||
if named_cursor:
|
||||
cursor = FetchManyCursor(cursor, array_size)
|
||||
return cursor
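# Usage sketch: the two extra keyword arguments accepted by this database
# class. Connection details are hypothetical.
#
#     db = PostgresqlExtDatabase('my_app', user='postgres',
#                                register_hstore=True,       # enable hstore
#                                server_side_cursors=False)  # use ServerSide() per query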
|
@@ -0,0 +1,833 @@
|
||||
try:
|
||||
from collections import OrderedDict
|
||||
except ImportError:
|
||||
OrderedDict = dict
|
||||
from collections import namedtuple
|
||||
from inspect import isclass
|
||||
import re
import warnings
|
||||
|
||||
from peewee import *
|
||||
from peewee import _StringField
|
||||
from peewee import _query_val_transform
|
||||
from peewee import CommaNodeList
|
||||
from peewee import SCOPE_VALUES
|
||||
from peewee import make_snake_case
|
||||
from peewee import text_type
|
||||
try:
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
except ImportError:
|
||||
try:
|
||||
from MySQLdb.constants import FIELD_TYPE
|
||||
except ImportError:
|
||||
FIELD_TYPE = None
|
||||
try:
|
||||
from playhouse import postgres_ext
|
||||
except ImportError:
|
||||
postgres_ext = None
|
||||
try:
|
||||
from playhouse.cockroachdb import CockroachDatabase
|
||||
except ImportError:
|
||||
CockroachDatabase = None
|
||||
|
||||
RESERVED_WORDS = set([
|
||||
'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
|
||||
'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if',
|
||||
'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise',
|
||||
'return', 'try', 'while', 'with', 'yield',
|
||||
])
|
||||
|
||||
|
||||
class UnknownField(object):
|
||||
pass
|
||||
|
||||
|
||||
class Column(object):
|
||||
"""
|
||||
Store metadata about a database column.
|
||||
"""
|
||||
primary_key_types = (IntegerField, AutoField)
|
||||
|
||||
def __init__(self, name, field_class, raw_column_type, nullable,
|
||||
primary_key=False, column_name=None, index=False,
|
||||
unique=False, default=None, extra_parameters=None):
|
||||
self.name = name
|
||||
self.field_class = field_class
|
||||
self.raw_column_type = raw_column_type
|
||||
self.nullable = nullable
|
||||
self.primary_key = primary_key
|
||||
self.column_name = column_name
|
||||
self.index = index
|
||||
self.unique = unique
|
||||
self.default = default
|
||||
self.extra_parameters = extra_parameters
|
||||
|
||||
# Foreign key metadata.
|
||||
self.rel_model = None
|
||||
self.related_name = None
|
||||
self.to_field = None
|
||||
|
||||
def __repr__(self):
|
||||
attrs = [
|
||||
'field_class',
|
||||
'raw_column_type',
|
||||
'nullable',
|
||||
'primary_key',
|
||||
'column_name']
|
||||
keyword_args = ', '.join(
|
||||
'%s=%s' % (attr, getattr(self, attr))
|
||||
for attr in attrs)
|
||||
return 'Column(%s, %s)' % (self.name, keyword_args)
|
||||
|
||||
def get_field_parameters(self):
|
||||
params = {}
|
||||
if self.extra_parameters is not None:
|
||||
params.update(self.extra_parameters)
|
||||
|
||||
# Set up default attributes.
|
||||
if self.nullable:
|
||||
params['null'] = True
|
||||
if self.field_class is ForeignKeyField or self.name != self.column_name:
|
||||
params['column_name'] = "'%s'" % self.column_name
|
||||
if self.primary_key and not issubclass(self.field_class, AutoField):
|
||||
params['primary_key'] = True
|
||||
if self.default is not None:
|
||||
params['constraints'] = '[SQL("DEFAULT %s")]' % self.default
|
||||
|
||||
# Handle ForeignKeyField-specific attributes.
|
||||
if self.is_foreign_key():
|
||||
params['model'] = self.rel_model
|
||||
if self.to_field:
|
||||
params['field'] = "'%s'" % self.to_field
|
||||
if self.related_name:
|
||||
params['backref'] = "'%s'" % self.related_name
|
||||
|
||||
# Handle indexes on column.
|
||||
if not self.is_primary_key():
|
||||
if self.unique:
|
||||
params['unique'] = 'True'
|
||||
elif self.index and not self.is_foreign_key():
|
||||
params['index'] = 'True'
|
||||
|
||||
return params
|
||||
|
||||
def is_primary_key(self):
|
||||
return self.field_class is AutoField or self.primary_key
|
||||
|
||||
def is_foreign_key(self):
|
||||
return self.field_class is ForeignKeyField
|
||||
|
||||
def is_self_referential_fk(self):
|
||||
return (self.field_class is ForeignKeyField and
|
||||
self.rel_model == "'self'")
|
||||
|
||||
def set_foreign_key(self, foreign_key, model_names, dest=None,
|
||||
related_name=None):
|
||||
self.foreign_key = foreign_key
|
||||
self.field_class = ForeignKeyField
|
||||
if foreign_key.dest_table == foreign_key.table:
|
||||
self.rel_model = "'self'"
|
||||
else:
|
||||
self.rel_model = model_names[foreign_key.dest_table]
|
||||
self.to_field = dest and dest.name or None
|
||||
self.related_name = related_name or None
|
||||
|
||||
def get_field(self):
|
||||
# Generate the field definition for this column.
|
||||
field_params = {}
|
||||
for key, value in self.get_field_parameters().items():
|
||||
if isclass(value) and issubclass(value, Field):
|
||||
value = value.__name__
|
||||
field_params[key] = value
|
||||
|
||||
param_str = ', '.join('%s=%s' % (k, v)
|
||||
for k, v in sorted(field_params.items()))
|
||||
field = '%s = %s(%s)' % (
|
||||
self.name,
|
||||
self.field_class.__name__,
|
||||
param_str)
|
||||
|
||||
if self.field_class is UnknownField:
|
||||
field = '%s # %s' % (field, self.raw_column_type)
|
||||
|
||||
return field
|
||||
|
||||
|
||||
class Metadata(object):
|
||||
column_map = {}
|
||||
extension_import = ''
|
||||
|
||||
def __init__(self, database):
|
||||
self.database = database
|
||||
self.requires_extension = False
|
||||
|
||||
def execute(self, sql, *params):
|
||||
return self.database.execute_sql(sql, params)
|
||||
|
||||
def get_columns(self, table, schema=None):
|
||||
metadata = OrderedDict(
|
||||
(metadata.name, metadata)
|
||||
for metadata in self.database.get_columns(table, schema))
|
||||
|
||||
# Look up the actual column type for each column.
|
||||
column_types, extra_params = self.get_column_types(table, schema)
|
||||
|
||||
# Look up the primary keys.
|
||||
pk_names = self.get_primary_keys(table, schema)
|
||||
if len(pk_names) == 1:
|
||||
pk = pk_names[0]
|
||||
if column_types[pk] is IntegerField:
|
||||
column_types[pk] = AutoField
|
||||
elif column_types[pk] is BigIntegerField:
|
||||
column_types[pk] = BigAutoField
|
||||
|
||||
columns = OrderedDict()
|
||||
for name, column_data in metadata.items():
|
||||
field_class = column_types[name]
|
||||
default = self._clean_default(field_class, column_data.default)
|
||||
|
||||
columns[name] = Column(
|
||||
name,
|
||||
field_class=field_class,
|
||||
raw_column_type=column_data.data_type,
|
||||
nullable=column_data.null,
|
||||
primary_key=column_data.primary_key,
|
||||
column_name=name,
|
||||
default=default,
|
||||
extra_parameters=extra_params.get(name))
|
||||
|
||||
return columns
|
||||
|
||||
def get_column_types(self, table, schema=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def _clean_default(self, field_class, default):
|
||||
if default is None or field_class in (AutoField, BigAutoField) or \
|
||||
default.lower() == 'null':
|
||||
return
|
||||
if issubclass(field_class, _StringField) and \
|
||||
isinstance(default, text_type) and not default.startswith("'"):
|
||||
default = "'%s'" % default
|
||||
return default or "''"
|
||||
|
||||
def get_foreign_keys(self, table, schema=None):
|
||||
return self.database.get_foreign_keys(table, schema)
|
||||
|
||||
def get_primary_keys(self, table, schema=None):
|
||||
return self.database.get_primary_keys(table, schema)
|
||||
|
||||
def get_indexes(self, table, schema=None):
|
||||
return self.database.get_indexes(table, schema)
|
||||
|
||||
|
||||
class PostgresqlMetadata(Metadata):
|
||||
column_map = {
|
||||
16: BooleanField,
|
||||
17: BlobField,
|
||||
20: BigIntegerField,
|
||||
21: SmallIntegerField,
|
||||
23: IntegerField,
|
||||
25: TextField,
|
||||
700: FloatField,
|
||||
701: DoubleField,
|
||||
1042: CharField, # blank-padded CHAR
|
||||
1043: CharField,
|
||||
1082: DateField,
|
||||
1114: DateTimeField,
|
||||
1184: DateTimeField,
|
||||
1083: TimeField,
|
||||
1266: TimeField,
|
||||
1700: DecimalField,
|
||||
2950: UUIDField, # UUID
|
||||
}
|
||||
array_types = {
|
||||
1000: BooleanField,
|
||||
1001: BlobField,
|
||||
1005: SmallIntegerField,
|
||||
1007: IntegerField,
|
||||
1009: TextField,
|
||||
1014: CharField,
|
||||
1015: CharField,
|
||||
1016: BigIntegerField,
|
||||
1115: DateTimeField,
|
||||
1182: DateField,
|
||||
1183: TimeField,
|
||||
2951: UUIDField,
|
||||
}
|
||||
extension_import = 'from playhouse.postgres_ext import *'
|
||||
|
||||
def __init__(self, database):
|
||||
super(PostgresqlMetadata, self).__init__(database)
|
||||
|
||||
if postgres_ext is not None:
|
||||
# Attempt to add types like HStore and JSON.
|
||||
cursor = self.execute('select oid, typname, format_type(oid, NULL)'
|
||||
' from pg_type;')
|
||||
results = cursor.fetchall()
|
||||
|
||||
for oid, typname, formatted_type in results:
|
||||
if typname == 'json':
|
||||
self.column_map[oid] = postgres_ext.JSONField
|
||||
elif typname == 'jsonb':
|
||||
self.column_map[oid] = postgres_ext.BinaryJSONField
|
||||
elif typname == 'hstore':
|
||||
self.column_map[oid] = postgres_ext.HStoreField
|
||||
elif typname == 'tsvector':
|
||||
self.column_map[oid] = postgres_ext.TSVectorField
|
||||
|
||||
for oid in self.array_types:
|
||||
self.column_map[oid] = postgres_ext.ArrayField
|
||||
|
||||
def get_column_types(self, table, schema):
|
||||
column_types = {}
|
||||
extra_params = {}
|
||||
extension_types = set((
|
||||
postgres_ext.ArrayField,
|
||||
postgres_ext.BinaryJSONField,
|
||||
postgres_ext.JSONField,
|
||||
postgres_ext.TSVectorField,
|
||||
postgres_ext.HStoreField)) if postgres_ext is not None else set()
|
||||
|
||||
# Look up the actual column type for each column.
|
||||
identifier = '%s."%s"' % (schema, table)
|
||||
cursor = self.execute(
|
||||
'SELECT attname, atttypid FROM pg_catalog.pg_attribute '
|
||||
'WHERE attrelid = %s::regclass AND attnum > %s', identifier, 0)
|
||||
|
||||
# Store column metadata in dictionary keyed by column name.
|
||||
for name, oid in cursor.fetchall():
|
||||
column_types[name] = self.column_map.get(oid, UnknownField)
|
||||
if column_types[name] in extension_types:
|
||||
self.requires_extension = True
|
||||
if oid in self.array_types:
|
||||
extra_params[name] = {'field_class': self.array_types[oid]}
|
||||
|
||||
return column_types, extra_params
|
||||
|
||||
def get_columns(self, table, schema=None):
|
||||
schema = schema or 'public'
|
||||
return super(PostgresqlMetadata, self).get_columns(table, schema)
|
||||
|
||||
def get_foreign_keys(self, table, schema=None):
|
||||
schema = schema or 'public'
|
||||
return super(PostgresqlMetadata, self).get_foreign_keys(table, schema)
|
||||
|
||||
def get_primary_keys(self, table, schema=None):
|
||||
schema = schema or 'public'
|
||||
return super(PostgresqlMetadata, self).get_primary_keys(table, schema)
|
||||
|
||||
def get_indexes(self, table, schema=None):
|
||||
schema = schema or 'public'
|
||||
return super(PostgresqlMetadata, self).get_indexes(table, schema)
|
||||
|
||||
|
||||
class CockroachDBMetadata(PostgresqlMetadata):
|
||||
# CRDB treats INT the same as BIGINT, so we just map bigint type OIDs to
|
||||
# regular IntegerField.
|
||||
column_map = PostgresqlMetadata.column_map.copy()
|
||||
column_map[20] = IntegerField
|
||||
array_types = PostgresqlMetadata.array_types.copy()
|
||||
array_types[1016] = IntegerField
|
||||
extension_import = 'from playhouse.cockroachdb import *'
|
||||
|
||||
def __init__(self, database):
|
||||
Metadata.__init__(self, database)
|
||||
self.requires_extension = True
|
||||
|
||||
if postgres_ext is not None:
|
||||
# Attempt to add JSON types.
|
||||
cursor = self.execute('select oid, typname, format_type(oid, NULL)'
|
||||
' from pg_type;')
|
||||
results = cursor.fetchall()
|
||||
|
||||
for oid, typname, formatted_type in results:
|
||||
if typname == 'jsonb':
|
||||
self.column_map[oid] = postgres_ext.BinaryJSONField
|
||||
|
||||
for oid in self.array_types:
|
||||
self.column_map[oid] = postgres_ext.ArrayField
|
||||
|
||||
|
||||
class MySQLMetadata(Metadata):
|
||||
if FIELD_TYPE is None:
|
||||
column_map = {}
|
||||
else:
|
||||
column_map = {
|
||||
FIELD_TYPE.BLOB: TextField,
|
||||
FIELD_TYPE.CHAR: CharField,
|
||||
FIELD_TYPE.DATE: DateField,
|
||||
FIELD_TYPE.DATETIME: DateTimeField,
|
||||
FIELD_TYPE.DECIMAL: DecimalField,
|
||||
FIELD_TYPE.DOUBLE: FloatField,
|
||||
FIELD_TYPE.FLOAT: FloatField,
|
||||
FIELD_TYPE.INT24: IntegerField,
|
||||
FIELD_TYPE.LONG_BLOB: TextField,
|
||||
FIELD_TYPE.LONG: IntegerField,
|
||||
FIELD_TYPE.LONGLONG: BigIntegerField,
|
||||
FIELD_TYPE.MEDIUM_BLOB: TextField,
|
||||
FIELD_TYPE.NEWDECIMAL: DecimalField,
|
||||
FIELD_TYPE.SHORT: IntegerField,
|
||||
FIELD_TYPE.STRING: CharField,
|
||||
FIELD_TYPE.TIMESTAMP: DateTimeField,
|
||||
FIELD_TYPE.TIME: TimeField,
|
||||
FIELD_TYPE.TINY_BLOB: TextField,
|
||||
FIELD_TYPE.TINY: IntegerField,
|
||||
FIELD_TYPE.VAR_STRING: CharField,
|
||||
}
|
||||
|
||||
def __init__(self, database, **kwargs):
|
||||
if 'password' in kwargs:
|
||||
kwargs['passwd'] = kwargs.pop('password')
|
||||
super(MySQLMetadata, self).__init__(database, **kwargs)
|
||||
|
||||
def get_column_types(self, table, schema=None):
|
||||
column_types = {}
|
||||
|
||||
# Look up the actual column type for each column.
|
||||
cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table)
|
||||
|
||||
# Store column metadata in dictionary keyed by column name.
|
||||
for column_description in cursor.description:
|
||||
name, type_code = column_description[:2]
|
||||
column_types[name] = self.column_map.get(type_code, UnknownField)
|
||||
|
||||
return column_types, {}
|
||||
|
||||
|
||||
class SqliteMetadata(Metadata):
|
||||
column_map = {
|
||||
'bigint': BigIntegerField,
|
||||
'blob': BlobField,
|
||||
'bool': BooleanField,
|
||||
'boolean': BooleanField,
|
||||
'char': CharField,
|
||||
'date': DateField,
|
||||
'datetime': DateTimeField,
|
||||
'decimal': DecimalField,
|
||||
'float': FloatField,
|
||||
'integer': IntegerField,
|
||||
'integer unsigned': IntegerField,
|
||||
'int': IntegerField,
|
||||
'long': BigIntegerField,
|
||||
'numeric': DecimalField,
|
||||
'real': FloatField,
|
||||
'smallinteger': IntegerField,
|
||||
'smallint': IntegerField,
|
||||
'smallint unsigned': IntegerField,
|
||||
'text': TextField,
|
||||
'time': TimeField,
|
||||
'varchar': CharField,
|
||||
}
|
||||
|
||||
begin = '(?:["\[\(]+)?'
|
||||
end = '(?:["\]\)]+)?'
|
||||
re_foreign_key = (
|
||||
'(?:FOREIGN KEY\s*)?'
|
||||
'{begin}(.+?){end}\s+(?:.+\s+)?'
|
||||
'references\s+{begin}(.+?){end}'
|
||||
'\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end)
|
||||
re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$'
|
||||
|
||||
def _map_col(self, column_type):
|
||||
raw_column_type = column_type.lower()
|
||||
if raw_column_type in self.column_map:
|
||||
field_class = self.column_map[raw_column_type]
|
||||
elif re.search(self.re_varchar, raw_column_type):
|
||||
field_class = CharField
|
||||
else:
|
||||
column_type = re.sub('\(.+\)', '', raw_column_type)
|
||||
if column_type == '':
|
||||
field_class = BareField
|
||||
else:
|
||||
field_class = self.column_map.get(column_type, UnknownField)
|
||||
return field_class
|
||||
|
||||
def get_column_types(self, table, schema=None):
|
||||
column_types = {}
|
||||
columns = self.database.get_columns(table)
|
||||
|
||||
for column in columns:
|
||||
column_types[column.name] = self._map_col(column.data_type)
|
||||
|
||||
return column_types, {}
|
||||
|
||||
|
||||
_DatabaseMetadata = namedtuple('_DatabaseMetadata', (
|
||||
'columns',
|
||||
'primary_keys',
|
||||
'foreign_keys',
|
||||
'model_names',
|
||||
'indexes'))
|
||||
|
||||
|
||||
class DatabaseMetadata(_DatabaseMetadata):
|
||||
def multi_column_indexes(self, table):
|
||||
accum = []
|
||||
for index in self.indexes[table]:
|
||||
if len(index.columns) > 1:
|
||||
field_names = [self.columns[table][column].name
|
||||
for column in index.columns
|
||||
if column in self.columns[table]]
|
||||
accum.append((field_names, index.unique))
|
||||
return accum
|
||||
|
||||
def column_indexes(self, table):
|
||||
accum = {}
|
||||
for index in self.indexes[table]:
|
||||
if len(index.columns) == 1:
|
||||
accum[index.columns[0]] = index.unique
|
||||
return accum
|
||||
|
||||
|
||||
class Introspector(object):
|
||||
pk_classes = [AutoField, IntegerField]
|
||||
|
||||
def __init__(self, metadata, schema=None):
|
||||
self.metadata = metadata
|
||||
self.schema = schema
|
||||
|
||||
def __repr__(self):
|
||||
return '<Introspector: %s>' % self.metadata.database
|
||||
|
||||
@classmethod
|
||||
def from_database(cls, database, schema=None):
|
||||
if CockroachDatabase and isinstance(database, CockroachDatabase):
|
||||
metadata = CockroachDBMetadata(database)
|
||||
elif isinstance(database, PostgresqlDatabase):
|
||||
metadata = PostgresqlMetadata(database)
|
||||
elif isinstance(database, MySQLDatabase):
|
||||
metadata = MySQLMetadata(database)
|
||||
elif isinstance(database, SqliteDatabase):
|
||||
metadata = SqliteMetadata(database)
|
||||
else:
|
||||
raise ValueError('Introspection not supported for %r' % database)
|
||||
return cls(metadata, schema=schema)
|
||||
|
||||
def get_database_class(self):
|
||||
return type(self.metadata.database)
|
||||
|
||||
def get_database_name(self):
|
||||
return self.metadata.database.database
|
||||
|
||||
def get_database_kwargs(self):
|
||||
return self.metadata.database.connect_params
|
||||
|
||||
def get_additional_imports(self):
|
||||
if self.metadata.requires_extension:
|
||||
return '\n' + self.metadata.extension_import
|
||||
return ''
|
||||
|
||||
def make_model_name(self, table, snake_case=True):
|
||||
if snake_case:
|
||||
table = make_snake_case(table)
|
||||
model = re.sub(r'[^\w]+', '', table)
|
||||
model_name = ''.join(sub.title() for sub in model.split('_'))
|
||||
if not model_name[0].isalpha():
|
||||
model_name = 'T' + model_name
|
||||
return model_name
|
||||
|
||||
def make_column_name(self, column, is_foreign_key=False, snake_case=True):
|
||||
column = column.strip()
|
||||
if snake_case:
|
||||
column = make_snake_case(column)
|
||||
column = column.lower()
|
||||
if is_foreign_key:
|
||||
# Strip "_id" from foreign keys, unless the foreign-key happens to
|
||||
# be named "_id", in which case the name is retained.
|
||||
column = re.sub('_id$', '', column) or column
|
||||
|
||||
# Remove characters that are invalid for Python identifiers.
|
||||
column = re.sub(r'[^\w]+', '_', column)
|
||||
if column in RESERVED_WORDS:
|
||||
column += '_'
|
||||
if len(column) and column[0].isdigit():
|
||||
column = '_' + column
|
||||
return column
|
||||
|
||||
def introspect(self, table_names=None, literal_column_names=False,
|
||||
include_views=False, snake_case=True):
|
||||
# Retrieve all the tables in the database.
|
||||
tables = self.metadata.database.get_tables(schema=self.schema)
|
||||
if include_views:
|
||||
views = self.metadata.database.get_views(schema=self.schema)
|
||||
tables.extend([view.name for view in views])
|
||||
|
||||
if table_names is not None:
|
||||
tables = [table for table in tables if table in table_names]
|
||||
table_set = set(tables)
|
||||
|
||||
# Store a mapping of table name -> dictionary of columns.
|
||||
columns = {}
|
||||
|
||||
# Store a mapping of table name -> set of primary key columns.
|
||||
primary_keys = {}
|
||||
|
||||
# Store a mapping of table -> foreign keys.
|
||||
foreign_keys = {}
|
||||
|
||||
# Store a mapping of table name -> model name.
|
||||
model_names = {}
|
||||
|
||||
# Store a mapping of table name -> indexes.
|
||||
indexes = {}
|
||||
|
||||
# Gather the columns for each table.
|
||||
for table in tables:
|
||||
table_indexes = self.metadata.get_indexes(table, self.schema)
|
||||
table_columns = self.metadata.get_columns(table, self.schema)
|
||||
try:
|
||||
foreign_keys[table] = self.metadata.get_foreign_keys(
|
||||
table, self.schema)
|
||||
except ValueError as exc:
|
||||
# Report the problem and continue without foreign-key data for this table.
warnings.warn(' '.join(map(str, exc.args)))
|
||||
foreign_keys[table] = []
|
||||
else:
|
||||
# If there is a possibility we could exclude a dependent table,
|
||||
# ensure that we introspect it so FKs will work.
|
||||
if table_names is not None:
|
||||
for foreign_key in foreign_keys[table]:
|
||||
if foreign_key.dest_table not in table_set:
|
||||
tables.append(foreign_key.dest_table)
|
||||
table_set.add(foreign_key.dest_table)
|
||||
|
||||
model_names[table] = self.make_model_name(table, snake_case)
|
||||
|
||||
# Collect sets of all the column names as well as all the
|
||||
# foreign-key column names.
|
||||
lower_col_names = set(column_name.lower()
|
||||
for column_name in table_columns)
|
||||
fks = set(fk_col.column for fk_col in foreign_keys[table])
|
||||
|
||||
for col_name, column in table_columns.items():
|
||||
if literal_column_names:
|
||||
new_name = re.sub(r'[^\w]+', '_', col_name)
|
||||
else:
|
||||
new_name = self.make_column_name(col_name, col_name in fks,
|
||||
snake_case)
|
||||
|
||||
# If we have two columns, "parent" and "parent_id", ensure
|
||||
# that we don't introduce naming conflicts.
|
||||
lower_name = col_name.lower()
|
||||
if lower_name.endswith('_id') and new_name in lower_col_names:
|
||||
new_name = col_name.lower()
|
||||
|
||||
column.name = new_name
|
||||
|
||||
for index in table_indexes:
|
||||
if len(index.columns) == 1:
|
||||
column = index.columns[0]
|
||||
if column in table_columns:
|
||||
table_columns[column].unique = index.unique
|
||||
table_columns[column].index = True
|
||||
|
||||
primary_keys[table] = self.metadata.get_primary_keys(
|
||||
table, self.schema)
|
||||
columns[table] = table_columns
|
||||
indexes[table] = table_indexes
|
||||
|
||||
# Gather all instances where we might have a `related_name` conflict,
|
||||
# either due to multiple FKs on a table pointing to the same table,
|
||||
# or a related_name that would conflict with an existing field.
|
||||
related_names = {}
|
||||
sort_fn = lambda foreign_key: foreign_key.column
|
||||
for table in tables:
|
||||
models_referenced = set()
|
||||
for foreign_key in sorted(foreign_keys[table], key=sort_fn):
|
||||
try:
|
||||
column = columns[table][foreign_key.column]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
dest_table = foreign_key.dest_table
|
||||
if dest_table in models_referenced:
|
||||
related_names[column] = '%s_%s_set' % (
|
||||
dest_table,
|
||||
column.name)
|
||||
else:
|
||||
models_referenced.add(dest_table)
|
||||
|
||||
# On the second pass convert all foreign keys.
|
||||
for table in tables:
|
||||
for foreign_key in foreign_keys[table]:
|
||||
src = columns[foreign_key.table][foreign_key.column]
|
||||
try:
|
||||
dest = columns[foreign_key.dest_table][
|
||||
foreign_key.dest_column]
|
||||
except KeyError:
|
||||
dest = None
|
||||
|
||||
src.set_foreign_key(
|
||||
foreign_key=foreign_key,
|
||||
model_names=model_names,
|
||||
dest=dest,
|
||||
related_name=related_names.get(src))
|
||||
|
||||
return DatabaseMetadata(
|
||||
columns,
|
||||
primary_keys,
|
||||
foreign_keys,
|
||||
model_names,
|
||||
indexes)
|
||||
|
||||
def generate_models(self, skip_invalid=False, table_names=None,
|
||||
literal_column_names=False, bare_fields=False,
|
||||
include_views=False):
|
||||
database = self.introspect(table_names, literal_column_names,
|
||||
include_views)
|
||||
models = {}
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = self.metadata.database
|
||||
schema = self.schema
|
||||
|
||||
def _create_model(table, models):
|
||||
for foreign_key in database.foreign_keys[table]:
|
||||
dest = foreign_key.dest_table
|
||||
|
||||
if dest not in models and dest != table:
|
||||
_create_model(dest, models)
|
||||
|
||||
primary_keys = []
|
||||
columns = database.columns[table]
|
||||
for column_name, column in columns.items():
|
||||
if column.primary_key:
|
||||
primary_keys.append(column.name)
|
||||
|
||||
multi_column_indexes = database.multi_column_indexes(table)
|
||||
column_indexes = database.column_indexes(table)
|
||||
|
||||
class Meta:
|
||||
indexes = multi_column_indexes
|
||||
table_name = table
|
||||
|
||||
# Fix models with multi-column primary keys.
|
||||
composite_key = False
|
||||
if len(primary_keys) == 0:
|
||||
primary_keys = columns.keys()
|
||||
if len(primary_keys) > 1:
|
||||
Meta.primary_key = CompositeKey(*[
|
||||
field.name for col, field in columns.items()
|
||||
if col in primary_keys])
|
||||
composite_key = True
|
||||
|
||||
attrs = {'Meta': Meta}
|
||||
for column_name, column in columns.items():
|
||||
FieldClass = column.field_class
|
||||
if FieldClass is not ForeignKeyField and bare_fields:
|
||||
FieldClass = BareField
|
||||
elif FieldClass is UnknownField:
|
||||
FieldClass = BareField
|
||||
|
||||
params = {
|
||||
'column_name': column_name,
|
||||
'null': column.nullable}
|
||||
if column.primary_key and composite_key:
|
||||
if FieldClass is AutoField:
|
||||
FieldClass = IntegerField
|
||||
params['primary_key'] = False
|
||||
elif column.primary_key and FieldClass is not AutoField:
|
||||
params['primary_key'] = True
|
||||
if column.is_foreign_key():
|
||||
if column.is_self_referential_fk():
|
||||
params['model'] = 'self'
|
||||
else:
|
||||
dest_table = column.foreign_key.dest_table
|
||||
params['model'] = models[dest_table]
|
||||
if column.to_field:
|
||||
params['field'] = column.to_field
|
||||
|
||||
# Generate a unique related name.
|
||||
params['backref'] = '%s_%s_rel' % (table, column_name)
|
||||
|
||||
if column.default is not None:
|
||||
constraint = SQL('DEFAULT %s' % column.default)
|
||||
params['constraints'] = [constraint]
|
||||
|
||||
if column_name in column_indexes and not \
|
||||
column.is_primary_key():
|
||||
if column_indexes[column_name]:
|
||||
params['unique'] = True
|
||||
elif not column.is_foreign_key():
|
||||
params['index'] = True
|
||||
|
||||
attrs[column.name] = FieldClass(**params)
|
||||
|
||||
try:
|
||||
models[table] = type(str(table), (BaseModel,), attrs)
|
||||
except ValueError:
|
||||
if not skip_invalid:
|
||||
raise
|
||||
|
||||
# Actually generate Model classes.
|
||||
for table, model in sorted(database.model_names.items()):
|
||||
if table not in models:
|
||||
_create_model(table, models)
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def introspect(database, schema=None):
|
||||
introspector = Introspector.from_database(database, schema=schema)
|
||||
return introspector.introspect()
|
||||
|
||||
|
||||
def generate_models(database, schema=None, **options):
|
||||
introspector = Introspector.from_database(database, schema=schema)
|
||||
return introspector.generate_models(**options)
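# Usage sketch: build Model classes dynamically from an existing database
# (handy in the interactive interpreter). The database is hypothetical.
#
#     db = PostgresqlDatabase('my_app', user='postgres')
#     models = generate_models(db)      # dict keyed by table name
#     User = models['user']
#     User.select().count()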
|
||||
|
||||
|
||||
def print_model(model, indexes=True, inline_indexes=False):
|
||||
print(model._meta.name)
|
||||
for field in model._meta.sorted_fields:
|
||||
parts = [' %s %s' % (field.name, field.field_type)]
|
||||
if field.primary_key:
|
||||
parts.append(' PK')
|
||||
elif inline_indexes:
|
||||
if field.unique:
|
||||
parts.append(' UNIQUE')
|
||||
elif field.index:
|
||||
parts.append(' INDEX')
|
||||
if isinstance(field, ForeignKeyField):
|
||||
parts.append(' FK: %s.%s' % (field.rel_model.__name__,
|
||||
field.rel_field.name))
|
||||
print(''.join(parts))
|
||||
|
||||
if indexes:
|
||||
index_list = model._meta.fields_to_index()
|
||||
if not index_list:
|
||||
return
|
||||
|
||||
print('\nindex(es)')
|
||||
for index in index_list:
|
||||
parts = [' ']
|
||||
ctx = model._meta.database.get_sql_context()
|
||||
with ctx.scope_values(param='%s', quote='""'):
|
||||
ctx.sql(CommaNodeList(index._expressions))
|
||||
if index._where:
|
||||
ctx.literal(' WHERE ')
|
||||
ctx.sql(index._where)
|
||||
sql, params = ctx.query()
|
||||
|
||||
clean = sql % tuple(map(_query_val_transform, params))
|
||||
parts.append(clean.replace('"', ''))
|
||||
|
||||
if index._unique:
|
||||
parts.append(' UNIQUE')
|
||||
print(''.join(parts))
|
||||
|
||||
|
||||
def get_table_sql(model):
|
||||
sql, params = model._schema._create_table().query()
|
||||
if model._meta.database.param != '%s':
|
||||
sql = sql.replace(model._meta.database.param, '%s')
|
||||
|
||||
# Format and indent the table declaration, simplest possible approach.
|
||||
match_obj = re.match('^(.+?\()(.+)(\).*)', sql)
|
||||
create, columns, extra = match_obj.groups()
|
||||
indented = ',\n'.join(' %s' % column for column in columns.split(', '))
|
||||
|
||||
clean = '\n'.join((create, indented, extra)).strip()
|
||||
return clean % tuple(map(_query_val_transform, params))
|
||||
|
||||
def print_table_sql(model):
|
||||
print(get_table_sql(model))
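# Usage sketch: quick interactive introspection helpers for a model class.
#
#     print_model(User)       # field names, types, PK/FK markers, indexes
#     print_table_sql(User)   # the CREATE TABLE statement, indented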
|
@@ -0,0 +1,252 @@
|
||||
from peewee import *
|
||||
from peewee import Alias
|
||||
from peewee import CompoundSelectQuery
|
||||
from peewee import SENTINEL
|
||||
from peewee import callable_
|
||||
|
||||
|
||||
_clone_set = lambda s: set(s) if s else set()
|
||||
|
||||
|
||||
def model_to_dict(model, recurse=True, backrefs=False, only=None,
|
||||
exclude=None, seen=None, extra_attrs=None,
|
||||
fields_from_query=None, max_depth=None, manytomany=False):
|
||||
"""
|
||||
Convert a model instance (and any related objects) to a dictionary.
|
||||
|
||||
:param bool recurse: Whether foreign-keys should be recursed.
|
||||
:param bool backrefs: Whether lists of related objects should be recursed.
|
||||
:param only: A list (or set) of field instances indicating which fields
|
||||
should be included.
|
||||
:param exclude: A list (or set) of field instances that should be
|
||||
excluded from the dictionary.
|
||||
:param list extra_attrs: Names of model instance attributes or methods
|
||||
that should be included.
|
||||
:param SelectQuery fields_from_query: Query that was source of model. Take
|
||||
fields explicitly selected by the query and serialize them.
|
||||
:param int max_depth: Maximum depth to recurse, value <= 0 means no max.
|
||||
:param bool manytomany: Process many-to-many fields.
|
||||
"""
|
||||
max_depth = -1 if max_depth is None else max_depth
|
||||
if max_depth == 0:
|
||||
recurse = False
|
||||
|
||||
only = _clone_set(only)
|
||||
extra_attrs = _clone_set(extra_attrs)
|
||||
should_skip = lambda n: (n in exclude) or (only and (n not in only))
|
||||
|
||||
if fields_from_query is not None:
|
||||
for item in fields_from_query._returning:
|
||||
if isinstance(item, Field):
|
||||
only.add(item)
|
||||
elif isinstance(item, Alias):
|
||||
extra_attrs.add(item._alias)
|
||||
|
||||
data = {}
|
||||
exclude = _clone_set(exclude)
|
||||
seen = _clone_set(seen)
|
||||
exclude |= seen
|
||||
model_class = type(model)
|
||||
|
||||
if manytomany:
|
||||
for name, m2m in model._meta.manytomany.items():
|
||||
if should_skip(name):
|
||||
continue
|
||||
|
||||
exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref]))
|
||||
for fkf in m2m.through_model._meta.refs:
|
||||
exclude.add(fkf)
|
||||
|
||||
accum = []
|
||||
for rel_obj in getattr(model, name):
|
||||
accum.append(model_to_dict(
|
||||
rel_obj,
|
||||
recurse=recurse,
|
||||
backrefs=backrefs,
|
||||
only=only,
|
||||
exclude=exclude,
|
||||
max_depth=max_depth - 1))
|
||||
data[name] = accum
|
||||
|
||||
for field in model._meta.sorted_fields:
|
||||
if should_skip(field):
|
||||
continue
|
||||
|
||||
field_data = model.__data__.get(field.name)
|
||||
if isinstance(field, ForeignKeyField) and recurse:
|
||||
if field_data is not None:
|
||||
seen.add(field)
|
||||
rel_obj = getattr(model, field.name)
|
||||
field_data = model_to_dict(
|
||||
rel_obj,
|
||||
recurse=recurse,
|
||||
backrefs=backrefs,
|
||||
only=only,
|
||||
exclude=exclude,
|
||||
seen=seen,
|
||||
max_depth=max_depth - 1)
|
||||
else:
|
||||
field_data = None
|
||||
|
||||
data[field.name] = field_data
|
||||
|
||||
if extra_attrs:
|
||||
for attr_name in extra_attrs:
|
||||
attr = getattr(model, attr_name)
|
||||
if callable_(attr):
|
||||
data[attr_name] = attr()
|
||||
else:
|
||||
data[attr_name] = attr
|
||||
|
||||
if backrefs and recurse:
|
||||
for foreign_key, rel_model in model._meta.backrefs.items():
|
||||
if foreign_key.backref == '+': continue
|
||||
descriptor = getattr(model_class, foreign_key.backref)
|
||||
if descriptor in exclude or foreign_key in exclude:
|
||||
continue
|
||||
if only and (descriptor not in only) and (foreign_key not in only):
|
||||
continue
|
||||
|
||||
accum = []
|
||||
exclude.add(foreign_key)
|
||||
related_query = getattr(model, foreign_key.backref)
|
||||
|
||||
for rel_obj in related_query:
|
||||
accum.append(model_to_dict(
|
||||
rel_obj,
|
||||
recurse=recurse,
|
||||
backrefs=backrefs,
|
||||
only=only,
|
||||
exclude=exclude,
|
||||
max_depth=max_depth - 1))
|
||||
|
||||
data[foreign_key.backref] = accum
|
||||
|
||||
return data
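# Usage sketch ("User" is a hypothetical model):
#
#     user = User.create(username='charlie')
#     model_to_dict(user)
#     # -> {'id': 1, 'username': 'charlie'}
#     model_to_dict(user, backrefs=True)   # also serialize related rows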
|
||||
|
||||
|
||||
def update_model_from_dict(instance, data, ignore_unknown=False):
|
||||
meta = instance._meta
|
||||
backrefs = dict([(fk.backref, fk) for fk in meta.backrefs])
|
||||
|
||||
for key, value in data.items():
|
||||
if key in meta.combined:
|
||||
field = meta.combined[key]
|
||||
is_backref = False
|
||||
elif key in backrefs:
|
||||
field = backrefs[key]
|
||||
is_backref = True
|
||||
elif ignore_unknown:
|
||||
setattr(instance, key, value)
|
||||
continue
|
||||
else:
|
||||
raise AttributeError('Unrecognized attribute "%s" for model '
|
||||
'class %s.' % (key, type(instance)))
|
||||
|
||||
is_foreign_key = isinstance(field, ForeignKeyField)
|
||||
|
||||
if not is_backref and is_foreign_key and isinstance(value, dict):
|
||||
try:
|
||||
rel_instance = instance.__rel__[field.name]
|
||||
except KeyError:
|
||||
rel_instance = field.rel_model()
|
||||
setattr(
|
||||
instance,
|
||||
field.name,
|
||||
update_model_from_dict(rel_instance, value, ignore_unknown))
|
||||
elif is_backref and isinstance(value, (list, tuple)):
|
||||
instances = [
|
||||
dict_to_model(field.model, row_data, ignore_unknown)
|
||||
for row_data in value]
|
||||
for rel_instance in instances:
|
||||
setattr(rel_instance, field.name, instance)
|
||||
setattr(instance, field.backref, instances)
|
||||
else:
|
||||
setattr(instance, field.name, value)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
def dict_to_model(model_class, data, ignore_unknown=False):
|
||||
return update_model_from_dict(model_class(), data, ignore_unknown)
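# Usage sketch: the inverse of model_to_dict(). "User" is hypothetical.
#
#     data = {'username': 'charlie', 'email': 'charlie@example.org'}
#     user = dict_to_model(User, data)   # unsaved instance
#     user.save()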
|
||||
|
||||
|
||||
class ReconnectMixin(object):
|
||||
"""
|
||||
Mixin class that attempts to automatically reconnect to the database under
|
||||
certain error conditions.
|
||||
|
||||
For example, MySQL servers will typically close connections that are idle
|
||||
for 28800 seconds ("wait_timeout" setting). If your application makes use
|
||||
of long-lived connections, you may find your connections are closed after
|
||||
a period of no activity. This mixin will attempt to reconnect automatically
|
||||
when these errors occur.
|
||||
|
||||
This mixin class probably should not be used with Postgres (unless you
|
||||
REALLY know what you are doing) and definitely has no business being used
|
||||
with Sqlite. If you wish to use with Postgres, you will need to adapt the
|
||||
`reconnect_errors` attribute to something appropriate for Postgres.
|
||||
"""
|
||||
reconnect_errors = (
|
||||
# Error class, error message fragment (or empty string for all).
|
||||
(OperationalError, '2006'), # MySQL server has gone away.
|
||||
(OperationalError, '2013'), # Lost connection to MySQL server.
|
||||
(OperationalError, '2014'), # Commands out of sync.
|
||||
|
||||
# mysql-connector raises a slightly different error when an idle
|
||||
# connection is terminated by the server. This is equivalent to 2013.
|
||||
(OperationalError, 'MySQL Connection not available.'),
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ReconnectMixin, self).__init__(*args, **kwargs)
|
||||
|
||||
# Normalize the reconnect errors to a more efficient data-structure.
|
||||
self._reconnect_errors = {}
|
||||
for exc_class, err_fragment in self.reconnect_errors:
|
||||
self._reconnect_errors.setdefault(exc_class, [])
|
||||
self._reconnect_errors[exc_class].append(err_fragment.lower())
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=SENTINEL):
|
||||
try:
|
||||
return super(ReconnectMixin, self).execute_sql(sql, params, commit)
|
||||
except Exception as exc:
|
||||
exc_class = type(exc)
|
||||
if exc_class not in self._reconnect_errors:
|
||||
raise exc
|
||||
|
||||
exc_repr = str(exc).lower()
|
||||
for err_fragment in self._reconnect_errors[exc_class]:
|
||||
if err_fragment in exc_repr:
|
||||
break
|
||||
else:
|
||||
raise exc
|
||||
|
||||
if not self.is_closed():
|
||||
self.close()
|
||||
self.connect()
|
||||
|
||||
return super(ReconnectMixin, self).execute_sql(sql, params, commit)
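# Usage sketch: compose the mixin with the database class you actually use
# (shown for MySQL; the subclass name is arbitrary, credentials hypothetical).
#
#     class ReconnectMySQLDatabase(ReconnectMixin, MySQLDatabase):
#         pass
#
#     db = ReconnectMySQLDatabase('my_app', user='root', password='secret')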
|
||||
|
||||
|
||||
def resolve_multimodel_query(query, key='_model_identifier'):
|
||||
mapping = {}
|
||||
accum = [query]
|
||||
while accum:
|
||||
curr = accum.pop()
|
||||
if isinstance(curr, CompoundSelectQuery):
|
||||
accum.extend((curr.lhs, curr.rhs))
|
||||
continue
|
||||
|
||||
model_class = curr.model
|
||||
name = model_class._meta.table_name
|
||||
mapping[name] = model_class
|
||||
curr._returning.append(Value(name).alias(key))
|
||||
|
||||
def wrapped_iterator():
|
||||
for row in query.dicts().iterator():
|
||||
identifier = row.pop(key)
|
||||
model = mapping[identifier]
|
||||
yield model(**row)
|
||||
|
||||
return wrapped_iterator()
|
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Provide django-style hooks for model events.
|
||||
"""
|
||||
from peewee import Model as _Model
|
||||
|
||||
|
||||
class Signal(object):
|
||||
def __init__(self):
|
||||
self._flush()
|
||||
|
||||
def _flush(self):
|
||||
self._receivers = set()
|
||||
self._receiver_list = []
|
||||
|
||||
def connect(self, receiver, name=None, sender=None):
|
||||
name = name or receiver.__name__
|
||||
key = (name, sender)
|
||||
if key not in self._receivers:
|
||||
self._receivers.add(key)
|
||||
self._receiver_list.append((name, receiver, sender))
|
||||
else:
|
||||
raise ValueError('receiver named %s (for sender=%s) already '
|
||||
'connected' % (name, sender or 'any'))
|
||||
|
||||
def disconnect(self, receiver=None, name=None, sender=None):
|
||||
if receiver:
|
||||
name = name or receiver.__name__
|
||||
if not name:
|
||||
raise ValueError('a receiver or a name must be provided')
|
||||
|
||||
key = (name, sender)
|
||||
if key not in self._receivers:
|
||||
raise ValueError('receiver named %s for sender=%s not found.' %
|
||||
(name, sender or 'any'))
|
||||
|
||||
self._receivers.remove(key)
|
||||
self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list
|
||||
if n != name and s != sender]
|
||||
|
||||
def __call__(self, name=None, sender=None):
|
||||
def decorator(fn):
|
||||
self.connect(fn, name, sender)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
def send(self, instance, *args, **kwargs):
|
||||
sender = type(instance)
|
||||
responses = []
|
||||
for n, r, s in self._receiver_list:
|
||||
if s is None or isinstance(instance, s):
|
||||
responses.append((r, r(sender, instance, *args, **kwargs)))
|
||||
return responses
|
||||
|
||||
|
||||
pre_save = Signal()
|
||||
post_save = Signal()
|
||||
pre_delete = Signal()
|
||||
post_delete = Signal()
|
||||
pre_init = Signal()
|
||||
|
||||
|
||||
class Model(_Model):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Model, self).__init__(*args, **kwargs)
|
||||
pre_init.send(self)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
pk_value = self._pk if self._meta.primary_key else True
|
||||
created = kwargs.get('force_insert', False) or not bool(pk_value)
|
||||
pre_save.send(self, created=created)
|
||||
ret = super(Model, self).save(*args, **kwargs)
|
||||
post_save.send(self, created=created)
|
||||
return ret
|
||||
|
||||
def delete_instance(self, *args, **kwargs):
|
||||
pre_delete.send(self)
|
||||
ret = super(Model, self).delete_instance(*args, **kwargs)
|
||||
post_delete.send(self)
|
||||
return ret
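# Usage sketch: connect a receiver to one of the signals above. The model
# and handler names are hypothetical; the model must subclass the
# signals-aware Model defined in this module.
#
#     @post_save(sender=MyModel)
#     def on_save_handler(model_class, instance, created):
#         print('saved', instance, 'created?', created)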
|
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Peewee integration with pysqlcipher.
|
||||
|
||||
Project page: https://github.com/leapcode/pysqlcipher/
|
||||
|
||||
**WARNING!!! EXPERIMENTAL!!!**
|
||||
|
||||
* Although this extension's code is short, it has not been properly
|
||||
peer-reviewed yet and may have introduced vulnerabilities.
|
||||
|
||||
Also note that this code relies on pysqlcipher and sqlcipher, and
|
||||
the code there might have vulnerabilities as well, but since these
|
||||
are widely used crypto modules, we can expect "short zero days" there.
|
||||
|
||||
Example usage:
|
||||
|
||||
from peewee.playground.ciphersql_ext import SqlCipherDatabase
|
||||
db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real")
|
||||
|
||||
* `passphrase`: should be "long enough".
|
||||
Note that *length beats vocabulary* (much exponential), and even
|
||||
a lowercase-only passphrase like easytorememberyethardforotherstoguess
|
||||
packs more noise than 8 random printable characters and *can* be memorized.
|
||||
|
||||
When opening an existing database, passphrase should be the one used when the
|
||||
database was created. If the passphrase is incorrect, an exception will only be
|
||||
raised **when you access the database**.
|
||||
|
||||
If you need to ask for an interactive passphrase, here's example code you can
|
||||
put after the `db = ...` line:
|
||||
|
||||
try: # Just access the database so that it checks the encryption.
|
||||
db.get_tables()
|
||||
# We're looking for a DatabaseError with a specific error message.
|
||||
except peewee.DatabaseError as e:
|
||||
# Check whether the message *means* "passphrase is wrong"
|
||||
if e.args[0] == 'file is encrypted or is not a database':
|
||||
raise Exception('Developer should prompt user for passphrase '
|
||||
'again.')
|
||||
else:
|
||||
# A different DatabaseError. Raise it.
|
||||
raise e
|
||||
|
||||
See a more elaborate example with this code at
|
||||
https://gist.github.com/thedod/11048875
|
||||
"""
|
||||
import datetime
|
||||
import decimal
|
||||
import sys
|
||||
|
||||
from peewee import *
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
if sys.version_info[0] != 3:
|
||||
from pysqlcipher import dbapi2 as sqlcipher
|
||||
else:
|
||||
try:
|
||||
from sqlcipher3 import dbapi2 as sqlcipher
|
||||
except ImportError:
|
||||
from pysqlcipher3 import dbapi2 as sqlcipher
|
||||
|
||||
sqlcipher.register_adapter(decimal.Decimal, str)
|
||||
sqlcipher.register_adapter(datetime.date, str)
|
||||
sqlcipher.register_adapter(datetime.time, str)
|
||||
|
||||
|
||||
class _SqlCipherDatabase(object):
|
||||
def _connect(self):
|
||||
params = dict(self.connect_params)
|
||||
passphrase = params.pop('passphrase', '').replace("'", "''")
|
||||
|
||||
conn = sqlcipher.connect(self.database, isolation_level=None, **params)
|
||||
try:
|
||||
if passphrase:
|
||||
conn.execute("PRAGMA key='%s'" % passphrase)
|
||||
self._add_conn_hooks(conn)
|
||||
except:
|
||||
conn.close()
|
||||
raise
|
||||
return conn
|
||||
|
||||
def set_passphrase(self, passphrase):
|
||||
if not self.is_closed():
|
||||
raise ImproperlyConfigured('Cannot set passphrase when database '
|
||||
'is open. To change passphrase of an '
|
||||
'open database use the rekey() method.')
|
||||
|
||||
self.connect_params['passphrase'] = passphrase
|
||||
|
||||
def rekey(self, passphrase):
|
||||
if self.is_closed():
|
||||
self.connect()
|
||||
|
||||
self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''"))
|
||||
self.connect_params['passphrase'] = passphrase
|
||||
return True
|
||||
|
||||
|
||||
class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase):
|
||||
pass
|
||||
|
||||
|
||||
class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase):
|
||||
pass
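# Usage sketch: change the passphrase of an existing encrypted database.
# The path and passphrases are placeholders.
#
#     db = SqlCipherDatabase('/path/to/my.db', passphrase='old passphrase')
#     db.rekey('new passphrase')          # re-encrypts under the new key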
|
@@ -0,0 +1,123 @@
|
||||
from peewee import *
|
||||
from playhouse.sqlite_ext import JSONField
|
||||
|
||||
|
||||
class BaseChangeLog(Model):
|
||||
timestamp = DateTimeField(constraints=[SQL('DEFAULT CURRENT_TIMESTAMP')])
|
||||
action = TextField()
|
||||
table = TextField()
|
||||
primary_key = IntegerField()
|
||||
changes = JSONField()
|
||||
|
||||
|
||||
class ChangeLog(object):
|
||||
# Model class that will serve as the base for the changelog. This model
|
||||
# will be subclassed and mapped to your application database.
|
||||
base_model = BaseChangeLog
|
||||
|
||||
# Template for the triggers that handle updating the changelog table.
|
||||
# table: table name
|
||||
# action: insert / update / delete
|
||||
# new_old: NEW or OLD (OLD is for DELETE)
|
||||
# primary_key: table primary key column name
|
||||
# column_array: output of build_column_array()
|
||||
# change_table: changelog table name
|
||||
template = """CREATE TRIGGER IF NOT EXISTS %(table)s_changes_%(action)s
|
||||
AFTER %(action)s ON %(table)s
|
||||
BEGIN
|
||||
INSERT INTO %(change_table)s
|
||||
("action", "table", "primary_key", "changes")
|
||||
SELECT
|
||||
'%(action)s', '%(table)s', %(new_old)s."%(primary_key)s", "changes"
|
||||
FROM (
|
||||
SELECT json_group_object(
|
||||
col,
|
||||
json_array("oldval", "newval")) AS "changes"
|
||||
FROM (
|
||||
SELECT json_extract(value, '$[0]') as "col",
|
||||
json_extract(value, '$[1]') as "oldval",
|
||||
json_extract(value, '$[2]') as "newval"
|
||||
FROM json_each(json_array(%(column_array)s))
|
||||
WHERE "oldval" IS NOT "newval"
|
||||
)
|
||||
);
|
||||
END;"""
|
||||
|
||||
drop_template = 'DROP TRIGGER IF EXISTS %(table)s_changes_%(action)s'
|
||||
|
||||
_actions = ('INSERT', 'UPDATE', 'DELETE')
|
||||
|
||||
def __init__(self, db, table_name='changelog'):
|
||||
self.db = db
|
||||
self.table_name = table_name
|
||||
|
||||
def _build_column_array(self, model, use_old, use_new, skip_fields=None):
|
||||
# Builds a list of SQL expressions for each field we are tracking. This
|
||||
# is used as the data source for change tracking in our trigger.
|
||||
col_array = []
|
||||
for field in model._meta.sorted_fields:
|
||||
if field.primary_key:
|
||||
continue
|
||||
|
||||
if skip_fields is not None and field.name in skip_fields:
|
||||
continue
|
||||
|
||||
column = field.column_name
|
||||
new = 'NULL' if not use_new else 'NEW."%s"' % column
|
||||
old = 'NULL' if not use_old else 'OLD."%s"' % column
|
||||
|
||||
if isinstance(field, JSONField):
|
||||
# Ensure that values are cast to JSON so that the serialization
|
||||
# is preserved when calculating the old / new.
|
||||
if use_old: old = 'json(%s)' % old
|
||||
if use_new: new = 'json(%s)' % new
|
||||
|
||||
col_array.append("json_array('%s', %s, %s)" % (column, old, new))
|
||||
|
||||
return ', '.join(col_array)
|
||||
|
||||
def trigger_sql(self, model, action, skip_fields=None):
|
||||
assert action in self._actions
|
||||
use_old = action != 'INSERT'
|
||||
use_new = action != 'DELETE'
|
||||
cols = self._build_column_array(model, use_old, use_new, skip_fields)
|
||||
return self.template % {
|
||||
'table': model._meta.table_name,
|
||||
'action': action,
|
||||
'new_old': 'NEW' if action != 'DELETE' else 'OLD',
|
||||
'primary_key': model._meta.primary_key.column_name,
|
||||
'column_array': cols,
|
||||
'change_table': self.table_name}
|
||||
|
||||
def drop_trigger_sql(self, model, action):
|
||||
assert action in self._actions
|
||||
return self.drop_template % {
|
||||
'table': model._meta.table_name,
|
||||
'action': action}
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
if not hasattr(self, '_changelog_model'):
|
||||
class ChangeLog(self.base_model):
|
||||
class Meta:
|
||||
database = self.db
|
||||
table_name = self.table_name
|
||||
self._changelog_model = ChangeLog
|
||||
|
||||
return self._changelog_model
|
||||
|
||||
def install(self, model, skip_fields=None, drop=True, insert=True,
|
||||
update=True, delete=True, create_table=True):
|
||||
ChangeLog = self.model
|
||||
if create_table:
|
||||
ChangeLog.create_table()
|
||||
|
||||
actions = list(zip((insert, update, delete), self._actions))
|
||||
if drop:
|
||||
for _, action in actions:
|
||||
self.db.execute_sql(self.drop_trigger_sql(model, action))
|
||||
|
||||
for enabled, action in actions:
|
||||
if enabled:
|
||||
sql = self.trigger_sql(model, action, skip_fields)
|
||||
self.db.execute_sql(sql)
|
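A hedged usage sketch for `ChangeLog` (not part of the diff): installing the triggers on an application model logs INSERT/UPDATE/DELETE changes into the changelog table. It assumes a SQLite build with the json1 extension; the module path, database file, and `User` model are placeholders.

```python
from peewee import Model, TextField, IntegerField
from playhouse.sqlite_ext import SqliteExtDatabase
from playhouse.sqlite_changelog import ChangeLog  # assumed module path

db = SqliteExtDatabase('app.db')  # placeholder filename

class User(Model):  # illustrative model to audit
    username = TextField()
    age = IntegerField(null=True)
    class Meta:
        database = db

db.create_tables([User])

changelog = ChangeLog(db)
changelog.install(User)  # creates the changelog table and the three triggers

u = User.create(username='alice', age=30)
u.age = 31
u.save()
# changelog.model.select() now holds one INSERT entry and one UPDATE entry,
# with the changed columns stored as {"column": [old, new]} JSON.
```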
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,536 @@
|
||||
import datetime
|
||||
import hashlib
|
||||
import heapq
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import zlib
|
||||
try:
|
||||
from collections import Counter
|
||||
except ImportError:
|
||||
Counter = None
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
try:
|
||||
from playhouse._sqlite_ext import TableFunction
|
||||
except ImportError:
|
||||
TableFunction = None
|
||||
|
||||
|
||||
SQLITE_DATETIME_FORMATS = (
|
||||
'%Y-%m-%d %H:%M:%S',
|
||||
'%Y-%m-%d %H:%M:%S.%f',
|
||||
'%Y-%m-%d',
|
||||
'%H:%M:%S',
|
||||
'%H:%M:%S.%f',
|
||||
'%H:%M')
|
||||
|
||||
from peewee import format_date_time
|
||||
|
||||
def format_date_time_sqlite(date_value):
|
||||
return format_date_time(date_value, SQLITE_DATETIME_FORMATS)
|
||||
|
||||
try:
|
||||
from playhouse import _sqlite_udf as cython_udf
|
||||
except ImportError:
|
||||
cython_udf = None
|
||||
|
||||
|
||||
# Group udf by function.
|
||||
CONTROL_FLOW = 'control_flow'
|
||||
DATE = 'date'
|
||||
FILE = 'file'
|
||||
HELPER = 'helpers'
|
||||
MATH = 'math'
|
||||
STRING = 'string'
|
||||
|
||||
AGGREGATE_COLLECTION = {}
|
||||
TABLE_FUNCTION_COLLECTION = {}
|
||||
UDF_COLLECTION = {}
|
||||
|
||||
|
||||
class synchronized_dict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(synchronized_dict, self).__init__(*args, **kwargs)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def __getitem__(self, key):
|
||||
with self._lock:
|
||||
return super(synchronized_dict, self).__getitem__(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
with self._lock:
|
||||
return super(synchronized_dict, self).__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
with self._lock:
|
||||
return super(synchronized_dict, self).__delitem__(key)
|
||||
|
||||
|
||||
STATE = synchronized_dict()
|
||||
SETTINGS = synchronized_dict()
|
||||
|
||||
# Class and function decorators.
|
||||
def aggregate(*groups):
|
||||
def decorator(klass):
|
||||
for group in groups:
|
||||
AGGREGATE_COLLECTION.setdefault(group, [])
|
||||
AGGREGATE_COLLECTION[group].append(klass)
|
||||
return klass
|
||||
return decorator
|
||||
|
||||
def table_function(*groups):
|
||||
def decorator(klass):
|
||||
for group in groups:
|
||||
TABLE_FUNCTION_COLLECTION.setdefault(group, [])
|
||||
TABLE_FUNCTION_COLLECTION[group].append(klass)
|
||||
return klass
|
||||
return decorator
|
||||
|
||||
def udf(*groups):
|
||||
def decorator(fn):
|
||||
for group in groups:
|
||||
UDF_COLLECTION.setdefault(group, [])
|
||||
UDF_COLLECTION[group].append(fn)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
# Register aggregates / functions with connection.
|
||||
def register_aggregate_groups(db, *groups):
|
||||
seen = set()
|
||||
for group in groups:
|
||||
klasses = AGGREGATE_COLLECTION.get(group, ())
|
||||
for klass in klasses:
|
||||
name = getattr(klass, 'name', klass.__name__)
|
||||
if name not in seen:
|
||||
seen.add(name)
|
||||
db.register_aggregate(klass, name)
|
||||
|
||||
def register_table_function_groups(db, *groups):
|
||||
seen = set()
|
||||
for group in groups:
|
||||
klasses = TABLE_FUNCTION_COLLECTION.get(group, ())
|
||||
for klass in klasses:
|
||||
if klass.name not in seen:
|
||||
seen.add(klass.name)
|
||||
db.register_table_function(klass)
|
||||
|
||||
def register_udf_groups(db, *groups):
|
||||
seen = set()
|
||||
for group in groups:
|
||||
functions = UDF_COLLECTION.get(group, ())
|
||||
for function in functions:
|
||||
name = function.__name__
|
||||
if name not in seen:
|
||||
seen.add(name)
|
||||
db.register_function(function, name)
|
||||
|
||||
def register_groups(db, *groups):
|
||||
register_aggregate_groups(db, *groups)
|
||||
register_table_function_groups(db, *groups)
|
||||
register_udf_groups(db, *groups)
|
||||
|
||||
def register_all(db):
|
||||
register_aggregate_groups(db, *AGGREGATE_COLLECTION)
|
||||
register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION)
|
||||
register_udf_groups(db, *UDF_COLLECTION)
|
||||
|
||||
|
||||
# Begin actual user-defined functions and aggregates.
|
||||
|
||||
# Scalar functions.
|
||||
@udf(CONTROL_FLOW)
|
||||
def if_then_else(cond, truthy, falsey=None):
|
||||
if cond:
|
||||
return truthy
|
||||
return falsey
|
||||
|
||||
@udf(DATE)
|
||||
def strip_tz(date_str):
|
||||
date_str = date_str.replace('T', ' ')
|
||||
tz_idx1 = date_str.find('+')
|
||||
if tz_idx1 != -1:
|
||||
return date_str[:tz_idx1]
|
||||
tz_idx2 = date_str.find('-')
|
||||
if tz_idx2 > 13:
|
||||
return date_str[:tz_idx2]
|
||||
return date_str
|
||||
|
||||
@udf(DATE)
|
||||
def human_delta(nseconds, glue=', '):
|
||||
parts = (
|
||||
(86400 * 365, 'year'),
|
||||
(86400 * 30, 'month'),
|
||||
(86400 * 7, 'week'),
|
||||
(86400, 'day'),
|
||||
(3600, 'hour'),
|
||||
(60, 'minute'),
|
||||
(1, 'second'),
|
||||
)
|
||||
accum = []
|
||||
for offset, name in parts:
|
||||
val, nseconds = divmod(nseconds, offset)
|
||||
if val:
|
||||
suffix = 's' if val != 1 else ''
|
||||
accum.append('%s %s%s' % (val, name, suffix))
|
||||
if not accum:
|
||||
return '0 seconds'
|
||||
return glue.join(accum)
|
||||
|
||||
@udf(FILE)
|
||||
def file_ext(filename):
|
||||
try:
|
||||
res = os.path.splitext(filename)
|
||||
except ValueError:
|
||||
return None
|
||||
return res[1]
|
||||
|
||||
@udf(FILE)
|
||||
def file_read(filename):
|
||||
try:
|
||||
with open(filename) as fh:
|
||||
return fh.read()
|
||||
except:
|
||||
pass  # Unreadable file: fall through and return NULL (None) to SQLite.
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
@udf(HELPER)
|
||||
def gzip(data, compression=9):
|
||||
return buffer(zlib.compress(data, compression))
|
||||
|
||||
@udf(HELPER)
|
||||
def gunzip(data):
|
||||
return zlib.decompress(data)
|
||||
else:
|
||||
@udf(HELPER)
|
||||
def gzip(data, compression=9):
|
||||
if isinstance(data, str):
|
||||
data = bytes(data.encode('raw_unicode_escape'))
|
||||
return zlib.compress(data, compression)
|
||||
|
||||
@udf(HELPER)
|
||||
def gunzip(data):
|
||||
return zlib.decompress(data)
|
||||
|
||||
@udf(HELPER)
|
||||
def hostname(url):
|
||||
parse_result = urlparse(url)
|
||||
if parse_result:
|
||||
return parse_result.netloc
|
||||
|
||||
@udf(HELPER)
|
||||
def toggle(key):
|
||||
key = key.lower()
|
||||
STATE[key] = ret = not STATE.get(key)
|
||||
return ret
|
||||
|
||||
@udf(HELPER)
|
||||
def setting(key, value=None):
|
||||
if value is None:
|
||||
return SETTINGS.get(key)
|
||||
else:
|
||||
SETTINGS[key] = value
|
||||
return value
|
||||
|
||||
@udf(HELPER)
|
||||
def clear_settings():
|
||||
SETTINGS.clear()
|
||||
|
||||
@udf(HELPER)
|
||||
def clear_toggles():
|
||||
STATE.clear()
|
||||
|
||||
@udf(MATH)
|
||||
def randomrange(start, end=None, step=None):
|
||||
if end is None:
|
||||
start, end = 0, start
|
||||
elif step is None:
|
||||
step = 1
|
||||
return random.randrange(start, end, step)
|
||||
|
||||
@udf(MATH)
|
||||
def gauss_distribution(mean, sigma):
|
||||
try:
|
||||
return random.gauss(mean, sigma)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@udf(MATH)
|
||||
def sqrt(n):
|
||||
try:
|
||||
return math.sqrt(n)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@udf(MATH)
|
||||
def tonumber(s):
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
try:
|
||||
return float(s)
|
||||
except:
|
||||
return None
|
||||
|
||||
@udf(STRING)
|
||||
def substr_count(haystack, needle):
|
||||
if not haystack or not needle:
|
||||
return 0
|
||||
return haystack.count(needle)
|
||||
|
||||
@udf(STRING)
|
||||
def strip_chars(haystack, chars):
|
||||
return haystack.strip(chars)
|
||||
|
||||
def _hash(constructor, *args):
|
||||
hash_obj = constructor()
|
||||
for arg in args:
|
||||
hash_obj.update(arg)
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
# Aggregates.
|
||||
class _heap_agg(object):
|
||||
def __init__(self):
|
||||
self.heap = []
|
||||
self.ct = 0
|
||||
|
||||
def process(self, value):
|
||||
return value
|
||||
|
||||
def step(self, value):
|
||||
self.ct += 1
|
||||
heapq.heappush(self.heap, self.process(value))
|
||||
|
||||
class _datetime_heap_agg(_heap_agg):
|
||||
def process(self, value):
|
||||
return format_date_time_sqlite(value)
|
||||
|
||||
if sys.version_info[:2] == (2, 6):
|
||||
def total_seconds(td):
|
||||
return (td.seconds +
|
||||
(td.days * 86400) +
|
||||
(td.microseconds / (10.**6)))
|
||||
else:
|
||||
total_seconds = lambda td: td.total_seconds()
|
||||
|
||||
@aggregate(DATE)
|
||||
class mintdiff(_datetime_heap_agg):
|
||||
def finalize(self):
|
||||
dtp = min_diff = None
|
||||
while self.heap:
|
||||
if min_diff is None:
|
||||
if dtp is None:
|
||||
dtp = heapq.heappop(self.heap)
|
||||
continue
|
||||
dt = heapq.heappop(self.heap)
|
||||
diff = dt - dtp
|
||||
if min_diff is None or min_diff > diff:
|
||||
min_diff = diff
|
||||
dtp = dt
|
||||
if min_diff is not None:
|
||||
return total_seconds(min_diff)
|
||||
|
||||
@aggregate(DATE)
|
||||
class avgtdiff(_datetime_heap_agg):
|
||||
def finalize(self):
|
||||
if self.ct < 1:
|
||||
return
|
||||
elif self.ct == 1:
|
||||
return 0
|
||||
|
||||
total = ct = 0
|
||||
dtp = None
|
||||
while self.heap:
|
||||
if total == 0:
|
||||
if dtp is None:
|
||||
dtp = heapq.heappop(self.heap)
|
||||
continue
|
||||
|
||||
dt = heapq.heappop(self.heap)
|
||||
diff = dt - dtp
|
||||
ct += 1
|
||||
total += total_seconds(diff)
|
||||
dtp = dt
|
||||
|
||||
return float(total) / ct
|
||||
|
||||
@aggregate(DATE)
|
||||
class duration(object):
|
||||
def __init__(self):
|
||||
self._min = self._max = None
|
||||
|
||||
def step(self, value):
|
||||
dt = format_date_time_sqlite(value)
|
||||
if self._min is None or dt < self._min:
|
||||
self._min = dt
|
||||
if self._max is None or dt > self._max:
|
||||
self._max = dt
|
||||
|
||||
def finalize(self):
|
||||
if self._min and self._max:
|
||||
td = (self._max - self._min)
|
||||
return total_seconds(td)
|
||||
return None
|
||||
|
||||
@aggregate(MATH)
|
||||
class mode(object):
|
||||
if Counter:
|
||||
def __init__(self):
|
||||
self.items = Counter()
|
||||
|
||||
def step(self, *args):
|
||||
self.items.update(args)
|
||||
|
||||
def finalize(self):
|
||||
if self.items:
|
||||
return self.items.most_common(1)[0][0]
|
||||
else:
|
||||
def __init__(self):
|
||||
self.items = []
|
||||
|
||||
def step(self, item):
|
||||
self.items.append(item)
|
||||
|
||||
def finalize(self):
|
||||
if self.items:
|
||||
return max(set(self.items), key=self.items.count)
|
||||
|
||||
@aggregate(MATH)
|
||||
class minrange(_heap_agg):
|
||||
def finalize(self):
|
||||
if self.ct == 0:
|
||||
return
|
||||
elif self.ct == 1:
|
||||
return 0
|
||||
|
||||
prev = min_diff = None
|
||||
|
||||
while self.heap:
|
||||
if min_diff is None:
|
||||
if prev is None:
|
||||
prev = heapq.heappop(self.heap)
|
||||
continue
|
||||
curr = heapq.heappop(self.heap)
|
||||
diff = curr - prev
|
||||
if min_diff is None or min_diff > diff:
|
||||
min_diff = diff
|
||||
prev = curr
|
||||
return min_diff
|
||||
|
||||
@aggregate(MATH)
|
||||
class avgrange(_heap_agg):
|
||||
def finalize(self):
|
||||
if self.ct == 0:
|
||||
return
|
||||
elif self.ct == 1:
|
||||
return 0
|
||||
|
||||
total = ct = 0
|
||||
prev = None
|
||||
while self.heap:
|
||||
if total == 0:
|
||||
if prev is None:
|
||||
prev = heapq.heappop(self.heap)
|
||||
continue
|
||||
|
||||
curr = heapq.heappop(self.heap)
|
||||
diff = curr - prev
|
||||
ct += 1
|
||||
total += diff
|
||||
prev = curr
|
||||
|
||||
return float(total) / ct
|
||||
|
||||
@aggregate(MATH)
|
||||
class _range(object):
|
||||
name = 'range'
|
||||
|
||||
def __init__(self):
|
||||
self._min = self._max = None
|
||||
|
||||
def step(self, value):
|
||||
if self._min is None or value < self._min:
|
||||
self._min = value
|
||||
if self._max is None or value > self._max:
|
||||
self._max = value
|
||||
|
||||
def finalize(self):
|
||||
if self._min is not None and self._max is not None:
|
||||
return self._max - self._min
|
||||
return None
|
||||
|
||||
@aggregate(MATH)
|
||||
class stddev(object):
|
||||
def __init__(self):
|
||||
self.n = 0
|
||||
self.values = []
|
||||
def step(self, v):
|
||||
self.n += 1
|
||||
self.values.append(v)
|
||||
def finalize(self):
|
||||
if self.n <= 1:
|
||||
return 0
|
||||
mean = sum(self.values) / self.n
|
||||
return math.sqrt(sum((i - mean) ** 2 for i in self.values) / (self.n - 1))
|
||||
|
||||
|
||||
if cython_udf is not None:
|
||||
damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist)
|
||||
levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist)
|
||||
str_dist = udf(STRING)(cython_udf.str_dist)
|
||||
median = aggregate(MATH)(cython_udf.median)
|
||||
|
||||
|
||||
if TableFunction is not None:
|
||||
@table_function(STRING)
|
||||
class RegexSearch(TableFunction):
|
||||
params = ['regex', 'search_string']
|
||||
columns = ['match']
|
||||
name = 'regex_search'
|
||||
|
||||
def initialize(self, regex=None, search_string=None):
|
||||
self._iter = re.finditer(regex, search_string)
|
||||
|
||||
def iterate(self, idx):
|
||||
return (next(self._iter).group(0),)
|
||||
|
||||
@table_function(DATE)
|
||||
class DateSeries(TableFunction):
|
||||
params = ['start', 'stop', 'step_seconds']
|
||||
columns = ['date']
|
||||
name = 'date_series'
|
||||
|
||||
def initialize(self, start, stop, step_seconds=86400):
|
||||
self.start = format_date_time_sqlite(start)
|
||||
self.stop = format_date_time_sqlite(stop)
|
||||
step_seconds = int(step_seconds)
|
||||
self.step_seconds = datetime.timedelta(seconds=step_seconds)
|
||||
|
||||
if (self.start.hour == 0 and
|
||||
self.start.minute == 0 and
|
||||
self.start.second == 0 and
|
||||
step_seconds >= 86400):
|
||||
self.format = '%Y-%m-%d'
|
||||
elif (self.start.year == 1900 and
|
||||
self.start.month == 1 and
|
||||
self.start.day == 1 and
|
||||
self.stop.year == 1900 and
|
||||
self.stop.month == 1 and
|
||||
self.stop.day == 1 and
|
||||
step_seconds < 86400):
|
||||
self.format = '%H:%M:%S'
|
||||
else:
|
||||
self.format = '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
def iterate(self, idx):
|
||||
if self.start > self.stop:
|
||||
raise StopIteration
|
||||
current = self.start
|
||||
self.start += self.step_seconds
|
||||
return (current.strftime(self.format),)
|
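A hedged usage sketch for the UDF collection above (not part of the diff). After registering one or more groups on a database, the scalar functions and aggregates become callable from SQL; the module path and sample queries are assumptions.

```python
from playhouse.sqlite_ext import SqliteExtDatabase
from playhouse.sqlite_udf import register_groups, MATH, STRING  # assumed module path

db = SqliteExtDatabase(':memory:')
register_groups(db, MATH, STRING)  # or register_all(db) to get every group

db.connect()
(root,) = db.execute_sql('SELECT sqrt(?)', (16,)).fetchone()
print(root)  # 4.0

(hits,) = db.execute_sql(
    'SELECT substr_count(?, ?)', ('banana', 'an')).fetchone()
print(hits)  # 2
```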
@ -0,0 +1,331 @@
|
||||
import logging
|
||||
import weakref
|
||||
from threading import local as thread_local
|
||||
from threading import Event
|
||||
from threading import Thread
|
||||
try:
|
||||
from Queue import Queue
|
||||
except ImportError:
|
||||
from queue import Queue
|
||||
|
||||
try:
|
||||
import gevent
|
||||
from gevent import Greenlet as GThread
|
||||
from gevent.event import Event as GEvent
|
||||
from gevent.local import local as greenlet_local
|
||||
from gevent.queue import Queue as GQueue
|
||||
except ImportError:
|
||||
GThread = GQueue = GEvent = None
|
||||
|
||||
from peewee import SENTINEL
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
|
||||
logger = logging.getLogger('peewee.sqliteq')
|
||||
|
||||
|
||||
class ResultTimeout(Exception):
|
||||
pass
|
||||
|
||||
class WriterPaused(Exception):
|
||||
pass
|
||||
|
||||
class ShutdownException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncCursor(object):
|
||||
__slots__ = ('sql', 'params', 'commit', 'timeout',
|
||||
'_event', '_cursor', '_exc', '_idx', '_rows', '_ready')
|
||||
|
||||
def __init__(self, event, sql, params, commit, timeout):
|
||||
self._event = event
|
||||
self.sql = sql
|
||||
self.params = params
|
||||
self.commit = commit
|
||||
self.timeout = timeout
|
||||
self._cursor = self._exc = self._idx = self._rows = None
|
||||
self._ready = False
|
||||
|
||||
def set_result(self, cursor, exc=None):
|
||||
self._cursor = cursor
|
||||
self._exc = exc
|
||||
self._idx = 0
|
||||
self._rows = cursor.fetchall() if exc is None else []
|
||||
self._event.set()
|
||||
return self
|
||||
|
||||
def _wait(self, timeout=None):
|
||||
timeout = timeout if timeout is not None else self.timeout
|
||||
if not self._event.wait(timeout=timeout) and timeout:
|
||||
raise ResultTimeout('results not ready, timed out.')
|
||||
if self._exc is not None:
|
||||
raise self._exc
|
||||
self._ready = True
|
||||
|
||||
def __iter__(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
if self._exc is not None:
|
||||
raise self._exc
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
try:
|
||||
obj = self._rows[self._idx]
|
||||
except IndexError:
|
||||
raise StopIteration
|
||||
else:
|
||||
self._idx += 1
|
||||
return obj
|
||||
__next__ = next
|
||||
|
||||
@property
|
||||
def lastrowid(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
return self._cursor.lastrowid
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
return self._cursor.rowcount
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self._cursor.description
|
||||
|
||||
def close(self):
|
||||
self._cursor.close()
|
||||
|
||||
def fetchall(self):
|
||||
return list(self) # Iterating implies waiting until populated.
|
||||
|
||||
def fetchone(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
try:
|
||||
return next(self)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
SHUTDOWN = StopIteration
|
||||
PAUSE = object()
|
||||
UNPAUSE = object()
|
||||
|
||||
|
||||
class Writer(object):
|
||||
__slots__ = ('database', 'queue')
|
||||
|
||||
def __init__(self, database, queue):
|
||||
self.database = database
|
||||
self.queue = queue
|
||||
|
||||
def run(self):
|
||||
conn = self.database.connection()
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if conn is None: # Paused.
|
||||
if self.wait_unpause():
|
||||
conn = self.database.connection()
|
||||
else:
|
||||
conn = self.loop(conn)
|
||||
except ShutdownException:
|
||||
logger.info('writer received shutdown request, exiting.')
|
||||
return
|
||||
finally:
|
||||
if conn is not None:
|
||||
self.database._close(conn)
|
||||
self.database._state.reset()
|
||||
|
||||
def wait_unpause(self):
|
||||
obj = self.queue.get()
|
||||
if obj is UNPAUSE:
|
||||
logger.info('writer unpaused - reconnecting to database.')
|
||||
return True
|
||||
elif obj is SHUTDOWN:
|
||||
raise ShutdownException()
|
||||
elif obj is PAUSE:
|
||||
logger.error('writer received pause, but is already paused.')
|
||||
else:
|
||||
obj.set_result(None, WriterPaused())
|
||||
logger.warning('writer paused, not handling %s', obj)
|
||||
|
||||
def loop(self, conn):
|
||||
obj = self.queue.get()
|
||||
if isinstance(obj, AsyncCursor):
|
||||
self.execute(obj)
|
||||
elif obj is PAUSE:
|
||||
logger.info('writer paused - closing database connection.')
|
||||
self.database._close(conn)
|
||||
self.database._state.reset()
|
||||
return
|
||||
elif obj is UNPAUSE:
|
||||
logger.error('writer received unpause, but is already running.')
|
||||
elif obj is SHUTDOWN:
|
||||
raise ShutdownException()
|
||||
else:
|
||||
logger.error('writer received unsupported object: %s', obj)
|
||||
return conn
|
||||
|
||||
def execute(self, obj):
|
||||
logger.debug('received query %s', obj.sql)
|
||||
try:
|
||||
cursor = self.database._execute(obj.sql, obj.params, obj.commit)
|
||||
except Exception as execute_err:
|
||||
cursor = None
|
||||
exc = execute_err  # Preserve the exception so it can be re-raised on the calling thread.
|
||||
else:
|
||||
exc = None
|
||||
return obj.set_result(cursor, exc)
|
||||
|
||||
|
||||
class SqliteQueueDatabase(SqliteExtDatabase):
|
||||
WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL '
|
||||
'journal mode when using this feature. WAL mode '
|
||||
'allows one or more readers to continue reading '
|
||||
'while another connection writes to the '
|
||||
'database.')
|
||||
|
||||
def __init__(self, database, use_gevent=False, autostart=True,
|
||||
queue_max_size=None, results_timeout=None, *args, **kwargs):
|
||||
kwargs['check_same_thread'] = False
|
||||
|
||||
# Ensure that journal_mode is WAL. This value is passed to the parent
|
||||
# class constructor below.
|
||||
pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None))
|
||||
|
||||
# Reference to execute_sql on the parent class. Since we've overridden
|
||||
# execute_sql(), this is just a handy way to reference the real
|
||||
# implementation.
|
||||
Parent = super(SqliteQueueDatabase, self)
|
||||
self._execute = Parent.execute_sql
|
||||
|
||||
# Call the parent class constructor with our modified pragmas.
|
||||
Parent.__init__(database, pragmas=pragmas, *args, **kwargs)
|
||||
|
||||
self._autostart = autostart
|
||||
self._results_timeout = results_timeout
|
||||
self._is_stopped = True
|
||||
|
||||
# Get different objects depending on the threading implementation.
|
||||
self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size)
|
||||
|
||||
# Create the writer thread, optionally starting it.
|
||||
self._create_write_queue()
|
||||
if self._autostart:
|
||||
self.start()
|
||||
|
||||
def get_thread_impl(self, use_gevent):
|
||||
return GreenletHelper if use_gevent else ThreadHelper
|
||||
|
||||
def _validate_journal_mode(self, pragmas=None):
|
||||
if not pragmas:
|
||||
return {'journal_mode': 'wal'}
|
||||
|
||||
if not isinstance(pragmas, dict):
|
||||
pragmas = dict((k.lower(), v) for (k, v) in pragmas)
|
||||
if pragmas.get('journal_mode', 'wal').lower() != 'wal':
|
||||
raise ValueError(self.WAL_MODE_ERROR_MESSAGE)
|
||||
|
||||
pragmas['journal_mode'] = 'wal'
|
||||
return pragmas
|
||||
|
||||
def _create_write_queue(self):
|
||||
self._write_queue = self._thread_helper.queue()
|
||||
|
||||
def queue_size(self):
|
||||
return self._write_queue.qsize()
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None):
|
||||
if commit is SENTINEL:
|
||||
commit = not sql.lower().startswith('select')
|
||||
|
||||
if not commit:
|
||||
return self._execute(sql, params, commit=commit)
|
||||
|
||||
cursor = AsyncCursor(
|
||||
event=self._thread_helper.event(),
|
||||
sql=sql,
|
||||
params=params,
|
||||
commit=commit,
|
||||
timeout=self._results_timeout if timeout is None else timeout)
|
||||
self._write_queue.put(cursor)
|
||||
return cursor
|
||||
|
||||
def start(self):
|
||||
with self._lock:
|
||||
if not self._is_stopped:
|
||||
return False
|
||||
def run():
|
||||
writer = Writer(self, self._write_queue)
|
||||
writer.run()
|
||||
|
||||
self._writer = self._thread_helper.thread(run)
|
||||
self._writer.start()
|
||||
self._is_stopped = False
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
logger.debug('environment stop requested.')
|
||||
with self._lock:
|
||||
if self._is_stopped:
|
||||
return False
|
||||
self._write_queue.put(SHUTDOWN)
|
||||
self._writer.join()
|
||||
self._is_stopped = True
|
||||
return True
|
||||
|
||||
def is_stopped(self):
|
||||
with self._lock:
|
||||
return self._is_stopped
|
||||
|
||||
def pause(self):
|
||||
with self._lock:
|
||||
self._write_queue.put(PAUSE)
|
||||
|
||||
def unpause(self):
|
||||
with self._lock:
|
||||
self._write_queue.put(UNPAUSE)
|
||||
|
||||
def __unsupported__(self, *args, **kwargs):
|
||||
raise ValueError('This method is not supported by %r.' % type(self))
|
||||
atomic = transaction = savepoint = __unsupported__
|
||||
|
||||
|
||||
class ThreadHelper(object):
|
||||
__slots__ = ('queue_max_size',)
|
||||
|
||||
def __init__(self, queue_max_size=None):
|
||||
self.queue_max_size = queue_max_size
|
||||
|
||||
def event(self): return Event()
|
||||
|
||||
def queue(self, max_size=None):
|
||||
max_size = max_size if max_size is not None else self.queue_max_size
|
||||
return Queue(maxsize=max_size or 0)
|
||||
|
||||
def thread(self, fn, *args, **kwargs):
|
||||
thread = Thread(target=fn, args=args, kwargs=kwargs)
|
||||
thread.daemon = True
|
||||
return thread
|
||||
|
||||
|
||||
class GreenletHelper(ThreadHelper):
|
||||
__slots__ = ()
|
||||
|
||||
def event(self): return GEvent()
|
||||
|
||||
def queue(self, max_size=None):
|
||||
max_size = max_size if max_size is not None else self.queue_max_size
|
||||
return GQueue(maxsize=max_size or 0)
|
||||
|
||||
def thread(self, fn, *args, **kwargs):
|
||||
def wrap(*a, **k):
|
||||
gevent.sleep()
|
||||
return fn(*a, **k)
|
||||
return GThread(wrap, *args, **kwargs)
|
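A hedged usage sketch for `SqliteQueueDatabase` (not part of the diff): writes are funneled through the single background writer thread, while reads run inline on the calling thread. The module path and file name are assumptions.

```python
from playhouse.sqliteq import SqliteQueueDatabase  # assumed module path

db = SqliteQueueDatabase(
    'queued_app.db',      # placeholder filename; WAL journal mode is enforced
    autostart=True,       # start the writer thread immediately
    queue_max_size=64,
    results_timeout=5.0)

db.execute_sql('CREATE TABLE IF NOT EXISTS note '
               '(id INTEGER PRIMARY KEY, body TEXT)')
curs = db.execute_sql('INSERT INTO note (body) VALUES (?)', ('hello',))
print(curs.lastrowid)     # blocks until the writer thread has run the INSERT

print(db.execute_sql('SELECT body FROM note').fetchall())  # SELECTs run inline
db.stop()                 # drain the queue and join the writer thread
```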
@ -0,0 +1,62 @@
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger('peewee')
|
||||
|
||||
|
||||
class _QueryLogHandler(logging.Handler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.queries = []
|
||||
logging.Handler.__init__(self, *args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
self.queries.append(record)
|
||||
|
||||
|
||||
class count_queries(object):
|
||||
def __init__(self, only_select=False):
|
||||
self.only_select = only_select
|
||||
self.count = 0
|
||||
|
||||
def get_queries(self):
|
||||
return self._handler.queries
|
||||
|
||||
def __enter__(self):
|
||||
self._handler = _QueryLogHandler()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(self._handler)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
logger.removeHandler(self._handler)
|
||||
if self.only_select:
|
||||
self.count = len([q for q in self._handler.queries
|
||||
if q.msg[0].startswith('SELECT ')])
|
||||
else:
|
||||
self.count = len(self._handler.queries)
|
||||
|
||||
|
||||
class assert_query_count(count_queries):
|
||||
def __init__(self, expected, only_select=False):
|
||||
super(assert_query_count, self).__init__(only_select=only_select)
|
||||
self.expected = expected
|
||||
|
||||
def __call__(self, f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwds):
|
||||
with self:
|
||||
ret = f(*args, **kwds)
|
||||
|
||||
self._assert_count()
|
||||
return ret
|
||||
|
||||
return decorated
|
||||
|
||||
def _assert_count(self):
|
||||
error_msg = '%s != %s' % (self.count, self.expected)
|
||||
assert self.count == self.expected, error_msg
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb)
|
||||
self._assert_count()
|
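A hedged usage sketch for the helpers above (not part of the diff): `count_queries` counts the statements peewee logs inside the block, and `assert_query_count` works as a decorator or a context manager. The module path and the `User` model are placeholders.

```python
from peewee import Model, SqliteDatabase, TextField
from playhouse.test_utils import assert_query_count, count_queries  # assumed path

db = SqliteDatabase(':memory:')

class User(Model):  # illustrative model
    username = TextField()
    class Meta:
        database = db

db.create_tables([User])

with count_queries() as counter:
    list(User.select())
print(counter.count)  # 1

@assert_query_count(1)
def fetch_users():
    return list(User.select())

fetch_users()  # raises AssertionError if more or fewer queries were logged
```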
@ -1,219 +0,0 @@
|
||||
# Copyright (c) 2014 Palantir Technologies
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
|
||||
"""Thread safe sqlite3 interface."""
|
||||
|
||||
__author__ = "Shawn Lee"
|
||||
__email__ = "shawnl@palantir.com"
|
||||
__license__ = "MIT"
|
||||
|
||||
import logging
|
||||
try:
|
||||
import queue as Queue # module re-named in Python 3
|
||||
except ImportError:
|
||||
import Queue
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
LOGGER = logging.getLogger('sqlite3worker')
|
||||
|
||||
|
||||
class Sqlite3Worker(threading.Thread):
|
||||
"""Sqlite thread safe object.
|
||||
|
||||
Example:
|
||||
from sqlite3worker import Sqlite3Worker
|
||||
sql_worker = Sqlite3Worker("/tmp/test.sqlite")
|
||||
sql_worker.execute(
|
||||
"CREATE TABLE tester (timestamp DATETIME, uuid TEXT)")
|
||||
sql_worker.execute(
|
||||
"INSERT into tester values (?, ?)", ("2010-01-01 13:00:00", "bow"))
|
||||
sql_worker.execute(
|
||||
"INSERT into tester values (?, ?)", ("2011-02-02 14:14:14", "dog"))
|
||||
sql_worker.execute("SELECT * from tester")
|
||||
sql_worker.close()
|
||||
"""
|
||||
def __init__(self, file_name, max_queue_size=100, as_dict=False):
|
||||
"""Automatically starts the thread.
|
||||
|
||||
Args:
|
||||
file_name: The name of the file.
|
||||
max_queue_size: The max queries that will be queued.
|
||||
as_dict: Return result as a dictionary.
|
||||
"""
|
||||
threading.Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.sqlite3_conn = sqlite3.connect(
|
||||
file_name, check_same_thread=False,
|
||||
detect_types=sqlite3.PARSE_DECLTYPES)
|
||||
if as_dict:
|
||||
self.sqlite3_conn.row_factory = dict_factory
|
||||
self.sqlite3_cursor = self.sqlite3_conn.cursor()
|
||||
self.sql_queue = Queue.Queue(maxsize=max_queue_size)
|
||||
self.results = {}
|
||||
self.max_queue_size = max_queue_size
|
||||
self.exit_set = False
|
||||
# Token that is put into queue when close() is called.
|
||||
self.exit_token = str(uuid.uuid4())
|
||||
self.start()
|
||||
self.thread_running = True
|
||||
|
||||
def run(self):
|
||||
"""Thread loop.
|
||||
|
||||
This is an infinite loop. The iter method calls self.sql_queue.get()
|
||||
which blocks if there are not values in the queue. As soon as values
|
||||
are placed into the queue the process will continue.
|
||||
|
||||
If many executes happen at once it will churn through them all before
|
||||
calling commit() to speed things up by reducing the number of times
|
||||
commit is called.
|
||||
"""
|
||||
LOGGER.debug("run: Thread started")
|
||||
execute_count = 0
|
||||
for token, query, values, only_one, execute_many in iter(self.sql_queue.get, None):
|
||||
LOGGER.debug("sql_queue: %s", self.sql_queue.qsize())
|
||||
if token != self.exit_token:
|
||||
LOGGER.debug("run: %s, %s", query, values)
|
||||
self.run_query(token, query, values, only_one, execute_many)
|
||||
execute_count += 1
|
||||
# Let the executes build up a little before committing to disk
|
||||
# to speed things up.
|
||||
if (
|
||||
self.sql_queue.empty() or
|
||||
execute_count == self.max_queue_size):
|
||||
LOGGER.debug("run: commit")
|
||||
self.sqlite3_conn.commit()
|
||||
execute_count = 0
|
||||
# Only exit if the queue is empty. Otherwise keep getting
|
||||
# through the queue until it's empty.
|
||||
if self.exit_set and self.sql_queue.empty():
|
||||
self.sqlite3_conn.commit()
|
||||
self.sqlite3_conn.close()
|
||||
self.thread_running = False
|
||||
return
|
||||
|
||||
def run_query(self, token, query, values, only_one=False, execute_many=False):
|
||||
"""Run a query.
|
||||
|
||||
Args:
|
||||
token: A uuid object of the query you want returned.
|
||||
query: A sql query with ? placeholders for values.
|
||||
values: A tuple of values to replace "?" in query.
|
||||
"""
|
||||
if query.lower().strip().startswith(("select", "pragma")):
|
||||
try:
|
||||
self.sqlite3_cursor.execute(query, values)
|
||||
if only_one:
|
||||
self.results[token] = self.sqlite3_cursor.fetchone()
|
||||
else:
|
||||
self.results[token] = self.sqlite3_cursor.fetchall()
|
||||
except sqlite3.Error as err:
|
||||
# Put the error into the output queue since a response
|
||||
# is required.
|
||||
self.results[token] = (
|
||||
"Query returned error: %s: %s: %s" % (query, values, err))
|
||||
LOGGER.error(
|
||||
"Query returned error: %s: %s: %s", query, values, err)
|
||||
else:
|
||||
try:
|
||||
if execute_many:
|
||||
self.sqlite3_cursor.executemany(query, values)
|
||||
if query.lower().strip().startswith(("insert", "update", "delete")):
|
||||
self.results[token] = self.sqlite3_cursor.rowcount
|
||||
else:
|
||||
self.sqlite3_cursor.execute(query, values)
|
||||
if query.lower().strip().startswith(("insert", "update", "delete")):
|
||||
self.results[token] = self.sqlite3_cursor.rowcount
|
||||
except sqlite3.Error as err:
|
||||
self.results[token] = (
|
||||
"Query returned error: %s: %s: %s" % (query, values, err))
|
||||
LOGGER.error(
|
||||
"Query returned error: %s: %s: %s", query, values, err)
|
||||
|
||||
def close(self):
|
||||
"""Close down the thread and close the sqlite3 database file."""
|
||||
self.exit_set = True
|
||||
self.sql_queue.put((self.exit_token, "", "", "", ""), timeout=5)
|
||||
# Sleep and check that the thread is done before returning.
|
||||
while self.thread_running:
|
||||
time.sleep(.01) # Don't kill the CPU waiting.
|
||||
|
||||
@property
|
||||
def queue_size(self):
|
||||
"""Return the queue size."""
|
||||
return self.sql_queue.qsize()
|
||||
|
||||
def query_results(self, token):
|
||||
"""Get the query results for a specific token.
|
||||
|
||||
Args:
|
||||
token: A uuid object of the query you want returned.
|
||||
|
||||
Returns:
|
||||
Return the results of the query when it's executed by the thread.
|
||||
"""
|
||||
delay = .001
|
||||
while True:
|
||||
if token in self.results:
|
||||
return_val = self.results[token]
|
||||
del self.results[token]
|
||||
return return_val
|
||||
# Double back on the delay to a max of 8 seconds. This prevents
|
||||
# a long lived select statement from trashing the CPU with this
|
||||
# infinite loop as it's waiting for the query results.
|
||||
LOGGER.debug("Sleeping: %s %s", delay, token)
|
||||
time.sleep(delay)
|
||||
if delay < 8:
|
||||
delay += delay
|
||||
|
||||
def execute(self, query, values=None, only_one=False, execute_many=False):
|
||||
"""Execute a query.
|
||||
|
||||
Args:
|
||||
query: The sql string using ? for placeholders of dynamic values.
|
||||
values: A tuple of values to be replaced into the ? of the query.
|
||||
|
||||
Returns:
|
||||
If it's a select query it will return the results of the query.
|
||||
"""
|
||||
if self.exit_set:
|
||||
LOGGER.debug("Exit set, not running: %s, %s", query, values)
|
||||
return "Exit Called"
|
||||
LOGGER.debug("execute: %s, %s", query, values)
|
||||
values = values or []
|
||||
# A token to track this query with.
|
||||
token = str(uuid.uuid4())
|
||||
# If it's a select we queue it up with a token to mark the results
|
||||
# into the output queue so we know what results are ours.
|
||||
if query.lower().strip().startswith(("select", "insert", "update", "delete", "pragma")):
|
||||
self.sql_queue.put((token, query, values, only_one, execute_many), timeout=5)
|
||||
return self.query_results(token)
|
||||
else:
|
||||
self.sql_queue.put((token, query, values, only_one, execute_many), timeout=5)
|
||||
|
||||
|
||||
def dict_factory(cursor, row):
|
||||
d = {}
|
||||
for idx, col in enumerate(cursor.description):
|
||||
d[col[0]] = row[idx]
|
||||
return d
|
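For reference, `dict_factory` (used above when `as_dict=True`) is a standard sqlite3 row factory; a small sketch using the function defined above, with an in-memory database as a placeholder:

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.row_factory = dict_factory  # rows come back as plain dicts
conn.execute('CREATE TABLE t (a, b)')
conn.execute('INSERT INTO t VALUES (1, 2)')
print(conn.execute('SELECT a, b FROM t').fetchone())  # {'a': 1, 'b': 2}
```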