Source code for pydal.adapters.base

# -*- coding: utf-8 -*-
import re
import os
import sys
import locale
import datetime
import decimal
import copy
import time
import base64
import types
import json

from .._compat import PY2, pjoin, exists, pickle, hashlib_md5, iterkeys, \
    iteritems, with_metaclass, to_unicode, integer_types, basestring, \
    string_types, to_bytes
from .._globals import IDENTITY
from .._load import portalocker
from ..connection import ConnectionPool
from ..objects import Expression, Field, Query, Table, Row, FieldVirtual, \
    FieldMethod, LazyReferenceGetter, LazySet, VirtualCommand, Rows, IterRows
from ..helpers.regex import REGEX_NO_GREEDY_ENTITY_NAME, REGEX_TYPE, \
    REGEX_SELECT_AS_PARSER
from ..helpers.methods import xorify, use_common_filters, bar_encode, \
    bar_decode_integer, bar_decode_string
from ..helpers.classes import SQLCustomType, SQLALL, Reference, \
    RecordUpdater, RecordDeleter, NullDriver, FakeCursor
from ..helpers.serializers import serializers

if PY2:
    # Python 2: use the lazy izip so zip() behaves like Python 3's zip()
    from itertools import izip as zip

# last entry of _compat.integer_types -- presumably `long` on Python 2 and
# `int` on Python 3; used for isinstance checks on row ids
long = integer_types[-1]

# NOTE(review): not referenced in this part of the file -- presumably the
# size of the query-timing history kept elsewhere; confirm
TIMINGSSIZE = 100

# every kind of Python callable (lambdas, functions, methods, builtins)
CALLABLETYPES = (
    types.LambdaType,
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.BuiltinMethodType,
)

# keyword arguments accepted by select(); _select rejects anything else
SELECT_ARGS = {
    'orderby', 'groupby', 'limitby', 'required', 'cache', 'left',
    'distinct', 'having', 'join', 'for_update', 'processor', 'cacheable',
    'orderby_on_limitby', 'outer_scoped',
}


class AdapterMeta(type):
    """Metaclass that post-processes every adapter instantiation.

    It removes the ``entity_quoting`` keyword before the adapter's own
    constructor runs. It copies ``adapter_args['uploads_in_blob']`` onto
    the adapter *class*, so the value persists for later instances. After
    construction it compiles ``REGEX_TABLE_DOT_FIELD``, which matches
    ``table.field`` written in the instance's quoting style.
    """

    def __call__(cls, *args, **kwargs):
        adapter_options = kwargs.get('adapter_args', {})
        # NOTE: assigned on the class, not the instance
        cls.uploads_in_blob = adapter_options.get(
            'uploads_in_blob', cls.uploads_in_blob)
        quoting_enabled = kwargs.pop('entity_quoting', False)
        instance = super(AdapterMeta, cls).__call__(*args, **kwargs)
        if quoting_enabled:
            template = instance.QUOTE_TEMPLATE
            entity_pattern = REGEX_NO_GREEDY_ENTITY_NAME
        else:
            # without entity quoting, identifiers are emitted verbatim
            instance.QUOTE_TEMPLATE = '%s'
            template = '%s'
            entity_pattern = r'(\w+)'
        quoted_entity = template % entity_pattern
        instance.REGEX_TABLE_DOT_FIELD = re.compile(
            r'^' + quoted_entity + r'\.' + quoted_entity + r'$')
        return instance
class BaseAdapter(with_metaclass(AdapterMeta, ConnectionPool)):
    """Generic SQL adapter that backend-specific adapters specialise.

    The class attributes below are defaults. Subclasses override them
    (notably ``types``, ``QUOTE_TEMPLATE`` and ``drivers``) to describe
    their database.
    """
    # NOTE(review): not referenced in this part of the file -- presumably
    # drivers that decode JSON natively; confirm
    driver_auto_json = []
    # DB-API module chosen by find_driver, and its name
    driver = None
    driver_name = None
    drivers = () # list of drivers from which to pick
    connection = None
    # when true, migrate_table commits (and saves the .table file) after
    # every ALTER TABLE statement
    commit_on_alter_table = False
    support_distributed_transaction = False
    # overwritten on the class by AdapterMeta from adapter_args
    uploads_in_blob = False
    # _select raises SyntaxError for for_update=True when this is False
    can_select_for_update = True
    dbpath = None
    # directory for migration (.table) and log files
    folder = None
    connector = lambda *args, **kwargs: None # __init__ should override this
    # SQL literals that expand() emits for Python booleans
    TRUE_exp = '1'
    FALSE_exp = '0'
    # NOTE(review): not used in this part of the file -- presumably the
    # stored representation of booleans in CHAR(1) columns; confirm
    TRUE = 'T'
    FALSE = 'F'
    T_SEP = ' '
    # template used to quote identifiers; AdapterMeta replaces it with
    # '%s' on the instance unless entity_quoting is requested
    QUOTE_TEMPLATE = '"%s"'
    test_query = 'SELECT 1;'
    # NOTE(review): this list is a class attribute shared by every
    # instance that does not assign its own; get_cursor appends to it
    cursors_in_use = []
    current_cursor_in_use = False
    # pydal field type -> SQL DDL template used by create_table.
    # NOTE(review): create_table also looks up 'reference TFK' for
    # composite foreign keys, which only subclasses define.
    types = {
        'boolean': 'CHAR(1)',
        'string': 'CHAR(%(length)s)',
        'text': 'TEXT',
        'json': 'TEXT',
        'password': 'CHAR(%(length)s)',
        'blob': 'BLOB',
        'upload': 'CHAR(%(length)s)',
        'integer': 'INTEGER',
        'bigint': 'INTEGER',
        'float':'DOUBLE',
        'double': 'DOUBLE',
        'decimal': 'DOUBLE',
        'date': 'DATE',
        'time': 'TIME',
        'datetime': 'TIMESTAMP',
        'id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
        'reference': 'INTEGER REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s %(null)s %(unique)s',
        'list:integer': 'TEXT',
        'list:string': 'TEXT',
        'list:reference': 'TEXT',
        # the two below are only used when DAL(...bigint_id=True) and replace 'id','reference'
        'big-id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
        'big-reference': 'INTEGER REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s %(null)s %(unique)s',
        'reference FK': ', CONSTRAINT "FK_%(constraint_name)s" FOREIGN KEY (%(field_name)s) REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s',
        }
def isOperationalError(self, exception):
    """Tell whether *exception* is the loaded driver's OperationalError.

    Returns ``None`` rather than ``False`` when the driver module does not
    define ``OperationalError`` at all.
    """
    if hasattr(self.driver, "OperationalError"):
        return isinstance(exception, self.driver.OperationalError)
    return None
def isProgrammingError(self, exception):
    """Tell whether *exception* is the loaded driver's ProgrammingError.

    Returns ``None`` rather than ``False`` when the driver module does not
    define ``ProgrammingError`` at all.
    """
    if hasattr(self.driver, "ProgrammingError"):
        return isinstance(exception, self.driver.ProgrammingError)
    return None
def id_query(self, table):
    """Build a query matching every row of *table* (``key != None``).

    The first primary-key field is used when the table declares
    ``_primarykey``; otherwise the ``_id`` field is used.
    """
    primary_keys = getattr(table, '_primarykey', None)
    key_field = table[primary_keys[0]] if primary_keys else table._id
    # `!=` (not `is not`) is deliberate: the field's overloaded operator
    # builds a query object instead of returning a bool
    return key_field != None  # noqa: E711
def adapt(self, obj):
    """Quote the string *obj* as an SQL literal, doubling single quotes."""
    escaped = obj.replace("'", "''")
    return "'" + escaped + "'"
def smart_adapt(self, obj):
    """Render ints/floats verbatim and anything else as a quoted literal."""
    if not isinstance(obj, (int, float)):
        return self.adapt(str(obj))
    return str(obj)
def file_exists(self, filename):
    """Return whether *filename* exists.

    Use this only for files that on GAE may not be on the filesystem, so
    that adapters can redirect the check to another storage.
    """
    found = exists(filename)
    return found
def file_open(self, filename, mode='rb', lock=True):
    """Open *filename* in *mode*, holding a portalocker lock unless *lock* is false.

    Use this only for files that on GAE may not be on the filesystem.
    """
    if not lock:
        return open(filename, mode)
    return portalocker.LockedFile(filename, mode)
def file_close(self, fileobj):
    """Close *fileobj* if it is truthy (``None`` is ignored).

    Use this only for files that on GAE may not be on the filesystem.
    """
    if not fileobj:
        return
    fileobj.close()
def file_delete(self, filename):
    """Remove *filename* from the filesystem (``os.remove`` is ``os.unlink``)."""
    os.remove(filename)
def find_driver(self, adapter_args, uri=None):
    """Choose the DB-API driver module for this adapter.

    Stores *adapter_args* on the instance. A driver that is already set is
    kept. Otherwise the driver is the one named in the URI scheme
    (``engine:driver://...``), or ``adapter_args['driver']``, or the first
    entry of ``self.drivers`` that the DAL reports as available.

    :raises RuntimeError: if the requested driver is not available, or no
        driver is available at all
    """
    self.adapter_args = adapter_args
    if getattr(self, 'driver', None) is not None:
        return
    loaded = self.db._drivers_available
    candidates = [name for name in self.drivers if name in loaded]
    requested = None
    if uri:
        scheme_parts = uri.split('://', 1)[0].split(':')
        if len(scheme_parts) > 1:
            requested = scheme_parts[1]
    requested = requested or adapter_args.get('driver')
    if requested:
        if requested not in candidates:
            raise RuntimeError("driver %s not available" % requested)
        chosen = requested
    elif candidates:
        chosen = candidates[0]
    else:
        raise RuntimeError("no driver available %s" % str(self.drivers))
    self.driver_name = chosen
    self.driver = loaded[chosen]
def log(self, message, table=None):
    """Append *message* to the migration log of *table*.

    Nothing is written when the ``logfile`` adapter arg is falsy, or the
    table has no ``_dbt`` migration file, or the adapter has no
    ``folder``. The log file defaults to ``sql.log``. A relative name is
    resolved against ``self.folder``. The resolved path is remembered on
    the table as ``_loggername``.
    """
    logfilename = self.adapter_args.get('logfile', 'sql.log')
    if not (logfilename and table and table._dbt and self.folder):
        return
    if os.path.isabs(logfilename):
        table._loggername = logfilename
    else:
        table._loggername = pjoin(self.folder, logfilename)
    logfile = self.file_open(table._loggername, 'ab')
    try:
        logfile.write(to_bytes(message))
    finally:
        # release the (possibly locked) file even if the write fails
        self.file_close(logfile)
def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
             credential_decoder=IDENTITY, driver_args={}, adapter_args={},
             do_connect=True, after_connection=None):
    """Initialise adapter state and connect through ``reconnect()``.

    :param db: the owning DAL instance
    :param uri: connection URI; the literal string ``"None"`` selects the
        NullDriver connector
    :param pool_size: connection pool size
    :param folder: directory for migration/log files
    :param db_codec: database text encoding
    :param after_connection: callback stored as ``_after_connection``

    ``credential_decoder``, ``driver_args``, ``adapter_args`` and
    ``do_connect`` are accepted for signature compatibility with
    subclasses but are not used by this base implementation. The mutable
    ``{}`` defaults are therefore never mutated here.
    """
    self.db = db
    self.dbengine = "None"
    self.uri = uri
    self.pool_size = pool_size
    self.folder = folder
    self.db_codec = db_codec
    self._after_connection = after_connection
    self.connection = None
    self.cursor = None
    # Per-instance list. The class-level default is a single list shared
    # by every adapter, so get_cursor() would otherwise park cursors of
    # unrelated connections in one global list.
    self.cursors_in_use = []
    if uri == "None":
        self.connector = NullDriver
    self.reconnect()
def get_cursor(self):
    """Open a new cursor on the current connection and make it current.

    If the current cursor is flagged as still in use, the flag is reset
    and a reference to that cursor is kept in ``cursors_in_use``.
    """
    # safe_reuse allows further queries to be executed using the same cursor
    if self.current_cursor_in_use == True:  # noqa: E712 (keeps == semantics)
        self.current_cursor_in_use = False
        # keep the busy cursor around instead of discarding it
        self.cursors_in_use.append(self.cursor)
    fresh_cursor = self.connection.cursor()
    self.cursor = fresh_cursor
    return fresh_cursor
def sequence_name(self, tablename):
    """Return the quoted name of the sequence associated with *tablename*."""
    bare_name = '%s_sequence' % tablename
    return self.QUOTE_TEMPLATE % bare_name
