You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1295 lines
43 KiB
1295 lines
43 KiB
4 years ago
|
import json
|
||
|
import math
|
||
|
import re
|
||
|
import struct
|
||
|
import sys
|
||
|
|
||
|
from peewee import *
|
||
|
from peewee import ColumnBase
|
||
|
from peewee import EnclosedNodeList
|
||
|
from peewee import Entity
|
||
|
from peewee import Expression
|
||
|
from peewee import Node
|
||
|
from peewee import NodeList
|
||
|
from peewee import OP
|
||
|
from peewee import VirtualField
|
||
|
from peewee import merge_dict
|
||
|
from peewee import sqlite3
|
||
|
try:
|
||
|
from playhouse._sqlite_ext import (
|
||
|
backup,
|
||
|
backup_to_file,
|
||
|
Blob,
|
||
|
ConnectionHelper,
|
||
|
register_bloomfilter,
|
||
|
register_hash_functions,
|
||
|
register_rank_functions,
|
||
|
sqlite_get_db_status,
|
||
|
sqlite_get_status,
|
||
|
TableFunction,
|
||
|
ZeroBlob,
|
||
|
)
|
||
|
CYTHON_SQLITE_EXTENSIONS = True
|
||
|
except ImportError:
|
||
|
CYTHON_SQLITE_EXTENSIONS = False
|
||
|
|
||
|
|
||
|
# Python 3 has no ``basestring``; alias it so the isinstance() checks in this
# module work on both major versions.
if sys.version_info[0] == 3:
    basestring = str

# Format strings handed to matchinfo() by the FTS3/FTS4 ranking helpers.
FTS3_MATCHINFO = 'pcx'
FTS4_MATCHINFO = 'pcnalx'

# Start from FTS3 and upgrade to FTS4 only when a SQLite driver is available
# and reports version 3.7.4 or newer.
FTS_VERSION = 3
if sqlite3 is not None and sqlite3.sqlite_version_info[:3] >= (3, 7, 4):
    FTS_VERSION = 4

FTS5_MIN_SQLITE_VERSION = (3, 9, 0)
|
||
|
|
||
|
|
||
|
class RowIDField(AutoField):
    """Auto-incrementing field that must be declared under the name ``rowid``."""
    auto_increment = True
    column_name = name = required_name = 'rowid'

    def bind(self, model, name, *args):
        """
        Bind the field to ``model`` under ``name``.

        Raises ValueError when ``name`` differs from ``required_name``.
        """
        expected = self.required_name
        if name == expected:
            super(RowIDField, self).bind(model, name, *args)
            return
        raise ValueError('%s must be named "%s".' % (type(self), expected))
|
||
|
|
||
|
|
||
|
class DocIDField(RowIDField):
    """RowIDField variant whose required name is ``docid``."""
    column_name = 'docid'
    name = column_name
    required_name = column_name
|
||
|
|
||
|
|
||
|
class AutoIncrementField(AutoField):
    """AutoField whose column definition ends with ``AUTOINCREMENT``."""

    def ddl(self, ctx):
        """Return the parent's DDL followed by the AUTOINCREMENT keyword."""
        base_ddl = super(AutoIncrementField, self).ddl(ctx)
        keyword = SQL('AUTOINCREMENT')
        return NodeList((base_ddl, keyword))
|
||
|
|
||
|
|
||
|
class TDecimalField(DecimalField):
    """DecimalField declared with the ``TEXT`` column type."""
    field_type = 'TEXT'

    def get_modifiers(self):
        """Report no column modifiers."""
        return None
|
||
|
|
||
|
|
||
|
class JSONPath(ColumnBase):
    """
    Column-like node addressing a location inside a JSON field.

    Indexing yields a new JSONPath with one more path component; the other
    methods build calls to SQLite ``json_*`` functions for that location.
    """
    def __init__(self, field, path=None):
        super(JSONPath, self).__init__()
        self._field = field
        # Tuple of pre-formatted components such as '.key' or '[0]'.
        self._path = path if path else ()

    @property
    def path(self):
        """The accumulated path wrapped in a ``Value``, rooted at ``$``."""
        components = ''.join(self._path)
        return Value('$%s' % components)

    def __getitem__(self, idx):
        """Return a new JSONPath extended by ``idx``."""
        # Integers are rendered as array subscripts, everything else as keys.
        template = '[%s]' if isinstance(idx, int) else '.%s'
        return JSONPath(self._field, self._path + (template % idx,))

    def set(self, value, as_json=None):
        """Build ``json_set(field, path, value)``.

        Lists and dicts (or any value when ``as_json`` is truthy) are
        serialized with the field's dumps function and wrapped in ``json()``.
        """
        needs_json = as_json or isinstance(value, (list, dict))
        if needs_json:
            serialized = self._field._json_dumps(value)
            value = fn.json(serialized)
        return fn.json_set(self._field, self.path, value)

    def update(self, value):
        """Set this location to ``json_patch(<this path>, value)``."""
        serialized = self._field._json_dumps(value)
        patched = fn.json_patch(self, serialized)
        return self.set(patched)

    def remove(self):
        """Build ``json_remove(field, path)``."""
        return fn.json_remove(self._field, self.path)

    def json_type(self):
        """Build ``json_type(field, path)``."""
        return fn.json_type(self._field, self.path)

    def length(self):
        """Build ``json_array_length(field, path)``."""
        return fn.json_array_length(self._field, self.path)

    def children(self):
        """Build ``json_each(field, path)``."""
        return fn.json_each(self._field, self.path)

    def tree(self):
        """Build ``json_tree(field, path)``."""
        return fn.json_tree(self._field, self.path)

    def __sql__(self, ctx):
        # With no path components the node is just the underlying field.
        if self._path:
            node = fn.json_extract(self._field, self.path)
        else:
            node = self._field
        return ctx.sql(node)
