Source code for pydal.adapters.mongo

# -*- coding: utf-8 -*-
import datetime
import re
import copy

from .._globals import IDENTITY
from .._compat import integer_types, basestring, PY2
from ..objects import Table, Query, Field, Expression, Row
from ..helpers.classes import SQLCustomType, SQLALL, Reference
from ..helpers.methods import use_common_filters, xorify
from .base import NoSQLAdapter, CALLABLETYPES, SELECT_ARGS

try:
    from bson import Binary
    from bson.binary import USER_DEFINED_SUBTYPE
except ImportError:
    # bson could not be imported: provide stand-ins so that this module can
    # still be imported. Only ImportError is handled so that any other
    # failure is not silently hidden.
    class Binary(object):
        pass
    USER_DEFINED_SUBTYPE = 0

# Integer type used for ids and references (last entry of integer_types).
long = integer_types[-1]

# select() attributes that this adapter handles.
SUPPORTED_SELECT_ARGS = set(('limitby', 'orderby', 'orderby_on_limitby',
                             'groupby', 'distinct', 'having'))
# select() attributes that select() rejects with NotOnNoSqlError.
SQL_ONLY_SELECT_ARGS = set(('join', 'left'))
# select() attributes for which select() only logs a warning.
NON_MONGO_SELECT_ARGS = set(('for_update',))
# Every remaining known select() attribute: not implemented by this adapter.
NOT_IMPLEMENTED_SELECT_ARGS = (SELECT_ARGS - SUPPORTED_SELECT_ARGS -
                               SQL_ONLY_SELECT_ARGS - NON_MONGO_SELECT_ARGS)
class MongoDBAdapter(NoSQLAdapter):
    """pydal adapter for MongoDB, using the ``pymongo`` driver."""
    drivers = ('pymongo',)
    driver_auto_json = ['loads', 'dumps']

    uploads_in_blob = False

    # pydal field type -> python type
    types = {
        'boolean': bool,
        'string': str,
        'text': str,
        'json': str,
        'password': str,
        'blob': str,
        'upload': str,
        'integer': long,
        'bigint': long,
        'float': float,
        'double': float,
        'date': datetime.date,
        'time': datetime.time,
        'datetime': datetime.datetime,
        'id': long,
        'reference': long,
        'list:string': list,
        'list:integer': list,
        'list:reference': list,
    }

    # Marker keys placed in expanded expressions and interpreted later when
    # the query / aggregation pipeline is assembled.
    GROUP_MARK = "__#GROUP#__"
    AS_MARK = "__#AS#__"
    REGEXP_MARK1 = "__#REGEXP_1#__"
    REGEXP_MARK2 = "__#REGEXP_2#__"

    def __init__(self, db, uri='mongodb://127.0.0.1:5984/db',
                 pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args=None,
                 adapter_args=None, do_connect=True, after_connection=None):
        """Set up the adapter and connect to the MongoDB server.

        ``driver_args`` and ``adapter_args`` default to a fresh empty dict
        per call (``None`` is accepted and means "no arguments"), so no dict
        is shared between adapter instances.

        Recognised ``adapter_args`` keys: ``minimumreplication`` and ``safe``.
        ``driver_args['fake_version']`` overrides the detected pymongo
        version string.

        Raises:
            Exception: if the pymongo major version is lower than 3.
            SyntaxError: if the uri does not name a database.
        """
        driver_args = {} if driver_args is None else driver_args
        adapter_args = {} if adapter_args is None else adapter_args

        super(MongoDBAdapter, self).__init__(
            db=db,
            uri=uri,
            pool_size=pool_size,
            folder=folder,
            db_codec=db_codec,
            credential_decoder=credential_decoder,
            driver_args=driver_args,
            adapter_args=adapter_args,
            do_connect=do_connect,
            after_connection=after_connection)

        if do_connect:
            self.find_driver(adapter_args)

        from pymongo import version
        if 'fake_version' in driver_args:
            version = driver_args['fake_version']
        if int(version.split('.')[0]) < 3:
            raise Exception(
                "pydal requires pymongo version >= 3.0, found '%s'"
                % version)

        import random
        from bson.objectid import ObjectId
        from bson.son import SON
        import pymongo.uri_parser
        from pymongo.write_concern import WriteConcern

        m = pymongo.uri_parser.parse_uri(uri)

        self.epoch = datetime.datetime.fromtimestamp(0)
        self.SON = SON
        self.ObjectId = ObjectId
        self.random = random
        self.WriteConcern = WriteConcern

        self.dbengine = 'mongodb'
        db['_lastsql'] = ''
        self.db_codec = 'UTF-8'
        self.find_or_make_work_folder()
        # this is the minimum amount of replicates that it should wait
        # for on insert/update
        self.minimumreplication = adapter_args.get('minimumreplication', 0)
        # Writes are synchronous (safe) by default, unless overruled by the
        # 'safe' adapter argument or by a per-call function parameter.
        self.safe = 1 if adapter_args.get('safe', True) else 0

        if isinstance(m, tuple):
            m = {"database": m[1]}
        if m.get('database') is None:
            raise SyntaxError("Database is required!")

        def connector(uri=self.uri, m=m):
            driver = self.driver.MongoClient(
                uri, w=self.safe)[m.get('database')]
            # Stub out the cursor/close/commit calls made on a connection.
            driver.cursor = lambda: self.fake_cursor
            driver.close = lambda: None
            driver.commit = lambda: None
            return driver
        self.connector = connector
        self.reconnect()

        # _server_version is a string like '3.0.3' or '2.4.12'
        self._server_version = \
            self.connection.command("serverStatus")['version']
        self.server_version = tuple(
            [int(x) for x in self._server_version.split('.')])
        # major + minor/10, e.g. (2, 6, x) -> 2.6; compared against 2.6
        self.server_version_major = (
            self.server_version[0] + self.server_version[1] / 10.0)
def object_id(self, arg=None):
    """Convert input to a valid MongoDB ObjectId instance.

    Accepts an ObjectId (returned unchanged), a Row/Reference (its 'id' is
    used), an integer, or a string holding a decimal or base 16 number.
    Any falsy input is treated as 0.

    self.object_id("<random>") -> ObjectId (not unique) instance

    Raises:
        ValueError: for a string that is not a decimal or hexadecimal number.
        TypeError: for any other unsupported argument type.
    """
    def int_from_text(text):
        # we assume an integer as default input
        is_raw_hex = len(text.replace("0x", "").replace("L", "")) == 24
        if text.isdigit() and not is_raw_hex:
            return int(text)
        if text == "<random>":
            digits = [self.random.choice("0123456789abcdef")
                      for _ in range(24)]
            return int("0x%s" % "".join(digits), 0)
        if not text.isalnum():
            raise ValueError("Invalid objectid argument string. "
                             "Requires an integer or base 16 value")
        prefixed = text if text.startswith("0x") else "0x%s" % text
        try:
            return int(prefixed, 0)
        except ValueError as e:
            raise ValueError(
                "invalid objectid argument string: %s" % e)

    if not arg:
        arg = 0

    if isinstance(arg, basestring):
        arg = int_from_text(arg)
    elif isinstance(arg, self.ObjectId):
        return arg
    elif isinstance(arg, (Row, Reference)):
        return self.object_id(long(arg['id']))
    elif not isinstance(arg, (int, long)):
        raise TypeError(
            "object_id argument must be of type "
            "ObjectId or an objectid representable integer"
            " (type %s)" % type(arg))

    # 24 hex digits, zero padded; a trailing 'L' (py2 long) is stripped
    hex_digits = hex(arg)[2:].rstrip('L').zfill(24)
    return self.ObjectId(hex_digits)
def parse_reference(self, value, field_type):
    """Parse a reference value, converting an ObjectId to an integer first.

    The ObjectId check has to happen before the base class parser runs.
    """
    if isinstance(value, self.ObjectId):
        hex_text = str(value)
        value = long(hex_text, 16)
    base = super(MongoDBAdapter, self)
    return base.parse_reference(value, field_type)
def parse_id(self, value, field_type):
    """Parse an id value, converting an ObjectId to an integer first."""
    if isinstance(value, self.ObjectId):
        hex_text = str(value)
        value = long(hex_text, 16)
    base = super(MongoDBAdapter, self)
    return base.parse_id(value, field_type)