def trigger_name(self, tablename):
    """Return the (unquoted) trigger name for *tablename*.

    The base adapter reuses the ``<table>_sequence`` naming scheme.
    """
    name = '%s_sequence' % tablename
    return name
def varquote(self, name):
    """Quote a variable name; the base adapter returns *name* unchanged."""
    quoted = name
    return quoted
def create_table(self, table, migrate=True, fake_migrate=False,
                 polymodel=None):
    """Build (and, when migrating, execute) the CREATE TABLE for *table*.

    For every field an SQL column definition is derived from ``self.types``.
    The per-field metadata (``sql_fields``) is compared with the pickled
    metadata in ``table._dbt`` to decide whether the table must be created
    or migrated.

    :param migrate: False -> only return the SQL; a string -> name of the
        ``.table`` metadata file; True -> default metadata file name
    :param fake_migrate: record metadata without executing DDL
    :param polymodel: unused here
    :returns: the CREATE TABLE statement
    :raises KeyError: a reference cannot be resolved
    :raises SyntaxError: unknown field type, or conflicting ``ondelete``
        actions on a composite foreign key
    :raises RuntimeError: geometry unsupported, or a corrupted ``.table`` file
    """
    db = table._db
    fields = []
    # PostGIS geo fields are added after the table has been created
    postcreation_fields = []
    sql_fields = {}
    sql_fields_aux = {}
    # composite (table-level) foreign keys: {rtablename: {rfieldname: local field name}}
    TFK = {}
    tablename = table._tablename
    types = self.types
    for sortable, field in enumerate(table, start=1):
        field_name = field.name
        field_type = field.type
        if isinstance(field_type, SQLCustomType):
            ftype = field_type.native or field_type.type
        elif field_type.startswith(('reference', 'big-reference')):
            if field_type.startswith('reference'):
                referenced = field_type[10:].strip()
                type_name = 'reference'
            else:
                referenced = field_type[14:].strip()
                type_name = 'big-reference'

            if referenced == '.':
                referenced = tablename
            constraint_name = self.constraint_name(tablename, field_name)
            # resolve "table" (its id field) or "table.field"
            try:
                rtable = db[referenced]
                rfield = rtable._id
                rfieldname = rfield.name
                rtablename = referenced
            except (KeyError, ValueError, AttributeError) as e:
                self.db.logger.debug('Error: %s' % e)
                try:
                    rtablename, rfieldname = referenced.split('.')
                    rtable = db[rtablename]
                    rfield = rtable[rfieldname]
                except Exception as e:
                    self.db.logger.debug('Error: %s' % e)
                    raise KeyError('Cannot resolve reference %s in %s definition' % (referenced, table._tablename))

            # must be PK reference or unique
            if getattr(rtable, '_primarykey', None) and rfieldname in rtable._primarykey or \
                    rfield.unique:
                ftype = types[rfield.type[:9]] % \
                    dict(length=rfield.length)
                # multicolumn primary key reference?
                if not rfield.unique and len(rtable._primarykey) > 1:
                    # then it has to be a table level FK
                    if rtablename not in TFK:
                        TFK[rtablename] = {}
                    TFK[rtablename][rfieldname] = field_name
                else:
                    ftype = ftype + \
                        types['reference FK'] % dict(
                            constraint_name=constraint_name,  # should be quoted
                            foreign_key=rtable.sqlsafe + ' (' + rfield.sqlsafe_name + ')',
                            table_name=table.sqlsafe,
                            field_name=field.sqlsafe_name,
                            on_delete_action=field.ondelete)
            else:
                # gotcha: the referenced table must be defined before
                # the referencing one to be able to create the table.
                # Although it is not recommended, references to tablenames
                # without rname are still supported, so that migrations and
                # model relationships also work when tables are not
                # defined in order.
                if referenced == tablename:
                    real_referenced = db[referenced].sqlsafe
                else:
                    real_referenced = (referenced in db and db[referenced].sqlsafe or referenced)
                rfield = db[referenced]._id
                ftype_info = dict(
                    index_name=self.QUOTE_TEMPLATE % (field_name + '__idx'),
                    field_name=field.sqlsafe_name,
                    constraint_name=self.QUOTE_TEMPLATE % constraint_name,
                    foreign_key='%s (%s)' % (real_referenced, rfield.sqlsafe_name),
                    on_delete_action=field.ondelete,
                )
                ftype_info['null'] = ' NOT NULL' if field.notnull else self.ALLOW_NULL()
                ftype_info['unique'] = ' UNIQUE' if field.unique else ''
                ftype = types[type_name] % ftype_info
        elif field_type.startswith('list:reference'):
            ftype = types[field_type[:14]]
        elif field_type.startswith('decimal'):
            precision, scale = map(int, field_type[8:-1].split(','))
            ftype = types[field_type[:7]] % \
                dict(precision=precision, scale=scale)
        elif field_type.startswith('geo'):
            if not hasattr(self, 'srid'):
                raise RuntimeError('Adapter does not support geometry')
            srid = self.srid
            geotype, parms = field_type[:-1].split('(')
            if geotype not in types:
                raise SyntaxError(
                    'Field: unknown field type: %s for %s'
                    % (field_type, field_name))
            ftype = types[geotype]
            if self.dbengine == 'postgres' and geotype == 'geometry':
                if self.ignore_field_case is True:
                    field_name = field_name.lower()
                # parameters: schema, srid, dimension
                dimension = 2  # GIS.dimension ???
                parms = parms.split(',')
                if len(parms) == 3:
                    schema, srid, dimension = parms
                elif len(parms) == 2:
                    schema, srid = parms
                else:
                    schema = parms[0]
                ftype = "SELECT AddGeometryColumn ('%%(schema)s', '%%(tablename)s', '%%(fieldname)s', %%(srid)s, '%s', %%(dimension)s);" % types[geotype]
                ftype = ftype % dict(schema=schema,
                                     tablename=tablename,
                                     fieldname=field_name, srid=srid,
                                     dimension=dimension)
                postcreation_fields.append(ftype)
        elif field_type not in types:
            raise SyntaxError('Field: unknown field type: %s for %s' %
                              (field_type, field_name))
        else:
            ftype = types[field_type] % {'length': field.length}

        if not field_type.startswith(('id', 'reference', 'big-reference')):
            if field.notnull:
                ftype += ' NOT NULL'
            else:
                ftype += self.ALLOW_NULL()
            if field.unique:
                ftype += ' UNIQUE'
            if field.custom_qualifier:
                ftype += ' %s' % field.custom_qualifier

        # add to list of fields
        sql_fields[field_name] = dict(
            length=field.length,
            unique=field.unique,
            notnull=field.notnull,
            sortable=sortable,
            type=str(field_type),
            sql=ftype)

        if field.notnull and field.default is not None:
            # Caveat: sql_fields and sql_fields_aux
            # differ for default values.
            # sql_fields is used to trigger migrations and sql_fields_aux
            # is used for create tables.
            # The reason is that we do not want to trigger
            # a migration simply because a default value changes.
            not_null = self.NOT_NULL(field.default, field_type)
            ftype = ftype.replace('NOT NULL', not_null)
        sql_fields_aux[field_name] = dict(sql=ftype)
        # Postgres - PostGIS:
        # geometry fields are added after the table has been created, not now
        if not (self.dbengine == 'postgres' and
                field_type.startswith('geom')):
            fields.append('%s %s' % (field.sqlsafe_name, ftype))
    other = ';'

    # backend-specific extensions to fields
    if self.dbengine == 'mysql':
        if not hasattr(table, "_primarykey"):
            fields.append('PRIMARY KEY (%s)' % (self.QUOTE_TEMPLATE % table._id.name))
        engine = self.adapter_args.get('engine', 'InnoDB')
        other = ' ENGINE=%s CHARACTER SET utf8;' % engine

    fields = ',\n    '.join(fields)
    for rtablename in TFK:
        rtable = db[rtablename]
        rfields = TFK[rtablename]
        # referenced PK columns and the local columns pointing at them,
        # both listed in the referenced table's primary-key order
        pkeys = [rtable[pk].sqlsafe_name for pk in rtable._primarykey]
        fk_fields = [table[rfields[pk]] for pk in rtable._primarykey]
        fkeys = [fk.sqlsafe_name for fk in fk_fields]
        on_delete = set(fk.ondelete for fk in fk_fields)
        if len(on_delete) > 1:
            raise SyntaxError(
                'Table %s: conflicting ondelete actions for the composite '
                'reference to %s' % (tablename, rtablename))
        fields = fields + ',\n    ' + \
            types['reference TFK'] % dict(
                table_name=table.sqlsafe,
                field_name=', '.join(fkeys),
                foreign_table=rtable.sqlsafe,
                foreign_key=', '.join(pkeys),
                on_delete_action=on_delete.pop())

    if getattr(table, '_primarykey', None):
        query = "CREATE TABLE %s(\n    %s,\n    %s) %s" % \
            (table.sqlsafe, fields,
             self.PRIMARY_KEY(', '.join([self.QUOTE_TEMPLATE % pk for pk in table._primarykey])), other)
    else:
        query = "CREATE TABLE %s(\n    %s\n)%s" % \
            (table.sqlsafe, fields, other)

    if self.uri.startswith('sqlite:///') \
            or self.uri.startswith('spatialite:///'):
        if PY2:
            path_encoding = sys.getfilesystemencoding() \
                or locale.getdefaultlocale()[1] or 'utf8'
            dbpath = self.uri[9:self.uri.rfind('/')].decode(
                'utf8').encode(path_encoding)
        else:
            dbpath = self.uri[9:self.uri.rfind('/')]
    else:
        dbpath = self.folder

    if not migrate:
        return query
    elif self.uri.startswith('sqlite:memory') \
            or self.uri.startswith('spatialite:memory'):
        table._dbt = None
    elif isinstance(migrate, string_types):
        table._dbt = pjoin(dbpath, migrate)
    else:
        table._dbt = pjoin(
            dbpath, '%s_%s.table' % (db._uri_hash, tablename))

    if not table._dbt or not self.file_exists(table._dbt):
        if table._dbt:
            self.log('timestamp: %s\n%s\n'
                     % (datetime.datetime.today().isoformat(),
                        query), table)
        if not fake_migrate:
            self.create_sequence_and_triggers(query, table)
            db.commit()
            # Postgres geom fields are added now,
            # after the table has been created.
            # (Separate loop variable: reusing `query` made this method
            # return the last AddGeometryColumn statement.)
            for geo_query in postcreation_fields:
                self.execute(geo_query)
                db.commit()
        if table._dbt:
            tfile = self.file_open(table._dbt, 'wb')
            try:
                pickle.dump(sql_fields, tfile)
            finally:
                self.file_close(tfile)
            if fake_migrate:
                self.log('faked!\n', table)
            else:
                self.log('success!\n', table)
    else:
        tfile = self.file_open(table._dbt, 'rb')
        try:
            sql_fields_old = pickle.load(tfile)
        except EOFError:
            raise RuntimeError('File %s appears corrupted' % table._dbt)
        finally:
            self.file_close(tfile)
        if sql_fields != sql_fields_old:
            self.migrate_table(
                table,
                sql_fields, sql_fields_old,
                sql_fields_aux, None,
                fake_migrate=fake_migrate
            )
    return query