|
||
|
|
||
|
|
||
|
class JSONField(TextField):
    """
    Text field that serializes Python values to JSON on the way into the
    database and deserializes them on the way out.

    ``json_dumps`` / ``json_loads`` override the default ``json.dumps`` and
    ``json.loads`` callables.
    """
    field_type = 'JSON'
    unpack = False

    def __init__(self, json_dumps=None, json_loads=None, **kwargs):
        self._json_dumps = json_dumps if json_dumps else json.dumps
        self._json_loads = json_loads if json_loads else json.loads
        super(JSONField, self).__init__(**kwargs)

    def python_value(self, value):
        """Decode ``value``; return it unchanged if it cannot be decoded."""
        if value is None:
            return None
        try:
            decoded = self._json_loads(value)
        except (TypeError, ValueError):
            return value
        return decoded

    def db_value(self, value):
        """Wrap non-Node values in ``json(<dumped value>)``; pass None through."""
        if value is None:
            return None
        if isinstance(value, Node):
            return value
        return fn.json(self._json_dumps(value))

    def _e(op):
        # Factory for the rich-comparison operators defined below.
        def inner(self, rhs):
            # Lists and dicts are converted through db_value as a single
            # value rather than being unpacked.
            operand = (Value(rhs, converter=self.db_value, unpack=False)
                       if isinstance(rhs, (list, dict)) else rhs)
            return Expression(self, op, operand)
        return inner
    __eq__ = _e(OP.EQ)
    __ne__ = _e(OP.NE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    # Defining __eq__ would otherwise drop hashing; reuse Field's.
    __hash__ = Field.__hash__

    def __getitem__(self, item):
        """Return a JSONPath for ``item`` within this field."""
        root = JSONPath(self)
        return root[item]

    def set(self, value, as_json=None):
        """Build a ``json_set`` call replacing the root of this field."""
        root = JSONPath(self)
        return root.set(value, as_json)

    def update(self, data):
        """Build an expression patching the root of this field with ``data``."""
        root = JSONPath(self)
        return root.update(data)

    def remove(self):
        """Build a ``json_remove`` call for the root of this field."""
        root = JSONPath(self)
        return root.remove()

    def json_type(self):
        """Build ``json_type(field)``."""
        return fn.json_type(self)

    def length(self):
        """Build ``json_array_length(field)``."""
        return fn.json_array_length(self)

    def children(self):
        """
        Build ``json_each(field)``.

        Schema of `json_each` and `json_tree`:

        key,
        value,
        type TEXT (object, array, string, etc),
        atom (value for primitive/scalar types, NULL for array and object)
        id INTEGER (unique identifier for element)
        parent INTEGER (unique identifier of parent element or NULL)
        fullkey TEXT (full path describing element)
        path TEXT (path to the container of the current element)
        json JSON hidden (1st input parameter to function)
        root TEXT hidden (2nd input parameter, path at which to start)
        """
        return fn.json_each(self)

    def tree(self):
        """Build ``json_tree(field)``."""
        return fn.json_tree(self)
|
||
|
|
||
|
|
||
|
class SearchField(Field):
    """
    Nullable field accepting only ``unindexed`` and ``column_name`` options.

    Any other keyword argument raises ValueError.
    """
    def __init__(self, unindexed=False, column_name=None, **k):
        if k:
            rejected = sorted(k)
            raise ValueError('SearchField does not accept these keyword '
                             'arguments: %s.' % rejected)
        super(SearchField, self).__init__(
            unindexed=unindexed,
            column_name=column_name,
            null=True)

    def match(self, term):
        """Build a match expression of this field against ``term``."""
        return match(self, term)
|
||
|
|
||
|
|
||
|
class VirtualTableSchemaManager(SchemaManager):
    """SchemaManager that emits ``CREATE VIRTUAL TABLE`` for VirtualModels."""

    def _create_virtual_table(self, safe=True, **options):
        """
        Build the CREATE VIRTUAL TABLE statement for ``self.model``.

        Structure:
        CREATE VIRTUAL TABLE <model>
        USING <extension_module>
        ([prefix_arguments, ...] fields, ... [arguments, ...], [options...])
        """
        model = self.model
        # clean_options runs first: models may adjust their meta here.
        options = model.clean_options(
            merge_dict(model._meta.options, options))

        ctx = self._create_context()
        ctx.literal('CREATE VIRTUAL TABLE ')
        if safe:
            ctx.literal('IF NOT EXISTS ')
        ctx.sql(model).literal(' USING ')

        meta = model._meta
        ext_module = meta.extension_module
        # A Node extension module is rendered as-is, with no argument list.
        if isinstance(ext_module, Node):
            return ctx.sql(ext_module)

        ctx.sql(SQL(ext_module)).literal(' ')

        arguments = [SQL(arg) for arg in (meta.prefix_arguments or ())]

        # Constraints, data-types, foreign and primary keys are all omitted.
        for field in meta.sorted_fields:
            skip = isinstance(field, RowIDField) or field._hidden
            if skip:
                continue
            column = [Entity(field.column_name)]
            if field.unindexed:
                column.append(SQL('UNINDEXED'))
            arguments.append(NodeList(column))

        if meta.arguments:
            arguments += [SQL(arg) for arg in meta.arguments]

        if options:
            arguments += list(self._create_table_option_sql(options))
        return ctx.sql(EnclosedNodeList(arguments))

    def _create_table(self, safe=True, **options):
        """Dispatch to the virtual-table builder for VirtualModel subclasses."""
        if not issubclass(self.model, VirtualModel):
            parent = super(VirtualTableSchemaManager, self)
            return parent._create_table(safe, **options)
        return self._create_virtual_table(safe, **options)
|
||
|
|
||
|
|
||
|
class VirtualModel(Model):
    """Base model for tables created through VirtualTableSchemaManager."""
    class Meta:
        # Extra arguments rendered after the column list.
        arguments = None
        # Name (or Node) rendered after USING in the CREATE statement.
        extension_module = None
        # Extra arguments rendered before the column list.
        prefix_arguments = None
        primary_key = False
        schema_manager_class = VirtualTableSchemaManager

    @classmethod
    def clean_options(cls, options):
        """Hook for subclasses to post-process table options; no-op here."""
        return options
|
||
|
|
||
|
|
||
|
class BaseFTSModel(VirtualModel):
    """Shared option handling for the full-text search model classes."""

    @classmethod
    def clean_options(cls, options):
        """
        Normalize the ``content``, ``prefix`` and ``tokenize`` table options.

        ``options`` is modified in place and returned.

        * ``content``: an empty string becomes the quoted empty string (a
          content-less table); a Field becomes a fully-qualified Entity.
        * ``prefix``: a list/tuple is comma-joined; any other non-string
          value (for example the integer ``2``) is converted with ``str``.
          The result is wrapped in single quotes.
        * ``tokenize``: wrapped in double quotes for FTS5 only.
        """
        content = options.get('content')
        prefix = options.get('prefix')
        tokenize = options.get('tokenize')

        if isinstance(content, basestring) and content == '':
            # Special-case content-less full-text search tables.
            options['content'] = "''"
        elif isinstance(content, Field):
            # Special-case to ensure fields are fully-qualified.
            options['content'] = Entity(content.model._meta.table_name,
                                        content.column_name)

        if prefix:
            if isinstance(prefix, (list, tuple)):
                prefix = ','.join([str(i) for i in prefix])
            elif not isinstance(prefix, basestring):
                # A bare integer such as prefix=2 has no .strip(); coerce it
                # so scalar prefixes work like their string form.
                prefix = str(prefix)
            options['prefix'] = "'%s'" % prefix.strip("' ")

        if tokenize and cls._meta.extension_module.lower() == 'fts5':
            # Tokenizers need to be in quoted string for FTS5, but not for FTS3
            # or FTS4.
            options['tokenize'] = '"%s"' % tokenize

        return options
|
||
|
|
||
|
|
||
|
class FTSModel(BaseFTSModel):
    """
    VirtualModel class for creating tables that use either the FTS3 or FTS4
    search extensions. Peewee automatically determines which version of the
    FTS extension is supported and will use FTS4 if possible.
    """
    # FTS3/4 uses "docid" in the same way a normal table uses "rowid".
    docid = DocIDField()

    class Meta:
        extension_module = 'FTS%s' % FTS_VERSION

    @classmethod
    def _fts_cmd(cls, cmd):
        """
        Run an FTS special command (``INSERT INTO tbl(tbl) VALUES(cmd)``).

        The command text is bound as a query parameter rather than being
        interpolated into the SQL, so caller-supplied pieces (for example the
        values given to ``merge``) cannot alter the statement.
        """
        tbl = cls._meta.table_name
        res = cls._meta.database.execute_sql(
            'INSERT INTO %s(%s) VALUES(?);' % (tbl, tbl), (cmd,))
        return res.fetchone()

    @classmethod
    def optimize(cls):
        """Issue the ``optimize`` command."""
        return cls._fts_cmd('optimize')

    @classmethod
    def rebuild(cls):
        """Issue the ``rebuild`` command."""
        return cls._fts_cmd('rebuild')

    @classmethod
    def integrity_check(cls):
        """Issue the ``integrity-check`` command."""
        return cls._fts_cmd('integrity-check')

    @classmethod
    def merge(cls, blocks=200, segments=8):
        """Issue the ``merge=<blocks>,<segments>`` command."""
        return cls._fts_cmd('merge=%s,%s' % (blocks, segments))

    @classmethod
    def automerge(cls, state=True):
        """Issue ``automerge=1`` when ``state`` is truthy, else ``automerge=0``."""
        return cls._fts_cmd('automerge=%s' % ('1' if state else '0'))

    @classmethod
    def match(cls, term):
        """
        Generate a `MATCH` expression appropriate for searching this table.
        """
        return match(cls._meta.entity, term)

    @classmethod
    def rank(cls, *weights):
        """Build ``fts_rank(matchinfo(tbl, FTS3_MATCHINFO), *weights)``."""
        matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO)
        return fn.fts_rank(matchinfo, *weights)

    @classmethod
    def bm25(cls, *weights):
        """Build ``fts_bm25(matchinfo(tbl, FTS4_MATCHINFO), *weights)``."""
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_bm25(match_info, *weights)

    @classmethod
    def bm25f(cls, *weights):
        """Build ``fts_bm25f(matchinfo(tbl, FTS4_MATCHINFO), *weights)``."""
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_bm25f(match_info, *weights)

    @classmethod
    def lucene(cls, *weights):
        """Build ``fts_lucene(matchinfo(tbl, FTS4_MATCHINFO), *weights)``."""
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_lucene(match_info, *weights)

    @classmethod
    def _search(cls, term, weights, with_score, score_alias, score_fn,
                explicit_ordering):
        """
        Shared implementation of the ``search*`` methods.

        ``weights`` may be empty/None (no weights are passed to ``score_fn``),
        a dict keyed by field instance or field name (missing fields default
        to 1.0), or an iterable of positional weights. Results are ordered by
        the score expression, or by ``score_alias`` when ``with_score`` is set
        and ``explicit_ordering`` is not.
        """
        if not weights:
            rank = score_fn()
        elif isinstance(weights, dict):
            weight_args = []
            for field in cls._meta.sorted_fields:
                # Attempt to get the specified weight of the field by looking
                # it up using its field instance followed by name.
                field_weight = weights.get(field, weights.get(field.name, 1.0))
                weight_args.append(field_weight)
            rank = score_fn(*weight_args)
        else:
            rank = score_fn(*weights)

        selection = ()
        order_by = rank
        if with_score:
            selection = (cls, rank.alias(score_alias))
        if with_score and not explicit_ordering:
            order_by = SQL(score_alias)

        return (cls
                .select(*selection)
                .where(cls.match(term))
                .order_by(order_by))

    @classmethod
    def search(cls, term, weights=None, with_score=False, score_alias='score',
               explicit_ordering=False):
        """Full-text search using selected `term`."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.rank,
            explicit_ordering)

    @classmethod
    def search_bm25(cls, term, weights=None, with_score=False,
                    score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using BM25 algorithm."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.bm25,
            explicit_ordering)

    @classmethod
    def search_bm25f(cls, term, weights=None, with_score=False,
                     score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using BM25f algorithm."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.bm25f,
            explicit_ordering)

    @classmethod
    def search_lucene(cls, term, weights=None, with_score=False,
                      score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using the Lucene scoring
        function."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.lucene,
            explicit_ordering)
|
||
|
|
||
|
|
||
|
# Characters accepted outside of quoted phrases in an FTS5 query, the
# complementary set of rejected ASCII characters, and a tokenizer that keeps
# double-quoted phrases together.
_alphabet = 'abcdefghijklmnopqrstuvwxyz'
_alphanum = set('\t ,"(){}*:_+0123456789')
_alphanum.update(_alphabet)
_alphanum.update(_alphabet.upper())
_alphanum.add(chr(26))
_invalid_ascii = set(map(chr, range(128))) - _alphanum
_quote_re = re.compile(r'(?:[^\s"]|"(?:\\.|[^"])*")+')
|
||
|
|
||
|
|
||
|
class FTS5Model(BaseFTSModel):
    """
    Requires SQLite >= 3.9.0.

    Table options:

    content: table name of external content, or empty string for "contentless"
    content_rowid: column name of external content primary key
    prefix: integer(s). Ex: '2' or '2 3 4'
    tokenize: porter, unicode61, ascii. Ex: 'porter unicode61'

    The unicode tokenizer supports the following parameters:

    * remove_diacritics (1 or 0, default is 1)
    * tokenchars (string of characters, e.g. '-_')
    * separators (string of characters)

    Parameters are passed as alternating parameter name and value, so:

    {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"}

    Content-less tables:

    If you don't need the full-text content in its original form, you can
    specify a content-less table. Searches and auxiliary functions will work
    as usual, but the only values returned when SELECT-ing can be rowid. Also
    content-less tables do not support UPDATE or DELETE.

    External content tables:

    You can set up triggers to sync these, e.g.

    -- Create a table. And an external content fts5 table to index it.
    CREATE TABLE tbl(a INTEGER PRIMARY KEY, b);
    CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a');

    -- Triggers to keep the FTS index up to date.
    CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN
      INSERT INTO ft(rowid, b) VALUES (new.a, new.b);
    END;
    CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN
      INSERT INTO ft(ft, rowid, b) VALUES('delete', old.a, old.b);
    END;
    CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN
      INSERT INTO ft(ft, rowid, b) VALUES('delete', old.a, old.b);
      INSERT INTO ft(rowid, b) VALUES (new.a, new.b);
    END;

    Built-in auxiliary functions:

    * bm25(tbl[, weight_0, ... weight_n])
    * highlight(tbl, col_idx, prefix, suffix)
    * snippet(tbl, col_idx, prefix, suffix, ?, max_tokens)
    """
    # FTS5 does not support declared primary keys, but we can use the
    # implicit rowid.
    rowid = RowIDField()

    class Meta:
        extension_module = 'fts5'

    _error_messages = {
        'field_type': ('Besides the implicit `rowid` column, all columns must '
                       'be instances of SearchField'),
        'index': 'Secondary indexes are not supported for FTS5 models',
        'pk': 'FTS5 models must use the default `rowid` primary key',
    }

    @classmethod
    def validate_model(cls):
        """
        Perform FTS5-specific validation.

        Raises ImproperlyConfigured when the primary key is not ``rowid``,
        when a field is neither a SearchField nor a RowIDField, or when the
        model declares indexes.
        """
        if cls._meta.primary_key.name != 'rowid':
            raise ImproperlyConfigured(cls._error_messages['pk'])
        for field in cls._meta.fields.values():
            if not isinstance(field, (SearchField, RowIDField)):
                raise ImproperlyConfigured(cls._error_messages['field_type'])
        if cls._meta.indexes:
            raise ImproperlyConfigured(cls._error_messages['index'])

    @classmethod
    def fts5_installed(cls):
        """
        Return True if FTS5 is usable, loading the extension if necessary.

        Returns False when no SQLite driver is available, when the SQLite
        version is older than FTS5_MIN_SQLITE_VERSION, or when FTS5 is
        neither compiled in nor loadable as an extension.
        """
        if sqlite3 is None:
            return False
        if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION:
            return False

        # Test in-memory DB to determine if the FTS5 extension is installed.
        tmp_db = sqlite3.connect(':memory:')
        try:
            tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);')
        except Exception:
            # Not compiled in; see whether it can be loaded as an extension.
            # ``Exception`` (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate.
            try:
                tmp_db.enable_load_extension(True)
                tmp_db.load_extension('fts5')
            except Exception:
                return False
            else:
                cls._meta.database.load_extension('fts5')
        finally:
            tmp_db.close()

        return True

    @staticmethod
    def validate_query(query):
        """
        Simple helper function to indicate whether a search query is a
        valid FTS5 query. Note: this simply looks at the characters being
        used, and is not guaranteed to catch all problematic queries.
        """
        tokens = _quote_re.findall(query)
        for token in tokens:
            # Fully double-quoted tokens are accepted as-is.
            if token.startswith('"') and token.endswith('"'):
                continue
            if set(token) & _invalid_ascii:
                return False
        return True

    @staticmethod
    def clean_query(query, replace=chr(26)):
        """
        Clean a query of invalid tokens.

        Every disallowed ASCII character outside a double-quoted token is
        replaced by ``replace``. If nothing had to be replaced, ``query`` is
        returned unchanged; otherwise the tokens are re-joined with single
        spaces.
        """
        accum = []
        any_invalid = False
        tokens = _quote_re.findall(query)
        for token in tokens:
            if token.startswith('"') and token.endswith('"'):
                accum.append(token)
                continue
            token_set = set(token)
            invalid_for_token = token_set & _invalid_ascii
            if invalid_for_token:
                any_invalid = True
                for c in invalid_for_token:
                    token = token.replace(c, replace)
            accum.append(token)

        if any_invalid:
            return ' '.join(accum)
        return query

    @classmethod
    def match(cls, term):
        """
        Generate a `MATCH` expression appropriate for searching this table.
        """
        return match(cls._meta.entity, term)

    @classmethod
    def rank(cls, *args):
        """Return ``bm25(*args)`` when weights are given, else the ``rank``
        column."""
        return cls.bm25(*args) if args else SQL('rank')

    @classmethod
    def bm25(cls, *weights):
        """Build ``bm25(tbl, *weights)``."""
        return fn.bm25(cls._meta.entity, *weights)

    @classmethod
    def search(cls, term, weights=None, with_score=False, score_alias='score',
               explicit_ordering=False):
        """Full-text search using selected `term`."""
        return cls.search_bm25(
            FTS5Model.clean_query(term),
            weights,
            with_score,
            score_alias,
            explicit_ordering)

    @classmethod
    def search_bm25(cls, term, weights=None, with_score=False,
                    score_alias='score', explicit_ordering=False):
        """
        Full-text search using selected `term`.

        ``weights`` may be empty/None (order by the ``rank`` column), a dict
        keyed by field instance or name (indexed SearchFields only, default
        1.0), or an iterable of positional weights for ``bm25``. The term is
        passed through ``clean_query`` before matching.
        """
        if not weights:
            rank = SQL('rank')
        elif isinstance(weights, dict):
            weight_args = []
            for field in cls._meta.sorted_fields:
                if isinstance(field, SearchField) and not field.unindexed:
                    weight_args.append(
                        weights.get(field, weights.get(field.name, 1.0)))
            rank = fn.bm25(cls._meta.entity, *weight_args)
        else:
            rank = fn.bm25(cls._meta.entity, *weights)

        selection = ()
        order_by = rank
        if with_score:
            selection = (cls, rank.alias(score_alias))
        if with_score and not explicit_ordering:
            order_by = SQL(score_alias)

        return (cls
                .select(*selection)
                .where(cls.match(FTS5Model.clean_query(term)))
                .order_by(order_by))

    @classmethod
    def _fts_cmd_sql(cls, cmd, **extra_params):
        """
        Build ``INSERT INTO tbl (tbl, <extra columns>) VALUES (cmd, ...)``.

        Each keyword argument adds a column named after the key and the
        corresponding value.
        """
        tbl = cls._meta.entity
        columns = [tbl]
        values = [cmd]
        for key, value in extra_params.items():
            columns.append(Entity(key))
            values.append(value)

        return NodeList((
            SQL('INSERT INTO'),
            cls._meta.entity,
            EnclosedNodeList(columns),
            SQL('VALUES'),
            EnclosedNodeList(values)))

    @classmethod
    def _fts_cmd(cls, cmd, **extra_params):
        """Build and execute an FTS5 special command."""
        query = cls._fts_cmd_sql(cmd, **extra_params)
        return cls._meta.database.execute(query)

    @classmethod
    def automerge(cls, level):
        """Issue the ``automerge`` command; ``level`` must be within 0..16."""
        if not (0 <= level <= 16):
            raise ValueError('level must be between 0 and 16')
        return cls._fts_cmd('automerge', rank=level)

    @classmethod
    def merge(cls, npages):
        """Issue the ``merge`` command with ``npages``."""
        return cls._fts_cmd('merge', rank=npages)

    @classmethod
    def set_pgsz(cls, pgsz):
        """Issue the ``pgsz`` command with ``pgsz``."""
        return cls._fts_cmd('pgsz', rank=pgsz)

    @classmethod
    def set_rank(cls, rank_expression):
        """Issue the ``rank`` command with ``rank_expression``."""
        return cls._fts_cmd('rank', rank=rank_expression)

    @classmethod
    def delete_all(cls):
        """Issue the ``delete-all`` command."""
        return cls._fts_cmd('delete-all')

    @classmethod
    def VocabModel(cls, table_type='row', table=None):
        """
        Return (creating and caching on first use) an ``fts5vocab`` model.

        ``table_type`` must be "row", "col" or "instance". ``table`` names
        the vocab table and defaults to ``<table_name>_v``; it is only
        honoured on the call that creates the model.
        """
        if table_type not in ('row', 'col', 'instance'):
            raise ValueError('table_type must be either "row", "col" or '
                             '"instance".')

        attr = '_vocab_model_%s' % table_type

        # Look only at this class's own namespace: hasattr() would also find
        # a vocab model cached on a parent class, and that model points at
        # the parent's table rather than this one.
        if attr not in cls.__dict__:
            class Meta:
                database = cls._meta.database
                table_name = table or cls._meta.table_name + '_v'
                extension_module = fn.fts5vocab(
                    cls._meta.entity,
                    SQL(table_type))

            attrs = {
                'term': VirtualField(TextField),
                'doc': IntegerField(),
                'cnt': IntegerField(),
                'rowid': RowIDField(),
                'Meta': Meta,
            }
            if table_type == 'col':
                attrs['col'] = VirtualField(TextField)
            elif table_type == 'instance':
                attrs['offset'] = VirtualField(IntegerField)

            class_name = '%sVocab' % cls.__name__
            setattr(cls, attr, type(class_name, (VirtualModel,), attrs))

        return cls.__dict__[attr]
|
||
|
|
||
|
|
||
|
def ClosureTable(model_class, foreign_key=None, referencing_class=None,
                 referencing_key=None):
    """
    Model factory for the transitive closure extension.

    :param model_class: model whose rows form the hierarchy.
    :param foreign_key: parent-pointing foreign key; when omitted, the first
        entry of ``model_class._meta.refs`` that refers back to
        ``model_class`` is used.
    :param referencing_class: model holding the relationship; defaults to
        ``model_class``.
    :param referencing_key: id column on ``referencing_class``; defaults to
        the primary key of ``model_class``.
    :raises ValueError: if no self-referential foreign key can be found.
    :returns: a new VirtualModel subclass named ``<ModelName>Closure``.
    """
    if referencing_class is None:
        referencing_class = model_class

    if foreign_key is None:
        for field_obj in model_class._meta.refs:
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')

    source_key = model_class._meta.primary_key
    if referencing_key is None:
        referencing_key = source_key

    class BaseClosureTable(VirtualModel):
        depth = VirtualField(IntegerField)
        id = VirtualField(IntegerField)
        idcolumn = VirtualField(TextField)
        parentcolumn = VirtualField(TextField)
        root = VirtualField(IntegerField)
        tablename = VirtualField(TextField)

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            """
            Query rows below ``node``, each annotated with ``depth``.

            With ``depth`` given only rows at exactly that depth are
            returned; otherwise ``node`` itself (depth 0) is excluded unless
            ``include_node`` is set.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(source_key == cls.id))
                     .where(cls.root == node)
                     .objects())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            """
            Query rows above ``node``, each annotated with ``depth``.

            Filtering mirrors ``descendants``: an explicit ``depth``
            (including 0) selects exactly that depth.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(source_key == cls.root))
                     .where(cls.id == node)
                     .objects())
            # ``is not None`` rather than truthiness, so depth=0 is treated
            # as an explicit depth exactly as in descendants().
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            """
            Query rows sharing ``node``'s parent; ``node`` itself is excluded
            unless ``include_node`` is set.
            """
            if referencing_class is model_class:
                # self-join
                fk_value = node.__data__.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # siblings as given in reference_class
                siblings = (referencing_class
                            .select(referencing_key)
                            .join(cls, on=(foreign_key == cls.root))
                            .where((cls.id == node) & (cls.depth == 1)))

                # the according models
                query = (model_class
                         .select()
                         .where(source_key << siblings)
                         .objects())

            if not include_node:
                query = query.where(source_key != node)

            return query

    class Meta:
        database = referencing_class._meta.database
        options = {
            'tablename': referencing_class._meta.table_name,
            'idcolumn': referencing_key.column_name,
            'parentcolumn': foreign_key.column_name}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,), {'Meta': Meta})
|
||
|
|
||
|
|
||
|
class LSMTable(VirtualModel):
    """
    VirtualModel for the ``lsm1`` extension.

    ``Meta.filename`` must name the LSM database file and the model must
    declare an explicit TextField, BlobField or IntegerField primary key.
    """
    class Meta:
        extension_module = 'lsm1'
        filename = None

    @classmethod
    def clean_options(cls, options):
        """
        Validate the model and populate ``_meta.prefix_arguments``.

        Also marks the key field hidden and records ``_meta._value_field``:
        the single non-key field when the model has exactly two fields,
        otherwise None. Returns ``options`` unchanged.

        :raises ValueError: on a missing filename or unsuitable primary key.
        """
        filename = cls._meta.filename
        if not filename:
            raise ValueError('LSM1 extension requires that you specify a '
                             'filename for the LSM database.')
        else:
            # Quote the filename unless it already starts with a quote.
            if len(filename) >= 2 and filename[0] != '"':
                filename = '"%s"' % filename
        if not cls._meta.primary_key:
            raise ValueError('LSM1 models must specify a primary-key field.')

        key = cls._meta.primary_key
        if isinstance(key, AutoField):
            raise ValueError('LSM1 models must explicitly declare a primary '
                             'key field.')
        if not isinstance(key, (TextField, BlobField, IntegerField)):
            raise ValueError('LSM1 key must be a TextField, BlobField, or '
                             'IntegerField.')
        key._hidden = True
        if isinstance(key, IntegerField):
            data_type = 'UINT'
        elif isinstance(key, BlobField):
            data_type = 'BLOB'
        else:
            data_type = 'TEXT'
        cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type]

        # Does the key map to a scalar value, or a tuple of values?
        if len(cls._meta.sorted_fields) == 2:
            cls._meta._value_field = cls._meta.sorted_fields[1]
        else:
            cls._meta._value_field = None

        return options

    @classmethod
    def load_extension(cls, path='lsm.so'):
        """Load the extension library at ``path`` into the model's database."""
        cls._meta.database.load_extension(path)

    @staticmethod
    def slice_to_expr(key, idx):
        """
        Translate slice ``idx`` into an expression on ``key``.

        Both bounds are inclusive. Returns None for an unbounded slice.
        """
        if idx.start is not None and idx.stop is not None:
            return key.between(idx.start, idx.stop)
        elif idx.start is not None:
            return key >= idx.start
        elif idx.stop is not None:
            return key <= idx.stop

    @staticmethod
    def _apply_lookup_to_query(query, key, lookup):
        """
        Filter ``query`` by ``lookup`` (a slice, an Expression or a key value).

        Returns ``(query, is_single)``; ``is_single`` is True only for a
        plain key-equality lookup.
        """
        if isinstance(lookup, slice):
            expr = LSMTable.slice_to_expr(key, lookup)
            if expr is not None:
                query = query.where(expr)
            return query, False
        elif isinstance(lookup, Expression):
            return query.where(lookup), False
        else:
            return query.where(key == lookup), True

    @classmethod
    def get_by_id(cls, pk):
        """
        Look up ``pk``.

        For a single-key lookup, return the row (or just its second column
        when the model has a single value field), raising KeyError if it
        does not exist. For slice/expression lookups, return the query.
        """
        query, is_single = cls._apply_lookup_to_query(
            cls.select().namedtuples(),
            cls._meta.primary_key,
            pk)

        if is_single:
            try:
                row = query.get()
            except cls.DoesNotExist:
                raise KeyError(pk)
            return row[1] if cls._meta._value_field is not None else row
        else:
            return query

    @classmethod
    def set_by_id(cls, key, value):
        """
        Store ``value`` under ``key`` using REPLACE.

        ``value`` is a scalar when the model has a single value field;
        otherwise it may be a tuple (matched positionally against the non-key
        fields), a dict, or an instance of the model. The caller's dict or
        instance is never modified.

        :raises TypeError: if ``value`` is none of the supported types.
        """
        pk_field = cls._meta.primary_key
        value_field = cls._meta._value_field
        if value_field is not None:
            data = {value_field: value}
        elif isinstance(value, tuple):
            data = dict(zip(cls._meta.sorted_fields[1:], value))
        elif isinstance(value, dict):
            # Copy: the key is added below and must not leak into the
            # caller's dict.
            data = dict(value)
        elif isinstance(value, cls):
            # Field values live in __data__ (keyed by field name); __dict__
            # holds instance bookkeeping, not column data.
            data = dict(value.__data__)
            # Drop any stored key so the column is not specified twice.
            data.pop(pk_field.name, None)
        else:
            raise TypeError('Unsupported value for LSM table: expected a '
                            'tuple, dict or %s instance, got %r.' %
                            (cls.__name__, type(value)))
        data[pk_field] = key
        cls.replace(data).execute()

    @classmethod
    def delete_by_id(cls, pk):
        """Delete the rows matched by ``pk`` (key, slice or Expression)."""
        query, is_single = cls._apply_lookup_to_query(
            cls.delete(),
            cls._meta.primary_key,
            pk)
        return query.execute()