def represent(self, obj, fieldtype):
    """Convert a python value into the form stored in MongoDB.

    Args:
        obj: the value; if it is one of CALLABLETYPES it is called first.
        fieldtype: a pydal field type string, an SQLCustomType or a Table.

    Returns:
        The value to store: ObjectIds for ids/references, a datetime for
        'date' and 'time' values, a MongoBlob for blobs, otherwise whatever
        NoSQLAdapter.represent returns.

    Raises:
        NotImplementedError: if ``fieldtype`` is a Table.
    """
    if isinstance(obj, CALLABLETYPES):
        obj = obj()
    if isinstance(fieldtype, SQLCustomType):
        return fieldtype.encoder(obj)
    if isinstance(obj, self.ObjectId):
        value = obj
    elif fieldtype == 'id':
        value = self.object_id(obj)
    elif fieldtype in ['double', 'float']:
        value = None if obj is None else float(obj)
    elif fieldtype == 'date':
        if obj is None:
            return None
        # this piece of data can be stripped off based on the fieldtype
        t = datetime.time(0, 0, 0)
        # mongodb doesn't have a date object, so it must be a datetime,
        # string or integer
        return datetime.datetime.combine(obj, t)
    elif fieldtype == 'time':
        if obj is None:
            return None
        # this piece of data can be stripped off based on the fieldtype
        d = datetime.date(2000, 1, 1)
        # mongodb doesn't have a time object, so it must be a datetime,
        # string or integer
        return datetime.datetime.combine(d, obj)
    elif fieldtype == "blob":
        if isinstance(obj, basestring) and obj == '':
            obj = None
        return MongoBlob(obj)
    # reference types must be converted to ObjectID
    elif isinstance(fieldtype, basestring):
        if fieldtype.startswith('list:reference'):
            value = [self.object_id(v) for v in obj]
        elif fieldtype.startswith("reference") or fieldtype == "id":
            value = self.object_id(obj)
        else:
            value = NoSQLAdapter.represent(self, obj, fieldtype)
    elif isinstance(fieldtype, Table):
        # The unreachable assignment that used to follow this raise was
        # removed; the branch has always raised.
        raise NotImplementedError("How did you reach this line of code???")
    else:
        value = NoSQLAdapter.represent(self, obj, fieldtype)
    return value
def parse_blob(self, value, field_type):
    """Decode a stored blob value through MongoBlob.decode."""
    decoded = MongoBlob.decode(value)
    return decoded
# Matches the alias recorded next to an AS_MARK key inside a column name.
REGEX_SELECT_AS_PARSER = re.compile("\\'" + AS_MARK + "\\': \\'(\\S+)\\'")

def _regex_select_as_parser(self, colname):
    """Search ``colname`` for an AS alias; returns the match or None."""
    pattern = self.REGEX_SELECT_AS_PARSER
    return pattern.search(colname)

def _get_collection(self, tablename, safe=None):
    """Return the collection for ``tablename``.

    When ``safe`` is given and differs from the adapter default, the
    collection is returned with a matching write concern.
    """
    collection = self.connection[tablename]
    if safe is None or safe == self.safe:
        return collection
    concern = self.WriteConcern(w=self._get_safe(safe))
    return collection.with_options(write_concern=concern)

def _get_safe(self, val=None):
    """Normalise ``val`` to 1/0; None selects the adapter default."""
    if val is None:
        return self.safe
    if val:
        return 1
    return 0
def create_table(self, table, migrate=True, fake_migrate=False,
                 polymodel=None):
    """Record the table's notnull and unique field names on the table.

    Nothing is created on the server here; the lists are stored as
    ``table._notnulls`` and ``table._uniques``, and ``table._dbt`` is
    reset to None.
    """
    table._dbt = None
    table._notnulls = [name for name in table.fields
                       if table[name].notnull]
    # this is unnecessary if the fields are indexed and unique
    table._uniques = [name for name in table.fields
                      if table[name].unique]