def migrate_table(
        self,
        table,
        sql_fields,
        sql_fields_old,
        sql_fields_aux,
        logfile,
        fake_migrate=False,
        ):
    """Bring the database schema of *table* in line with its definition.

    Compares the column metadata built by create_table (*sql_fields*)
    with the metadata pickled by a previous run (*sql_fields_old*). It
    then emits ALTER TABLE statements to add, drop or change columns.
    With *fake_migrate* the statements are only logged, not executed.
    The current metadata is written back through save_dbt.

    :param sql_fields_aux: per-field SQL used for ADD statements (it may
        carry a DEFAULT clause that sql_fields omits)
    :param logfile: deprecated and ignored; logging goes through self.log
    """
    # logfile is deprecated (moved to adapter.log method)
    db = table._db
    db._migrated.append(table._tablename)
    tablename = table._tablename
    # Normalise each (name, metadata) pair: legacy entries that are plain
    # SQL strings become dicts of type 'unknown'; names are lower-cased
    # only when ignore_field_case is True.
    def fix(item):
        k,v=item
        if not isinstance(v,dict):
            v=dict(type='unknown',sql=v)
        if self.ignore_field_case is not True: return k, v
        return k.lower(),v
    # make sure all field names are lower case to avoid
    # migrations because of case change
    sql_fields = dict(map(fix, iteritems(sql_fields)))
    sql_fields_old = dict(map(fix, iteritems(sql_fields_old)))
    sql_fields_aux = dict(map(fix, iteritems(sql_fields_aux)))
    if db._debug:
        db.logger.debug('migrating %s to %s' % (sql_fields_old,sql_fields))

    # union of old and new field names, new ones first
    keys = list(sql_fields.keys())
    for key in sql_fields_old:
        if not key in keys:
            keys.append(key)

    # NOTE(review): replaces ', ' in the ADD column SQL -- presumably to
    # turn an inline ", CONSTRAINT ..." suffix into a valid clause; confirm
    new_add = self.concat_add(tablename)

    metadata_change = False
    # metadata reflecting what has actually been applied so far
    sql_fields_current = copy.copy(sql_fields_old)
    for key in keys:
        query = None
        if not key in sql_fields_old:
            # new column
            sql_fields_current[key] = sql_fields[key]
            if self.dbengine in ('postgres',) and \
               sql_fields[key]['type'].startswith('geometry'):
                # 'sql' == ftype in sql
                query = [ sql_fields[key]['sql'] ]
            else:
                # NOTE(review): the column name is not quoted here, unlike
                # in the DROP / type-change branches below
                query = ['ALTER TABLE %s ADD %s %s;' % \
                         (table.sqlsafe, key,
                          sql_fields_aux[key]['sql'].replace(', ', new_add))]
            metadata_change = True
        elif self.dbengine in ('sqlite', 'spatialite'):
            # SQLite: no ALTER statements issued; only metadata is updated
            if key in sql_fields:
                sql_fields_current[key] = sql_fields[key]
            metadata_change = True
        elif not key in sql_fields:
            # column removed from the model
            del sql_fields_current[key]
            ftype = sql_fields_old[key]['type']
            if (self.dbengine in ('postgres',) and
                ftype.startswith('geometry')):
                geotype, parms = ftype[:-1].split('(')
                schema = parms.split(',')[0]
                query = [ "SELECT DropGeometryColumn ('%(schema)s', '%(table)s', '%(field)s');" %
                          dict(schema=schema, table=tablename, field=key) ]
            elif self.dbengine in ('firebird',):
                query = ['ALTER TABLE %s DROP %s;' %
                         (self.QUOTE_TEMPLATE % tablename, self.QUOTE_TEMPLATE % key)]
            else:
                query = ['ALTER TABLE %s DROP COLUMN %s;' %
                         (self.QUOTE_TEMPLATE % tablename, self.QUOTE_TEMPLATE % key)]
            metadata_change = True
        elif sql_fields[key]['sql'] != sql_fields_old[key]['sql'] \
          and not (key in table.fields and
                   isinstance(table[key].type, SQLCustomType)) \
          and not sql_fields[key]['type'].startswith('reference')\
          and not sql_fields[key]['type'].startswith('double')\
          and not sql_fields[key]['type'].startswith('id'):
            # column definition changed: copy the data through a temporary
            # "<key>__tmp" column, drop and re-create the column with the
            # new definition, copy the data back and drop the temporary
            sql_fields_current[key] = sql_fields[key]
            t = tablename
            tt = sql_fields_aux[key]['sql'].replace(', ', new_add)
            if self.dbengine in ('firebird',):
                drop_expr = 'ALTER TABLE %s DROP %s;'
            else:
                drop_expr = 'ALTER TABLE %s DROP COLUMN %s;'
            key_tmp = key + '__tmp'
            query = ['ALTER TABLE %s ADD %s %s;' % (self.QUOTE_TEMPLATE % t, self.QUOTE_TEMPLATE % key_tmp, tt),
                     'UPDATE %s SET %s=%s;' %
                     (self.QUOTE_TEMPLATE % t, self.QUOTE_TEMPLATE % key_tmp, self.QUOTE_TEMPLATE % key),
                     drop_expr % (self.QUOTE_TEMPLATE % t, self.QUOTE_TEMPLATE % key),
                     'ALTER TABLE %s ADD %s %s;' %
                     (self.QUOTE_TEMPLATE % t, self.QUOTE_TEMPLATE % key, tt),
                     'UPDATE %s SET %s=%s;' %
                     (self.QUOTE_TEMPLATE % t, self.QUOTE_TEMPLATE % key, self.QUOTE_TEMPLATE % key_tmp),
                     drop_expr % (self.QUOTE_TEMPLATE % t, self.QUOTE_TEMPLATE % key_tmp)]
            metadata_change = True
        elif sql_fields[key]['type'] != sql_fields_old[key]['type']:
            # type label changed without DDL changes: metadata only
            sql_fields_current[key] = sql_fields[key]
            metadata_change = True

        if query:
            self.log('timestamp: %s\n'
                     % datetime.datetime.today().isoformat(), table)
            db['_lastsql'] = '\n'.join(query)
            for sub_query in query:
                self.log(sub_query + '\n', table)
                if fake_migrate:
                    if db._adapter.commit_on_alter_table:
                        self.save_dbt(table,sql_fields_current)
                    self.log('faked!\n', table)
                else:
                    self.execute(sub_query)
                    # Caveat: mysql, oracle and firebird
                    # do not allow multiple alter table
                    # in one transaction so we must commit
                    # partial transactions and
                    # update table._dbt after alter table.
                    if db._adapter.commit_on_alter_table:
                        db.commit()
                        self.save_dbt(table,sql_fields_current)
                        self.log('success!\n', table)

        elif metadata_change:
            self.save_dbt(table,sql_fields_current)

    # Final commit/save, unless the per-statement commits above already
    # did it. `query` here is the value from the last loop iteration.
    if metadata_change and not (query and db._adapter.commit_on_alter_table):
        db.commit()
        self.save_dbt(table,sql_fields_current)
        self.log('success!\n', table)
def save_dbt(self, table, sql_fields_current):
    """Pickle *sql_fields_current* into the table's ``_dbt`` metadata file.

    The file is closed (and any lock released) even if pickling fails.
    """
    tfile = self.file_open(table._dbt, 'wb')
    try:
        pickle.dump(sql_fields_current, tfile)
    finally:
        self.file_close(tfile)
def LOWER(self, first):
    """SQL ``LOWER()`` of the expanded *first*."""
    inner = self.expand(first)
    return 'LOWER(%s)' % inner
def UPPER(self, first):
    """SQL ``UPPER()`` of the expanded *first*."""
    inner = self.expand(first)
    return 'UPPER(%s)' % inner
def COUNT(self, first, distinct=None):
    """SQL ``COUNT`` aggregate, over DISTINCT values when *distinct* is truthy."""
    template = 'COUNT(DISTINCT %s)' if distinct else 'COUNT(%s)'
    return template % self.expand(first)
def EXTRACT(self, first, what):
    """SQL ``EXTRACT(<what> FROM <first>)``."""
    source = self.expand(first)
    return "EXTRACT(%s FROM %s)" % (what, source)
def EPOCH(self, first):
    """Seconds since the epoch, expressed as ``EXTRACT(epoch FROM ...)``."""
    unit = 'epoch'
    return self.EXTRACT(first, unit)
def LENGTH(self, first):
    """SQL ``LENGTH()`` of the expanded *first*."""
    inner = self.expand(first)
    return "LENGTH(%s)" % inner
def AGGREGATE(self, first, what):
    """Apply the SQL aggregate function named *what* to *first*."""
    argument = self.expand(first)
    return "%s(%s)" % (what, argument)
def JOIN(self):
    """Keyword used for inner joins."""
    keyword = 'JOIN'
    return keyword
def LEFT_JOIN(self):
    """Keyword used for left outer joins."""
    keyword = 'LEFT JOIN'
    return keyword
def RANDOM(self):
    """SQL expression producing a random value (used for random ordering)."""
    expression = 'Random()'
    return expression
def NOT_NULL(self, default, field_type):
    """``NOT NULL DEFAULT <default>``, with *default* represented as *field_type*."""
    default_sql = self.represent(default, field_type)
    return 'NOT NULL DEFAULT %s' % default_sql
def COALESCE(self, first, second):
    """SQL ``COALESCE(first, *second)``; *second* is an iterable of fallbacks."""
    arguments = [self.expand(first)]
    arguments.extend(self.expand(fallback) for fallback in second)
    return 'COALESCE(%s)' % ','.join(arguments)
def COALESCE_ZERO(self, first):
    """*first*, or 0 when it is NULL."""
    fallbacks = [0]
    return self.COALESCE(first, fallbacks)
def RAW(self, first):
    """Return *first* untouched (raw SQL passthrough)."""
    raw_sql = first
    return raw_sql
def ALLOW_NULL(self):
    """Column suffix for nullable columns; empty in the base adapter."""
    suffix = ''
    return suffix
def SUBSTRING(self, field, parameters):
    """SQL ``SUBSTR(field, start, length)``.

    The first two items of *parameters* are the start position and length.
    """
    start, length = parameters[0], parameters[1]
    return 'SUBSTR(%s,%s,%s)' % (self.expand(field), start, length)
def PRIMARY_KEY(self, key):
    """Table-level ``PRIMARY KEY(...)`` clause; *key* is the column list SQL."""
    template = 'PRIMARY KEY(%s)'
    return template % key
# SQL statement for dropping table
def _drop(self, table, mode):
    """Return the SQL statements that drop *table*; *mode* is unused here."""
    statement = 'DROP TABLE %s;' % table.sqlsafe
    return [statement]

# PYDAL cleanup
def _drop_cleanup(self, table):
    """Forget *table* in its DAL after the DROP has been executed.

    Removes the table from the DAL and from its ``tables`` list, detaches
    references to it, and deletes its ``.table`` migration file if any.
    """
    db = table._db
    name = table._tablename
    del db[name]
    db.tables.remove(name)
    db._remove_references_to(table)
    if table._dbt:
        self.file_delete(table._dbt)
        self.log('success!\n', table)
    return
def drop(self, table, mode=''):
    """Drop *table*: run its DROP statements, commit, then clean up the DAL.

    Each statement is logged first when the table has a migration file.
    """
    db = table._db
    for statement in self._drop(table, mode):
        if table._dbt:
            self.log(statement + '\n', table)
        self.execute(statement)
    db.commit()
    self._drop_cleanup(table)
    return
def _insert(self, table, fields):
    """Build the INSERT statement for *fields*, a sequence of (Field, value) pairs.

    With no fields, an insert of default values is produced instead.
    """
    if not fields:
        return self._insert_empty(table)
    columns = ','.join(field.sqlsafe_name for field, _ in fields)
    values = ','.join(self.expand(value, field.type) for field, value in fields)
    return 'INSERT INTO %s(%s) VALUES (%s);' % (table.sqlsafe, columns, values)

def _insert_empty(self, table):
    """INSERT statement that creates a row made only of default values."""
    target = table.sqlsafe
    return 'INSERT INTO %s DEFAULT VALUES;' % target
def insert(self, table, fields):
    """Insert one row and return its identifier.

    :param fields: sequence of (Field, value) pairs
    :returns: the table's ``_on_insert_error`` result if the INSERT fails
        and the table defines that hook. For keyed tables, a dict of the
        supplied primary-key values. Otherwise a Reference built from the
        driver's last row id, or that id as-is when it is not an integer.
    :raises: the original driver exception when no error hook exists
    """
    query = self._insert(table, fields)
    try:
        self.execute(query)
    except Exception as e:
        if hasattr(table, '_on_insert_error'):
            return table._on_insert_error(table, fields, e)
        # bare raise keeps the original traceback (``raise e`` loses it on Python 2)
        raise
    if hasattr(table, '_primarykey'):
        pk_values = dict((field.name, value) for field, value in fields
                         if field.name in table._primarykey)
        if pk_values:
            return pk_values
    rowid = self.lastrowid(table)
    if hasattr(table, '_primarykey') and len(table._primarykey) == 1:
        rowid = {table._primarykey[0]: rowid}
    if not isinstance(rowid, (int, long)):
        return rowid
    rid = Reference(rowid)
    (rid._table, rid._record) = (table, None)
    return rid
def bulk_insert(self, table, items):
    """Insert each element of *items* one by one; return the list of ids."""
    insert_one = self.insert
    return [insert_one(table, row_fields) for row_fields in items]
def NOT(self, first):
    """Logical negation of *first*."""
    operand = self.expand(first)
    return '(NOT %s)' % operand
def AND(self, first, second):
    """Logical conjunction of two sub-queries."""
    left = self.expand(first)
    right = self.expand(second)
    return '(%s AND %s)' % (left, right)
