parent
72419538a3
commit
7c987ac8f7
@ -1,3 +0,0 @@
|
||||
__author__ = 'Audrey Roy'
|
||||
__email__ = 'audreyr@gmail.com'
|
||||
__version__ = '0.4.4'
|
@ -1,52 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
binaryornot.check
|
||||
-----------------
|
||||
|
||||
Main code for checking if a file is binary or text.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
|
||||
from binaryornot.helpers import get_starting_chunk, is_binary_string
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_binary(filename):
|
||||
"""
|
||||
:param filename: File to check.
|
||||
:returns: True if it's a binary file, otherwise False.
|
||||
"""
|
||||
logger.debug('is_binary: %(filename)r', locals())
|
||||
|
||||
# Check if the file extension is in a list of known binary types
|
||||
# binary_extensions = ['.pyc', ]
|
||||
# for ext in binary_extensions:
|
||||
# if filename.endswith(ext):
|
||||
# return True
|
||||
|
||||
# Check if the starting chunk is a binary string
|
||||
chunk = get_starting_chunk(filename)
|
||||
return is_binary_string(chunk)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Check if a "
|
||||
"file passed as argument is "
|
||||
"binary or not")
|
||||
|
||||
parser.add_argument("filename", help="File name to check for. If "
|
||||
"the file is not in the same folder, "
|
||||
"include full path")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(is_binary(**vars(args)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,135 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
"""
|
||||
binaryornot.helpers
|
||||
-------------------
|
||||
|
||||
Helper utilities used by BinaryOrNot.
|
||||
"""
|
||||
|
||||
import chardet
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def print_as_hex(s):
|
||||
"""
|
||||
Print a string as hex bytes.
|
||||
"""
|
||||
print(":".join("{0:x}".format(ord(c)) for c in s))
|
||||
|
||||
|
||||
def get_starting_chunk(filename, length=1024):
|
||||
"""
|
||||
:param filename: File to open and get the first little chunk of.
|
||||
:param length: Number of bytes to read, default 1024.
|
||||
:returns: Starting chunk of bytes.
|
||||
"""
|
||||
# Ensure we open the file in binary mode
|
||||
try:
|
||||
with open(filename, 'rb') as f:
|
||||
chunk = f.read(length)
|
||||
return chunk
|
||||
except IOError as e:
|
||||
print(e)
|
||||
|
||||
|
||||
_control_chars = b'\n\r\t\f\b'
|
||||
if bytes is str:
|
||||
# Python 2 means we need to invoke chr() explicitly
|
||||
_printable_ascii = _control_chars + b''.join(map(chr, range(32, 127)))
|
||||
_printable_high_ascii = b''.join(map(chr, range(127, 256)))
|
||||
else:
|
||||
# Python 3 means bytes accepts integer input directly
|
||||
_printable_ascii = _control_chars + bytes(range(32, 127))
|
||||
_printable_high_ascii = bytes(range(127, 256))
|
||||
|
||||
|
||||
def is_binary_string(bytes_to_check):
|
||||
"""
|
||||
Uses a simplified version of the Perl detection algorithm,
|
||||
based roughly on Eli Bendersky's translation to Python:
|
||||
http://eli.thegreenplace.net/2011/10/19/perls-guess-if-file-is-text-or-binary-implemented-in-python/
|
||||
|
||||
This is biased slightly more in favour of deeming files as text
|
||||
files than the Perl algorithm, since all ASCII compatible character
|
||||
sets are accepted as text, not just utf-8.
|
||||
|
||||
:param bytes: A chunk of bytes to check.
|
||||
:returns: True if appears to be a binary, otherwise False.
|
||||
"""
|
||||
|
||||
# Empty files are considered text files
|
||||
if not bytes_to_check:
|
||||
return False
|
||||
|
||||
# Now check for a high percentage of ASCII control characters
|
||||
# Binary if control chars are > 30% of the string
|
||||
low_chars = bytes_to_check.translate(None, _printable_ascii)
|
||||
nontext_ratio1 = float(len(low_chars)) / float(len(bytes_to_check))
|
||||
logger.debug('nontext_ratio1: %(nontext_ratio1)r', locals())
|
||||
|
||||
# and check for a low percentage of high ASCII characters:
|
||||
# Binary if high ASCII chars are < 5% of the string
|
||||
# From: https://en.wikipedia.org/wiki/UTF-8
|
||||
# If the bytes are random, the chances of a byte with the high bit set
|
||||
# starting a valid UTF-8 character is only 6.64%. The chances of finding 7
|
||||
# of these without finding an invalid sequence is actually lower than the
|
||||
# chance of the first three bytes randomly being the UTF-8 BOM.
|
||||
|
||||
high_chars = bytes_to_check.translate(None, _printable_high_ascii)
|
||||
nontext_ratio2 = float(len(high_chars)) / float(len(bytes_to_check))
|
||||
logger.debug('nontext_ratio2: %(nontext_ratio2)r', locals())
|
||||
|
||||
if nontext_ratio1 > 0.90 and nontext_ratio2 > 0.90:
|
||||
return True
|
||||
|
||||
is_likely_binary = (
|
||||
(nontext_ratio1 > 0.3 and nontext_ratio2 < 0.05) or
|
||||
(nontext_ratio1 > 0.8 and nontext_ratio2 > 0.8)
|
||||
)
|
||||
logger.debug('is_likely_binary: %(is_likely_binary)r', locals())
|
||||
|
||||
# then check for binary for possible encoding detection with chardet
|
||||
detected_encoding = chardet.detect(bytes_to_check)
|
||||
logger.debug('detected_encoding: %(detected_encoding)r', locals())
|
||||
|
||||
# finally use all the check to decide binary or text
|
||||
decodable_as_unicode = False
|
||||
if (detected_encoding['confidence'] > 0.9 and
|
||||
detected_encoding['encoding'] != 'ascii'):
|
||||
try:
|
||||
try:
|
||||
bytes_to_check.decode(encoding=detected_encoding['encoding'])
|
||||
except TypeError:
|
||||
# happens only on Python 2.6
|
||||
unicode(bytes_to_check, encoding=detected_encoding['encoding']) # noqa
|
||||
decodable_as_unicode = True
|
||||
logger.debug('success: decodable_as_unicode: '
|
||||
'%(decodable_as_unicode)r', locals())
|
||||
except LookupError:
|
||||
logger.debug('failure: could not look up encoding %(encoding)s',
|
||||
detected_encoding)
|
||||
except UnicodeDecodeError:
|
||||
logger.debug('failure: decodable_as_unicode: '
|
||||
'%(decodable_as_unicode)r', locals())
|
||||
|
||||
logger.debug('failure: decodable_as_unicode: '
|
||||
'%(decodable_as_unicode)r', locals())
|
||||
if is_likely_binary:
|
||||
if decodable_as_unicode:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
if decodable_as_unicode:
|
||||
return False
|
||||
else:
|
||||
if b'\x00' in bytes_to_check or b'\xff' in bytes_to_check:
|
||||
# Check for NULL bytes last
|
||||
logger.debug('has nulls:' + repr(b'\x00' in bytes_to_check))
|
||||
return True
|
||||
return False
|
File diff suppressed because it is too large
Load Diff
@ -1,40 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import bottle
|
||||
import inspect
|
||||
import beaker
|
||||
from beaker import middleware
|
||||
|
||||
|
||||
class BeakerPlugin(object):
|
||||
name = 'beaker'
|
||||
|
||||
def __init__(self, keyword='beaker'):
|
||||
"""
|
||||
:param keyword: Keyword used to inject beaker in a route
|
||||
"""
|
||||
self.keyword = keyword
|
||||
|
||||
def setup(self, app):
|
||||
""" Make sure that other installed plugins don't affect the same
|
||||
keyword argument and check if metadata is available."""
|
||||
for other in app.plugins:
|
||||
if not isinstance(other, BeakerPlugin):
|
||||
continue
|
||||
if other.keyword == self.keyword:
|
||||
raise bottle.PluginError("Found another beaker plugin "
|
||||
"with conflicting settings ("
|
||||
"non-unique keyword).")
|
||||
|
||||
def apply(self, callback, context):
|
||||
args = inspect.getargspec(context['callback'])[0]
|
||||
|
||||
if self.keyword not in args:
|
||||
return callback
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs[self.keyword] = beaker
|
||||
kwargs["{0}_middleware".format(self.keyword)] = middleware
|
||||
return callback(*args, **kwargs)
|
||||
|
||||
return wrapper
|
@ -1,7 +0,0 @@
|
||||
# Cork - Authentication module for the Bottle web framework
|
||||
# Copyright (C) 2013 Federico Ceratto and others, see AUTHORS file.
|
||||
# Released under LGPLv3+ license, see LICENSE.txt
|
||||
#
|
||||
# Backends API - used to make backends available for importing
|
||||
#
|
||||
from .cork import Cork, JsonBackend, AAAException, AuthException, Mailer, FlaskCork, Redirect
|
@ -1,13 +0,0 @@
|
||||
# Cork - Authentication module for the Bottle web framework
|
||||
# Copyright (C) 2013 Federico Ceratto and others, see AUTHORS file.
|
||||
# Released under LGPLv3+ license, see LICENSE.txt
|
||||
|
||||
"""
|
||||
.. module:: backends
|
||||
:synopsis: Backends API - used to make backends available for importing
|
||||
"""
|
||||
|
||||
from .json_backend import JsonBackend
|
||||
from .mongodb_backend import MongoDBBackend
|
||||
from .sqlalchemy_backend import SqlAlchemyBackend
|
||||
from .sqlite_backend import SQLiteBackend
|
@ -1,31 +0,0 @@
|
||||
# Cork - Authentication module for the Bottle web framework
|
||||
# Copyright (C) 2013 Federico Ceratto and others, see AUTHORS file.
|
||||
# Released under LGPLv3+ license, see LICENSE.txt
|
||||
|
||||
"""
|
||||
.. module:: backend.py
|
||||
:synopsis: Base Backend.
|
||||
"""
|
||||
|
||||
class BackendIOException(Exception):
|
||||
"""Generic Backend I/O Exception"""
|
||||
pass
|
||||
|
||||
def ni(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
class Backend(object):
|
||||
"""Base Backend class - to be subclassed by real backends."""
|
||||
save_users = ni
|
||||
save_roles = ni
|
||||
save_pending_registrations = ni
|
||||
|
||||
class Table(object):
|
||||
"""Base Table class - to be subclassed by real backends."""
|
||||
__len__ = ni
|
||||
__contains__ = ni
|
||||
__setitem__ = ni
|
||||
__getitem__ = ni
|
||||
__iter__ = ni
|
||||
iteritems = ni
|
||||
|
@ -1,975 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Cork - Authentication module for the Bottle web framework
|
||||
# Copyright (C) 2013 Federico Ceratto and others, see AUTHORS file.
|
||||
#
|
||||
# This package is free software; you can redistribute it and/or
|
||||
# modify it under the terms of the GNU Lesser General Public
|
||||
# License as published by the Free Software Foundation; either
|
||||
# version 3 of the License, or (at your option) any later version.
|
||||
#
|
||||
# This package is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
# Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
|
||||
from base64 import b64encode, b64decode
|
||||
from datetime import datetime, timedelta
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from logging import getLogger
|
||||
from smtplib import SMTP, SMTP_SSL
|
||||
from threading import Thread
|
||||
from time import time
|
||||
import bottle
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
try:
|
||||
import scrypt
|
||||
scrypt_available = True
|
||||
except ImportError: # pragma: no cover
|
||||
scrypt_available = False
|
||||
|
||||
try:
|
||||
basestring
|
||||
except NameError:
|
||||
basestring = str
|
||||
|
||||
from .backends import JsonBackend
|
||||
|
||||
is_py3 = (sys.version_info.major == 3)
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
class AAAException(Exception):
|
||||
"""Generic Authentication/Authorization Exception"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthException(AAAException):
|
||||
"""Authentication Exception: incorrect username/password pair"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseCork(object):
|
||||
"""Abstract class"""
|
||||
|
||||
def __init__(self, directory=None, backend=None, email_sender=None,
|
||||
initialize=False, session_domain=None, smtp_server=None,
|
||||
smtp_url='localhost', session_key_name=None):
|
||||
"""Auth/Authorization/Accounting class
|
||||
|
||||
:param directory: configuration directory
|
||||
:type directory: str.
|
||||
:param users_fname: users filename (without .json), defaults to 'users'
|
||||
:type users_fname: str.
|
||||
:param roles_fname: roles filename (without .json), defaults to 'roles'
|
||||
:type roles_fname: str.
|
||||
"""
|
||||
if smtp_server:
|
||||
smtp_url = smtp_server
|
||||
self.mailer = Mailer(email_sender, smtp_url)
|
||||
self.password_reset_timeout = 3600 * 24
|
||||
self.session_domain = session_domain
|
||||
self.session_key_name = session_key_name or 'beaker.session'
|
||||
self.preferred_hashing_algorithm = 'PBKDF2'
|
||||
|
||||
# Setup JsonBackend by default for backward compatibility.
|
||||
if backend is None:
|
||||
self._store = JsonBackend(directory, users_fname='users',
|
||||
roles_fname='roles', pending_reg_fname='register',
|
||||
initialize=initialize)
|
||||
|
||||
else:
|
||||
self._store = backend
|
||||
|
||||
def login(self, username, password, success_redirect=None,
|
||||
fail_redirect=None):
|
||||
"""Check login credentials for an existing user.
|
||||
Optionally redirect the user to another page (typically /login)
|
||||
|
||||
:param username: username
|
||||
:type username: str or unicode.
|
||||
:param password: cleartext password
|
||||
:type password: str.or unicode
|
||||
:param success_redirect: redirect authorized users (optional)
|
||||
:type success_redirect: str.
|
||||
:param fail_redirect: redirect unauthorized users (optional)
|
||||
:type fail_redirect: str.
|
||||
:returns: True for successful logins, else False
|
||||
"""
|
||||
#assert isinstance(username, type(u'')), "the username must be a string"
|
||||
#assert isinstance(password, type(u'')), "the password must be a string"
|
||||
|
||||
if username in self._store.users:
|
||||
salted_hash = self._store.users[username]['hash']
|
||||
if hasattr(salted_hash, 'encode'):
|
||||
salted_hash = salted_hash.encode('ascii')
|
||||
authenticated = self._verify_password(
|
||||
username,
|
||||
password,
|
||||
salted_hash,
|
||||
)
|
||||
if authenticated:
|
||||
# Setup session data
|
||||
self._setup_cookie(username)
|
||||
self._store.users[username]['last_login'] = str(datetime.utcnow())
|
||||
self._store.save_users()
|
||||
if success_redirect:
|
||||
self._redirect(success_redirect)
|
||||
return True
|
||||
|
||||
if fail_redirect:
|
||||
self._redirect(fail_redirect)
|
||||
|
||||
return False
|
||||
|
||||
def logout(self, success_redirect='/login', fail_redirect='/login'):
|
||||
"""Log the user out, remove cookie
|
||||
|
||||
:param success_redirect: redirect the user after logging out
|
||||
:type success_redirect: str.
|
||||
:param fail_redirect: redirect the user if it is not logged in
|
||||
:type fail_redirect: str.
|
||||
"""
|
||||
try:
|
||||
session = self._beaker_session
|
||||
session.delete()
|
||||
except Exception as e:
|
||||
log.debug("Exception %s while logging out." % repr(e))
|
||||
self._redirect(fail_redirect)
|
||||
|
||||
self._redirect(success_redirect)
|
||||
|
||||
def require(self, username=None, role=None, fixed_role=False,
|
||||
fail_redirect=None):
|
||||
"""Ensure the user is logged in has the required role (or higher).
|
||||
Optionally redirect the user to another page (typically /login)
|
||||
If both `username` and `role` are specified, both conditions need to be
|
||||
satisfied.
|
||||
If none is specified, any authenticated user will be authorized.
|
||||
By default, any role with higher level than `role` will be authorized;
|
||||
set fixed_role=True to prevent this.
|
||||
|
||||
:param username: username (optional)
|
||||
:type username: str.
|
||||
:param role: role
|
||||
:type role: str.
|
||||
:param fixed_role: require user role to match `role` strictly
|
||||
:type fixed_role: bool.
|
||||
:param redirect: redirect unauthorized users (optional)
|
||||
:type redirect: str.
|
||||
"""
|
||||
# Parameter validation
|
||||
if username is not None:
|
||||
if username not in self._store.users:
|
||||
raise AAAException("Nonexistent user")
|
||||
|
||||
if fixed_role and role is None:
|
||||
raise AAAException(
|
||||
"""A role must be specified if fixed_role has been set""")
|
||||
|
||||
if role is not None and role not in self._store.roles:
|
||||
raise AAAException("Role not found")
|
||||
|
||||
# Authentication
|
||||
try:
|
||||
cu = self.current_user
|
||||
except AAAException:
|
||||
if fail_redirect is None:
|
||||
raise AuthException("Unauthenticated user")
|
||||
else:
|
||||
self._redirect(fail_redirect)
|
||||
|
||||
# Authorization
|
||||
if cu.role not in self._store.roles:
|
||||
raise AAAException("Role not found for the current user")
|
||||
|
||||
if username is not None:
|
||||
# A specific user is required
|
||||
if username == self.current_user.username:
|
||||
return
|
||||
|
||||
if fail_redirect is None:
|
||||
raise AuthException("Unauthorized access: incorrect"
|
||||
" username")
|
||||
|
||||
self._redirect(fail_redirect)
|
||||
|
||||
if fixed_role:
|
||||
# A specific role is required
|
||||
if role == self.current_user.role:
|
||||
return
|
||||
|
||||
if fail_redirect is None:
|
||||
raise AuthException("Unauthorized access: incorrect role")
|
||||
|
||||
self._redirect(fail_redirect)
|
||||
|
||||
if role is not None:
|
||||
# Any role with higher level is allowed
|
||||
current_lvl = self._store.roles[self.current_user.role]
|
||||
threshold_lvl = self._store.roles[role]
|
||||
if current_lvl >= threshold_lvl:
|
||||
return
|
||||
|
||||
if fail_redirect is None:
|
||||
raise AuthException("Unauthorized access: ")
|
||||
|
||||
self._redirect(fail_redirect)
|
||||
|
||||
return # success
|
||||
|
||||
def create_role(self, role, level):
|
||||
"""Create a new role.
|
||||
|
||||
:param role: role name
|
||||
:type role: str.
|
||||
:param level: role level (0=lowest, 100=admin)
|
||||
:type level: int.
|
||||
:raises: AuthException on errors
|
||||
"""
|
||||
if self.current_user.level < 100:
|
||||
raise AuthException("The current user is not authorized to ")
|
||||
if role in self._store.roles:
|
||||
raise AAAException("The role is already existing")
|
||||
try:
|
||||
int(level)
|
||||
except ValueError:
|
||||
raise AAAException("The level must be numeric.")
|
||||
self._store.roles[role] = level
|
||||
self._store.save_roles()
|
||||
|
||||
def delete_role(self, role):
|
||||
"""Deleta a role.
|
||||
|
||||
:param role: role name
|
||||
:type role: str.
|
||||
:raises: AuthException on errors
|
||||
"""
|
||||
if self.current_user.level < 100:
|
||||
raise AuthException("The current user is not authorized to ")
|
||||
if role not in self._store.roles:
|
||||
raise AAAException("Nonexistent role.")
|
||||
self._store.roles.pop(role)
|
||||
self._store.save_roles()
|
||||
|
||||
def list_roles(self):
|
||||
"""List roles.
|
||||
|
||||
:returns: (role, role_level) generator (sorted by role)
|
||||
"""
|
||||
for role in sorted(self._store.roles):
|
||||
yield (role, self._store.roles[role])
|
||||
|
||||
def create_user(self, username, role, password, email_addr=None,
|
||||
description=None):
|
||||
"""Create a new user account.
|
||||
This method is available to users with level>=100
|
||||
|
||||
:param username: username
|
||||
:type username: str.
|
||||
:param role: role
|
||||
:type role: str.
|
||||
:param password: cleartext password
|
||||
:type password: str.
|
||||
:param email_addr: email address (optional)
|
||||
:type email_addr: str.
|
||||
:param description: description (free form)
|
||||
:type description: str.
|
||||
:raises: AuthException on errors
|
||||
"""
|
||||
assert username, "Username must be provided."
|
||||
if self.current_user.level < 100:
|
||||
raise AuthException("The current user is not authorized"
|
||||
" to create users.")
|
||||
|
||||
if username in self._store.users:
|
||||
raise AAAException("User is already existing.")
|
||||
if role not in self._store.roles:
|
||||
raise AAAException("Nonexistent user role.")
|
||||
tstamp = str(datetime.utcnow())
|
||||
h = self._hash(username, password)
|
||||
h = h.decode('ascii')
|
||||
self._store.users[username] = {
|
||||
'role': role,
|
||||
'hash': h,
|
||||
'email_addr': email_addr,
|
||||
'desc': description,
|
||||
'creation_date': tstamp,
|
||||
'last_login': tstamp
|
||||
}
|
||||
self._store.save_users()
|
||||
|
||||
def delete_user(self, username):
|
||||
"""Delete a user account.
|
||||
This method is available to users with level>=100
|
||||
|
||||
:param username: username
|
||||
:type username: str.
|
||||
:raises: Exceptions on errors
|
||||
"""
|
||||
if self.current_user.level < 100:
|
||||
raise AuthException("The current user is not authorized to ")
|
||||
if username not in self._store.users:
|
||||
raise AAAException("Nonexistent user.")
|
||||
self.user(username).delete()
|
||||
|
||||
def list_users(self):
|
||||
"""List users.
|
||||
|
||||
:return: (username, role, email_addr, description) generator (sorted by
|
||||
username)
|
||||
"""
|
||||
for un in sorted(self._store.users):
|
||||
d = self._store.users[un]
|
||||
yield (un, d['role'], d['email_addr'], d['desc'])
|
||||
|
||||
@property
|
||||
def current_user(self):
|
||||
"""Current autenticated user
|
||||
|
||||
:returns: User() instance, if authenticated
|
||||
:raises: AuthException otherwise
|
||||
"""
|
||||
session = self._beaker_session
|
||||
username = session.get('username', None)
|
||||
if username is None:
|
||||
raise AuthException("Unauthenticated user")
|
||||
if username is not None and username in self._store.users:
|
||||
return User(username, self, session=session)
|
||||
raise AuthException("Unknown user: %s" % username)
|
||||
|
||||
@property
|
||||
def user_is_anonymous(self):
|
||||
"""Check if the current user is anonymous.
|
||||
|
||||
:returns: True if the user is anonymous, False otherwise
|
||||
:raises: AuthException if the session username is unknown
|
||||
"""
|
||||
try:
|
||||
username = self._beaker_session['username']
|
||||
except KeyError:
|
||||
return True
|
||||
|
||||
if username not in self._store.users:
|
||||
raise AuthException("Unknown user: %s" % username)
|
||||
|
||||
return False
|
||||
|
||||
def user(self, username):
|
||||
"""Existing user
|
||||
|
||||
:returns: User() instance if the user exist, None otherwise
|
||||
"""
|
||||
if username is not None and username in self._store.users:
|
||||
return User(username, self)
|
||||
return None
|
||||
|
||||
def register(self, username, password, email_addr, role='user',
|
||||
max_level=50, subject="Signup confirmation",
|
||||
email_template='views/registration_email.tpl',
|
||||
description=None, **kwargs):
|
||||
"""Register a new user account. An email with a registration validation
|
||||
is sent to the user.
|
||||
WARNING: this method is available to unauthenticated users
|
||||
|
||||
:param username: username
|
||||
:type username: str.
|
||||
:param password: cleartext password
|
||||
:type password: str.
|
||||
:param role: role (optional), defaults to 'user'
|
||||
:type role: str.
|
||||
:param max_level: maximum role level (optional), defaults to 50
|
||||
:type max_level: int.
|
||||
:param email_addr: email address
|
||||
:type email_addr: str.
|
||||
:param subject: email subject
|
||||
:type subject: str.
|
||||
:param email_template: email template filename
|
||||
:type email_template: str.
|
||||
:param description: description (free form)
|
||||
:type description: str.
|
||||
:raises: AssertError or AAAException on errors
|
||||
"""
|
||||
assert username, "Username must be provided."
|
||||
assert password, "A password must be provided."
|
||||
assert email_addr, "An email address must be provided."
|
||||
if username in self._store.users:
|
||||
raise AAAException("User is already existing.")
|
||||
if role not in self._store.roles:
|
||||
raise AAAException("Nonexistent role")
|
||||
if self._store.roles[role] > max_level:
|
||||
raise AAAException("Unauthorized role")
|
||||
|
||||
registration_code = uuid.uuid4().hex
|
||||
creation_date = str(datetime.utcnow())
|
||||
|
||||
# send registration email
|
||||
email_text = bottle.template(
|
||||
email_template,
|
||||
username=username,
|
||||
email_addr=email_addr,
|
||||
role=role,
|
||||
creation_date=creation_date,
|
||||
registration_code=registration_code,
|
||||
**kwargs
|
||||
)
|
||||
self.mailer.send_email(email_addr, subject, email_text)
|
||||
|
||||
# store pending registration
|
||||
h = self._hash(username, password)
|
||||
h = h.decode('ascii')
|
||||
self._store.pending_registrations[registration_code] = {
|
||||
'username': username,
|
||||
'role': role,
|
||||
'hash': h,
|
||||
'email_addr': email_addr,
|
||||
'desc': description,
|
||||
'creation_date': creation_date,
|
||||
}
|
||||
self._store.save_pending_registrations()
|
||||
|
||||
def validate_registration(self, registration_code):
|
||||
"""Validate pending account registration, create a new account if
|
||||
successful.
|
||||
|
||||
:param registration_code: registration code
|
||||
:type registration_code: str.
|
||||
"""
|
||||
try:
|
||||
data = self._store.pending_registrations.pop(registration_code)
|
||||
except KeyError:
|
||||
raise AuthException("Invalid registration code.")
|
||||
|
||||
username = data['username']
|
||||
if username in self._store.users:
|
||||
raise AAAException("User is already existing.")
|
||||
|
||||
# the user data is moved from pending_registrations to _users
|
||||
self._store.users[username] = {
|
||||
'role': data['role'],
|
||||
'hash': data['hash'],
|
||||
'email_addr': data['email_addr'],
|
||||
'desc': data['desc'],
|
||||
'creation_date': data['creation_date'],
|
||||
'last_login': str(datetime.utcnow())
|
||||
}
|
||||
self._store.save_users()
|
||||
|
||||
def send_password_reset_email(self, username=None, email_addr=None,
|
||||
subject="Password reset confirmation",
|
||||
email_template='views/password_reset_email',
|
||||
**kwargs):
|
||||
"""Email the user with a link to reset his/her password
|
||||
If only one parameter is passed, fetch the other from the users
|
||||
database. If both are passed they will be matched against the users
|
||||
database as a security check.
|
||||
|
||||
:param username: username
|
||||
:type username: str.
|
||||
:param email_addr: email address
|
||||
:type email_addr: str.
|
||||
:param subject: email subject
|
||||
:type subject: str.
|
||||
:param email_template: email template filename
|
||||
:type email_template: str.
|
||||
:raises: AAAException on missing username or email_addr,
|
||||
AuthException on incorrect username/email_addr pair
|
||||
"""
|
||||
if username is None:
|
||||
if email_addr is None:
|
||||
raise AAAException("At least `username` or `email_addr` must"
|
||||
" be specified.")
|
||||
|
||||
# only email_addr is specified: fetch the username
|
||||
for k, v in self._store.users.iteritems():
|
||||
if v['email_addr'] == email_addr:
|
||||
username = k
|
||||
break
|
||||
else:
|
||||
raise AAAException("Email address not found.")
|
||||
|
||||
else: # username is provided
|
||||
if username not in self._store.users:
|
||||
raise AAAException("Nonexistent user.")
|
||||
if email_addr is None:
|
||||
email_addr = self._store.users[username].get('email_addr', None)
|
||||
if not email_addr:
|
||||
raise AAAException("Email address not available.")
|
||||
else:
|
||||
# both username and email_addr are provided: check them
|
||||
stored_email_addr = self._store.users[username]['email_addr']
|
||||
if email_addr != stored_email_addr:
|
||||
raise AuthException("Username/email address pair not found.")
|
||||
|
||||
# generate a reset_code token
|
||||
reset_code = self._reset_code(username, email_addr)
|
||||
|
||||
# send reset email
|
||||
email_text = bottle.template(
|
||||
email_template,
|
||||
username=username,
|
||||
email_addr=email_addr,
|
||||
reset_code=reset_code,
|
||||
**kwargs
|
||||
)
|
||||
self.mailer.send_email(email_addr, subject, email_text)
|
||||
|
||||
def reset_password(self, reset_code, password):
|
||||
"""Validate reset_code and update the account password
|
||||
The username is extracted from the reset_code token
|
||||
|
||||
:param reset_code: reset token
|
||||
:type reset_code: str.
|
||||
:param password: new password
|
||||
:type password: str.
|
||||
:raises: AuthException for invalid reset tokens, AAAException
|
||||
"""
|
||||
try:
|
||||
reset_code = b64decode(reset_code).decode()
|
||||
username, email_addr, tstamp, h = reset_code.split(':', 3)
|
||||
tstamp = int(tstamp)
|
||||
assert isinstance(username, type(u''))
|
||||
assert isinstance(email_addr, type(u''))
|
||||
if not isinstance(h, type(b'')):
|
||||
h = h.encode('utf-8')
|
||||
except (TypeError, ValueError):
|
||||
raise AuthException("Invalid reset code.")
|
||||
|
||||
if time() - tstamp > self.password_reset_timeout:
|
||||
raise AuthException("Expired reset code.")
|
||||
|
||||
assert isinstance(h, type(b''))
|
||||
if not self._verify_password(username, email_addr, h):
|
||||
raise AuthException("Invalid reset code.")
|
||||
user = self.user(username)
|
||||
if user is None:
|
||||
raise AAAException("Nonexistent user.")
|
||||
user.update(pwd=password)
|
||||
|
||||
def make_auth_decorator(self, username=None, role=None, fixed_role=False, fail_redirect='/login'):
|
||||
'''
|
||||
Create a decorator to be used for authentication and authorization
|
||||
|
||||
:param username: A resource can be protected for a specific user
|
||||
:param role: Minimum role level required for authorization
|
||||
:param fixed_role: Only this role gets authorized
|
||||
:param fail_redirect: The URL to redirect to if a login is required.
|
||||
'''
|
||||
session_manager = self
|
||||
def auth_require(username=username, role=role, fixed_role=fixed_role,
|
||||
fail_redirect=fail_redirect):
|
||||
def decorator(func):
|
||||
import functools
|
||||
@functools.wraps(func)
|
||||
def wrapper(*a, **ka):
|
||||
session_manager.require(username=username, role=role, fixed_role=fixed_role,
|
||||
fail_redirect=fail_redirect)
|
||||
return func(*a, **ka)
|
||||
return wrapper
|
||||
return decorator
|
||||
return(auth_require)
|
||||
|
||||
|
||||
## Private methods
|
||||
|
||||
def _setup_cookie(self, username):
|
||||
"""Setup cookie for a user that just logged in"""
|
||||
session = self._beaker_session
|
||||
session['username'] = username
|
||||
if self.session_domain is not None:
|
||||
session.domain = self.session_domain
|
||||
|
||||
self._save_session()
|
||||
|
||||
def _hash(self, username, pwd, salt=None, algo=None):
|
||||
"""Hash username and password, generating salt value if required
|
||||
"""
|
||||
if algo is None:
|
||||
algo = self.preferred_hashing_algorithm
|
||||
|
||||
if algo == 'PBKDF2':
|
||||
return self._hash_pbkdf2(username, pwd, salt=salt)
|
||||
|
||||
if algo == 'scrypt':
|
||||
return self._hash_scrypt(username, pwd, salt=salt)
|
||||
|
||||
raise RuntimeError("Unknown hashing algorithm requested: %s" % algo)
|
||||
|
||||
@staticmethod
|
||||
def _hash_scrypt(username, pwd, salt=None):
|
||||
"""Hash username and password, generating salt value if required
|
||||
Use scrypt.
|
||||
|
||||
:returns: base-64 encoded str.
|
||||
"""
|
||||
if not scrypt_available:
|
||||
raise Exception("scrypt.hash required."
|
||||
" Please install the scrypt library.")
|
||||
|
||||
if salt is None:
|
||||
salt = os.urandom(32)
|
||||
|
||||
assert len(salt) == 32, "Incorrect salt length"
|
||||
|
||||
cleartext = "%s\0%s" % (username, pwd)
|
||||
h = scrypt.hash(cleartext, salt)
|
||||
|
||||
# 's' for scrypt
|
||||
hashed = b's' + salt + h
|
||||
return b64encode(hashed)
|
||||
|
||||
@staticmethod
|
||||
def _hash_pbkdf2(username, pwd, salt=None):
|
||||
"""Hash username and password, generating salt value if required
|
||||
Use PBKDF2 from Beaker
|
||||
|
||||
:returns: base-64 encoded str.
|
||||
"""
|
||||
if salt is None:
|
||||
salt = os.urandom(32)
|
||||
|
||||
assert isinstance(salt, bytes)
|
||||
assert len(salt) == 32, "Incorrect salt length"
|
||||
|
||||
username = username.encode('utf-8')
|
||||
assert isinstance(username, bytes)
|
||||
|
||||
pwd = pwd.encode('utf-8')
|
||||
assert isinstance(pwd, bytes)
|
||||
|
||||
cleartext = username + b'\0' + pwd
|
||||
h = hashlib.pbkdf2_hmac('sha1', cleartext, salt, 10, dklen=32)
|
||||
|
||||
# 'p' for PBKDF2
|
||||
hashed = b'p' + salt + h
|
||||
return b64encode(hashed)
|
||||
|
||||
def _verify_password(self, username, pwd, salted_hash):
|
||||
"""Verity username/password pair against a salted hash
|
||||
|
||||
:returns: bool
|
||||
"""
|
||||
assert isinstance(salted_hash, type(b''))
|
||||
decoded = b64decode(salted_hash)
|
||||
hash_type = decoded[0]
|
||||
if isinstance(hash_type, int):
|
||||
hash_type = chr(hash_type)
|
||||
|
||||
salt = decoded[1:33]
|
||||
|
||||
if hash_type == 'p': # PBKDF2
|
||||
h = self._hash_pbkdf2(username, pwd, salt)
|
||||
return salted_hash == h
|
||||
|
||||
if hash_type == 's': # scrypt
|
||||
h = self._hash_scrypt(username, pwd, salt)
|
||||
return salted_hash == h
|
||||
|
||||
raise RuntimeError("Unknown hashing algorithm in hash: %r" % decoded)
|
||||
|
||||
def _purge_expired_registrations(self, exp_time=96):
|
||||
"""Purge expired registration requests.
|
||||
|
||||
:param exp_time: expiration time (hours)
|
||||
:type exp_time: float.
|
||||
"""
|
||||
pending = self._store.pending_registrations.items()
|
||||
if is_py3:
|
||||
pending = list(pending)
|
||||
|
||||
for uuid_code, data in pending:
|
||||
creation = datetime.strptime(data['creation_date'],
|
||||
"%Y-%m-%d %H:%M:%S.%f")
|
||||
now = datetime.utcnow()
|
||||
maxdelta = timedelta(hours=exp_time)
|
||||
if now - creation > maxdelta:
|
||||
self._store.pending_registrations.pop(uuid_code)
|
||||
|
||||
def _reset_code(self, username, email_addr):
|
||||
"""generate a reset_code token
|
||||
|
||||
:param username: username
|
||||
:type username: str.
|
||||
:param email_addr: email address
|
||||
:type email_addr: str.
|
||||
:returns: Base-64 encoded token
|
||||
"""
|
||||
h = self._hash(username, email_addr)
|
||||
t = "%d" % time()
|
||||
t = t.encode('utf-8')
|
||||
reset_code = b':'.join((username.encode('utf-8'), email_addr.encode('utf-8'), t, h))
|
||||
return b64encode(reset_code)
|
||||
|
||||
|
||||
class User(object):
|
||||
|
||||
def __init__(self, username, cork_obj, session=None):
|
||||
"""Represent an authenticated user, exposing useful attributes:
|
||||
username, role, level, description, email_addr, session_creation_time,
|
||||
session_accessed_time, session_id. The session-related attributes are
|
||||
available for the current user only.
|
||||
|
||||
:param username: username
|
||||
:type username: str.
|
||||
:param cork_obj: instance of :class:`Cork`
|
||||
"""
|
||||
self._cork = cork_obj
|
||||
assert username in self._cork._store.users, "Unknown user"
|
||||
self.username = username
|
||||
user_data = self._cork._store.users[username]
|
||||
self.role = user_data['role']
|
||||
self.description = user_data['desc']
|
||||
self.email_addr = user_data['email_addr']
|
||||
self.level = self._cork._store.roles[self.role]
|
||||
|
||||
if session is not None:
|
||||
try:
|
||||
self.session_creation_time = session['_creation_time']
|
||||
self.session_accessed_time = session['_accessed_time']
|
||||
self.session_id = session['_id']
|
||||
except:
|
||||
pass
|
||||
|
||||
def update(self, role=None, pwd=None, email_addr=None):
|
||||
"""Update an user account data
|
||||
|
||||
:param role: change user role, if specified
|
||||
:type role: str.
|
||||
:param pwd: change user password, if specified
|
||||
:type pwd: str.
|
||||
:param email_addr: change user email address, if specified
|
||||
:type email_addr: str.
|
||||
:raises: AAAException on nonexistent user or role.
|
||||
"""
|
||||
username = self.username
|
||||
if username not in self._cork._store.users:
|
||||
raise AAAException("User does not exist.")
|
||||
|
||||
if role is not None:
|
||||
if role not in self._cork._store.roles:
|
||||
raise AAAException("Nonexistent role.")
|
||||
|
||||
self._cork._store.users[username]['role'] = role
|
||||
|
||||
if pwd is not None:
|
||||
self._cork._store.users[username]['hash'] = self._cork._hash(
|
||||
username, pwd)
|
||||
|
||||
if email_addr is not None:
|
||||
self._cork._store.users[username]['email_addr'] = email_addr
|
||||
|
||||
self._cork._store.save_users()
|
||||
|
||||
def delete(self):
|
||||
"""Delete user account
|
||||
|
||||
:raises: AAAException on nonexistent user.
|
||||
"""
|
||||
try:
|
||||
self._cork._store.users.pop(self.username)
|
||||
except KeyError:
|
||||
raise AAAException("Nonexistent user.")
|
||||
self._cork._store.save_users()
|
||||
|
||||
|
||||
class Redirect(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def raise_redirect(path):
|
||||
raise Redirect(path)
|
||||
|
||||
|
||||
class Cork(BaseCork):
|
||||
@staticmethod
|
||||
def _redirect(location):
|
||||
bottle.redirect(location)
|
||||
|
||||
@property
|
||||
def _beaker_session(self):
|
||||
"""Get session"""
|
||||
return bottle.request.environ.get(self.session_key_name)
|
||||
|
||||
def _save_session(self):
|
||||
self._beaker_session.save()
|
||||
|
||||
|
||||
class FlaskCork(BaseCork):
|
||||
@staticmethod
|
||||
def _redirect(location):
|
||||
raise_redirect(location)
|
||||
|
||||
@property
|
||||
def _beaker_session(self):
|
||||
"""Get session"""
|
||||
import flask
|
||||
return flask.session
|
||||
|
||||
def _save_session(self):
|
||||
pass
|
||||
|
||||
|
||||
class Mailer(object):
|
||||
|
||||
def __init__(self, sender, smtp_url, join_timeout=5, use_threads=True):
|
||||
"""Send emails asyncronously
|
||||
|
||||
:param sender: Sender email address
|
||||
:type sender: str.
|
||||
:param smtp_server: SMTP server
|
||||
:type smtp_server: str.
|
||||
"""
|
||||
self.sender = sender
|
||||
self.join_timeout = join_timeout
|
||||
self.use_threads = use_threads
|
||||
self._threads = []
|
||||
self._conf = self._parse_smtp_url(smtp_url)
|
||||
|
||||
def _parse_smtp_url(self, url):
|
||||
"""Parse SMTP URL"""
|
||||
match = re.match(r"""
|
||||
( # Optional protocol
|
||||
(?P<proto>smtp|starttls|ssl) # Protocol name
|
||||
://
|
||||
)?
|
||||
( # Optional user:pass@
|
||||
(?P<user>[^:]*) # Match every char except ':'
|
||||
(: (?P<pass>.*) )? @ # Optional :pass
|
||||
)?
|
||||
(?P<fqdn> # Required FQDN on IP address
|
||||
()| # Empty string
|
||||
( # FQDN
|
||||
[a-zA-Z_\-] # First character cannot be a number
|
||||
[a-zA-Z0-9_\-\.]{,254}
|
||||
)
|
||||
|( # IPv4
|
||||
([0-9]{1,3}\.){3}
|
||||
[0-9]{1,3}
|
||||
)
|
||||
|( # IPv6
|
||||
\[ # Square brackets
|
||||
([0-9a-f]{,4}:){1,8}
|
||||
[0-9a-f]{,4}
|
||||
\]
|
||||
)
|
||||
)
|
||||
( # Optional :port
|
||||
:
|
||||
(?P<port>[0-9]{,5}) # Up to 5-digits port
|
||||
)?
|
||||
[/]?
|
||||
$
|
||||
""", url, re.VERBOSE)
|
||||
|
||||
if not match:
|
||||
raise RuntimeError("SMTP URL seems incorrect")
|
||||
|
||||
d = match.groupdict()
|
||||
if d['proto'] is None:
|
||||
d['proto'] = 'smtp'
|
||||
|
||||
if d['port'] is None:
|
||||
d['port'] = 25
|
||||
else:
|
||||
d['port'] = int(d['port'])
|
||||
|
||||
if not 0 < d['port'] < 65536:
|
||||
raise RuntimeError("Incorrect SMTP port")
|
||||
|
||||
return d
|
||||
|
||||
def send_email(self, email_addr, subject, email_text):
|
||||
"""Send an email
|
||||
|
||||
:param email_addr: email address
|
||||
:type email_addr: str.
|
||||
:param subject: subject
|
||||
:type subject: str.
|
||||
:param email_text: email text
|
||||
:type email_text: str.
|
||||
:raises: AAAException if smtp_server and/or sender are not set
|
||||
"""
|
||||
if not (self._conf['fqdn'] and self.sender):
|
||||
raise AAAException("SMTP server or sender not set")
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['Subject'] = subject
|
||||
msg['From'] = self.sender
|
||||
msg['To'] = email_addr
|
||||
if isinstance(email_text, bytes):
|
||||
email_text = email_text.encode('utf-8')
|
||||
|
||||
part = MIMEText(email_text, 'html')
|
||||
msg.attach(part)
|
||||
msg = msg.as_string()
|
||||
|
||||
log.debug("Sending email using %s" % self._conf['fqdn'])
|
||||
|
||||
if self.use_threads:
|
||||
thread = Thread(target=self._send, args=(email_addr, msg))
|
||||
thread.start()
|
||||
self._threads.append(thread)
|
||||
|
||||
else:
|
||||
self._send(email_addr, msg)
|
||||
|
||||
def _send(self, email_addr, msg):
|
||||
"""Deliver an email using SMTP
|
||||
|
||||
:param email_addr: recipient
|
||||
:type email_addr: str.
|
||||
:param msg: email text
|
||||
:type msg: str.
|
||||
"""
|
||||
proto = self._conf['proto']
|
||||
assert proto in ('smtp', 'starttls', 'ssl'), \
|
||||
"Incorrect protocol: %s" % proto
|
||||
|
||||
try:
|
||||
if proto == 'ssl':
|
||||
log.debug("Setting up SSL")
|
||||
session = SMTP_SSL(self._conf['fqdn'], self._conf['port'])
|
||||
else:
|
||||
session = SMTP(self._conf['fqdn'], self._conf['port'])
|
||||
|
||||
if proto == 'starttls':
|
||||
log.debug('Sending EHLO and STARTTLS')
|
||||
session.ehlo()
|
||||
session.starttls()
|
||||
session.ehlo()
|
||||
|
||||
if self._conf['user'] is not None:
|
||||
log.debug('Performing login')
|
||||
session.login(self._conf['user'], self._conf['pass'])
|
||||
|
||||
log.debug('Sending')
|
||||
session.sendmail(self.sender, email_addr, msg)
|
||||
session.quit()
|
||||
log.info('Email sent')
|
||||
|
||||
except Exception as e: # pragma: no cover
|
||||
log.error("Error sending email: %s" % e, exc_info=True)
|
||||
|
||||
def join(self):
|
||||
"""Flush email queue by waiting the completion of the existing threads
|
||||
|
||||
:returns: None
|
||||
"""
|
||||
return [t.join(self.join_timeout) for t in self._threads]
|
||||
|
||||
def __del__(self):
|
||||
"""Class destructor: wait for threads to terminate within a timeout"""
|
||||
try:
|
||||
self.join()
|
||||
except TypeError:
|
||||
pass
|
@ -1,134 +0,0 @@
|
||||
# Cork - Authentication module for the Bottle web framework
|
||||
# Copyright (C) 2013 Federico Ceratto and others, see AUTHORS file.
|
||||
# Released under LGPLv3+ license, see LICENSE.txt
|
||||
|
||||
"""
|
||||
.. module:: json_backend
|
||||
:synopsis: JSON file-based storage backend.
|
||||
"""
|
||||
|
||||
from logging import getLogger
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
try:
|
||||
import json
|
||||
except ImportError: # pragma: no cover
|
||||
import simplejson as json
|
||||
|
||||
from .base_backend import BackendIOException
|
||||
|
||||
is_py3 = (sys.version_info.major == 3)
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
try:
|
||||
dict.iteritems
|
||||
py23dict = dict
|
||||
except AttributeError:
|
||||
class py23dict(dict):
|
||||
iteritems = dict.items
|
||||
|
||||
class BytesEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if is_py3 and isinstance(obj, bytes):
|
||||
return obj.decode()
|
||||
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
class JsonBackend(object):
|
||||
"""JSON file-based storage backend."""
|
||||
|
||||
def __init__(self, directory, users_fname='users',
|
||||
roles_fname='roles', pending_reg_fname='register', initialize=False):
|
||||
"""Data storage class. Handles JSON files
|
||||
|
||||
:param users_fname: users file name (without .json)
|
||||
:type users_fname: str.
|
||||
:param roles_fname: roles file name (without .json)
|
||||
:type roles_fname: str.
|
||||
:param pending_reg_fname: pending registrations file name (without .json)
|
||||
:type pending_reg_fname: str.
|
||||
:param initialize: create empty JSON files (defaults to False)
|
||||
:type initialize: bool.
|
||||
"""
|
||||
assert directory, "Directory name must be valid"
|
||||
self._directory = directory
|
||||
self.users = py23dict()
|
||||
self._users_fname = users_fname
|
||||
self.roles = py23dict()
|
||||
self._roles_fname = roles_fname
|
||||
self._mtimes = py23dict()
|
||||
self._pending_reg_fname = pending_reg_fname
|
||||
self.pending_registrations = py23dict()
|
||||
if initialize:
|
||||
self._initialize_storage()
|
||||
self._refresh() # load users and roles
|
||||
|
||||
def _initialize_storage(self):
|
||||
"""Create empty JSON files"""
|
||||
self._savejson(self._users_fname, {})
|
||||
self._savejson(self._roles_fname, {})
|
||||
self._savejson(self._pending_reg_fname, {})
|
||||
|
||||
def _refresh(self):
|
||||
"""Load users and roles from JSON files, if needed"""
|
||||
self._loadjson(self._users_fname, self.users)
|
||||
self._loadjson(self._roles_fname, self.roles)
|
||||
self._loadjson(self._pending_reg_fname, self.pending_registrations)
|
||||
|
||||
def _loadjson(self, fname, dest):
|
||||
"""Load JSON file located under self._directory, if needed
|
||||
|
||||
:param fname: short file name (without path and .json)
|
||||
:type fname: str.
|
||||
:param dest: destination
|
||||
:type dest: dict
|
||||
"""
|
||||
try:
|
||||
fname = "%s/%s.json" % (self._directory, fname)
|
||||
mtime = os.stat(fname).st_mtime
|
||||
|
||||
if self._mtimes.get(fname, 0) == mtime:
|
||||
# no need to reload the file: the mtime has not been changed
|
||||
return
|
||||
|
||||
with open(fname) as f:
|
||||
json_data = f.read()
|
||||
except Exception as e:
|
||||
raise BackendIOException("Unable to read json file %s: %s" % (fname, e))
|
||||
|
||||
try:
|
||||
json_obj = json.loads(json_data)
|
||||
dest.clear()
|
||||
dest.update(json_obj)
|
||||
self._mtimes[fname] = os.stat(fname).st_mtime
|
||||
except Exception as e:
|
||||
raise BackendIOException("Unable to parse JSON data from %s: %s" \
|
||||
% (fname, e))
|
||||
|
||||
def _savejson(self, fname, obj):
|
||||
"""Save obj in JSON format in a file in self._directory"""
|
||||
fname = "%s/%s.json" % (self._directory, fname)
|
||||
try:
|
||||
with open("%s.tmp" % fname, 'w') as f:
|
||||
json.dump(obj, f, cls=BytesEncoder)
|
||||
f.flush()
|
||||
shutil.move("%s.tmp" % fname, fname)
|
||||
except Exception as e:
|
||||
raise BackendIOException("Unable to save JSON file %s: %s" \
|
||||
% (fname, e))
|
||||
|
||||
def save_users(self):
|
||||
"""Save users in a JSON file"""
|
||||
self._savejson(self._users_fname, self.users)
|
||||
|
||||
def save_roles(self):
|
||||
"""Save roles in a JSON file"""
|
||||
self._savejson(self._roles_fname, self.roles)
|
||||
|
||||
def save_pending_registrations(self):
|
||||
"""Save pending registrations in a JSON file"""
|
||||
self._savejson(self._pending_reg_fname, self.pending_registrations)
|
@ -1,180 +0,0 @@
|
||||
# Cork - Authentication module for the Bottle web framework
|
||||
# Copyright (C) 2013 Federico Ceratto and others, see AUTHORS file.
|
||||
# Released under LGPLv3+ license, see LICENSE.txt
|
||||
|
||||
"""
|
||||
.. module:: mongodb_backend
|
||||
:synopsis: MongoDB storage backend.
|
||||
"""
|
||||
from logging import getLogger
|
||||
log = getLogger(__name__)
|
||||
|
||||
from .base_backend import Backend, Table
|
||||
|
||||
try:
|
||||
import pymongo
|
||||
is_pymongo_2 = (pymongo.version_tuple[0] == 2)
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
class MongoTable(Table):
|
||||
"""Abstract MongoDB Table.
|
||||
Allow dictionary-like access.
|
||||
"""
|
||||
def __init__(self, name, key_name, collection):
|
||||
self._name = name
|
||||
self._key_name = key_name
|
||||
self._coll = collection
|
||||
|
||||
def create_index(self):
|
||||
"""Create collection index."""
|
||||
self._coll.create_index(
|
||||
self._key_name,
|
||||
drop_dups=True,
|
||||
unique=True,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self._coll.count()
|
||||
|
||||
def __contains__(self, value):
|
||||
r = self._coll.find_one({self._key_name: value})
|
||||
return r is not None
|
||||
|
||||
def __iter__(self):
|
||||
"""Iter on dictionary keys"""
|
||||
if is_pymongo_2:
|
||||
r = self._coll.find(fields=[self._key_name,])
|
||||
else:
|
||||
r = self._coll.find(projection=[self._key_name,])
|
||||
|
||||
return (i[self._key_name] for i in r)
|
||||
|
||||
def iteritems(self):
|
||||
"""Iter on dictionary items.
|
||||
|
||||
:returns: generator of (key, value) tuples
|
||||
"""
|
||||
r = self._coll.find()
|
||||
for i in r:
|
||||
d = i.copy()
|
||||
d.pop(self._key_name)
|
||||
d.pop('_id')
|
||||
yield (i[self._key_name], d)
|
||||
|
||||
def pop(self, key_val):
|
||||
"""Remove a dictionary item"""
|
||||
r = self[key_val]
|
||||
self._coll.remove({self._key_name: key_val}, w=1)
|
||||
return r
|
||||
|
||||
|
||||
class MongoSingleValueTable(MongoTable):
|
||||
"""MongoDB table accessible as a simple key -> value dictionary.
|
||||
Used to store roles.
|
||||
"""
|
||||
# Values are stored in a MongoDB "column" named "val"
|
||||
def __init__(self, *args, **kw):
|
||||
super(MongoSingleValueTable, self).__init__(*args, **kw)
|
||||
|
||||
def __setitem__(self, key_val, data):
|
||||
assert not isinstance(data, dict)
|
||||
spec = {self._key_name: key_val}
|
||||
data = {self._key_name: key_val, 'val': data}
|
||||
if is_pymongo_2:
|
||||
self._coll.update(spec, {'$set': data}, upsert=True, w=1)
|
||||
else:
|
||||
self._coll.update_one(spec, {'$set': data}, upsert=True)
|
||||
|
||||
def __getitem__(self, key_val):
|
||||
r = self._coll.find_one({self._key_name: key_val})
|
||||
if r is None:
|
||||
raise KeyError(key_val)
|
||||
|
||||
return r['val']
|
||||
|
||||
class MongoMutableDict(dict):
|
||||
"""Represent an item from a Table. Acts as a dictionary.
|
||||
"""
|
||||
def __init__(self, parent, root_key, d):
|
||||
"""Create a MongoMutableDict instance.
|
||||
:param parent: Table instance
|
||||
:type parent: :class:`MongoTable`
|
||||
"""
|
||||
super(MongoMutableDict, self).__init__(d)
|
||||
self._parent = parent
|
||||
self._root_key = root_key
|
||||
|
||||
def __setitem__(self, k, v):
|
||||
super(MongoMutableDict, self).__setitem__(k, v)
|
||||
spec = {self._parent._key_name: self._root_key}
|
||||
if is_pymongo_2:
|
||||
r = self._parent._coll.update(spec, {'$set': {k: v}}, upsert=True)
|
||||
else:
|
||||
r = self._parent._coll.update_one(spec, {'$set': {k: v}}, upsert=True)
|
||||
|
||||
|
||||
|
||||
class MongoMultiValueTable(MongoTable):
|
||||
"""MongoDB table accessible as a dictionary.
|
||||
"""
|
||||
def __init__(self, *args, **kw):
|
||||
super(MongoMultiValueTable, self).__init__(*args, **kw)
|
||||
|
||||
def __setitem__(self, key_val, data):
|
||||
assert isinstance(data, dict)
|
||||
key_name = self._key_name
|
||||
if key_name in data:
|
||||
assert data[key_name] == key_val
|
||||
else:
|
||||
data[key_name] = key_val
|
||||
|
||||
spec = {key_name: key_val}
|
||||
if u'_id' in data:
|
||||
del(data[u'_id'])
|
||||
|
||||
if is_pymongo_2:
|
||||
self._coll.update(spec, {'$set': data}, upsert=True, w=1)
|
||||
else:
|
||||
self._coll.update_one(spec, {'$set': data}, upsert=True)
|
||||
|
||||
def __getitem__(self, key_val):
|
||||
r = self._coll.find_one({self._key_name: key_val})
|
||||
if r is None:
|
||||
raise KeyError(key_val)
|
||||
|
||||
return MongoMutableDict(self, key_val, r)
|
||||
|
||||
|
||||
class MongoDBBackend(Backend):
|
||||
def __init__(self, db_name='cork', hostname='localhost', port=27017, initialize=False, username=None, password=None):
|
||||
"""Initialize MongoDB Backend"""
|
||||
connection = pymongo.MongoClient(host=hostname, port=port)
|
||||
db = connection[db_name]
|
||||
if username and password:
|
||||
db.authenticate(username, password)
|
||||
self.users = MongoMultiValueTable('users', 'login', db.users)
|
||||
self.pending_registrations = MongoMultiValueTable(
|
||||
'pending_registrations',
|
||||
'pending_registration',
|
||||
db.pending_registrations
|
||||
)
|
||||
self.roles = MongoSingleValueTable('roles', 'role', db.roles)
|
||||
|
||||
if initialize:
|
||||
self._initialize_storage()
|
||||
|
||||
def _initialize_storage(self):
|
||||
"""Create MongoDB indexes."""
|
||||
for c in (self.users, self.roles, self.pending_registrations):
|
||||
c.create_index()
|
||||
|
||||
def save_users(self):
|
||||
pass
|
||||
|
||||
def save_roles(self):
|
||||
pass
|
||||
|
||||
def save_pending_registrations(self):
|
||||
pass
|
@ -1,204 +0,0 @@
|
||||
# Cork - Authentication module for the Bottle web framework
|
||||
# Copyright (C) 2013 Federico Ceratto and others, see AUTHORS file.
|
||||
# Released under LGPLv3+ license, see LICENSE.txt
|
||||
|
||||
"""
|
||||
.. module:: sqlalchemy_backend
|
||||
:synopsis: SQLAlchemy storage backend.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from logging import getLogger
|
||||
|
||||
from . import base_backend
|
||||
|
||||
log = getLogger(__name__)
|
||||
is_py3 = (sys.version_info.major == 3)
|
||||
|
||||
try:
|
||||
from sqlalchemy import create_engine, delete, select, \
|
||||
Column, ForeignKey, Integer, MetaData, String, Table, Unicode
|
||||
sqlalchemy_available = True
|
||||
except ImportError: # pragma: no cover
|
||||
sqlalchemy_available = False
|
||||
|
||||
|
||||
class SqlRowProxy(dict):
|
||||
def __init__(self, sql_dict, key, *args, **kwargs):
|
||||
dict.__init__(self, *args, **kwargs)
|
||||
self.sql_dict = sql_dict
|
||||
self.key = key
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
dict.__setitem__(self, key, value)
|
||||
if self.sql_dict is not None:
|
||||
self.sql_dict[self.key] = {key: value}
|
||||
|
||||
|
||||
class SqlTable(base_backend.Table):
|
||||
"""Provides dictionary-like access to an SQL table."""
|
||||
|
||||
def __init__(self, engine, table, key_col_name):
|
||||
self._engine = engine
|
||||
self._table = table
|
||||
self._key_col = table.c[key_col_name]
|
||||
|
||||
def _row_to_value(self, row):
|
||||
row_key = row[self._key_col]
|
||||
row_value = SqlRowProxy(self, row_key,
|
||||
((k, row[k]) for k in row.keys() if k != self._key_col.name))
|
||||
return row_key, row_value
|
||||
|
||||
def __len__(self):
|
||||
query = self._table.count()
|
||||
c = self._engine.execute(query).scalar()
|
||||
return int(c)
|
||||
|
||||
def __contains__(self, key):
|
||||
query = select([self._key_col], self._key_col == key)
|
||||
row = self._engine.execute(query).fetchone()
|
||||
return row is not None
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key in self:
|
||||
values = value
|
||||
query = self._table.update().where(self._key_col == key)
|
||||
|
||||
else:
|
||||
values = {self._key_col.name: key}
|
||||
values.update(value)
|
||||
query = self._table.insert()
|
||||
|
||||
self._engine.execute(query.values(**values))
|
||||
|
||||
def __getitem__(self, key):
|
||||
query = select([self._table], self._key_col == key)
|
||||
row = self._engine.execute(query).fetchone()
|
||||
if row is None:
|
||||
raise KeyError(key)
|
||||
return self._row_to_value(row)[1]
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over table index key values"""
|
||||
query = select([self._key_col])
|
||||
result = self._engine.execute(query)
|
||||
for row in result:
|
||||
key = row[0]
|
||||
yield key
|
||||
|
||||
def iteritems(self):
|
||||
"""Iterate over table rows"""
|
||||
query = select([self._table])
|
||||
result = self._engine.execute(query)
|
||||
for row in result:
|
||||
key = row[0]
|
||||
d = self._row_to_value(row)[1]
|
||||
yield (key, d)
|
||||
|
||||
def pop(self, key):
|
||||
query = select([self._table], self._key_col == key)
|
||||
row = self._engine.execute(query).fetchone()
|
||||
if row is None:
|
||||
raise KeyError
|
||||
|
||||
query = delete(self._table, self._key_col == key)
|
||||
self._engine.execute(query)
|
||||
return row
|
||||
|
||||
def insert(self, d):
|
||||
query = self._table.insert(d)
|
||||
self._engine.execute(query)
|
||||
log.debug("%s inserted" % repr(d))
|
||||
|
||||
def empty_table(self):
|
||||
query = self._table.delete()
|
||||
self._engine.execute(query)
|
||||
log.info("Table purged")
|
||||
|
||||
|
||||
class SqlSingleValueTable(SqlTable):
|
||||
def __init__(self, engine, table, key_col_name, col_name):
|
||||
SqlTable.__init__(self, engine, table, key_col_name)
|
||||
self._col_name = col_name
|
||||
|
||||
def _row_to_value(self, row):
|
||||
return row[self._key_col], row[self._col_name]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
SqlTable.__setitem__(self, key, {self._col_name: value})
|
||||
|
||||
|
||||
|
||||
class SqlAlchemyBackend(base_backend.Backend):
|
||||
|
||||
def __init__(self, db_full_url, users_tname='users', roles_tname='roles',
|
||||
pending_reg_tname='register', initialize=False):
|
||||
|
||||
if not sqlalchemy_available:
|
||||
raise RuntimeError("The SQLAlchemy library is not available.")
|
||||
|
||||
self._metadata = MetaData()
|
||||
if initialize:
|
||||
# Create new database if needed.
|
||||
db_url, db_name = db_full_url.rsplit('/', 1)
|
||||
if is_py3 and db_url.startswith('mysql'):
|
||||
print("WARNING: MySQL is not supported under Python3")
|
||||
|
||||
self._engine = create_engine(db_url, encoding='utf-8')
|
||||
try:
|
||||
self._engine.execute("CREATE DATABASE %s" % db_name)
|
||||
except Exception as e:
|
||||
log.info("Failed DB creation: %s" % e)
|
||||
|
||||
# SQLite in-memory database URL: "sqlite://:memory:"
|
||||
if db_name != ':memory:' and not db_url.startswith('postgresql'):
|
||||
self._engine.execute("USE %s" % db_name)
|
||||
|
||||
else:
|
||||
self._engine = create_engine(db_full_url, encoding='utf-8')
|
||||
|
||||
|
||||
self._users = Table(users_tname, self._metadata,
|
||||
Column('username', Unicode(128), primary_key=True),
|
||||
Column('role', ForeignKey(roles_tname + '.role')),
|
||||
Column('hash', String(256), nullable=False),
|
||||
Column('email_addr', String(128)),
|
||||
Column('desc', String(128)),
|
||||
Column('creation_date', String(128), nullable=False),
|
||||
Column('last_login', String(128), nullable=False)
|
||||
|
||||
)
|
||||
self._roles = Table(roles_tname, self._metadata,
|
||||
Column('role', String(128), primary_key=True),
|
||||
Column('level', Integer, nullable=False)
|
||||
)
|
||||
self._pending_reg = Table(pending_reg_tname, self._metadata,
|
||||
Column('code', String(128), primary_key=True),
|
||||
Column('username', Unicode(128), nullable=False),
|
||||
Column('role', ForeignKey(roles_tname + '.role')),
|
||||
Column('hash', String(256), nullable=False),
|
||||
Column('email_addr', String(128)),
|
||||
Column('desc', String(128)),
|
||||
Column('creation_date', String(128), nullable=False)
|
||||
)
|
||||
|
||||
self.users = SqlTable(self._engine, self._users, 'username')
|
||||
self.roles = SqlSingleValueTable(self._engine, self._roles, 'role', 'level')
|
||||
self.pending_registrations = SqlTable(self._engine, self._pending_reg, 'code')
|
||||
|
||||
if initialize:
|
||||
self._initialize_storage(db_name)
|
||||
log.debug("Tables created")
|
||||
|
||||
|
||||
def _initialize_storage(self, db_name):
|
||||
self._metadata.create_all(self._engine)
|
||||
|
||||
def _drop_all_tables(self):
|
||||
for table in reversed(self._metadata.sorted_tables):
|
||||
log.info("Dropping table %s" % repr(table.name))
|
||||
self._engine.execute(table.delete())
|
||||
|
||||
def save_users(self): pass
|
||||
def save_roles(self): pass
|
||||
def save_pending_registrations(self): pass
|
@ -1,242 +0,0 @@
|
||||
# Cork - Authentication module for the Bottle web framework
|
||||
# Copyright (C) 2013 Federico Ceratto and others, see AUTHORS file.
|
||||
# Released under LGPLv3+ license, see LICENSE.txt
|
||||
|
||||
"""
|
||||
.. module:: sqlite_backend
|
||||
:synopsis: SQLite storage backend.
|
||||
"""
|
||||
|
||||
from . import base_backend
|
||||
from logging import getLogger
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
class SqlRowProxy(dict):
|
||||
def __init__(self, table, key, row):
|
||||
li = ((k, v) for (k, ktype), v in zip(table._columns[1:], row[1:]))
|
||||
dict.__init__(self, li)
|
||||
self._table = table
|
||||
self._key = key
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
dict.__setitem__(self, key, value)
|
||||
self._table[self._key] = self
|
||||
|
||||
|
||||
class Table(base_backend.Table):
|
||||
"""Provides dictionary-like access to an SQL table."""
|
||||
|
||||
def __init__(self, backend, table_name):
|
||||
self._backend = backend
|
||||
self._engine = backend.connection
|
||||
self._table_name = table_name
|
||||
self._column_names = [n for n, t in self._columns]
|
||||
self._key_col_num = 0
|
||||
self._key_col_name = self._column_names[self._key_col_num]
|
||||
self._key_col = self._column_names[self._key_col_num]
|
||||
|
||||
def _row_to_value(self, key, row):
|
||||
assert isinstance(row, tuple)
|
||||
row_key = row[self._key_col_num]
|
||||
row_value = SqlRowProxy(self, key, row)
|
||||
return row_key, row_value
|
||||
|
||||
def __len__(self):
|
||||
query = "SELECT count() FROM %s" % self._table_name
|
||||
ret = self._backend.run_query(query)
|
||||
return ret.fetchone()[0]
|
||||
|
||||
def __contains__(self, key):
|
||||
#FIXME: count()
|
||||
query = "SELECT * FROM %s WHERE %s='%s'" % \
|
||||
(self._table_name, self._key_col, key)
|
||||
row = self._backend.fetch_one(query)
|
||||
return row is not None
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Create or update a row"""
|
||||
assert isinstance(value, dict)
|
||||
v, cn = set(value), set(self._column_names[1:])
|
||||
assert not v - cn, repr(v - cn)
|
||||
assert not cn - v, repr(cn - v)
|
||||
|
||||
assert set(value) == set(self._column_names[1:]), "%s %s" % \
|
||||
(repr(set(value)), repr(set(self._column_names[1:])))
|
||||
|
||||
col_values = [key] + [value[k] for k in self._column_names[1:]]
|
||||
|
||||
col_names = ', '.join(self._column_names)
|
||||
question_marks = ', '.join('?' for x in col_values)
|
||||
query = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)" % \
|
||||
(self._table_name, col_names, question_marks)
|
||||
|
||||
ret = self._backend.run_query_using_conversion(query, col_values)
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
query = "SELECT * FROM %s WHERE %s='%s'" % \
|
||||
(self._table_name, self._key_col, key)
|
||||
row = self._backend.fetch_one(query)
|
||||
if row is None:
|
||||
raise KeyError(key)
|
||||
|
||||
return self._row_to_value(key, row)[1]
|
||||
#return dict(zip(self._column_names, row))
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over table index key values"""
|
||||
query = "SELECT %s FROM %s" % (self._key_col, self._table_name)
|
||||
result = self._backend.run_query(query)
|
||||
for row in result:
|
||||
yield row[0]
|
||||
|
||||
def iteritems(self):
|
||||
"""Iterate over table rows"""
|
||||
query = "SELECT * FROM %s" % self._table_name
|
||||
result = self._backend.run_query(query)
|
||||
for row in result:
|
||||
d = dict(zip(self._column_names, row))
|
||||
d.pop(self._key_col)
|
||||
|
||||
yield (self._key_col, d)
|
||||
|
||||
def pop(self, key):
|
||||
d = self.__getitem__(key)
|
||||
query = "DELETE FROM %s WHERE %s='%s'" % \
|
||||
(self._table_name, self._key_col, key)
|
||||
self._backend.fetch_one(query)
|
||||
#FIXME: check deletion
|
||||
return d
|
||||
|
||||
def insert(self, d):
|
||||
raise NotImplementedError
|
||||
|
||||
def empty_table(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def create_table(self):
|
||||
"""Issue table creation"""
|
||||
cc = []
|
||||
for col_name, col_type in self._columns:
|
||||
if col_type == int:
|
||||
col_type = 'INTEGER'
|
||||
elif col_type == str:
|
||||
col_type = 'TEXT'
|
||||
|
||||
if col_name == self._key_col:
|
||||
extras = 'PRIMARY KEY ASC'
|
||||
else:
|
||||
extras = ''
|
||||
|
||||
cc.append("%s %s %s" % (col_name, col_type, extras))
|
||||
|
||||
cc = ','.join(cc)
|
||||
query = "CREATE TABLE %s (%s)" % (self._table_name, cc)
|
||||
self._backend.run_query(query)
|
||||
|
||||
|
||||
class SingleValueTable(Table):
|
||||
def __init__(self, *args):
|
||||
super(SingleValueTable, self).__init__(*args)
|
||||
self._value_col = self._column_names[1]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Create or update a row"""
|
||||
assert not isinstance(value, dict)
|
||||
query = "INSERT OR REPLACE INTO %s (%s, %s) VALUES (?, ?)" % \
|
||||
(self._table_name, self._key_col, self._value_col)
|
||||
|
||||
col_values = (key, value)
|
||||
ret = self._backend.run_query_using_conversion(query, col_values)
|
||||
|
||||
def __getitem__(self, key):
|
||||
query = "SELECT %s FROM %s WHERE %s='%s'" % \
|
||||
(self._value_col, self._table_name, self._key_col, key)
|
||||
row = self._backend.fetch_one(query)
|
||||
if row is None:
|
||||
raise KeyError(key)
|
||||
|
||||
return row[0]
|
||||
|
||||
class UsersTable(Table):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._columns = (
|
||||
('username', str),
|
||||
('role', str),
|
||||
('hash', str),
|
||||
('email_addr', str),
|
||||
('desc', str),
|
||||
('creation_date', str),
|
||||
('last_login', str)
|
||||
)
|
||||
super(UsersTable, self).__init__(*args, **kwargs)
|
||||
|
||||
class RolesTable(SingleValueTable):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._columns = (
|
||||
('role', str),
|
||||
('level', int)
|
||||
)
|
||||
super(RolesTable, self).__init__(*args, **kwargs)
|
||||
|
||||
class PendingRegistrationsTable(Table):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._columns = (
|
||||
('code', str),
|
||||
('username', str),
|
||||
('role', str),
|
||||
('hash', str),
|
||||
('email_addr', str),
|
||||
('desc', str),
|
||||
('creation_date', str)
|
||||
)
|
||||
super(PendingRegistrationsTable, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
|
||||
class SQLiteBackend(base_backend.Backend):
|
||||
|
||||
def __init__(self, filename, users_tname='users', roles_tname='roles',
|
||||
pending_reg_tname='register', initialize=False):
|
||||
|
||||
self._filename = filename
|
||||
|
||||
self.users = UsersTable(self, users_tname)
|
||||
self.roles = RolesTable(self, roles_tname)
|
||||
self.pending_registrations = PendingRegistrationsTable(self, pending_reg_tname)
|
||||
|
||||
if initialize:
|
||||
self.users.create_table()
|
||||
self.roles.create_table()
|
||||
self.pending_registrations.create_table()
|
||||
log.debug("Tables created")
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
try:
|
||||
return self._connection
|
||||
except AttributeError:
|
||||
import sqlite3
|
||||
self._connection = sqlite3.connect(self._filename)
|
||||
return self._connection
|
||||
|
||||
def run_query(self, query):
|
||||
return self._connection.execute(query)
|
||||
|
||||
def run_query_using_conversion(self, query, args):
|
||||
return self._connection.execute(query, args)
|
||||
|
||||
def fetch_one(self, query):
|
||||
return self._connection.execute(query).fetchone()
|
||||
|
||||
def _initialize_storage(self, db_name):
|
||||
raise NotImplementedError
|
||||
|
||||
def _drop_all_tables(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_users(self): pass
|
||||
def save_roles(self): pass
|
||||
def save_pending_registrations(self): pass
|
@ -1,221 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import datetime
|
||||
import sys
|
||||
from getpass import getpass
|
||||
from optparse import OptionParser
|
||||
|
||||
from peewee import *
|
||||
from peewee import print_
|
||||
from peewee import __version__ as peewee_version
|
||||
from playhouse.reflection import *
|
||||
|
||||
|
||||
HEADER = """from peewee import *%s
|
||||
|
||||
database = %s('%s'%s)
|
||||
"""
|
||||
|
||||
BASE_MODEL = """\
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = database
|
||||
"""
|
||||
|
||||
UNKNOWN_FIELD = """\
|
||||
class UnknownField(object):
|
||||
def __init__(self, *_, **__): pass
|
||||
"""
|
||||
|
||||
DATABASE_ALIASES = {
|
||||
MySQLDatabase: ['mysql', 'mysqldb'],
|
||||
PostgresqlDatabase: ['postgres', 'postgresql'],
|
||||
SqliteDatabase: ['sqlite', 'sqlite3'],
|
||||
}
|
||||
|
||||
DATABASE_MAP = dict((value, key)
|
||||
for key in DATABASE_ALIASES
|
||||
for value in DATABASE_ALIASES[key])
|
||||
|
||||
def make_introspector(database_type, database_name, **kwargs):
|
||||
if database_type not in DATABASE_MAP:
|
||||
err('Unrecognized database, must be one of: %s' %
|
||||
', '.join(DATABASE_MAP.keys()))
|
||||
sys.exit(1)
|
||||
|
||||
schema = kwargs.pop('schema', None)
|
||||
DatabaseClass = DATABASE_MAP[database_type]
|
||||
db = DatabaseClass(database_name, **kwargs)
|
||||
return Introspector.from_database(db, schema=schema)
|
||||
|
||||
def print_models(introspector, tables=None, preserve_order=False,
|
||||
include_views=False, ignore_unknown=False, snake_case=True):
|
||||
database = introspector.introspect(table_names=tables,
|
||||
include_views=include_views,
|
||||
snake_case=snake_case)
|
||||
|
||||
db_kwargs = introspector.get_database_kwargs()
|
||||
header = HEADER % (
|
||||
introspector.get_additional_imports(),
|
||||
introspector.get_database_class().__name__,
|
||||
introspector.get_database_name(),
|
||||
', **%s' % repr(db_kwargs) if db_kwargs else '')
|
||||
print_(header)
|
||||
|
||||
if not ignore_unknown:
|
||||
print_(UNKNOWN_FIELD)
|
||||
|
||||
print_(BASE_MODEL)
|
||||
|
||||
def _print_table(table, seen, accum=None):
|
||||
accum = accum or []
|
||||
foreign_keys = database.foreign_keys[table]
|
||||
for foreign_key in foreign_keys:
|
||||
dest = foreign_key.dest_table
|
||||
|
||||
# In the event the destination table has already been pushed
|
||||
# for printing, then we have a reference cycle.
|
||||
if dest in accum and table not in accum:
|
||||
print_('# Possible reference cycle: %s' % dest)
|
||||
|
||||
# If this is not a self-referential foreign key, and we have
|
||||
# not already processed the destination table, do so now.
|
||||
if dest not in seen and dest not in accum:
|
||||
seen.add(dest)
|
||||
if dest != table:
|
||||
_print_table(dest, seen, accum + [table])
|
||||
|
||||
print_('class %s(BaseModel):' % database.model_names[table])
|
||||
columns = database.columns[table].items()
|
||||
if not preserve_order:
|
||||
columns = sorted(columns)
|
||||
primary_keys = database.primary_keys[table]
|
||||
for name, column in columns:
|
||||
skip = all([
|
||||
name in primary_keys,
|
||||
name == 'id',
|
||||
len(primary_keys) == 1,
|
||||
column.field_class in introspector.pk_classes])
|
||||
if skip:
|
||||
continue
|
||||
if column.primary_key and len(primary_keys) > 1:
|
||||
# If we have a CompositeKey, then we do not want to explicitly
|
||||
# mark the columns as being primary keys.
|
||||
column.primary_key = False
|
||||
|
||||
is_unknown = column.field_class is UnknownField
|
||||
if is_unknown and ignore_unknown:
|
||||
disp = '%s - %s' % (column.name, column.raw_column_type or '?')
|
||||
print_(' # %s' % disp)
|
||||
else:
|
||||
print_(' %s' % column.get_field())
|
||||
|
||||
print_('')
|
||||
print_(' class Meta:')
|
||||
print_(' table_name = \'%s\'' % table)
|
||||
multi_column_indexes = database.multi_column_indexes(table)
|
||||
if multi_column_indexes:
|
||||
print_(' indexes = (')
|
||||
for fields, unique in sorted(multi_column_indexes):
|
||||
print_(' ((%s), %s),' % (
|
||||
', '.join("'%s'" % field for field in fields),
|
||||
unique,
|
||||
))
|
||||
print_(' )')
|
||||
|
||||
if introspector.schema:
|
||||
print_(' schema = \'%s\'' % introspector.schema)
|
||||
if len(primary_keys) > 1:
|
||||
pk_field_names = sorted([
|
||||
field.name for col, field in columns
|
||||
if col in primary_keys])
|
||||
pk_list = ', '.join("'%s'" % pk for pk in pk_field_names)
|
||||
print_(' primary_key = CompositeKey(%s)' % pk_list)
|
||||
elif not primary_keys:
|
||||
print_(' primary_key = False')
|
||||
print_('')
|
||||
|
||||
seen.add(table)
|
||||
|
||||
seen = set()
|
||||
for table in sorted(database.model_names.keys()):
|
||||
if table not in seen:
|
||||
if not tables or table in tables:
|
||||
_print_table(table, seen)
|
||||
|
||||
def print_header(cmd_line, introspector):
|
||||
timestamp = datetime.datetime.now()
|
||||
print_('# Code generated by:')
|
||||
print_('# python -m pwiz %s' % cmd_line)
|
||||
print_('# Date: %s' % timestamp.strftime('%B %d, %Y %I:%M%p'))
|
||||
print_('# Database: %s' % introspector.get_database_name())
|
||||
print_('# Peewee version: %s' % peewee_version)
|
||||
print_('')
|
||||
|
||||
|
||||
def err(msg):
|
||||
sys.stderr.write('\033[91m%s\033[0m\n' % msg)
|
||||
sys.stderr.flush()
|
||||
|
||||
def get_option_parser():
|
||||
parser = OptionParser(usage='usage: %prog [options] database_name')
|
||||
ao = parser.add_option
|
||||
ao('-H', '--host', dest='host')
|
||||
ao('-p', '--port', dest='port', type='int')
|
||||
ao('-u', '--user', dest='user')
|
||||
ao('-P', '--password', dest='password', action='store_true')
|
||||
engines = sorted(DATABASE_MAP)
|
||||
ao('-e', '--engine', dest='engine', default='postgresql', choices=engines,
|
||||
help=('Database type, e.g. sqlite, mysql or postgresql. Default '
|
||||
'is "postgresql".'))
|
||||
ao('-s', '--schema', dest='schema')
|
||||
ao('-t', '--tables', dest='tables',
|
||||
help=('Only generate the specified tables. Multiple table names should '
|
||||
'be separated by commas.'))
|
||||
ao('-v', '--views', dest='views', action='store_true',
|
||||
help='Generate model classes for VIEWs in addition to tables.')
|
||||
ao('-i', '--info', dest='info', action='store_true',
|
||||
help=('Add database information and other metadata to top of the '
|
||||
'generated file.'))
|
||||
ao('-o', '--preserve-order', action='store_true', dest='preserve_order',
|
||||
help='Model definition column ordering matches source table.')
|
||||
ao('-I', '--ignore-unknown', action='store_true', dest='ignore_unknown',
|
||||
help='Ignore fields whose type cannot be determined.')
|
||||
ao('-L', '--legacy-naming', action='store_true', dest='legacy_naming',
|
||||
help='Use legacy table- and column-name generation.')
|
||||
return parser
|
||||
|
||||
def get_connect_kwargs(options):
|
||||
ops = ('host', 'port', 'user', 'schema')
|
||||
kwargs = dict((o, getattr(options, o)) for o in ops if getattr(options, o))
|
||||
if options.password:
|
||||
kwargs['password'] = getpass()
|
||||
return kwargs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raw_argv = sys.argv
|
||||
|
||||
parser = get_option_parser()
|
||||
options, args = parser.parse_args()
|
||||
|
||||
if len(args) < 1:
|
||||
err('Missing required parameter "database"')
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
connect = get_connect_kwargs(options)
|
||||
database = args[-1]
|
||||
|
||||
tables = None
|
||||
if options.tables:
|
||||
tables = [table.strip() for table in options.tables.split(',')
|
||||
if table.strip()]
|
||||
|
||||
introspector = make_introspector(options.engine, database, **connect)
|
||||
if options.info:
|
||||
cmd_line = ' '.join(raw_argv[1:])
|
||||
print_header(cmd_line, introspector)
|
||||
|
||||
print_models(introspector, tables, options.preserve_order, options.views,
|
||||
options.ignore_unknown, not options.legacy_naming)
|
Loading…
Reference in new issue