class Expanded(object):
    """
    Class to encapsulate a pydal expression and track the parse
    expansion and its results.

    Two different MongoDB mechanisms are targeted here.  If the query
    is sufficiently simple, then simple queries are generated.  The
    bulk of the complexity here is however to support more complex
    queries that are targeted to the MongoDB Aggregation Pipeline.

    This class supports four operations: 'count', 'select', 'update'
    and 'delete'.

    Behavior varies somewhat for each operation type.  However
    building each pipeline stage is shared where the behavior is the
    same (or similar) for the different operations.

    In general an attempt is made to build the query without using the
    pipeline, and if that fails then the query is rebuilt with the
    pipeline.

    QUERY constructed in _build_pipeline_query():
      $project : used to calculate expressions if needed
      $match: filters out records

    FIELDS constructed in _expand_fields():
        FIELDS:COUNT
          $group : filter for distinct if needed
          $group: count the records remaining

        FIELDS:SELECT
          $group : implement aggregations if needed
          $project: implement expressions (etc) for select

        FIELDS:UPDATE
          $project: implement expressions (etc) for update

    HAVING constructed in _add_having():
      $project : used to calculate expressions
      $match: filters out records
      $project : used to filter out previous expression fields
    """

    def __init__(self, adapter, crud, query, fields=(), tablename=None,
                 groupby=None, distinct=False, having=None):
        """Expand ``query`` and ``fields`` for the ``crud`` operation.

        After construction either ``self.pipeline`` is non-empty (use the
        aggregation engine) or ``self.query_dict`` / ``self.field_dicts``
        hold a plain query and field projection.

        For 'update', ``fields`` is a sequence of (field, value) pairs;
        otherwise it is a sequence of fields/expressions.
        """
        self.adapter = adapter
        # shared with every annotated expression node, so that operators
        # can flag that the aggregation pipeline is required
        self._parse_data = {'pipeline': False,
                            'need_group':
                                bool(groupby or distinct or having)}
        self.crud = crud
        self.having = having
        self.distinct = distinct
        if not groupby and distinct:
            if distinct is True:
                # groupby gets all fields
                self.groupby = fields
            else:
                self.groupby = distinct
        else:
            self.groupby = groupby

        if crud == 'update':
            self.values = [(f[0], self.annotate_expression(f[1]))
                           for f in (fields or [])]
            self.fields = [f[0] for f in self.values]
        else:
            self.fields = [self.annotate_expression(f)
                           for f in (fields or [])]

        self.tablename = tablename or adapter.get_table(query, *self.fields)
        if use_common_filters(query):
            query = adapter.common_filter(query, [self.tablename])
        self.query = self.annotate_expression(query)

        # expand the query
        self.pipeline = []
        self.query_dict = adapter.expand(self.query)
        self.field_dicts = adapter.SON()
        self.field_groups = adapter.SON()
        self.field_groups['_id'] = adapter.SON()

        if self._parse_data['pipeline']:
            # if the query needs the aggregation engine, set that up
            self._build_pipeline_query()

            # expand the fields for the aggregation engine
            self._expand_fields(None)
        else:
            # expand the fields
            try:
                if not self._parse_data['need_group']:
                    self._expand_fields(self._fields_loop_abort)
                else:
                    self._parse_data['pipeline'] = True
                    # StopIteration is the "start over with the pipeline"
                    # signal, also raised by _fields_loop_abort
                    raise StopIteration

            except StopIteration:
                # if the fields needs the aggregation engine, set that up
                self.field_dicts = adapter.SON()
                if self.query_dict:
                    if self.query_dict != \
                            MongoDBAdapter.Expanded.NULL_QUERY:
                        self.pipeline = [{'$match': self.query_dict}]
                    self.query_dict = {}
                # expand the fields for the aggregation engine
                self._expand_fields(None)

        if not self._parse_data['pipeline']:
            if crud == 'update':
                # do not update id fields
                for fieldname in ("_id", "id"):
                    if fieldname in self.field_dicts:
                        del self.field_dicts[fieldname]
        else:
            if crud == 'update':
                self._add_all_fields_projection(self.field_dicts)
                self.field_dicts = adapter.SON()

            elif crud == 'select':
                if self._parse_data['need_group']:
                    if not self.groupby:
                        # no groupby, aggregate all records
                        self.field_groups['_id'] = None
                    # id has no value after aggregations
                    self.field_dicts['_id'] = False
                    self.pipeline.append({'$group': self.field_groups})
                if self.field_dicts:
                    self.pipeline.append({'$project': self.field_dicts})
                    self.field_dicts = adapter.SON()
                self._add_having()

            elif crud == 'count':
                if self._parse_data['need_group']:
                    self.pipeline.append({'$group': self.field_groups})
                self.pipeline.append(
                    {'$group': {"_id": None, 'count': {"$sum": 1}}})

            # crud == 'delete' needs no extra stage

    # Query compared against in __init__ to skip a redundant '$match' stage.
    # Left undefined when bson cannot be imported (only ImportError is
    # tolerated, so other failures are not hidden).
    try:
        from bson.objectid import ObjectId
        NULL_QUERY = {'_id': {'$gt': ObjectId('000000000000000000000000')}}
    except ImportError:
        pass

    def _build_pipeline_query(self):
        """Append the '$project' and '$match' stages for the query."""
        # search for anything needing the $match stage.
        #   currently only '$regex' requires the match stage
        def parse_need_match_stage(items, parent, parent_key):
            need_match = False
            non_matched_indices = []
            if isinstance(items, list):
                indices = range(len(items))
            elif isinstance(items, dict):
                indices = items.keys()
            else:
                return

            for i in indices:
                if parse_need_match_stage(items[i], items, i):
                    need_match = True

                elif i not in [MongoDBAdapter.REGEXP_MARK1,
                               MongoDBAdapter.REGEXP_MARK2]:
                    non_matched_indices.append(i)

                if i == MongoDBAdapter.REGEXP_MARK1:
                    need_match = True
                    self.query_dict['project'].update(items[i])
                    parent[parent_key] = items[MongoDBAdapter.REGEXP_MARK2]

            if need_match:
                # move the non-regex siblings into the '$project' stage and
                # reference them by name from the '$match' stage
                for i in non_matched_indices:
                    name = str(items[i])
                    self.query_dict['project'][name] = items[i]
                    items[i] = {name: True}

            if parent is None and self.query_dict['project']:
                self.query_dict['match'] = items
            return need_match

        expanded = self.adapter.expand(self.query)

        if MongoDBAdapter.REGEXP_MARK1 in expanded:
            # the REGEXP_MARK is at the top of the tree, so can just split
            # the regex over a '$project' and a '$match'
            self.query_dict = None
            match = expanded[MongoDBAdapter.REGEXP_MARK2]
            project = expanded[MongoDBAdapter.REGEXP_MARK1]

        else:
            self.query_dict = {'project': {}, 'match': {}}
            if parse_need_match_stage(expanded, None, None):
                project = self.query_dict['project']
                match = self.query_dict['match']
            else:
                project = {'__query__': expanded}
                match = {'__query__': True}

        if self.crud in ['select', 'update']:
            self._add_all_fields_projection(project)
        else:
            self.pipeline.append({'$project': project})
        self.pipeline.append({'$match': match})
        self.query_dict = None

    def _expand_fields(self, mid_loop):
        """Expand every field; ``mid_loop`` post-processes each expansion.

        When ``mid_loop`` is None the pipeline handler for the current
        crud operation is used.
        """
        if self.crud == 'update':
            mid_loop = mid_loop or self._fields_loop_update_pipeline
            for field, value in self.values:
                self._expand_field(field, value, mid_loop)
        elif self.crud in ['select', 'count']:
            mid_loop = mid_loop or self._fields_loop_select_pipeline
            for field in self.fields:
                self._expand_field(field, field, mid_loop)
        elif self.fields:
            raise RuntimeError(self.crud + " not supported with fields")

    def _expand_field(self, field, value, mid_loop):
        """Expand one field and store it in ``self.field_dicts``."""
        expanded = {}
        if isinstance(field, Field):
            expanded = self.adapter.expand(value, field.type)
        elif isinstance(field, (Expression, Query)):
            expanded = self.adapter.expand(field)
            # the expansion's string form becomes the expression's name
            field.name = str(expanded)
        else:
            raise RuntimeError("%s not supported with fields" % type(field))

        if mid_loop:
            expanded = mid_loop(expanded, field, value)
        self.field_dicts[field.name] = expanded

    def _fields_loop_abort(self, expanded, *args):
        """Raise StopIteration as soon as the pipeline turns out needed."""
        # if we need the aggregation engine, then start over
        if self._parse_data['pipeline']:
            raise StopIteration()
        return expanded

    def _fields_loop_update_pipeline(self, expanded, field, value):
        """Wrap literal update values so '$project' treats them as data."""
        if not isinstance(value, Expression):
            if self.adapter.server_version_major >= 2.6:
                expanded = {'$literal': expanded}

            # '$literal' not present in server versions < 2.6
            elif field.type in ['string', 'text', 'password']:
                expanded = {'$concat': [expanded]}
            elif field.type in ['integer', 'bigint', 'float', 'double']:
                expanded = {'$add': [expanded]}
            elif field.type == 'boolean':
                expanded = {'$and': [expanded]}
            elif field.type in ['date', 'time', 'datetime']:
                expanded = {'$add': [expanded]}
            else:
                raise RuntimeError("updating with expressions not "
                                   + "supported for field type '"
                                   + "%s' in MongoDB version < 2.6"
                                   % field.type)
        return expanded

    def _fields_loop_select_pipeline(self, expanded, field, value):
        """Split a select field over the '$group' and '$project' stages."""
        # search for anything needing $group
        def parse_groups(items, parent, parent_key):
            for item in items:
                if isinstance(items[item], list):
                    for list_item in items[item]:
                        if isinstance(list_item, dict):
                            parse_groups(list_item, items[item],
                                         items[item].index(list_item))

                elif isinstance(items[item], dict):
                    parse_groups(items[item], items, item)

                if item == MongoDBAdapter.GROUP_MARK:
                    name = str(items)
                    self.field_groups[name] = items[item]
                    parent[parent_key] = '$' + name
            return items

        if MongoDBAdapter.AS_MARK in field.name:
            # The AS_MARK in the field name is used by base to alias the
            # result, we don't actually need the AS_MARK in the parse tree
            # so we remove it here.
            if isinstance(expanded, list):
                # AS mark is first element in list, drop it
                expanded = expanded[1]

            elif MongoDBAdapter.AS_MARK in expanded:
                # AS mark is element in dict, drop it
                del expanded[MongoDBAdapter.AS_MARK]

            else:
                # ::TODO:: should be possible to do this...
                raise SyntaxError("AS() not at top of parse tree")

        if MongoDBAdapter.GROUP_MARK in expanded:
            # the GROUP_MARK is at the top of the tree, so can just pass
            # the group result straight through the '$project' stage
            self.field_groups[field.name] = \
                expanded[MongoDBAdapter.GROUP_MARK]
            expanded = 1

        elif MongoDBAdapter.GROUP_MARK in field.name:
            # the GROUP_MARK is not at the top of the tree, so we need to
            # pass the group results through to a '$project' stage.
            expanded = parse_groups(expanded, None, None)

        elif self._parse_data['need_group']:
            if field in self.groupby:
                # this is a 'groupby' field
                self.field_groups['_id'][field.name] = expanded
                expanded = '$_id.' + field.name
            else:
                raise SyntaxError("field '%s' not in groupby" % field)

        return expanded

    def _add_all_fields_projection(self, fields):
        """Append a '$project' stage passing all table fields through."""
        for fieldname in self.adapter.db[self.tablename].fields:
            # add all fields to projection to pass them through
            if fieldname not in fields and fieldname not in ("_id", "id"):
                fields[fieldname] = 1
        self.pipeline.append({'$project': fields})

    def _add_having(self):
        """Append the three 'having' stages (project, match, project)."""
        if not self.having:
            return
        self._expand_field(
            self.having, None, self._fields_loop_select_pipeline)
        fields = {'__having__': self.field_dicts[self.having.name]}
        for fieldname in self.pipeline[-1]['$project']:
            # add all fields to projection to pass them through
            if fieldname not in fields and fieldname not in ("_id", "id"):
                fields[fieldname] = 1

        # copy: 'fields' is modified below and reused for the last stage
        self.pipeline.append({'$project': copy.copy(fields)})
        self.pipeline.append({'$match': {'__having__': True}})
        del fields['__having__']
        self.pipeline.append({'$project': fields})

    def annotate_expression(self, expression):
        """Attach parse bookkeeping to every node of ``expression``.

        Each Expression/Query node gets ``parse_root``, ``parse_parent``,
        ``parse_depth``, the shared ``_parse_data`` dict and a
        ``has_field`` flag.  The expression itself is returned.
        """
        def mark_has_field(expression):
            if not isinstance(expression, (Expression, Query)):
                return False
            first_has_field = mark_has_field(expression.first)
            second_has_field = mark_has_field(expression.second)
            expression.has_field = (isinstance(expression, Field)
                                    or first_has_field
                                    or second_has_field)
            return expression.has_field

        def add_parse_data(child, parent):
            if isinstance(child, (Expression, Query)):
                child.parse_root = parent.parse_root
                child.parse_parent = parent
                child.parse_depth = parent.parse_depth + 1
                child._parse_data = parent._parse_data
                add_parse_data(child.first, child)
                add_parse_data(child.second, child)
            elif isinstance(child, (list, tuple)):
                for c in child:
                    add_parse_data(c, parent)

        if isinstance(expression, (Expression, Query)):
            expression.parse_root = expression
            # -1 so that add_parse_data leaves the root at depth 0
            expression.parse_depth = -1
            expression._parse_data = self._parse_data
            add_parse_data(expression, expression)

        mark_has_field(expression)
        return expression

    def get_collection(self, safe=None):
        """Return the adapter's collection for this expansion's table."""
        return self.adapter._get_collection(self.tablename, safe)