def OR(self, first, second):
    """Logical disjunction of two sub-queries."""
    left = self.expand(first)
    right = self.expand(second)
    return '(%s OR %s)' % (left, right)
def BELONGS(self, first, second):
    """SQL ``IN`` membership test.

    *second* is either an SQL string (a nested select whose trailing ``;``
    is stripped) or an iterable of values represented as ``first.type``.
    An empty collection yields the always-false ``(1=0)``.
    """
    if isinstance(second, str):
        subquery = second[:-1]
        return '(%s IN (%s))' % (self.expand(first), subquery)
    if not second:
        return '(1=0)'
    member_type = first.type
    members = ','.join(self.expand(value, member_type) for value in second)
    return '(%s IN (%s))' % (self.expand(first), members)
def REGEXP(self, first, second):
    """Regular expression operator.

    Not supported by the base adapter; backends that support it override
    this method.
    """
    raise NotImplementedError
def like_escaper_default(self, term):
    """Escape LIKE wildcards in *term*, using backslash as the escape char.

    Backslashes are doubled first, then ``%`` and ``_`` are prefixed with
    a backslash. Expression objects are returned unchanged. Raw strings
    are used because ``'\\%'`` and ``'\\_'`` are invalid escape sequences
    (SyntaxWarning on Python 3.12+).
    """
    if isinstance(term, Expression):
        return term
    term = term.replace('\\', '\\\\')
    return term.replace('%', r'\%').replace('_', r'\_')
def LIKE(self, first, second, escape=None):
    """Case sensitive like operator.

    For a literal pattern with no explicit *escape*, the escape character
    defaults to backslash and existing backslashes are doubled so they
    match literally. An Expression pattern is used as-is. If no escape
    character applies, the ESCAPE clause is omitted: the previous code
    rendered the invalid ``ESCAPE 'None'`` in that case.
    """
    if isinstance(second, Expression):
        pattern = self.expand(second, 'string')
    else:
        pattern = self.expand(second, 'string')
        if escape is None:
            escape = '\\'
            pattern = pattern.replace(escape, escape * 2)
    if escape is None:
        return "(%s LIKE %s)" % (self.expand(first), pattern)
    return "(%s LIKE %s ESCAPE '%s')" % (self.expand(first), pattern, escape)
def ILIKE(self, first, second, escape=None):
    """Case insensitive like operator (``LOWER(first) LIKE pattern``).

    A literal pattern is lower-cased. When no explicit *escape* is given,
    the escape defaults to backslash and existing backslashes are doubled.
    An Expression pattern is used as-is. If no escape character applies,
    the ESCAPE clause is omitted instead of rendering ``ESCAPE 'None'``.
    """
    if isinstance(second, Expression):
        pattern = self.expand(second, 'string')
    else:
        pattern = self.expand(second, 'string').lower()
        if escape is None:
            escape = '\\'
            pattern = pattern.replace(escape, escape * 2)
    if escape is None:
        return "(LOWER(%s) LIKE %s)" % (self.expand(first), pattern)
    return "(LOWER(%s) LIKE %s ESCAPE '%s')" % (self.expand(first), pattern, escape)
def STARTSWITH(self, first, second):
    """*first* begins with *second*; LIKE wildcards in *second* are escaped."""
    column = self.expand(first)
    pattern = self.expand(self.like_escaper_default(second) + '%', 'string')
    return "(%s LIKE %s ESCAPE '\\')" % (column, pattern)
def ENDSWITH(self, first, second):
    """*first* ends with *second*; LIKE wildcards in *second* are escaped."""
    column = self.expand(first)
    pattern = self.expand('%' + self.like_escaper_default(second), 'string')
    return "(%s LIKE %s ESCAPE '\\')" % (column, pattern)
def CONTAINS(self, first, second, case_sensitive=True):
    """*first* contains *second*, implemented with LIKE/ILIKE and ``\\`` as escape.

    For string/text/json columns the pattern is ``%<second>%``. For
    ``list:`` columns, which store values bar-encoded as ``|a|b|``, the
    pattern is ``%|<second>|%`` with ``|`` doubled. Expression operands are
    escaped in SQL with REPLACE. Raw strings avoid the invalid escape
    sequence ``'\\%'``.

    NOTE(review): the Expression branches only escape ``%``, not ``_`` or
    ``\\`` as like_escaper_default does for literals -- confirm intent.
    """
    if first.type in ('string', 'text', 'json'):
        if isinstance(second, Expression):
            second = Expression(second.db, self.CONCAT(
                '%', Expression(second.db, self.REPLACE(second, ('%', r'\%'))), '%'))
        else:
            second = '%' + self.like_escaper_default(str(second)) + '%'
    elif first.type.startswith('list:'):
        if isinstance(second, Expression):
            second = Expression(second.db, self.CONCAT(
                '%|', Expression(second.db, self.REPLACE(
                    Expression(second.db, self.REPLACE(second, ('%', r'\%'))),
                    ('|', '||'))), '|%'))
        else:
            second = str(second).replace('|', '||')
            second = '%|' + self.like_escaper_default(second) + '|%'
    op = self.LIKE if case_sensitive else self.ILIKE
    return op(first, second, escape='\\')
def EQ(self, first, second=None):
    """Equality; comparing with ``None`` renders ``IS NULL``."""
    column = self.expand(first)
    if second is not None:
        return '(%s = %s)' % (column, self.expand(second, first.type))
    return '(%s IS NULL)' % column
def NE(self, first, second=None):
    """Inequality; comparing with ``None`` renders ``IS NOT NULL``."""
    column = self.expand(first)
    if second is not None:
        return '(%s <> %s)' % (column, self.expand(second, first.type))
    return '(%s IS NOT NULL)' % column
def LT(self, first, second=None):
    """``first < second``; raises RuntimeError when *second* is None."""
    if second is None:
        raise RuntimeError("Cannot compare %s < None" % first)
    lhs = self.expand(first)
    rhs = self.expand(second, first.type)
    return '(%s < %s)' % (lhs, rhs)
def LE(self, first, second=None):
    """``first <= second``; raises RuntimeError when *second* is None."""
    if second is None:
        raise RuntimeError("Cannot compare %s <= None" % first)
    lhs = self.expand(first)
    rhs = self.expand(second, first.type)
    return '(%s <= %s)' % (lhs, rhs)
def GT(self, first, second=None):
    """``first > second``; raises RuntimeError when *second* is None."""
    if second is None:
        raise RuntimeError("Cannot compare %s > None" % first)
    lhs = self.expand(first)
    rhs = self.expand(second, first.type)
    return '(%s > %s)' % (lhs, rhs)
def GE(self, first, second=None):
    """``first >= second``; raises RuntimeError when *second* is None."""
    if second is None:
        raise RuntimeError("Cannot compare %s >= None" % first)
    lhs = self.expand(first)
    rhs = self.expand(second, first.type)
    return '(%s >= %s)' % (lhs, rhs)
def is_numerical_type(self, ftype):
    """True for integer, boolean, double, bigint and ``decimal(...)`` types."""
    if ftype in ('integer', 'boolean', 'double', 'bigint'):
        return True
    return ftype.startswith('decimal')
def REPLACE(self, first, tup):
    """SQL ``REPLACE(first, old, new)`` where *tup* is ``(old, new)``.

    All three operands are expanded as strings.
    """
    old, new = tup
    operands = [self.expand(item, 'string') for item in (first, old, new)]
    return 'REPLACE(%s)' % ','.join(operands)
def CONCAT(self, *items):
    """String concatenation of *items* with the ``||`` operator."""
    pieces = [self.expand(item, 'string') for item in items]
    return '(' + ' || '.join(pieces) + ')'
def ADD(self, first, second):
    """Numeric ``+``; falls back to string concatenation for non-numeric types."""
    numeric = self.is_numerical_type(first.type) or isinstance(first.type, Field)
    if not numeric:
        return self.CONCAT(first, second)
    return '(%s + %s)' % (self.expand(first), self.expand(second, first.type))
def SUB(self, first, second):
    """Arithmetic subtraction; *second* is represented as ``first.type``."""
    lhs = self.expand(first)
    rhs = self.expand(second, first.type)
    return '(%s - %s)' % (lhs, rhs)
def MUL(self, first, second):
    """Arithmetic multiplication; *second* is represented as ``first.type``."""
    lhs = self.expand(first)
    rhs = self.expand(second, first.type)
    return '(%s * %s)' % (lhs, rhs)
def DIV(self, first, second):
    """Arithmetic division; *second* is represented as ``first.type``."""
    lhs = self.expand(first)
    rhs = self.expand(second, first.type)
    return '(%s / %s)' % (lhs, rhs)
def MOD(self, first, second):
    """Modulo (SQL ``%``); *second* is represented as ``first.type``."""
    lhs = self.expand(first)
    rhs = self.expand(second, first.type)
    return '(%s %% %s)' % (lhs, rhs)
def AS(self, first, second):
    """Alias the expanded *first* as *second* (used verbatim)."""
    expression = self.expand(first)
    return '%s AS %s' % (expression, second)
def ON(self, first, second):
    """JOIN target clause ``<table alias> ON <condition>``.

    Common filters of the joined table are applied to the condition when
    enabled.
    """
    joined = self.table_alias(first)
    condition = second
    if use_common_filters(condition):
        condition = self.common_filter(condition, [first._tablename])
    return '%s ON %s' % (self.expand(joined), self.expand(condition))
def INVERT(self, first):
    """Descending sort key for ORDER BY."""
    key = self.expand(first)
    return '%s DESC' % key
def COMMA(self, first, second):
    """Comma-separated pair of expressions (e.g. multi-column ORDER BY)."""
    left = self.expand(first)
    right = self.expand(second)
    return '%s, %s' % (left, right)
def CAST(self, first, second):
    """SQL ``CAST(first AS second)``; both arguments are already SQL text."""
    template = 'CAST(%s AS %s)'
    return template % (first, second)
def expand(self, expression, field_type=None, colnames=False):
    """Render *expression* as an SQL fragment; the result is always ``str``.

    Dispatch order matters. Field objects come first, then
    Expression/Query trees. Plain values are represented with
    *field_type* when one is given. Lists/tuples and booleans follow;
    anything else is passed through ``str``.

    :param expression: a Field, an Expression/Query, or a Python value
    :param field_type: pydal type used to represent plain values; for a
        Field, ``'string'`` requests a CAST of non-textual columns to the
        adapter's text type
    :param colnames: when true, render a Field as ``table.field`` from the
        plain quoted names, ignoring aliases and rnames
    """
    if isinstance(expression, Field):
        et = expression.table
        if not colnames:
            # aliased tables (alias() sets _ot) are referenced by their
            # alias name; otherwise prefer the table's rname
            table_rname = et._ot and self.QUOTE_TEMPLATE % et._tablename \
                or et._rname or self.QUOTE_TEMPLATE % et._tablename
            rv = '%s.%s' % (table_rname, expression._rname or
                            (self.QUOTE_TEMPLATE % (expression.name)))
        else:
            rv = '%s.%s' % (self.QUOTE_TEMPLATE % et._tablename,
                            self.QUOTE_TEMPLATE % expression.name)
        # a non-textual column used where a string is expected
        if field_type == 'string' and expression.type not in (
                'string', 'text', 'json', 'password'):
            rv = self.CAST(rv, self.types['text'])
    elif isinstance(expression, (Expression, Query)):
        first = expression.first
        second = expression.second
        op = expression.op
        optional_args = expression.optional_args or {}
        if second is not None:
            rv = op(first, second, **optional_args)
        elif first is not None:
            rv = op(first, **optional_args)
        elif isinstance(op, str):
            # raw SQL stored as the op (e.g. a nested select): drop the
            # trailing ';' and parenthesise it
            if op.endswith(';'):
                op = op[:-1]
            rv = '(%s)' % op
        else:
            rv = op()
    elif field_type:
        rv = self.represent(expression, field_type)
    elif isinstance(expression, (list, tuple)):
        rv = ','.join(self.represent(item, field_type) for item in expression)
    elif isinstance(expression, bool):
        rv = self.db._adapter.TRUE_exp if expression else \
            self.db._adapter.FALSE_exp
    else:
        rv = expression
    return str(rv)
def table_alias(self, tbl):
    """Return the SQL (``sqlsafe_alias``) for a Table object or a table name."""
    table = tbl if isinstance(tbl, Table) else self.db[tbl]
    return table.sqlsafe_alias