|
||
|
|
||
|
|
||
|
# Add a MATCH entry to peewee's operator table (presumably consumed by the
# module-level match() helper used by the FTS models; confirm against it).
OP.MATCH = 'MATCH'
|
||
|
|
||
|
def _sqlite_regexp(regex, value):
    """
    Python implementation of a two-argument ``regexp`` SQL function.

    Returns True when ``regex`` matches anywhere in ``value``, False when it
    does not, and None (SQL NULL) when either argument is NULL, instead of
    letting ``re.search`` raise TypeError.
    """
    if regex is None or value is None:
        return None
    return re.search(regex, value) is not None
|
||
|
|
||
|
|
||
|
class SqliteExtDatabase(SqliteDatabase):
    """
    SqliteDatabase that can register extra SQL functions at construction.

    Each boolean flag enables one group of functions. Hash functions and the
    bloomfilter require the C extension module; rank functions fall back to
    the Python implementations when it is not used.
    """
    def __init__(self, database, c_extensions=None, rank_functions=True,
                 hash_functions=False, regexp_function=False,
                 bloomfilter=False, json_contains=False, *args, **kwargs):
        super(SqliteExtDatabase, self).__init__(database, *args, **kwargs)
        self._row_factory = None

        # Explicitly requesting C extensions that are not importable is an
        # error; otherwise they are used whenever available and not disabled.
        if c_extensions and not CYTHON_SQLITE_EXTENSIONS:
            raise ImproperlyConfigured('SqliteExtDatabase initialized with '
                                       'C extensions, but shared library was '
                                       'not found!')
        prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False)

        if rank_functions and prefer_c:
            register_rank_functions(self)
        elif rank_functions:
            python_rankers = (
                (bm25, 'fts_bm25'),
                (rank, 'fts_rank'),
                (bm25, 'fts_bm25f'),  # Fall back to bm25.
                (bm25, 'fts_lucene'),
            )
            for func, sql_name in python_rankers:
                self.register_function(func, sql_name)

        if hash_functions:
            if not prefer_c:
                raise ValueError('C extension required to register hash '
                                 'functions.')
            register_hash_functions(self)

        if regexp_function:
            self.register_function(_sqlite_regexp, 'regexp', 2)

        if bloomfilter:
            if not prefer_c:
                raise ValueError('C extension required to use bloomfilter.')
            register_bloomfilter(self)

        if json_contains:
            self.register_function(_json_contains, 'json_contains')

        self._c_extensions = prefer_c

    def _add_conn_hooks(self, conn):
        """Run the parent hooks, then install the stored row factory, if any."""
        super(SqliteExtDatabase, self)._add_conn_hooks(conn)
        factory = self._row_factory
        if factory:
            conn.row_factory = factory

    def row_factory(self, fn):
        """Remember ``fn`` as the row factory applied in ``_add_conn_hooks``."""
        self._row_factory = fn