@staticmethod
def parse_data(expression, attribute, value=None):
    """Get (and optionally first set) a parse-data entry of an expression.

    For a list/tuple the operation is applied to every element and the
    results are or-ed together.  Returns None when the expression carries
    no usable ``_parse_data``.
    """
    if isinstance(expression, (list, tuple)):
        combined = False
        for element in expression:
            # the call comes first so that every element is visited
            combined = (MongoDBAdapter.parse_data(element, attribute, value)
                        or combined)
        return combined

    if value is not None:
        try:
            expression._parse_data[attribute] = value
        except AttributeError:
            return None

    try:
        found = expression._parse_data[attribute]
    except (AttributeError, TypeError):
        return None
    return found
@staticmethod
def has_field(expression):
    """Return the expression's ``has_field`` flag, False when it has none."""
    return getattr(expression, 'has_field', False)
def expand(self, expression, field_type=None):
    """Expand a pydal expression into its MongoDB representation.

    * Field: its name ('_id' for the id field), prefixed with '$' when the
      expression is flagged for the aggregation pipeline.
    * Expression/Query: the result of calling its operator.
    * Expanded: its query is expanded in place and the object is returned.
    * other values: represented as ``field_type`` when one is given,
      otherwise returned unchanged.

    Raises:
        NotImplementedError: for a list or tuple.
    """
    if isinstance(expression, Field):
        if expression.type == 'id':
            result = "_id"
        else:
            result = expression.name
        if self.parse_data(expression, 'pipeline'):
            # field names as part of expressions need to start with '$'
            result = '$' + result

    elif isinstance(expression, (Expression, Query)):
        first = expression.first
        second = expression.second
        if isinstance(first, Field) and "reference" in first.type:
            # cast to Mongo ObjectId
            if isinstance(second, (tuple, list, set)):
                second = [self.object_id(item)
                          for item in expression.second]
            else:
                second = self.object_id(expression.second)

        op = expression.op
        optional_args = expression.optional_args or {}
        if second is not None:
            result = op(first, second, **optional_args)
        elif first is not None:
            result = op(first, **optional_args)
        elif isinstance(op, str):
            result = op
        else:
            result = op(**optional_args)

    elif isinstance(expression, MongoDBAdapter.Expanded):
        expression.query = (self.expand(expression.query, field_type))
        result = expression

    elif isinstance(expression, (list, tuple)):
        # The unreachable assignment that used to follow this raise was
        # removed; the branch has always raised.
        raise NotImplementedError("How did you reach this line of code???")

    elif field_type:
        result = self.represent(expression, field_type)

    else:
        result = expression
    return result
def drop(self, table, mode=''):
    """Drop the table's collection, then run the adapter's drop cleanup."""
    self.connection[table._tablename].drop()
    self._drop_cleanup(table)
    return None
def truncate(self, table, mode, safe=None):
    """Delete every document of the table's collection.

    ``mode`` and ``safe`` are accepted but not used here.
    """
    collection = self.connection[table._tablename]
    match_everything = {}
    collection.delete_many(match_everything)
def count(self, query, distinct=None, snapshot=True):
    """Count the records matching ``query``.

    ``distinct`` may be True (all fields except 'id'), a single Field, or
    an expression chain whose ``second`` members are Fields.

    Raises:
        SyntaxError: if ``query`` is not a Query.
    """
    if not isinstance(query, Query):
        raise SyntaxError("Type '%s' not supported in count" % type(query))

    distinct_fields = []
    if distinct is True:
        distinct_fields = [fld for fld in query.first.table
                           if fld.name != 'id']
    elif distinct:
        if isinstance(distinct, Field):
            distinct_fields = [distinct]
        else:
            # walk the expression chain collecting its fields
            node = distinct
            while (isinstance(node, Expression) and
                    isinstance(node.second, Field)):
                distinct_fields.append(node.second)
                node = node.first
            if isinstance(node, Field):
                distinct_fields.append(node)
        distinct = True

    expanded = MongoDBAdapter.Expanded(
        self, 'count', query, fields=distinct_fields, distinct=distinct)
    collection = expanded.get_collection()
    if expanded.pipeline:
        # the count is carried by the first record; no record means 0
        for record in collection.aggregate(expanded.pipeline):
            return record['count']
        return 0
    return collection.count(filter=expanded.query_dict)
def select(self, query, fields, attributes, snapshot=False):
    """Run a select and return the processed rows.

    Args:
        query: the pydal query.
        fields: fields/expressions to return; SQLALL items are replaced
            by the fields of their table.
        attributes: select attributes ('orderby', 'limitby', 'groupby',
            'orderby_on_limitby', 'distinct', 'having', 'processor').
        snapshot: passed to find() as a modifier; cannot be combined with
            an orderby.

    Raises:
        RuntimeError: if both snapshot and an orderby are requested.
        SyntaxError: for an unknown attribute.
        MongoDBAdapter.NotOnNoSqlError: for an SQL only attribute.
    """
    new_fields = []
    for item in fields:
        if isinstance(item, SQLALL):
            new_fields += item._table
        else:
            new_fields.append(item)
    fields = new_fields
    tablename = self.get_table(query, *fields)

    orderby = attributes.get('orderby', False)
    limitby = attributes.get('limitby', False)
    groupby = attributes.get('groupby', None)
    orderby_on_limitby = attributes.get('orderby_on_limitby', True)
    distinct = attributes.get('distinct', False)
    having = attributes.get('having', None)

    for key in attributes.keys():
        if attributes[key] and key not in SUPPORTED_SELECT_ARGS:
            if key in NON_MONGO_SELECT_ARGS:
                self.db.logger.warning(
                    "Attribute '%s' unsupported by MongoDB" % key)
            elif key in NOT_IMPLEMENTED_SELECT_ARGS:
                self.db.logger.warning(
                    "Attribute '%s' is not implemented by MongoDB" % key)
            elif key in SQL_ONLY_SELECT_ARGS:
                raise MongoDBAdapter.NotOnNoSqlError(
                    "Attribute '%s' not supported on NoSQL databases" % key)
            else:
                raise SyntaxError(
                    "Attribute '%s' is unknown" % key)

    if limitby and orderby_on_limitby and not orderby:
        if groupby:
            orderby = groupby
        else:
            table = self.db[tablename]
            orderby = [table[x] for x in (
                hasattr(table, '_primarykey') and
                table._primarykey or ['_id'])]

    if not orderby:
        mongosort_list = []
    else:
        if snapshot:
            raise RuntimeError(
                "snapshot and orderby are mutually exclusive")
        if isinstance(orderby, (list, tuple)):
            orderby = xorify(orderby)

        if str(orderby) == '<random>':
            # !!!! need to add 'random'
            mongosort_list = self.RANDOM()
        else:
            mongosort_list = []
            for f in self.expand(orderby).split(','):
                include = 1
                if f.startswith('-'):
                    # a leading '-' (see INVERT) means descending order
                    include = -1
                    f = f[1:]
                if f.startswith('$'):
                    f = f[1:]
                mongosort_list.append((f, include))

    expanded = MongoDBAdapter.Expanded(
        self, 'select', query, fields or self.db[tablename],
        groupby=groupby, distinct=distinct, having=having)
    ctable = self.connection[tablename]
    modifiers = {'snapshot': snapshot}

    if not expanded.pipeline:
        if limitby:
            # limitby is (min, max): skip 'min' records and return
            # 'max - min' of them, which is what the pipeline branch below
            # does with '$limit' max followed by '$skip' min.
            # NOTE(review): a limit of 0 means "no limit" to find(), so
            # min == max returns everything that matches -- confirm.
            limitby_skip = limitby[0]
            limitby_limit = int(limitby[1]) - int(limitby[0])
        else:
            limitby_skip = limitby_limit = 0
        mongo_list_dicts = ctable.find(
            expanded.query_dict, expanded.field_dicts, skip=limitby_skip,
            limit=limitby_limit, sort=mongosort_list, modifiers=modifiers)
        null_rows = []
    else:
        if mongosort_list:
            sortby_dict = self.SON()
            for f in mongosort_list:
                sortby_dict[f[0]] = f[1]
            expanded.pipeline.append({'$sort': sortby_dict})
        if limitby and limitby[1]:
            expanded.pipeline.append({'$limit': limitby[1]})
        if limitby and limitby[0]:
            expanded.pipeline.append({'$skip': limitby[0]})

        mongo_list_dicts = ctable.aggregate(expanded.pipeline)
        null_rows = [(None,)]

    rows = []
    # populate row in proper order
    # Here we replace ._id with .id to follow the standard naming
    colnames = []
    newnames = []
    for field in expanded.fields:
        if hasattr(field, "tablename"):
            if field.name in ('id', '_id'):
                # Mongodb reserved uuid key
                colname = (tablename + '.' + 'id', '_id')
            else:
                colname = (tablename + '.' + field.name, field.name)
        elif not isinstance(query, Expression):
            colname = (field.name, field.name)
        # NOTE(review): when neither branch above applies, 'colname' keeps
        # its value from the previous iteration (or is unbound on the
        # first one) -- confirm whether that can happen.
        colnames.append(colname[1])
        newnames.append(colname[0])

    for record in mongo_list_dicts:
        row = []
        for colname in colnames:
            try:
                value = record[colname]
            except Exception:
                # best effort: a column missing from the record is None
                value = None
            if self.server_version_major < 2.6:
                # '$size' not present in server versions < 2.6
                if isinstance(value, list) and '$addToSet' in colname:
                    value = len(value)

            row.append(value)
        rows.append(row)
    if not rows:
        rows = null_rows

    processor = attributes.get('processor', self.parse)
    result = processor(rows, fields, newnames, blob_decode=True)
    return result