def alias(self, table, alias):
    """
    Given a table object, makes a new table object with alias name.

    The copy remembers the original table SQL in ``_ot``, gets its own
    ``ALL`` and its own field copies bound to the alias, and is
    registered in the DAL under *alias*.
    """
    aliased = copy.copy(table)
    aliased['_ot'] = aliased._ot or aliased.sqlsafe
    aliased['ALL'] = SQLALL(aliased)
    aliased['_tablename'] = alias
    for name in aliased.fields:
        field_copy = copy.copy(aliased[name])
        aliased[name] = field_copy
        field_copy._tablename = alias
        field_copy.tablename = alias
        field_copy.table = aliased
    table._db[alias] = aliased
    return aliased
def _truncate(self, table, mode=''):
    """Return the TRUNCATE statement(s) for *table*; *mode* is appended verbatim."""
    suffix = mode or ''
    return ['TRUNCATE TABLE %s %s;' % (table.sqlsafe, suffix)]
def truncate(self, table, mode=' '):
    """Empty *table*, logging each TRUNCATE statement and then success.

    The statements come from the DAL's adapter (``table._db._adapter``).
    The former ``try/finally: pass`` wrapper did nothing and has been
    removed, along with a stale comment about logfile helpers.
    """
    queries = table._db._adapter._truncate(table, mode)
    for query in queries:
        self.log(query + '\n', table)
        self.execute(query)
    self.log('success!\n', table)
def _update(self, tablename, query, fields):
    """Build the UPDATE statement for *tablename*.

    :param query: optional WHERE condition; common filters are applied
        when enabled
    :param fields: sequence of (Field, value) pairs to assign
    """
    where_clause = ''
    if query:
        if use_common_filters(query):
            query = self.common_filter(query, [tablename])
        where_clause = ' WHERE ' + self.expand(query)
    assignments = ','.join(
        '%s=%s' % (field.sqlsafe_name, self.expand(value, field.type))
        for field, value in fields)
    table_sql = self.db[tablename].sqlsafe
    return 'UPDATE %s SET %s%s;' % (table_sql, assignments, where_clause)
def update(self, tablename, query, fields):
    """Run an UPDATE and return the affected row count.

    :returns: ``cursor.rowcount``, or ``None`` if the cursor cannot report
        it. If the statement fails and the table defines
        ``_on_update_error``, that hook's result is returned instead.
    :raises: the original driver exception when no error hook exists
    """
    sql = self._update(tablename, query, fields)
    try:
        self.execute(sql)
    except Exception as e:
        table = self.db[tablename]
        if hasattr(table, '_on_update_error'):
            return table._on_update_error(table, query, fields, e)
        # bare raise keeps the original traceback
        raise
    try:
        return self.cursor.rowcount
    except Exception:
        # a missing cursor or one without rowcount support; a bare except
        # would also have swallowed KeyboardInterrupt/SystemExit
        return None
def _delete(self, tablename, query):
    """Build the DELETE statement for *tablename*, optionally filtered by *query*.

    Common filters are applied to *query* when enabled.
    """
    where_clause = ''
    if query:
        if use_common_filters(query):
            query = self.common_filter(query, [tablename])
        where_clause = ' WHERE ' + self.expand(query)
    table_sql = self.db[tablename].sqlsafe
    return 'DELETE FROM %s%s;' % (table_sql, where_clause)
def delete(self, tablename, query):
    """Run a DELETE and return ``cursor.rowcount`` (``None`` if unavailable)."""
    sql = self._delete(tablename, query)
    self.execute(sql)
    try:
        counter = self.cursor.rowcount
    except Exception:
        # narrowed from a bare except so KeyboardInterrupt/SystemExit propagate
        counter = None
    return counter
def get_table(self, *queries):
    """Return the single table name referenced by *queries*.

    :raises RuntimeError: when no table, or more than one, is referenced
    """
    tablenames = self.tables(*queries)
    count = len(tablenames)
    if count == 1:
        return tablenames[0]
    if count == 0:
        raise RuntimeError("No table selected")
    raise RuntimeError("Too many tables selected (%s)" % str(tablenames))
def expand_all(self, fields, tablenames):
    """Normalise a select field list into field/expression objects.

    ``SQLALL`` items expand to all fields of their table. A
    ``"table.field"`` string becomes that Field. Any other string becomes
    a raw-SQL Expression, and other items pass through. When the result
    is empty, every field of *tablenames* is used.
    """
    db = self.db
    expanded = []
    for item in fields:
        if isinstance(item, SQLALL):
            expanded += item._table
        elif isinstance(item, str):
            match = self.REGEX_TABLE_DOT_FIELD.match(item)
            if match:
                tname, fname = match.groups()
                expanded.append(db[tname][fname])
            else:
                # default arg binds the current string (avoids late binding)
                expanded.append(Expression(db, lambda item=item: item))
        else:
            expanded.append(item)
    # ## if no fields specified take them all from the requested tables
    if not expanded:
        for table in tablenames:
            expanded.extend(db[table])
    return expanded
def _select(self, query, fields, attributes):
    """Build the SQL SELECT statement for `query` and `fields`.

    Args:
        query: the WHERE condition; when falsy no WHERE clause is emitted.
        fields: the expressions to select.
        attributes: select options; every key must be in SELECT_ARGS
            (orderby, groupby, limitby, left, join, distinct, having,
            for_update, outer_scoped, ...).

    Returns:
        The SQL string built by ``self.select_limitby``. ' FOR UPDATE' is
        appended when ``for_update`` is requested and
        ``can_select_for_update`` is True.

    Raises:
        SyntaxError: on an attribute not in SELECT_ARGS, when no table is
            involved, or when ``for_update`` is True while
            ``can_select_for_update`` is False.

    Side effect: stores the column-expanded names of `fields` in
    ``self._colnames``, which ``_select_aux`` / ``iterselect`` read later.
    """
    tables = self.tables
    for key in set(attributes.keys())-SELECT_ARGS:
        raise SyntaxError('invalid select attribute: %s' % key)
    args_get = attributes.get
    tablenames = tables(query)
    # NOTE: same list object as `tablenames`, so the appends below are
    # visible through it until it is rebound in the `left` branch.
    tablenames_for_common_filters = tablenames
    outer_scoped = [t._tablename for t in args_get('outer_scoped',[])]
    # every table referenced by a selected field must appear in FROM
    for field in fields:
        for tablename in tables(field):
            if not tablename in tablenames:
                tablenames.append(tablename)

    if len(tablenames) < 1:
        raise SyntaxError('Set: no tables selected')
    def colexpand(field):
        return self.expand(field, colnames=True)
    self._colnames = list(map(colexpand, fields))
    # Field objects whose type starts with 'geo' are selected via st_astext()
    def geoexpand(field):
        if isinstance(field.type,str) and field.type.startswith('geo') and isinstance(field, Field):
            field = field.st_astext()
        return self.expand(field)
    sql_f = ', '.join(map(geoexpand, fields))
    sql_o = ''
    sql_s = ''
    left = args_get('left', False)
    inner_join = args_get('join', False)
    distinct = args_get('distinct', False)
    groupby = args_get('groupby', False)
    orderby = args_get('orderby', False)
    having = args_get('having', False)
    limitby = args_get('limitby', False)
    orderby_on_limitby = args_get('orderby_on_limitby', True)
    for_update = args_get('for_update', False)
    if self.can_select_for_update is False and for_update is True:
        raise SyntaxError('invalid select attribute: for_update')
    if distinct is True:
        sql_s += 'DISTINCT'
    elif distinct:
        # non-boolean distinct is interpolated with str(), not expanded
        sql_s += 'DISTINCT ON (%s)' % distinct
    if inner_join:
        icommand = self.JOIN()
        if not isinstance(inner_join, (tuple, list)):
            inner_join = [inner_join]
        # plain tables vs. ON-expressions (t.first is the joined table)
        ijoint = [t._tablename for t in inner_join
                  if not isinstance(t,Expression)]
        ijoinon = [t for t in inner_join if isinstance(t, Expression)]
        # tables mentioned inside ON expressions but not being joined
        # themselves must still be listed in FROM
        itables_to_merge={} #issue 490
        [itables_to_merge.update(
                dict.fromkeys(tables(t))) for t in ijoinon]
        ijoinont = [t.first._tablename for t in ijoinon]
        [itables_to_merge.pop(t) for t in ijoinont
         if t in itables_to_merge] #issue 490
        iimportant_tablenames = ijoint + ijoinont + list(itables_to_merge.keys())
        iexcluded = [t for t in tablenames
                     if not t in iimportant_tablenames]
    if left:
        join = attributes['left']
        command = self.LEFT_JOIN()
        if not isinstance(join, (tuple, list)):
            join = [join]
        joint = [t._tablename for t in join
                 if not isinstance(t, Expression)]
        joinon = [t for t in join if isinstance(t, Expression)]
        #patch join+left patch (solves problem with ordering in left joins)
        tables_to_merge={}
        [tables_to_merge.update(
                dict.fromkeys(tables(t))) for t in joinon]
        joinont = [t.first._tablename for t in joinon]
        [tables_to_merge.pop(t) for t in joinont if t in tables_to_merge]
        # left-joined tables get no common filter in the WHERE clause
        tablenames_for_common_filters = [t for t in tablenames
                    if not t in joinont ]
        important_tablenames = joint + joinont + list(tables_to_merge.keys())
        excluded = [t for t in tablenames
                    if not t in important_tablenames ]
    else:
        excluded = tablenames

    # outer_scoped tables are dropped from the table list from here on
    tablenames = [t for t in tablenames if t not in outer_scoped]

    if use_common_filters(query):
        query = self.common_filter(query,tablenames_for_common_filters)
    sql_w = ' WHERE ' + self.expand(query) if query else ''

    JOIN = ' CROSS JOIN '

    if inner_join and not left:
        # Wrap table references with parenthesis (approach 1)
        # sql_t = ', '.join([self.table_alias(t) for t in iexcluded + \
        #                   itables_to_merge.keys()])
        # sql_t = '(%s)' % sql_t
        # or approach 2: Use 'JOIN' instead comma:
        sql_t = JOIN.join([self.table_alias(t) for t in iexcluded + list(itables_to_merge.keys())])
        for t in ijoinon:
            sql_t += ' %s %s' % (icommand, t)
    elif not inner_join and left:
        sql_t = JOIN.join([self.table_alias(t) for t in excluded + \
                              list(tables_to_merge.keys())])
        if joint:
            # NOTE(review): plain left-joined table names are joined with
            # ',' and not passed through table_alias — confirm intended
            sql_t += ' %s %s' % (command,
                                 ','.join([t for t in joint]))
        for t in joinon:
            sql_t += ' %s %s' % (command, t)
    elif inner_join and left:
        # FROM lists every involved table that is not the target of an
        # ON-expression; inner joins are emitted before left joins
        all_tables_in_query = set(important_tablenames + \
                                  iimportant_tablenames + \
                                  tablenames)
        tables_in_joinon = set(joinont + ijoinont)
        tables_not_in_joinon = \
            all_tables_in_query.difference(tables_in_joinon)
        sql_t = JOIN.join([self.table_alias(t) for t in tables_not_in_joinon])
        for t in ijoinon:
            sql_t += ' %s %s' % (icommand, t)
        if joint:
            sql_t += ' %s %s' % (command,
                                 ','.join([t for t in joint]))
        for t in joinon:
            sql_t += ' %s %s' % (command, t)
    else:
        sql_t = ', '.join(self.table_alias(t) for t in tablenames)
    if groupby:
        if isinstance(groupby, (list, tuple)):
            groupby = xorify(groupby)
        sql_o += ' GROUP BY %s' % self.expand(groupby)
    if having:
        # NOTE(review): `having` is interpolated with str(), not
        # self.expand() — confirm this is intended
        sql_o += ' HAVING %s' % attributes['having']
    if orderby:
        if isinstance(orderby, (list, tuple)):
            orderby = xorify(orderby)
        if str(orderby) == '<random>':
            sql_o += ' ORDER BY %s' % self.RANDOM()
        else:
            sql_o += ' ORDER BY %s' % self.expand(orderby)
    # limitby without an explicit orderby: order by each table's primary
    # key(s), or '_id' — presumably so LIMIT/OFFSET paging is stable
    if (limitby and not groupby and tablenames and orderby_on_limitby and not orderby):
        sql_o += ' ORDER BY %s' % ', '.join(
            [self.db[t].sqlsafe + '.' + self.db[t][x].sqlsafe_name for t in tablenames for x in (
                hasattr(self.db[t], '_primarykey') and self.db[t]._primarykey
                or ['_id']
                )
             ]
            )
    # oracle does not support limitby
    sql = self.select_limitby(sql_s, sql_f, sql_t, sql_w, sql_o, limitby)
    if for_update and self.can_select_for_update is True:
        sql = sql.rstrip(';') + ' FOR UPDATE;'
    return sql