|
||
|
|
||
|
|
||
|
if CYTHON_SQLITE_EXTENSIONS:
    # Integer flag constants, defined only when the C extension imported.
    # NOTE(review): presumably these are the `flag` values handed to
    # __status__ below and mirror SQLite's SQLITE_STATUS_* codes -- confirm.
    SQLITE_STATUS_MEMORY_USED = 0
    SQLITE_STATUS_PAGECACHE_USED = 1
    SQLITE_STATUS_PAGECACHE_OVERFLOW = 2
    SQLITE_STATUS_SCRATCH_USED = 3
    SQLITE_STATUS_SCRATCH_OVERFLOW = 4
    SQLITE_STATUS_MALLOC_SIZE = 5
    SQLITE_STATUS_PARSER_STACK = 6
    SQLITE_STATUS_PAGECACHE_SIZE = 7
    SQLITE_STATUS_SCRATCH_SIZE = 8
    SQLITE_STATUS_MALLOC_COUNT = 9
    # NOTE(review): presumably the `flag` values for __dbstatus__ below,
    # mirroring SQLite's SQLITE_DBSTATUS_* codes -- confirm.
    SQLITE_DBSTATUS_LOOKASIDE_USED = 0
    SQLITE_DBSTATUS_CACHE_USED = 1
    SQLITE_DBSTATUS_SCHEMA_USED = 2
    SQLITE_DBSTATUS_STMT_USED = 3
    SQLITE_DBSTATUS_LOOKASIDE_HIT = 4
    SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5
    SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6
    SQLITE_DBSTATUS_CACHE_HIT = 7
    SQLITE_DBSTATUS_CACHE_MISS = 8
    SQLITE_DBSTATUS_CACHE_WRITE = 9
    SQLITE_DBSTATUS_DEFERRED_FKS = 10
    #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11