def check_notnull(self, table, values):
    """Raise if a notnull field of ``table`` is missing or None.

    Raises:
        Exception: naming the first offending field.
    """
    for fieldname in table._notnulls:
        if fieldname in values and values[fieldname] is not None:
            continue
        raise Exception("NOT NULL constraint failed: %s" % fieldname)
def check_unique(self, table, values):
    """Raise if a unique field's value is already stored.

    A unique field missing from ``values`` is checked with its default.

    Raises:
        Exception: naming the first unique field whose value exists.
    """
    if not len(table._uniques) > 0:
        return

    db = table._db
    per_field = []
    for fieldname in table._uniques:
        if fieldname in values:
            candidate = values[fieldname]
        else:
            candidate = table[fieldname].default
        per_field.append(Query(db, self.EQ, table[fieldname], candidate))

    if not len(per_field) > 0:
        return

    # if more than one field, build a query of ORs
    any_clash = per_field[0]
    for single in per_field[1:]:
        any_clash = Query(db, self.OR, any_clash, single)

    if self.count(any_clash, distinct=False) == 0:
        return

    # one of the 'OR' queries failed, see which one
    for single in per_field:
        if self.count(single, distinct=False) != 0:
            raise Exception(
                "NOT UNIQUE constraint failed: %s" % single.first.name)
def insert(self, table, fields, safe=None):
    """Insert one record and return a Reference to it.

    Safe determines whether an asynchronous request is done or a
    synchronous action is done.
    For safety, we use by default synchronous requests.

    When a constraint check fails and the table has an
    ``_on_insert_error`` handler, the handler's result is returned;
    otherwise the exception is raised.  Returns None when the insert was
    not acknowledged.
    """
    safe = self._get_safe(safe)
    ctable = self._get_collection(table._tablename, safe)

    values = {}
    for field, raw in fields:
        name = field.name
        if name in ["id", "safe"]:
            continue
        values[name] = self.represent(raw, table[name].type)

    # validate notnulls, then uniques
    for validate in (self.check_notnull, self.check_unique):
        try:
            validate(table, values)
        except Exception as e:
            if hasattr(table, '_on_insert_error'):
                return table._on_insert_error(table, fields, e)
            raise e

    # perform the insert
    result = ctable.insert_one(values)
    if not result.acknowledged:
        return None

    # the ObjectId is exposed as the integer value of its hex string
    rid = Reference(long(str(result.inserted_id), 16))
    rid._table = table
    rid._record = None
    return rid
def update(self, tablename, query, fields, safe=None):
    """Update the records matching ``query``; return the affected amount.

    With a safe (acknowledged) write the amount is the matched count
    reported by the server; otherwise it is the count taken before the
    update.

    Raises:
        RuntimeError: if ``query`` is not a Query, or if the update fails.
    """
    # return amount of adjusted rows or zero, but no exceptions
    # @ related not finding the result
    if not isinstance(query, Query):
        raise RuntimeError("Not implemented")

    safe = self._get_safe(safe)
    if safe:
        amount = 0
    else:
        amount = self.count(query, distinct=False)
        if amount == 0:
            return amount

    expanded = MongoDBAdapter.Expanded(self, 'update', query, fields)
    ctable = expanded.get_collection(safe)
    try:
        if expanded.pipeline:
            # the pipeline computes the new documents, which then replace
            # the stored ones one at a time
            for doc in ctable.aggregate(expanded.pipeline):
                outcome = ctable.replace_one({'_id': doc['_id']}, doc)
                if safe and outcome.acknowledged:
                    amount += outcome.matched_count
        else:
            outcome = ctable.update_many(
                filter=expanded.query_dict,
                update={'$set': expanded.field_dicts})
            if safe and outcome.acknowledged:
                amount = outcome.matched_count
        return amount
    except Exception as e:
        # TODO Reverse update query to verify that the query succeeded
        raise RuntimeError(
            "uncaught exception when updating rows: %s" % e)
def delete(self, tablename, query, safe=None):
    """Delete the records matching ``query``; return the deleted amount.

    References to the deleted records are cleaned up afterwards according
    to the referencing field's ``ondelete`` ('CASCADE' or 'SET NULL'),
    for both 'reference' and 'list:reference' fields.

    Raises:
        RuntimeError: if ``query`` is not a Query.
    """
    if not isinstance(query, Query):
        raise RuntimeError("query type %s is not supported" % type(query))

    safe = self._get_safe(safe)
    expanded = MongoDBAdapter.Expanded(self, 'delete', query)
    ctable = expanded.get_collection(safe)
    if expanded.pipeline:
        matches = ctable.aggregate(expanded.pipeline)
    else:
        matches = ctable.find(expanded.query_dict)
    deleted = [doc['_id'] for doc in matches]

    def split_by_ondelete(referrers, wanted_type):
        # sort the referencing fields into (cascade, set null) lists
        cascading, nulling = [], []
        for ref_field in referrers:
            if ref_field.type == wanted_type:
                if ref_field.ondelete == 'CASCADE':
                    cascading.append(ref_field)
                if ref_field.ondelete == 'SET NULL':
                    nulling.append(ref_field)
        return cascading, nulling

    # find references to deleted items
    db = self.db
    table = db[tablename]
    cascade, set_null = split_by_ondelete(
        table._referenced_by, 'reference ' + tablename)
    cascade_list, set_null_list = split_by_ondelete(
        table._referenced_by_list, 'list:reference ' + tablename)

    # perform delete
    result = ctable.delete_many({"_id": {"$in": deleted}})
    if result.acknowledged:
        amount = result.deleted_count
    else:
        amount = len(deleted)

    # clean up any references
    if amount and deleted:
        # ::TODO:: test if deleted references cascade
        def remove_from_list(field, deleted, safe):
            for oid in deleted:
                modify = {field.name: oid}
                dtable = self._get_collection(field.tablename, safe)
                dtable.update_many(
                    filter=modify, update={'$pull': modify})

        # for cascaded items, if the reference is the only item in the
        # list, then remove the entire record, else delete reference from
        # the list
        for field in cascade_list:
            for oid in deleted:
                only_item = {field.name: [oid]}
                dtable = self._get_collection(field.tablename, safe)
                dtable.delete_many(filter=only_item)
            remove_from_list(field, deleted, safe)
        for field in set_null_list:
            remove_from_list(field, deleted, safe)
        for field in cascade:
            db(field.belongs(deleted)).delete()
        for field in set_null:
            db(field.belongs(deleted)).update(**{field.name: None})

    return amount