def select_limitby(self, sql_s, sql_f, sql_t, sql_w, sql_o, limitby):
    """Assemble the final SELECT, rendering `limitby` as LIMIT/OFFSET.

    Args:
        sql_s: DISTINCT prefix (may be empty).
        sql_f: the select list.
        sql_t: the FROM clause body.
        sql_w: the WHERE clause (with leading space) or ''.
        sql_o: GROUP BY / HAVING / ORDER BY tail or ''.
        limitby: optional ``(min, max)`` row range, max exclusive.
    """
    if limitby:
        lmin, lmax = limitby
        sql_o = '%s LIMIT %i OFFSET %i' % (sql_o, lmax - lmin, lmin)
    return 'SELECT %s %s FROM %s%s%s;' % (sql_s, sql_f, sql_t, sql_w, sql_o)
def _fetchall(self): return self.cursor.fetchall() def _fetchone(self): return self.cursor.fetchone() def _select_aux(self, sql, fields, attributes): args_get = attributes.get cache = args_get('cache',None) if not cache: self.execute(sql) rows = self._fetchall() else: if isinstance(cache, dict): cache_model = cache['model'] time_expire = cache['expiration'] key = cache.get('key') if not key: key = self.uri + '/' + sql + '/rows' key = hashlib_md5(key).hexdigest() else: (cache_model, time_expire) = cache key = self.uri + '/' + sql + '/rows' key = hashlib_md5(key).hexdigest() def _select_aux2(): self.execute(sql) return self._fetchall() rows = cache_model(key,_select_aux2,time_expire) if isinstance(rows,tuple): rows = list(rows) limitby = args_get('limitby', None) or (0,) rows = self.rowslice(rows,limitby[0],None) processor = args_get('processor', self.parse) cacheable = args_get('cacheable',False) return processor(rows,fields,self._colnames,cacheable=cacheable)
def select(self, query, fields, attributes):
    """
    Always returns a Rows object, possibly empty.

    When both ``cache`` and ``cacheable`` are set, the processed result of
    ``_select_aux`` is cached through the cache model. Otherwise
    ``_select_aux`` runs the query, and it may still cache the raw rows
    itself.

    ``cache`` may be a ``(cache_model, time_expire)`` tuple or a dict with
    'model', 'expiration' and an optional 'key'. These are the same two
    forms ``_select_aux`` accepts. The caller's ``attributes`` dict is not
    modified.
    """
    sql = self._select(query, fields, attributes)
    cache = attributes.get('cache', None)
    if not (cache and attributes.get('cacheable', False)):
        return self._select_aux(sql, fields, attributes)
    # work on a copy so the caller's dict keeps its 'cache' entry
    attributes = dict(attributes)
    del attributes['cache']
    if isinstance(cache, dict):
        cache_model = cache['model']
        time_expire = cache['expiration']
        key = cache.get('key')
        if not key:
            key = hashlib_md5(self.uri + '/' + sql).hexdigest()
    else:
        (cache_model, time_expire) = cache
        key = hashlib_md5(self.uri + '/' + sql).hexdigest()
    args = (sql, fields, attributes)
    return cache_model(
        key,
        lambda self=self, args=args: self._select_aux(*args),
        time_expire)
def iterselect(self, query, fields, attributes):
    """Build the SELECT for `query` and return a row iterator over it.

    ``_select`` must run first because it sets ``self._colnames``.
    """
    statement = self._select(query, fields, attributes)
    return self.iterparse(statement, fields, self._colnames,
                          cacheable=attributes.get('cacheable', False))
def _count(self, query, distinct=None): tablenames = self.tables(query) if query: if use_common_filters(query): query = self.common_filter(query, tablenames) sql_w = ' WHERE ' + self.expand(query) else: sql_w = '' sql_t = ','.join(self.table_alias(t) for t in tablenames) if distinct: if isinstance(distinct,(list, tuple)): distinct = xorify(distinct) sql_d = self.expand(distinct) return 'SELECT count(DISTINCT %s) FROM %s%s;' % \ (sql_d, sql_t, sql_w) return 'SELECT count(*) FROM %s%s;' % (sql_t, sql_w)
def count(self, query, distinct=None):
    """Execute the count statement and return the first column of its row."""
    statement = self._count(query, distinct)
    self.execute(statement)
    first_row = self.cursor.fetchone()
    return first_row[0]
def tables(self, *queries):
    """Return the names of all tables referenced by `queries`.

    Field objects contribute their own table name. Expressions and Queries
    are searched recursively through their ``first`` and ``second``
    operands. The result is a list in no particular order.
    """
    found = set()
    for query in queries:
        if isinstance(query, Field):
            found.add(query.tablename)
        elif isinstance(query, (Expression, Query)):
            for operand in (query.first, query.second):
                if operand is not None:
                    found = found.union(self.tables(operand))
    return list(found)
def commit(self):
    """Commit the open connection; return None when there is none."""
    connection = self.connection
    return connection.commit() if connection else None
def rollback(self):
    """Roll back the open connection; return None when there is none."""
    connection = self.connection
    return connection.rollback() if connection else None
def close_connection(self):
    """Close and forget the connection, returning the driver's close() result."""
    if not self.connection:
        return None
    result = self.connection.close()
    self.connection = None
    return result
def distributed_transaction_begin(self, key):
    """No-op in the base adapter; always returns None."""
    return None
def prepare(self, key):
    """Call ``prepare()`` on the open connection, if any; `key` is unused here."""
    connection = self.connection
    if connection:
        connection.prepare()
def commit_prepared(self, key):
    """Commit the open connection, if any; `key` is unused here."""
    connection = self.connection
    if connection:
        connection.commit()
def rollback_prepared(self, key):
    """Roll back the open connection, if any; `key` is unused here."""
    connection = self.connection
    if connection:
        connection.rollback()
def concat_add(self, tablename):
    """Return the separator used between column additions: ', ADD '."""
    separator = ', ADD '
    return separator
def constraint_name(self, table, fieldname):
    """Return the constraint name '<table>_<fieldname>__constraint'."""
    return '%s_%s__constraint' % (table, fieldname)
def create_sequence_and_triggers(self, query, table, **args):
    """Execute `query` as-is; the base adapter creates no extra objects."""
    self.execute(query)
def log_execute(self, *a, **b):
    """Execute a SQL command on a fresh cursor and record it.

    ``a[0]`` is the SQL command. The remaining positional and keyword
    arguments are forwarded to ``cursor.execute``. If the adapter defines
    ``filter_sql_command``, the command is passed through it first. The
    command is stored in ``db._lastsql`` and, when ``db._debug`` is set,
    logged. Its execution time is appended to ``db._timings``, which is
    trimmed to the last TIMINGSSIZE entries.

    Raises:
        ValueError: (with the command as message) when there is no
            connection.
    """
    # (a second "if not self.connection: return None" check that followed
    # here was unreachable and has been removed)
    if not self.connection:
        raise ValueError(a[0])
    command = a[0]
    if hasattr(self, 'filter_sql_command'):
        command = self.filter_sql_command(command)
    if self.db._debug:
        self.db.logger.debug('SQL: %s', command)
    self.db._lastsql = command
    t0 = time.time()
    ret = self.get_cursor().execute(command, *a[1:], **b)
    self.db._timings.append((command, time.time() - t0))
    del self.db._timings[:-TIMINGSSIZE]
    return ret
def execute(self, *a, **b):
    """Delegate every argument unchanged to ``log_execute``."""
    result = self.log_execute(*a, **b)
    return result
def execute_test_query(self):
    """Execute the adapter's ``test_query`` and return the result."""
    probe = self.test_query
    return self.execute(probe)
def represent(self, obj, fieldtype):
    """Render `obj` as an SQL literal string for a column of `fieldtype`.

    Args:
        obj: the Python value. If it is one of CALLABLETYPES it is called
            first and its result is rendered.
        fieldtype: a field type name (e.g. 'integer', 'list:string',
            'reference <table>') or an SQLCustomType.

    Returns:
        A string:
        - 'NULL' for None, and for '' on most non-text types;
        - str(obj) for Expression/Field objects;
        - a bare literal for id/integer/decimal/reference/double values;
        - otherwise the value passed through ``self.adapt`` or
          ``self.smart_adapt``.
    """
    # NOTE(review): evaluated before the SQLCustomType check, so an
    # SQLCustomType fieldtype is assumed to provide startswith() — confirm
    field_is_type = fieldtype.startswith
    if isinstance(obj, CALLABLETYPES):
        obj = obj()
    if isinstance(fieldtype, SQLCustomType):
        value = fieldtype.encoder(obj)
        if value and fieldtype.type in ('string', 'text', 'json'):
            return self.adapt(value)
        return value or 'NULL'
    if isinstance(obj, (Expression, Field)):
        return str(obj)
    # list:* values are normalized to a list and converted element-wise;
    # they are bar-encoded just below
    if field_is_type('list:'):
        if not obj:
            obj = []
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        if field_is_type('list:string'):
            if PY2:
                try:
                    obj = map(str, obj)
                except:
                    obj = map(lambda x: unicode(x).encode(self.db_codec), obj)
            else:
                obj = list(map(str,obj))
        else:
            # empty-string items are dropped before int conversion
            obj = list(map(int,[o for o in obj if o != '']))
    # we don't want to bar_encode json objects
    if isinstance(obj, (list, tuple)) and (not fieldtype == "json"):
        obj = bar_encode(obj)
    if obj is None:
        return 'NULL'
    # '' is only kept for types whose name starts with st/te/js/pa/up
    if obj == '' and not fieldtype[:2] in ['st', 'te', 'js', 'pa', 'up']:
        return 'NULL'
    # adapter-specific hook; a non-None result short-circuits everything below
    r = self.represent_exceptions(obj, fieldtype)
    if r is not None:
        return r
    if fieldtype == 'boolean':
        if obj and not str(obj)[:1].upper() in '0F':
            return self.smart_adapt(self.TRUE)
        else:
            return self.smart_adapt(self.FALSE)
    if fieldtype == 'id' or fieldtype == 'integer':
        return str(long(obj))
    if field_is_type('decimal'):
        return str(obj)
    elif field_is_type('reference'): # reference
        # check for tablename first
        referenced = fieldtype[9:].strip()
        if referenced in self.db.tables:
            return str(long(obj))
        # 'reference table.field': represent with the target field's type
        p = referenced.partition('.')
        if p[2] != '':
            try:
                ftype = self.db[p[0]][p[2]].type
                return self.represent(obj, ftype)
            except (ValueError, KeyError):
                return repr(obj)
        elif isinstance(obj, (Row, Reference)):
            return str(obj['id'])
        return str(long(obj))
    elif fieldtype == 'double':
        return repr(float(obj))
    if PY2 and isinstance(obj, unicode):
        obj = obj.encode(self.db_codec)
    if fieldtype == 'blob':
        if PY2:
            obj = base64.b64encode(str(obj))
        else:
            # NOTE(review): calls obj.encode('utf-8'), so a bytes value
            # would fail here on Python 3 — confirm callers pass str
            obj = base64.b64encode(obj.encode('utf-8'))
    elif fieldtype == 'date':
        if isinstance(obj, (datetime.date, datetime.datetime)):
            obj = obj.isoformat()[:10]
        else:
            obj = str(obj)
    elif fieldtype == 'datetime':
        if isinstance(obj, datetime.datetime):
            obj = obj.isoformat(self.T_SEP)[:19]
        elif isinstance(obj, datetime.date):
            obj = obj.isoformat()[:10]+self.T_SEP+'00:00:00'
        else:
            obj = str(obj)
    elif fieldtype == 'time':
        if isinstance(obj, datetime.time):
            # NOTE(review): with microseconds set, [:10] keeps
            # 'HH:MM:SS.f' (one fractional digit) — confirm intended
            obj = obj.isoformat()[:10]
        else:
            obj = str(obj)
    elif fieldtype == 'json':
        if not 'dumps' in self.driver_auto_json:
            # always pass a string JSON string
            obj = serializers.json(obj)
    # final text normalization: py2 -> bytes valid in db_codec (falling
    # back to a latin1 transcode), py3 -> unicode text
    if PY2:
        if not isinstance(obj, bytes):
            obj = bytes(obj)
        try:
            obj.decode(self.db_codec)
        except:
            obj = obj.decode('latin1').encode(self.db_codec)
    else:
        obj = to_unicode(obj)
    return self.adapt(obj)