|
||
|
|
||
|
def __status__(flag, return_highwater=False):
|
||
|
"""
|
||
|
Expose a sqlite3_status() call for a particular flag as a property of
|
||
|
the Database object.
|
||
|
"""
|
||
|
def getter(self):
|
||
|
result = sqlite_get_status(flag)
|
||
|
return result[1] if return_highwater else result
|
||
|
return property(getter)
|
||
|
|
||
|
def __dbstatus__(flag, return_highwater=False, return_current=False):
|
||
|
"""
|
||
|
Expose a sqlite3_dbstatus() call for a particular flag as a property of
|
||
|
the Database instance. Unlike sqlite3_status(), the dbstatus properties
|
||
|
pertain to the current connection.
|
||
|
"""
|
||
|
def getter(self):
|
||
|
if self._state.conn is None:
|
||
|
raise ImproperlyConfigured('database connection not opened.')
|
||
|
result = sqlite_get_db_status(self._state.conn, flag)
|
||
|
if return_current:
|
||
|
return result[0]
|
||
|
return result[1] if return_highwater else result
|
||
|
return property(getter)
|
||
|
|
||
|
class CSqliteExtDatabase(SqliteExtDatabase):
|
||
|
def __init__(self, *args, **kwargs):
|
||
|
self._conn_helper = None
|
||
|
self._commit_hook = self._rollback_hook = self._update_hook = None
|
||
|
self._replace_busy_handler = False
|
||
|
super(CSqliteExtDatabase, self).__init__(*args, **kwargs)
|
||
|
|
||
|
def init(self, database, replace_busy_handler=False, **kwargs):
|
||
|
super(CSqliteExtDatabase, self).init(database, **kwargs)
|
||
|
self._replace_busy_handler = replace_busy_handler
|
||
|
|
||
|
def _close(self, conn):
|
||
|
if self._commit_hook:
|
||
|
self._conn_helper.set_commit_hook(None)
|
||
|
if self._rollback_hook:
|
||
|
self._conn_helper.set_rollback_hook(None)
|
||
|
if self._update_hook:
|
||
|
self._conn_helper.set_update_hook(None)
|
||
|
return super(CSqliteExtDatabase, self)._close(conn)
|
||
|
|
||
|
def _add_conn_hooks(self, conn):
|
||
|
super(CSqliteExtDatabase, self)._add_conn_hooks(conn)
|
||
|
self._conn_helper = ConnectionHelper(conn)
|
||
|
if self._commit_hook is not None:
|
||
|
self._conn_helper.set_commit_hook(self._commit_hook)
|
||
|
if self._rollback_hook is not None:
|
||
|
self._conn_helper.set_rollback_hook(self._rollback_hook)
|
||
|
if self._update_hook is not None:
|
||
|
self._conn_helper.set_update_hook(self._update_hook)
|
||
|
if self._replace_busy_handler:
|
||
|
timeout = self._timeout or 5
|
||
|
self._conn_helper.set_busy_handler(timeout * 1000)
|
||
|
|
||
|
def on_commit(self, fn):
|
||
|
self._commit_hook = fn
|
||
|
if not self.is_closed():
|
||
|
self._conn_helper.set_commit_hook(fn)
|
||
|
return fn
|
||
|
|
||
|
def on_rollback(self, fn):
|
||
|
self._rollback_hook = fn
|
||
|
if not self.is_closed():
|
||
|
self._conn_helper.set_rollback_hook(fn)
|
||
|
return fn
|
||
|
|
||
|
def on_update(self, fn):
|
||
|
self._update_hook = fn
|
||
|
if not self.is_closed():
|
||
|
self._conn_helper.set_update_hook(fn)
|
||
|
return fn
|
||
|
|
||
|
def changes(self):
|
||
|
return self._conn_helper.changes()
|
||
|
|
||
|
@property
|
||
|
def last_insert_rowid(self):
|
||
|
return self._conn_helper.last_insert_rowid()
|
||
|
|
||
|
@property
|
||
|
def autocommit(self):
|
||
|
return self._conn_helper.autocommit()
|
||
|
|
||
|
def backup(self, destination, pages=None, name=None, progress=None):
|
||
|
return backup(self.connection(), destination.connection(),
|
||
|
pages=pages, name=name, progress=progress)
|
||
|
|
||
|
def backup_to_file(self, filename, pages=None, name=None,
|
||
|
progress=None):
|
||
|
return backup_to_file(self.connection(), filename, pages=pages,
|
||
|
name=name, progress=progress)
|
||
|
|
||
|
def blob_open(self, table, column, rowid, read_only=False):
|
||
|
return Blob(self, table, column, rowid, read_only)
|
||
|
|
||
|
# Status properties.
|
||
|
memory_used = __status__(SQLITE_STATUS_MEMORY_USED)
|
||
|
malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True)
|
||
|
malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT)
|
||
|
pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED)
|
||
|
pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW)
|
||
|
pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True)
|
||
|
scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED)
|
||
|
scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW)
|
||
|
scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True)
|
||
|
|
||
|
# Connection status properties.
|
||
|
lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED)
|
||
|
lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True)
|
||
|
lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE,
|
||
|
True)
|
||
|
lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL,
|
||
|
True)
|
||
|
cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True)
|
||
|
#cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED,
|
||
|
# False, True)
|
||
|
schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True)
|
||
|
statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True)
|
||
|
cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True)
|
||
|
cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True)
|
||
|
cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True)
|
||
|
|
||
|
|
||
|
def match(lhs, rhs):
    """Build the expression ``lhs MATCH rhs``."""
    operator = OP.MATCH
    return Expression(lhs, operator, rhs)