def bulk_insert(self, table, items):
    """Insert every item in turn; return the list of insert() results."""
    inserted = []
    for item in items:
        inserted.append(self.insert(table, item))
    return inserted
## OPERATORS
def needs_mongodb_aggregation_pipeline(f):
    """Decorator for operators that only work in the aggregation pipeline.

    Before the operator runs, its first operand (and the second one, when
    given positionally) is flagged with parse data 'pipeline' = True.
    The wrapper keeps the operator's name and docstring.
    """
    import functools

    @functools.wraps(f)
    def mark_pipeline(self, first, *args, **kwargs):
        self.parse_data(first, 'pipeline', True)
        if len(args) > 0:
            self.parse_data(args[0], 'pipeline', True)
        return f(self, first, *args, **kwargs)
    return mark_pipeline
def INVERT(self, first):
    """Return the expansion of ``first`` prefixed with '-'."""
    expanded = self.expand(first)
    return '-%s' % expanded
def NOT(self, first):
    """Negate the expansion of ``first``.

    * and/or lists are negated with De Morgan's laws,
    * a lone '$ne' is turned back into a plain equality,
    * a plain value becomes a '$ne' comparison,
    * anything else is wrapped in '$not'.
    """
    expanded = self.expand(first)
    key = list(expanded)[0]
    body = expanded[key]

    if type(body) is list:
        # apply De Morgan law for and/or
        # not(A and B) -> not(A) or not(B)
        # not(A or B)  -> not(A) and not(B)
        flipped = '$and' if key == '$or' else '$or'
        return {flipped: [self.NOT(first.first), self.NOT(first.second)]}

    try:
        sub_ops = list(body.keys())
    except AttributeError:
        # not a dict: a plain value, i.e. an equality test
        return {key: {'$ne': body}}

    if len(sub_ops) == 1 and sub_ops[0] == '$ne':
        return {key: body['$ne']}
    return {key: {'$not': body}}
def AND(self, first, second):
    """Build a '$and' of both operands.

    A boolean ``second`` short-cuts: True leaves just ``first``, False
    gives ``first != first``.
    """
    # pymongo expects: .find({'$and': [{'x':'1'}, {'y':'2'}]})
    if not isinstance(second, bool):
        return {'$and': [self.expand(first), self.expand(second)]}
    if second:
        return self.expand(first)
    return self.NE(first, first)
def OR(self, first, second):
    """Build a '$or' of both operands.

    A boolean ``second`` short-cuts: False leaves just ``first``, True
    gives True.
    """
    # pymongo expects: .find({'$or': [{'name':'1'}, {'name':'2'}]})
    if not isinstance(second, bool):
        return {'$or': [self.expand(first), self.expand(second)]}
    if second:
        return True
    return self.expand(first)
def BELONGS(self, first, second):
    """Build an '$in' test of ``first`` against the items of ``second``.

    Raises:
        RuntimeError: if ``second`` is a string (a nested query).
    """
    if isinstance(second, str):
        # this is broken, the only way second is a string is if it has
        # been converted to SQL.  This no worky.  This might be made to
        # work if _select did not return SQL.
        raise RuntimeError("nested queries not supported")
    candidates = []
    for item in second:
        candidates.append(self.expand(item, first.type))
    key = self.expand(first)
    return {key: {"$in": candidates}}
def validate_second(f):
    """Decorator rejecting a comparison whose second operand is None.

    The operands are read from the positional arguments (self, first,
    second) or, when not given positionally, from the 'first' / 'second'
    keyword arguments.  The wrapper keeps the operator's name and
    docstring.

    Raises (from the wrapper):
        RuntimeError: if the second operand is missing or None.
    """
    import functools

    @functools.wraps(f)
    def check_second(*args, **kwargs):
        first = args[1] if len(args) > 1 else kwargs.get('first')
        second = args[2] if len(args) > 2 else kwargs.get('second')
        if second is None:
            raise RuntimeError("Cannot compare %s with None" % first)
        return f(*args, **kwargs)
    return check_second
def check_fields_for_cmp(f):
    """Decorator deciding whether a comparison needs the pipeline.

    The comparison is passed ``pipeline=True`` when either operand is
    already flagged for the pipeline, when ``first`` is not a Field, or
    when ``second`` contains a field; in the latter two cases both
    operands get flagged as well.  Otherwise ``pipeline=False`` is passed.
    The wrapper keeps the operator's name and docstring.
    """
    import functools

    @functools.wraps(f)
    def check_fields(self, first, second=None, *args, **kwargs):
        if (self.parse_data((first, second), 'pipeline')):
            pipeline = True
        elif not isinstance(first, Field) or self.has_field(second):
            pipeline = True
            self.parse_data((first, second), 'pipeline', True)
        else:
            pipeline = False
        return f(self, first, second, *args, pipeline=pipeline, **kwargs)
    return check_fields
def CMP_OPS_AGGREGATION_PIPELINE(self, op, first, second):
    """Build the aggregation pipeline form of a comparison.

    ``second`` is represented with the type of ``first`` when ``first``
    has a ``type`` attribute, and expanded without a type otherwise.

    Returns:
        ``{op: [expanded first, expanded second]}``
    """
    # no 'type' attribute on the first operand means "no field type"
    field_type = getattr(first, 'type', None)
    return {op: [self.expand(first), self.expand(second, field_type)]}
@check_fields_for_cmp
def EQ(self, first, second=None, pipeline=False):
    """Equality test of ``first`` against ``second``.

    With ``pipeline`` set the '$eq' aggregation form is produced;
    otherwise a ``{field: value}`` match document is returned.
    """
    if not pipeline:
        return {self.expand(first): self.expand(second, first.type)}
    return self.CMP_OPS_AGGREGATION_PIPELINE('$eq', first, second)
@check_fields_for_cmp
def NE(self, first, second=None, pipeline=False):
    """Inequality test of ``first`` against ``second``.

    With ``pipeline`` set the '$ne' aggregation form is produced;
    otherwise a ``{field: {'$ne': value}}`` match document is returned.
    """
    if not pipeline:
        unequal = {'$ne': self.expand(second, first.type)}
        return {self.expand(first): unequal}
    return self.CMP_OPS_AGGREGATION_PIPELINE('$ne', first, second)
@validate_second
@check_fields_for_cmp
def LT(self, first, second=None, pipeline=False):
    """Less-than test of ``first`` against ``second``.

    With ``pipeline`` set the '$lt' aggregation form is produced;
    otherwise a ``{field: {'$lt': value}}`` match document is returned.
    """
    if not pipeline:
        bound = {'$lt': self.expand(second, first.type)}
        return {self.expand(first): bound}
    return self.CMP_OPS_AGGREGATION_PIPELINE('$lt', first, second)
@validate_second
@check_fields_for_cmp
def LE(self, first, second=None, pipeline=False):
    """Less-than-or-equal test of ``first`` against ``second``.

    With ``pipeline`` set the '$lte' aggregation form is produced;
    otherwise a ``{field: {'$lte': value}}`` match document is returned.
    """
    if not pipeline:
        bound = {'$lte': self.expand(second, first.type)}
        return {self.expand(first): bound}
    return self.CMP_OPS_AGGREGATION_PIPELINE('$lte', first, second)
@validate_second
@check_fields_for_cmp
def GT(self, first, second=None, pipeline=False):
    """Greater-than test of ``first`` against ``second``.

    With ``pipeline`` set the '$gt' aggregation form is produced;
    otherwise a ``{field: {'$gt': value}}`` match document is returned.
    """
    if not pipeline:
        bound = {'$gt': self.expand(second, first.type)}
        return {self.expand(first): bound}
    return self.CMP_OPS_AGGREGATION_PIPELINE('$gt', first, second)
@validate_second
@check_fields_for_cmp
def GE(self, first, second=None, pipeline=False):
    """Greater-than-or-equal test of ``first`` against ``second``.

    With ``pipeline`` set the '$gte' aggregation form is produced;
    otherwise a ``{field: {'$gte': value}}`` match document is returned.
    """
    if not pipeline:
        bound = {'$gte': self.expand(second, first.type)}
        return {self.expand(first): bound}
    return self.CMP_OPS_AGGREGATION_PIPELINE('$gte', first, second)
@needs_mongodb_aggregation_pipeline
def ADD(self, first, second):
    """Build an addition, or a concatenation for textual operands.

    '$concat' is used when either operand has a ``type`` of 'string',
    'text' or 'password'; otherwise '$add' is used.  Operands without a
    ``type`` attribute are treated as non-textual.
    """
    textual_types = ('string', 'text', 'password')
    is_textual = any(getattr(operand, 'type', None) in textual_types
                     for operand in (first, second))
    op_code = '$concat' if is_textual else '$add'
    return {op_code: [self.expand(first), self.expand(second, first.type)]}