def represent_exceptions(self, obj, fieldtype):
    """Hook for adapter-specific representations; the base returns None."""
    return None
def lastrowid(self, table):
    """Return ``lastrowid`` of the current cursor; `table` is unused here."""
    cursor = self.cursor
    return cursor.lastrowid
def rowslice(self, rows, minimum=0, maximum=None):
    """
    Return `rows` unchanged.

    Adapters whose database cannot slice results should override this.
    """
    return rows
def parse_value(self, value, field_type, blob_decode=True):
    """Convert a raw driver value into the Python value for `field_type`.

    Args:
        value: the raw value returned by the driver.
        field_type: a field type name or an SQLCustomType. Its
            ``decoder`` is applied when it is an SQLCustomType.
        blob_decode: when false, 'blob' values are returned undecoded.

    Returns:
        The converted value. None, string-like types, 'geo*' types and
        non-str field types are returned without going through parsemap.
    """
    if PY2:
        # Python 2 only: byte strings are decoded with the database codec
        # and re-encoded as UTF-8. On Python 3, str has no .decode(), so
        # the unconditional attempt always raised (and was swallowed) for
        # every string value; it is now skipped there.
        if field_type != 'blob' and isinstance(value, str):
            try:
                value = value.decode(self.db._db_codec)
            except Exception:
                pass
        if isinstance(value, unicode):
            value = value.encode('utf-8')
    if isinstance(field_type, SQLCustomType):
        value = field_type.decoder(value)
    if not isinstance(field_type, str) or value is None:
        return value
    if field_type in ('string', 'text', 'password', 'upload', 'dict'):
        return value
    if field_type.startswith('geo'):
        return value
    if field_type == 'blob' and not blob_decode:
        return value
    key = REGEX_TYPE.match(field_type).group(0)
    return self.parsemap[key](value, field_type)
def parse_reference(self, value, field_type):
    """Wrap a reference id in a Reference bound to the referenced table.

    For 'reference table.field' types the value is returned unchanged.
    """
    referee = field_type[10:].strip()
    if '.' in referee:
        return value
    ref = Reference(value)
    ref._table, ref._record = self.db[referee], None
    return ref
def parse_boolean(self, value, field_type):
    """True when `value` equals ``self.TRUE`` or its text starts with 't'/'T'."""
    matches_true = value == self.TRUE
    return matches_true or str(value)[:1].lower() == 't'
def parse_date(self, value, field_type):
    """Return a ``datetime.date`` from a date, a datetime or a 'YYYY-MM-DD...' value."""
    if isinstance(value, datetime.datetime):
        # keep only the date portion
        return value.date()
    if isinstance(value, datetime.date):
        return value
    year, month, day = (int(part) for part in str(value)[:10].strip().split('-'))
    return datetime.date(year, month, day)
def parse_time(self, value, field_type):
    """Return a ``datetime.time`` from a time, a datetime or an 'HH[:MM[:SS]]' value.

    Only the first 8 characters of the text form are used. Missing minute
    and second components default to 0. Previously only one component
    was padded, so an hour-only value raised ValueError.
    """
    if isinstance(value, datetime.datetime):
        # Extract the time portion from the datetime
        return value.time()
    if isinstance(value, datetime.time):
        return value
    time_items = [int(part) for part in str(value)[:8].strip().split(':')[:3]]
    time_items += [0] * (3 - len(time_items))
    (h, mi, s) = time_items
    return datetime.time(h, mi, s)
def parse_datetime(self, value, field_type):
    """Return a naive ``datetime.datetime`` from a datetime or its text form.

    The text form is 'YYYY-MM-DD[ HH:MM[:SS]][.ffffff][Z|+HH[:MM]|-HH[:MM]]'.
    A '+'/'-' offset is added to / subtracted from the parsed time.
    Fractional seconds are truncated to microseconds by reading the digits
    directly. The previous ``int(float(...) * 1e6)`` could land one
    microsecond low. An offset given as '+HH' or '+HHMM' (no colon) is
    also accepted.
    """
    if isinstance(value, datetime.datetime):
        return value

    def offset_of(tz):
        # "HH:MM", "HHMM" or "HH" -> timedelta
        tz = tz.strip()
        if ':' in tz:
            hours, minutes = tz.split(':')
        else:
            hours, minutes = tz[:2], tz[2:] or '0'
        return datetime.timedelta(seconds=3600 * int(hours) + 60 * int(minutes))

    value = str(value)
    date_part, time_part, timezone = value[:10], value[11:19], value[19:]
    if '+' in timezone:
        ms, tz = timezone.split('+')
        offset = offset_of(tz)
    elif '-' in timezone:
        ms, tz = timezone.split('-')
        offset = -offset_of(tz)
    else:
        ms = timezone.upper().split('Z')[0]
        offset = None
    (year, month, day) = map(int, date_part.split('-'))
    clock = time_part.split(':')[:3] if time_part else []
    clock += ['0'] * (3 - len(clock))
    (hour, minute, second) = map(int, clock)
    micro = 0
    if ms and ms[0] == '.':
        digits = ms[1:].strip()
        if digits:
            micro = int((digits + '000000')[:6])
    result = datetime.datetime(year, month, day, hour, minute, second, micro)
    if offset:
        result = result + offset
    return result
def parse_blob(self, value, field_type):
    """Decode a base64-stored blob.

    On Python 2 the decoded bytes are returned. On Python 3 they are
    decoded as UTF-8 text.
    """
    if PY2:
        return base64.b64decode(str(value))
    # TODO: better implement the check; the original note says some
    # drivers (py3.3.x + psycopg2) return neither bytes nor str
    if not isinstance(value, (bytes, str)):
        value = bytes(value)
    decoded = base64.b64decode(value)
    return decoded.decode('utf-8')
def parse_decimal(self, value, field_type):
    """Return a ``decimal.Decimal`` for a 'decimal(p,s)' field.

    On sqlite/spatialite the value is first formatted with `s` decimal
    places.
    """
    scale = int(field_type[8:-1].split(',')[-1])
    if self.dbengine in ('sqlite', 'spatialite'):
        value = '%.*f' % (scale, value)
    if isinstance(value, decimal.Decimal):
        return value
    return decimal.Decimal(str(value))
def parse_list_integers(self, value, field_type):
    """Decode a bar-encoded list of integers."""
    return bar_decode_integer(value)
def parse_list_references(self, value, field_type):
    """Decode a bar-encoded id list and parse each id as a reference.

    The leading 'list:' (5 characters) is stripped from `field_type`.
    """
    target_type = field_type[5:]
    return [self.parse_reference(item, target_type)
            for item in bar_decode_integer(value)]
def parse_list_strings(self, value, field_type):
    """Decode a bar-encoded list of strings."""
    return bar_decode_string(value)
def parse_id(self, value, field_type):
    """Convert a raw id value with ``long``."""
    parsed = long(value)
    return parsed
def parse_integer(self, value, field_type):
    """Convert a raw integer value with ``long``."""
    parsed = long(value)
    return parsed
def parse_double(self, value, field_type):
    """Convert a raw numeric value to ``float``."""
    parsed = float(value)
    return parsed
def parse_json(self, value, field_type):
    """Decode a JSON column value.

    When 'loads' is in ``self.driver_auto_json`` the value is returned
    as-is; presumably the driver has already decoded it.

    Raises:
        RuntimeError: if the value to decode is not a string.
    """
    if 'loads' in self.driver_auto_json:
        return value
    if not isinstance(value, basestring):
        raise RuntimeError('json data not a string')
    if PY2 and isinstance(value, unicode):
        value = value.encode('utf-8')
    return json.loads(value)
def build_parsemap(self):
    """Populate ``self.parsemap``: field-type key -> bound parse_* method."""
    handler_names = (
        ('id', 'parse_id'),
        ('integer', 'parse_integer'),
        ('bigint', 'parse_integer'),
        ('float', 'parse_double'),
        ('double', 'parse_double'),
        ('reference', 'parse_reference'),
        ('boolean', 'parse_boolean'),
        ('date', 'parse_date'),
        ('time', 'parse_time'),
        ('datetime', 'parse_datetime'),
        ('blob', 'parse_blob'),
        ('decimal', 'parse_decimal'),
        ('json', 'parse_json'),
        ('list:integer', 'parse_list_integers'),
        ('list:reference', 'parse_list_references'),
        ('list:string', 'parse_list_strings'),
    )
    self.parsemap = dict(
        (type_key, getattr(self, method_name))
        for type_key, method_name in handler_names)
def _parse(self, row, tmps, fields, colnames, blob_decode,
           cacheable, fields_virtual, fields_lazy):
    """
    Return a parsed row

    Args:
        row: one raw row, indexed by column position.
        tmps: per-column metadata from ``_parse_expand_colnames``. Each
            entry is a (tablename, fieldname, table, field, field_type)
            tuple for 'table.field' columns, or None for other
            expressions.
        fields: the selected fields/expressions. ``fields[j].type`` is
            used to parse non-table columns.
        colnames: the column names, parallel to `row`.
        blob_decode: forwarded to ``parse_value``.
        cacheable: when true, update_record/delete_record helpers, lazy
            reference getters and LazySets are NOT attached.
        fields_virtual, fields_lazy: per-table lists of (name, field)
            pairs for virtual fields / virtual methods to compute.

    Returns:
        A ``db.Row`` holding one sub-Row per table. Non-table columns go
        into an '_extra' sub-Row.
    """
    # pre-create a sub-Row for every table that has virtual fields
    new_row = self.db.Row(
        dict((tablename, self.db.Row()) for tablename in fields_virtual.keys()))
    for (j, colname) in enumerate(colnames):
        value = row[j]
        tmp = tmps[j]
        tablename = None
        if tmp:
            (tablename, fieldname, table, field, ft) = tmp
            colset = new_row.get(tablename, None)
            if colset is None:
                colset = new_row[tablename] = self.db.Row()
            value = self.parse_value(value, ft, blob_decode)
            if field.filter_out:
                value = field.filter_out(value)
            colset[fieldname] = value

            # for backward compatibility
            if ft == 'id' and fieldname != 'id' and \
                    'id' not in table.fields:
                colset['id'] = value

            # attach record helpers keyed on the primary-key value
            if ft == 'id' and not cacheable:
                if self.dbengine == 'google:datastore':
                    id = value.key.id()
                    colset[fieldname] = id
                    colset.gae_item = value
                else:
                    id = value
                colset.update_record = RecordUpdater(colset, table, id)
                colset.delete_record = RecordDeleter(table, id)
                if table._db._lazy_tables:
                    colset['__get_lazy_reference__'] = \
                        LazyReferenceGetter(table, id)
                # one LazySet per referencing field, named with
                # db._referee_name, unless that name is already taken
                for rfield in table._referenced_by:
                    referee_link = self.db._referee_name and \
                        self.db._referee_name % dict(
                            table=rfield.tablename, field=rfield.name)
                    if (referee_link and referee_link not in colset and
                            referee_link != tablename):
                        colset[referee_link] = LazySet(rfield, id)
        else:
            if '_extra' not in new_row:
                new_row['_extra'] = self.db.Row()
            value = self.parse_value(value, fields[j].type, blob_decode)
            new_row['_extra'][colname] = value
            # presumably the alias of an 'expr AS alias' column (see
            # REGEX_SELECT_AS_PARSER); it is also set as a row attribute
            new_column_name = self._regex_select_as_parser(colname)
            if new_column_name is not None:
                column_name = new_column_name.groups(0)
                setattr(new_row, column_name[0], value)

    for tablename in fields_virtual.keys():
        for f, v in fields_virtual[tablename]:
            try:
                new_row[tablename][f] = v.f(new_row)
            except (AttributeError, KeyError):
                pass  # not enough fields to define virtual field
        for f, v in fields_lazy[tablename]:
            try:
                new_row[tablename][f] = (v.handler or VirtualCommand)(
                    v.f, new_row
                )
            except (AttributeError, KeyError):
                pass  # not enough fields to define virtual field
    return new_row

def _regex_select_as_parser(self, colname):
    """Return the REGEX_SELECT_AS_PARSER search match for `colname`, or None."""
    return REGEX_SELECT_AS_PARSER.search(colname)