|
||
|
|
||
|
def _parse_match_info(buf):
|
||
|
# See http://sqlite.org/fts3.html#matchinfo
|
||
|
bufsize = len(buf) # Length in bytes.
|
||
|
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
|
||
|
|
||
|
def get_weights(ncol, raw_weights):
    """
    Expand user-supplied column weights into a list of exactly ``ncol``
    entries.

    :param ncol: number of columns in the search table.
    :param raw_weights: sequence of weights, one per leading column.
    :returns: ``[1] * ncol`` when no weights are given; otherwise the given
        weights padded with ``0`` for any columns that were not assigned
        one. Weights beyond ``ncol`` are ignored rather than raising
        ``IndexError``.
    """
    if not raw_weights:
        return [1] * ncol

    # Keep at most one weight per column, then zero-fill the remainder.
    weights = list(raw_weights)[:ncol]
    weights.extend([0] * (ncol - len(weights)))
    return weights
|
||
|
|
||
|
# Ranking implementation, which parses matchinfo.
|
||
|
def rank(raw_match_info, *raw_weights):
    """
    Score a full-text match from a matchinfo() blob in the default 'pcx'
    format. Modelled on the example rank function at
    http://sqlite.org/fts3.html#appendix_a

    :param raw_match_info: matchinfo() blob.
    :param raw_weights: optional per-column weights; columns with a falsy
        weight are skipped.
    :returns: the negated score (presumably so that better matches sort
        first in ascending order).
    """
    info = _parse_match_info(raw_match_info)
    total = 0.0

    phrase_count, col_count = info[:2]
    weights = get_weights(col_count, raw_weights)

    # After the two leading integers (p, c) the 'x' data holds, for every
    # phrase, three integers per column. With two phrases and three columns
    # the layout is:
    # p0 : c0=[0, 1, 2],   c1=[3, 4, 5],    c2=[6, 7, 8]
    # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17]
    for phrase in range(phrase_count):
        phrase_base = 2 + (phrase * col_count * 3)
        for col in range(col_count):
            col_weight = weights[col]
            if not col_weight:
                continue

            offset = phrase_base + (col * 3)

            # Compare how often the phrase occurs in this column of the
            # current row against how often it occurs in this column over
            # all rows; the ratio is a rough measure of "high value" terms.
            hits_here = info[offset]
            hits_everywhere = info[offset + 1]
            if hits_here > 0:
                total += col_weight * (float(hits_here) / hits_everywhere)

    return -total