@needs_mongodb_aggregation_pipeline
def SUB(self, first, second):
    """Build a '$subtract' of ``second`` from ``first``."""
    minuend = self.expand(first)
    subtrahend = self.expand(second, first.type)
    return {'$subtract': [minuend, subtrahend]}
@needs_mongodb_aggregation_pipeline
def MUL(self, first, second):
    """Build a '$multiply' of ``first`` and ``second``."""
    left = self.expand(first)
    right = self.expand(second, first.type)
    return {'$multiply': [left, right]}
@needs_mongodb_aggregation_pipeline
def DIV(self, first, second):
    """Build a '$divide' of ``first`` by ``second``."""
    dividend = self.expand(first)
    divisor = self.expand(second, first.type)
    return {'$divide': [dividend, divisor]}
@needs_mongodb_aggregation_pipeline
def MOD(self, first, second):
    """Build a '$mod' of ``first`` by ``second``."""
    dividend = self.expand(first)
    divisor = self.expand(second, first.type)
    return {'$mod': [dividend, divisor]}
# aggregate name -> MongoDB accumulator operator
_aggregate_map = {
    'SUM': '$sum',
    'MAX': '$max',
    'MIN': '$min',
    'AVG': '$avg',
}

@needs_mongodb_aggregation_pipeline
def AGGREGATE(self, first, what):
    """Build the aggregate ``what`` over ``first``.

    'ABS' is expressed as a '$cond' and needs no grouping.  Names found
    in ``_aggregate_map`` are wrapped under GROUP_MARK and flag the
    parsed data with 'need_group'.

    Raises:
        NotImplementedError: for any other aggregate name.
    """
    if what == 'ABS':
        is_negative = {"$lt": [self.expand(first), 0]}
        negated = {"$subtract": [0, self.expand(first)]}
        return {"$cond": [is_negative, negated, self.expand(first)]}

    try:
        operator = self._aggregate_map[what]
        expanded = {operator: self.expand(first)}
    except KeyError:
        raise NotImplementedError("'%s' not implemented" % what)

    self.parse_data(first, 'need_group', True)
    return {MongoDBAdapter.GROUP_MARK: expanded}
@needs_mongodb_aggregation_pipeline
def COUNT(self, first, distinct=None):
    """Build a count expression and flag the parsed data 'need_group'.

    Without ``distinct`` this is a grouped ``{"$sum": 1}``.  With
    ``distinct`` the values are gathered with '$addToSet' and, on
    servers reporting a version >= 2.6, wrapped in '$size'.
    """
    self.parse_data(first, 'need_group', True)
    if not distinct:
        return {MongoDBAdapter.GROUP_MARK: {"$sum": 1}}

    collected = {MongoDBAdapter.GROUP_MARK:
                 {"$addToSet": self.expand(first)}}
    if self.server_version_major >= 2.6:
        # '$size' not present in server versions < 2.6
        collected = {'$size': collected}
    return collected
# date part name -> MongoDB date operator
_extract_map = {
    'dayofyear': '$dayOfYear',
    'day': '$dayOfMonth',
    'dayofweek': '$dayOfWeek',
    'year': '$year',
    'month': '$month',
    'week': '$week',
    'hour': '$hour',
    'minute': '$minute',
    'second': '$second',
    'millisecond': '$millisecond',
    'string': '$dateToString',
}

@needs_mongodb_aggregation_pipeline
def EXTRACT(self, first, what):
    """Build the date-part extraction ``what`` applied to ``first``.

    Raises:
        NotImplementedError: if ``what`` is not in ``_extract_map``.
    """
    try:
        operator = self._extract_map[what]
        return {operator: self.expand(first)}
    except KeyError:
        raise NotImplementedError("EXTRACT(%s) not implemented" % what)
@needs_mongodb_aggregation_pipeline
def EPOCH(self, first):
    """Build ``(first - self.epoch) / 1000``.

    NOTE(review): presumably converts a millisecond date difference to
    seconds -- confirm against the MongoDB date subtraction semantics.
    """
    elapsed = {"$subtract": [self.expand(first), self.epoch]}
    return {"$divide": [elapsed, 1000]}
@needs_mongodb_aggregation_pipeline
def EXPAND_CASE(self, query, true_false):
    """Build a '$cond' selecting ``true_false[0]`` or ``true_false[1]``
    depending on ``query``.
    """
    condition = self.expand(query)
    when_true = self.expand(true_false[0])
    when_false = self.expand(true_false[1])
    return {"$cond": [condition, when_true, when_false]}
@needs_mongodb_aggregation_pipeline
def AS(self, first, second):
    """Attach the alias ``second`` to the expansion of ``first``.

    For a ``Field`` a two-item list ``[{AS_MARK: alias}, expanded]`` is
    returned; for anything else the alias is stored under AS_MARK
    inside the expanded result itself.
    """
    # put the AS_MARK into the structure.  The 'AS' name will be parsed
    # later from the string of the field name.
    if not isinstance(first, Field):
        aliased = self.expand(first)
        aliased[MongoDBAdapter.AS_MARK] = second
        return aliased
    return [{MongoDBAdapter.AS_MARK: second}, self.expand(first)]
# We could implement an option that simulates a full-featured SQL
# database, but I think such an option should be set explicitly or
# implemented as a separate library.
def ON(self, first, second):
    """Join conditions are not available here.

    Raises:
        MongoDBAdapter.NotOnNoSqlError: always.
    """
    unsupported = MongoDBAdapter.NotOnNoSqlError()
    raise unsupported
def COMMA(self, first, second):
    """Join the expansions of ``first`` and ``second`` with a comma."""
    # returns field name lists, to be separated via split(',')
    left = self.expand(first)
    right = self.expand(second)
    return '%s,%s' % (left, right)
#TODO verify full compatibilty with official SQL Like operator def _build_like_regex(self, first, second, case_sensitive=True, escape=None, ends_with=False, starts_with=False, whole_string=True, like_wildcards=False): import re base = self.expand(second, 'string') need_regex = (whole_string or not case_sensitive or starts_with or ends_with or like_wildcards and ('_' in base or '%' in base)) if not need_regex: return base else: expr = re.escape(base) if like_wildcards: if escape: # protect % and _ which are escaped expr = expr.replace(escape+'\\%', '%') if PY2: expr = expr.replace(escape+'\\_', '_') elif escape+'_' in expr: set_aside = str(self.object_id('<random>')) while set_aside in expr: set_aside = str(self.object_id('<random>')) expr = expr.replace(escape+'_', set_aside) else: set_aside = None expr = expr.replace('\\%', '.*') if PY2: expr = expr.replace('\\_', '.') else: expr = expr.replace('_', '.') if escape: # convert to protected % and _ expr = expr.replace('%', '\\%') if PY2: expr = expr.replace('_', '\\_') elif set_aside: expr = expr.replace(set_aside, '_') if starts_with: pattern = '^%s' elif ends_with: pattern = '%s$' elif whole_string: pattern = '^%s$' else: pattern = '%s' return self.REGEXP(first, pattern % expr, case_sensitive)
def LIKE(self, first, second, case_sensitive=True, escape=None):
    """Match ``first`` against the LIKE pattern ``second``.

    Delegates to ``_build_like_regex`` with wildcard translation on.
    """
    options = {
        'case_sensitive': case_sensitive,
        'escape': escape,
        'like_wildcards': True,
    }
    return self._build_like_regex(first, second, **options)
def ILIKE(self, first, second, escape=None):
    """Case-insensitive variant of ``LIKE``."""
    options = {'case_sensitive': False, 'escape': escape}
    return self.LIKE(first, second, **options)
def STARTSWITH(self, first, second):
    """Match values of ``first`` that begin with ``second``."""
    options = {'starts_with': True}
    return self._build_like_regex(first, second, **options)
def ENDSWITH(self, first, second):
    """Match values of ``first`` that end with ``second``."""
    options = {'ends_with': True}
    return self._build_like_regex(first, second, **options)