def _parse_expand_colnames(self, colnames):
    """
    - Expand a list of colnames into a list of
      (tablename, fieldname, table_obj, field_obj, field_type)
    - Create a list of table for virtual/lazy fields

    Returns:
        (fields_virtual, fields_lazy, tmps). The first two map each table
        seen to its (name, field) lists of ``_virtual_fields`` /
        ``_virtual_methods``. `tmps` is parallel to `colnames`, with None
        for columns that do not match REGEX_TABLE_DOT_FIELD.
    """
    fields_virtual = {}
    fields_lazy = {}
    tmps = []
    for colname in colnames:
        col_m = self.REGEX_TABLE_DOT_FIELD.match(colname)
        if not col_m:
            tmps.append(None)
        else:
            tablename, fieldname = col_m.groups()
            table = self.db[tablename]
            field = table[fieldname]
            ft = field.type
            tmps.append((tablename, fieldname, table, field, ft))
            if tablename not in fields_virtual:
                fields_virtual[tablename] = [
                    (f.name, f) for f in table._virtual_fields
                ]
                fields_lazy[tablename] = [
                    (f.name, f) for f in table._virtual_methods
                ]
    return (fields_virtual, fields_lazy, tmps)
def parse(self, rows, fields, colnames, blob_decode=True,
          cacheable=False):
    """Parse raw driver rows into a ``Rows`` object.

    Each row goes through ``_parse``. Afterwards the old-style
    ``virtualfields`` of every involved table are applied. Tables whose
    virtual fields fail with KeyError/AttributeError (e.g. on a partial
    select) are skipped.
    """
    fields_virtual, fields_lazy, tmps = self._parse_expand_colnames(colnames)
    parsed = [
        self._parse(row, tmps, fields, colnames, blob_decode, cacheable,
                    fields_virtual, fields_lazy)
        for row in rows
    ]
    result = Rows(self.db, parsed, colnames, rawrows=rows)
    for tablename in fields_virtual.keys():
        for item in self.db[tablename].virtualfields:
            try:
                result = result.setvirtualfields(**{tablename: item})
            except (KeyError, AttributeError):
                # to avoid breaking virtualfields when partial select
                pass
    return result
def iterparse(self, sql, fields, colnames, blob_decode=True,
              cacheable=False):
    """
    Return an IterRows that parses one row at a time.

    Old-style virtual fields are not supported by this path.
    """
    iterator = IterRows(self.db, sql, fields, colnames, blob_decode, cacheable)
    return iterator
def common_filter(self, query, tablenames):
    """AND per-table common filters and multi-tenant conditions into `query`.

    For each table in `tablenames`:
    - if ``table._common_filter`` is set, its result (called with the
      current query) is AND-ed in;
    - if the table has the field named by ``db._request_tenant`` and that
      field's default is not None, ``field == default`` is AND-ed in.

    A None `query` is replaced by the first condition rather than
    combined. The user-filter branch previously evaluated ``None & ...``
    and raised TypeError.

    Returns:
        The resulting query; None if `query` was None and nothing applied.
    """
    tenant_fieldname = self.db._request_tenant
    for tablename in tablenames:
        table = self.db[tablename]
        # deal with user provided filters
        if table._common_filter is not None:
            newquery = table._common_filter(query)
            query = newquery if query is None else query & newquery
        # deal with multi_tenant filters
        if tenant_fieldname in table:
            default = table[tenant_fieldname].default
            if default is not None:
                newquery = table[tenant_fieldname] == default
                query = newquery if query is None else query & newquery
    return query
def CASE(self, query, t, f):
    """Return an Expression rendered by EXPAND_CASE as CASE WHEN query THEN t ELSE f END."""
    branches = (t, f)
    return Expression(self.db, self.EXPAND_CASE, query, branches)
def EXPAND_CASE(self, query, true_false):
    """Render 'CASE WHEN <query> THEN <a> ELSE <b> END'.

    Each branch value is rendered as follows:
    - None becomes 'NULL';
    - an Expression is rendered with str();
    - any other value goes through ``self.represent``, with a type picked
      from its Python type (bool/int/float, otherwise 'string').
    """
    # exact-type lookup (bool is not treated as int here)
    type_names = {type(True): 'boolean', type(0): 'integer', type(1.0): 'double'}

    def literal(x):
        if x is None:
            return 'NULL'
        if isinstance(x, Expression):
            return str(x)
        return self.represent(x, type_names.get(type(x), 'string'))

    when_true, when_false = true_false[0], true_false[1]
    return 'CASE WHEN %s THEN %s ELSE %s END' % (
        self.expand(query), literal(when_true), literal(when_false))
def sqlsafe_table(self, tablename, ot=None):
    """Quote `tablename` with QUOTE_TEMPLATE, optionally as '<ot> AS <quoted>'."""
    quoted = self.QUOTE_TEMPLATE % tablename
    if ot is None:
        return quoted
    return '%s AS %s' % (ot, quoted)
def sqlsafe_field(self, fieldname):
    """Quote `fieldname` with QUOTE_TEMPLATE."""
    template = self.QUOTE_TEMPLATE
    return template % fieldname
def can_join(self):
    """SQL adapters support joins."""
    supported = True
    return supported
class NoSQLAdapter(BaseAdapter):
    """Base class for adapters whose backend has no SQL layer.

    SELECT ... FOR UPDATE is disabled and identifiers are not quoted
    (QUOTE_TEMPLATE is '%s').
    """
    can_select_for_update = False
    QUOTE_TEMPLATE = '%s'

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args=None,
                 adapter_args=None, do_connect=True, after_connection=None):
        # Fresh dicts per instance. The previous `{}` defaults were single
        # objects shared by every call, so any mutation of them would leak
        # into later adapters.
        if driver_args is None:
            driver_args = {}
        if adapter_args is None:
            adapter_args = {}
        super(NoSQLAdapter, self).__init__(
            db=db,
            uri=uri,
            pool_size=pool_size,
            folder=folder,
            db_codec=db_codec,
            credential_decoder=credential_decoder,
            driver_args=driver_args,
            adapter_args=adapter_args,
            do_connect=do_connect,
            after_connection=after_connection)
        self.fake_cursor = FakeCursor()
def id_query(self, table):
    """Return the query ``table._id > 0``, i.e. every record."""
    id_field = table._id
    return id_field > 0
def execute_test_query(self):
    '''
    NoSQL DBs don't have a universal query language, so this returns None.
    Override it in a specific driver to test the connection status; it
    should raise an exception on failure.
    '''
    return None
def represent(self, obj, fieldtype):
    """Convert a Python value to the native form stored by a NoSQL backend.

    Unlike the SQL adapter, this returns Python objects (int, float, bool,
    date/time/datetime, lists, unicode text) rather than SQL literals.

    Args:
        obj: the value. If it is one of CALLABLETYPES it is called first.
        fieldtype: a field type name or an SQLCustomType. For the latter,
            its encoder's result is returned directly.

    Returns:
        The converted value. None is returned for None and for '' on
        non-text types.

    Fixes over the previous version:
    - 'datetime' accepts ``date`` values and date-only strings;
    - 'datetime' and 'time' accept fractional seconds, kept as
      microseconds;
    - 'time' accepts ``datetime`` values and hour-only strings.
    Each of these previously raised ValueError.
    """
    def split_clock(text):
        # "HH[:MM[:SS[.ffffff]]]" -> (hour, minute, second, microsecond);
        # missing parts default to 0, extra fraction digits are truncated
        text = text.strip()
        parts = text.split(':')[:3] if text else []
        parts += ['0'] * (3 - len(parts))
        seconds, _, fraction = parts[2].partition('.')
        micro = int((fraction + '000000')[:6]) if fraction else 0
        return int(parts[0]), int(parts[1]), int(seconds), micro

    field_is_type = fieldtype.startswith
    if isinstance(obj, CALLABLETYPES):
        obj = obj()
    if isinstance(fieldtype, SQLCustomType):
        return fieldtype.encoder(obj)
    is_string = isinstance(fieldtype, str)
    is_list = is_string and field_is_type('list:')
    if is_list:
        if not obj:
            obj = []
        if not isinstance(obj, (list, tuple)):
            obj = [obj]
        obj = [item for item in obj if item is not None]
    if obj == '' and not \
            (is_string and fieldtype[:2] in ['st', 'te', 'pa', 'up']):
        return None
    if obj is not None:
        if isinstance(obj, list) and not is_list:
            obj = [self.represent(o, fieldtype) for o in obj]
        elif fieldtype in ('integer', 'bigint', 'id'):
            obj = long(obj)
        elif fieldtype == 'double':
            obj = float(obj)
        elif is_string and field_is_type('reference'):
            if isinstance(obj, (Row, Reference)):
                obj = obj['id']
            obj = long(obj)
        elif fieldtype == 'boolean':
            if obj and not str(obj)[0].upper() in '0F':
                obj = True
            else:
                obj = False
        elif fieldtype == 'date':
            if not isinstance(obj, datetime.date):
                (y, m, d) = map(int, str(obj).strip().split('-'))
                obj = datetime.date(y, m, d)
            elif isinstance(obj, datetime.datetime):
                (y, m, d) = (obj.year, obj.month, obj.day)
                obj = datetime.date(y, m, d)
        elif fieldtype == 'time':
            if isinstance(obj, datetime.datetime):
                obj = obj.time()
            elif not isinstance(obj, datetime.time):
                (h, mi, s, us) = split_clock(str(obj))
                obj = datetime.time(h, mi, s, us)
        elif fieldtype == 'datetime':
            if not isinstance(obj, datetime.datetime):
                if isinstance(obj, datetime.date):
                    obj = datetime.datetime(obj.year, obj.month, obj.day)
                else:
                    text = str(obj)
                    (y, m, d) = map(int, text[:10].strip().split('-'))
                    (h, mi, s, us) = split_clock(text[11:])
                    obj = datetime.datetime(y, m, d, h, mi, s, us)
        elif fieldtype == 'blob':
            pass
        elif fieldtype == 'json':
            if isinstance(obj, basestring):
                obj = to_unicode(obj)
                obj = json.loads(obj)
        elif is_string and field_is_type('list:string'):
            return list(map(to_unicode, obj))
        elif is_list:
            return list(map(int, obj))
        else:
            obj = to_unicode(obj)
    return obj
def _insert(self,table,fields): return 'insert %s in %s' % (fields, table) def _count(self,query,distinct=None): return 'count %s' % repr(query) def _select(self,query,fields,attributes): return 'select %s where %s' % (repr(fields), repr(query)) def _delete(self,tablename, query): return 'delete %s where %s' % (repr(tablename),repr(query)) def _update(self,tablename,query,fields): return 'update %s (%s) where %s' % (repr(tablename), repr(fields),repr(query))
def commit(self):
    """
    No-op: many NoSQL backends have no transactions.
    """
    return None
def rollback(self):
    """
    No-op: many NoSQL backends have no transactions.
    """
    return None
def close_connection(self):
    """
    No-op in the NoSQL base adapter.
    """
    return None
def parse_list_integers(self, value, field_type):
    """Return `value` unchanged; the NoSQL backend stores native lists."""
    parsed = value
    return parsed
def parse_list_references(self, value, field_type):
    """Parse each item of a native list with ``parse_reference``.

    The leading 'list:' (5 characters) is stripped from `field_type`.
    """
    target_type = field_type[5:]
    return [self.parse_reference(item, target_type) for item in value]
def parse_list_strings(self, value, field_type):
    """Return `value` unchanged; the NoSQL backend stores native lists."""
    parsed = value
    return parsed
# The SQL-specific operations below are not supported by NoSQL adapters; most raise SyntaxError("Not supported") if called.
def OR(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def AND(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def AS(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def ON(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def STARTSWITH(self, first, second=None):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def ENDSWITH(self, first, second=None):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def ADD(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def SUB(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def MUL(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def DIV(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def LOWER(self, first):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def UPPER(self, first):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def EXTRACT(self, first, what):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def LENGTH(self, first):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def AGGREGATE(self, first, what):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def LEFT_JOIN(self):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def RANDOM(self):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def SUBSTRING(self, field, parameters):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def PRIMARY_KEY(self, key):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def ILIKE(self, first, second):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def drop(self, table, mode):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def migrate_table(self, *a, **b):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def distributed_transaction_begin(self, key):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def prepare(self, key):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def commit_prepared(self, key):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def rollback_prepared(self, key):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def concat_add(self, table):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def constraint_name(self, table, fieldname):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def create_sequence_and_triggers(self, query, table, **args):
    """No-op for NoSQL adapters: nothing is executed."""
    return None
def log_execute(self, *a, **b):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def execute(self, *a, **b):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def represent_exceptions(self, obj, fieldtype):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def lastrowid(self, table):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def rowslice(self, rows, minimum=0, maximum=None):
    """Not supported by NoSQL adapters; always raises SyntaxError."""
    raise SyntaxError("Not supported")
def can_join(self):
    """NoSQL adapters do not support joins."""
    supported = False
    return supported