|
||
|
|
||
|
# Okapi BM25 ranking implementation (FTS4 only).
|
||
|
def bm25(raw_match_info, *args):
    """
    Okapi BM25 score computed from a matchinfo() blob.

    Usage:

        # Format string *must* be pcnalx.
        # Any arguments after the matchinfo blob are per-column weights,
        # given in column order.
        bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank

    :returns: the negated score.
    """
    info = _parse_match_info(raw_match_info)
    K = 1.2
    B = 0.75
    total = 0.0

    # Fixed offsets of p, c, n and the first 'a' value in the buffer.
    P_O, C_O, N_O, A_O = range(4)
    term_count = info[P_O]  # n
    col_count = info[C_O]
    total_docs = info[N_O]  # N
    L_O = A_O + col_count  # First 'l' value follows the 'a' values.
    X_O = L_O + col_count  # First 'x' value follows the 'l' values.

    # Worked example of pcnalx for two columns and two phrases, 100 docs
    # total.
    # {
    #   p  = 2
    #   c  = 2
    #   n  = 100
    #   a0 = 4   -- avg number of tokens for col0, e.g. title
    #   a1 = 40  -- avg number of tokens for col1, e.g. body
    #   l0 = 5   -- curr doc has 5 tokens in col0
    #   l1 = 30  -- curr doc has 30 tokens in col1
    #
    #   x000     -- hits this row for phrase0, col0
    #   x001     -- hits all rows for phrase0, col0
    #   x002     -- rows with phrase0 in col0 at least once
    #
    #   x010     -- hits this row for phrase0, col1
    #   x011     -- hits all rows for phrase0, col1
    #   x012     -- rows with phrase0 in col1 at least once
    #
    #   x100     -- hits this row for phrase1, col0
    #   x101     -- hits all rows for phrase1, col0
    #   x102     -- rows with phrase1 in col0 at least once
    #
    #   x110     -- hits this row for phrase1, col1
    #   x111     -- hits all rows for phrase1, col1
    #   x112     -- rows with phrase1 in col1 at least once
    # }

    weights = get_weights(col_count, args)

    for term in range(term_count):
        for col in range(col_count):
            col_weight = weights[col]
            if col_weight == 0:
                continue

            x = X_O + (3 * (col + term * col_count))
            term_frequency = float(info[x])  # f(qi, D)
            docs_with_term = float(info[x + 2])  # n(qi)

            # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
            idf = math.log(
                (total_docs - docs_with_term + 0.5) /
                (docs_with_term + 0.5))
            if idf <= 0.0:
                # Replace non-positive IDF values with a tiny positive one.
                idf = 1e-6

            doc_length = float(info[L_O + col])  # |D|
            avg_length = float(info[A_O + col]) or 1.  # avgdl
            ratio = doc_length / avg_length

            numerator = term_frequency * (K + 1.0)
            length_norm = 1.0 - B + (B * ratio)
            denominator = term_frequency + (K * length_norm)

            term_score = idf * (numerator / denominator)
            total += (term_score * col_weight)

    return -total