# TODO: verify full compatibility with the official Oracle CONTAINS operator
def CONTAINS(self, first, second, case_sensitive=True):
    """Build a containment test of ``second`` within ``first``.

    Dispatches on the type of ``second``:

    * an ``ObjectId`` becomes a direct ``{field: value}`` match;
    * a ``Field`` becomes a '$where' JavaScript ``indexOf`` test
      (string/text and integer/bigint fields only);
    * a string or int is delegated to ``_build_like_regex``, matching
      the whole string only when ``first`` is a 'list:string' field.

    Raises:
        NotImplementedError: for any unsupported type combination.
    """
    if isinstance(second, self.ObjectId):
        return {self.expand(first): second}

    if isinstance(second, Field):
        if second.type in ['string', 'text']:
            if not isinstance(first, Field):
                raise NotImplementedError(
                    "x.CONTAINS() not implemented for x type of '%s'" %
                    type(first))
            if first.type not in ['list:string', 'string', 'text']:
                raise NotImplementedError("field.CONTAINS() not " +
                    "implemented for field type of '%s'" % first.type)
            return {'$where': "this.%s.indexOf(this.%s) > -1" %
                    (first.name, second.name)}
        if second.type in ['integer', 'bigint']:
            return {'$where': "this.%s.indexOf(this.%s + '') > -1" %
                    (first.name, second.name)}
        raise NotImplementedError(
            "CONTAINS(field) not implemented for field type '%s'" %
            second.type)

    if isinstance(second, (basestring, int)):
        # list members must match completely, plain text may match
        # anywhere in the value
        match_whole = (isinstance(first, Field) and
                       first.type == 'list:string')
        return self._build_like_regex(first, second,
                                      case_sensitive=case_sensitive,
                                      whole_string=match_whole)
        # first.type in ('string', 'text', 'json', 'upload')
        # or first.type.startswith('list:'):

    raise NotImplementedError(
        "CONTAINS() not implemented for type '%s'" % type(second))
@needs_mongodb_aggregation_pipeline
def SUBSTRING(self, field, parameters):
    """Build a '$substr' expression over ``field``.

    ``parameters`` is unpacked as ``(pos0, length)``; the start offset
    passed to '$substr' is ``pos0 - 1`` and the length is -1 when
    ``length`` is not a string.
    """
    def parse_parameters(pos0, length):
        """
        The expression object can return these as string based expressions.
        We can't use that so we have to tease it apart.

        These are the possibilities:

          pos0 = '(%s - %d)' % (self.len(), abs(start) - 1)
          pos0 = start + 1

          length = self.len()
          length = '(%s - %d - %s)' % (self.len(), abs(stop) - 1, pos0)
          length = '(%s - %s)' % (stop + 1, pos0)

        Two of these five require the length of the string which is not
        supported by Mongo, so for now these cause an Exception and
        won't reach here.

        If this were to ever be supported it may require a change to
        Expression.__getitem__ so that it either returned the base
        expression to be expanded here, or converted length to a string
        to be parsed back to a call to STRLEN()
        """
        if isinstance(length, basestring):
            # NOTE(review): eval() on a string; per the docstring above
            # it is expected to be an arithmetic expression built by
            # the expression object -- confirm no user-controlled text
            # can reach here.
            return (pos0 - 1, eval(length))
        else:
            # take the rest of the string
            return (pos0 - 1, -1)

    parameters = parse_parameters(*parameters)
    return {'$substr': [self.expand(field), parameters[0], parameters[1]]}
@needs_mongodb_aggregation_pipeline
def LOWER(self, first):
    """Build a '$toLower' of ``first``."""
    operand = self.expand(first)
    return {'$toLower': operand}
@needs_mongodb_aggregation_pipeline
def UPPER(self, first):
    """Build a '$toUpper' of ``first``."""
    operand = self.expand(first)
    return {'$toUpper': operand}
def REGEXP(self, first, second, case_sensitive=True):
    """Build a regular-expression match of ``first`` against ``second``.

    MongoDB provides regular expression capabilities for pattern
    matching strings in queries. MongoDB uses Perl compatible
    regular expressions (i.e. 'PCRE') version 8.36 with UTF-8 support.

    Numeric fields are matched through a '$where' JavaScript test on
    the value's string form.  Otherwise a '$regex' document is built
    (with the 'i' option when not ``case_sensitive``); in pipeline mode
    it is returned under the two REGEXP marks.

    Returns:
        A query dict, or ``{}`` when ``first`` expands to something
        unhashable -- in that case the 'pipeline' flag is set instead.
    """
    import json

    if (isinstance(first, Field) and
            first.type in ['integer', 'bigint', 'float', 'double']):
        # Emit the pattern as a properly quoted JavaScript string
        # literal: hand-quoting with '...' breaks on quotes in the
        # pattern and lets JS string escaping eat regex backslashes.
        js_pattern = json.dumps(self.expand(second, 'string'))
        return {'$where': "RegExp(%s).test(this.%s + '')"
                % (js_pattern, first.name)}

    expanded_first = self.expand(first)
    regex_second = {'$regex': self.expand(second, 'string')}
    if not case_sensitive:
        regex_second['$options'] = 'i'

    if self.parse_data((first, second), 'pipeline'):
        name = str(expanded_first)
        return {MongoDBAdapter.REGEXP_MARK1: {name: expanded_first},
                MongoDBAdapter.REGEXP_MARK2: {name: regex_second}}
    try:
        return {expanded_first: regex_second}
    except TypeError:
        # if first is not hashable, then will need the pipeline
        self.parse_data((first, second), 'pipeline', True)
        return {}
def LENGTH(self, first):
    """String length is not implemented for this adapter.

    See:
        https://jira.mongodb.org/browse/SERVER-5319
        https://github.com/afchin/mongo/commit/f52105977e4d0ccb53bdddfb9c4528a3f3c40bdf

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
@needs_mongodb_aggregation_pipeline
def COALESCE(self, first, second):
    """Build nested '$ifNull' expressions.

    ``second`` is a sequence of fallbacks; more than one fallback is
    folded into a recursive COALESCE of its head over its tail.
    """
    fallbacks = second
    if len(fallbacks) > 1:
        fallbacks = [self.COALESCE(fallbacks[0], fallbacks[1:])]
    primary = self.expand(first)
    alternative = self.expand(fallbacks[0])
    return {"$ifNull": [primary, alternative]}
def RANDOM(self):
    """ORDER BY RANDOM() is not implemented for this adapter.

    See:
        https://github.com/mongodb/cookbook/blob/master/content/patterns/random-attribute.txt
        https://jira.mongodb.org/browse/SERVER-533
        http://stackoverflow.com/questions/19412/how-to-request-a-random-row-in-sql

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
class NotOnNoSqlError(NotImplementedError):
    """Raised for operations that are not supported on NoSQL databases."""

    def __init__(self, message=None):
        # fall back to the generic text only when no message was given
        text = ("Not Supported on NoSQL databases"
                if message is None else message)
        super(MongoDBAdapter.NotOnNoSqlError, self).__init__(text)
class MongoBlob(Binary):
    """Binary wrapper that tags how a blob value was originally typed.

    Construction may return something other than a MongoBlob: None,
    existing ``Binary`` objects and UTF-8 encodable strings are returned
    as they are.
    """
    # subtype markers stored on the Binary value
    MONGO_BLOB_BYTES = USER_DEFINED_SUBTYPE
    MONGO_BLOB_NON_UTF8_STR = USER_DEFINED_SUBTYPE + 1

    def __new__(cls, value):
        # return None and Binary() unmolested
        if value is None or isinstance(value, Binary):
            return value

        # bytearray is marked as MONGO_BLOB_BYTES
        if isinstance(value, bytearray):
            return Binary.__new__(cls, bytes(value),
                                  MongoBlob.MONGO_BLOB_BYTES)

        # return non-strings as Binary(), eg: PY3 bytes()
        if not isinstance(value, basestring):
            return Binary(value)

        # if string is encodable as UTF-8, then return as string
        try:
            value.encode('utf-8')
        except UnicodeDecodeError:
            # string which can not be UTF-8 encoded, eg: pickle strings
            return Binary.__new__(cls, value,
                                  MongoBlob.MONGO_BLOB_NON_UTF8_STR)
        return value

    def __repr__(self):
        return repr(MongoBlob.decode(self))

    @staticmethod
    def decode(value):
        """Convert a tagged ``Binary`` back to bytearray or str.

        Values that are not ``Binary``, or that carry any other subtype,
        are returned unchanged.
        """
        if not isinstance(value, Binary):
            return value
        subtype = value.subtype
        if subtype == MongoBlob.MONGO_BLOB_BYTES:
            return bytearray(value)
        if subtype == MongoBlob.MONGO_BLOB_NON_UTF8_STR:
            return str(value)
        return value