|
||
|
|
||
|
|
||
|
def _json_contains(src_json, obj_json):
|
||
|
stack = []
|
||
|
try:
|
||
|
stack.append((json.loads(obj_json), json.loads(src_json)))
|
||
|
except:
|
||
|
# Invalid JSON!
|
||
|
return False
|
||
|
|
||
|
while stack:
|
||
|
obj, src = stack.pop()
|
||
|
if isinstance(src, dict):
|
||
|
if isinstance(obj, dict):
|
||
|
for key in obj:
|
||
|
if key not in src:
|
||
|
return False
|
||
|
stack.append((obj[key], src[key]))
|
||
|
elif isinstance(obj, list):
|
||
|
for item in obj:
|
||
|
if item not in src:
|
||
|
return False
|
||
|
elif obj not in src:
|
||
|
return False
|
||
|
elif isinstance(src, list):
|
||
|
if isinstance(obj, dict):
|
||
|
return False
|
||
|
elif isinstance(obj, list):
|
||
|
try:
|
||
|
for i in range(len(obj)):
|
||
|
stack.append((obj[i], src[i]))
|
||
|
except IndexError:
|
||
|
return False
|
||
|
elif obj not in src:
|
||
|
return False
|
||
|
elif obj != src:
|
||
|
return False
|
||
|
return True
|