Source code for pydal.adapters.postgres

# -*- coding: utf-8 -*-
import re

from .._globals import IDENTITY
from ..drivers import psycopg2_adapt
from .._compat import PY2
from ..helpers.methods import varquote_aux
from .base import BaseAdapter
from ..objects import Expression


class PostgreSQLAdapter(BaseAdapter):
    drivers = ('psycopg2', 'pg8000')

    QUOTE_TEMPLATE = '"%s"'

    support_distributed_transaction = True
    types = {
        'boolean': 'CHAR(1)',
        'string': 'VARCHAR(%(length)s)',
        'text': 'TEXT',
        'json': 'TEXT',
        'password': 'VARCHAR(%(length)s)',
        'blob': 'BYTEA',
        'upload': 'VARCHAR(%(length)s)',
        'integer': 'INTEGER',
        'bigint': 'BIGINT',
        'float': 'FLOAT',
        'double': 'FLOAT8',
        'decimal': 'NUMERIC(%(precision)s,%(scale)s)',
        'date': 'DATE',
        'time': 'TIME',
        'datetime': 'TIMESTAMP',
        'id': 'SERIAL PRIMARY KEY',
        'reference': 'INTEGER REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s %(null)s %(unique)s',
        'list:integer': 'TEXT',
        'list:string': 'TEXT',
        'list:reference': 'TEXT',
        'geometry': 'GEOMETRY',
        'geography': 'GEOGRAPHY',
        'big-id': 'BIGSERIAL PRIMARY KEY',
        'big-reference': 'BIGINT REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s %(null)s %(unique)s',
        'reference FK': ', CONSTRAINT "FK_%(constraint_name)s" FOREIGN KEY (%(field_name)s) REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s',
        'reference TFK': ' CONSTRAINT "FK_%(foreign_table)s_PK" FOREIGN KEY (%(field_name)s) REFERENCES %(foreign_table)s (%(foreign_key)s) ON DELETE %(on_delete_action)s',
        }
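
    # Editorial note (sketch, not part of the original source): each entry in
    # `types` is a template that gets filled in with the field's attributes
    # when the CREATE TABLE statement is built, so a string field with
    # length=128 is rendered as VARCHAR(128) and decimal(10,2) as NUMERIC(10,2).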

    def varquote(self, name):
        return varquote_aux(name, '"%s"')

    def adapt(self, obj):
        if self.driver_name == 'psycopg2':
            adapted = psycopg2_adapt(obj)
            # deal with new relic Connection Wrapper (newrelic>=2.10.0.8)
            cxn = getattr(self.connection, '__wrapped__', self.connection)
            adapted.prepare(cxn)
            rv = adapted.getquoted()
            if not PY2:
                if isinstance(rv, bytes):
                    return rv.decode('utf-8')
            return rv
        elif self.driver_name == 'pg8000':
            return "'%s'" % obj.replace("%", "%%").replace("'", "''")
        else:
            return "'%s'" % obj.replace("'", "''")
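
    # Illustrative example (assumption, not in the original source): adapt()
    # returns a safely quoted SQL literal, e.g. adapt("O'Reilly") yields
    # "'O''Reilly'" with any of the supported drivers.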

    def sequence_name(self, table):
        return self.QUOTE_TEMPLATE % (table + '_id_seq')

    def RANDOM(self):
        return 'RANDOM()'

    def ADD(self, first, second):
        t = first.type
        if t in ('text', 'string', 'password', 'json', 'upload', 'blob'):
            return '(%s || %s)' % (self.expand(first), self.expand(second, t))
        else:
            return '(%s + %s)' % (self.expand(first), self.expand(second, t))

    def distributed_transaction_begin(self, key):
        return

    def prepare(self, key):
        self.execute("PREPARE TRANSACTION '%s';" % key)

    def commit_prepared(self, key):
        self.execute("COMMIT PREPARED '%s';" % key)

    def rollback_prepared(self, key):
        self.execute("ROLLBACK PREPARED '%s';" % key)

    def create_sequence_and_triggers(self, query, table, **args):
        # following lines should only be executed if table._sequence_name does not exist
        # self.execute('CREATE SEQUENCE %s;' % table._sequence_name)
        # self.execute("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT NEXTVAL('%s');"
        #              % (table._tablename, table._fieldname, table._sequence_name))
        self.execute(query)

    REGEX_URI = re.compile(
        r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?'
        r'@(?P<host>\[[^/]+\]|[^\:@]+)(\:(?P<port>[0-9]+))?'
        r'/(?P<db>[^\?]+)(\?sslmode=(?P<sslmode>.+))?$')

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args={},
                 adapter_args={}, do_connect=True, srid=4326,
                 after_connection=None):
        self.db = db
        self.dbengine = "postgres"
        self.uri = uri
        if do_connect:
            self.find_driver(adapter_args, uri)
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self._after_connection = after_connection
        self.srid = srid
        self.find_or_make_work_folder()
        self._last_insert = None  # for INSERT ... RETURNING ID
        self.TRUE_exp = 'TRUE'
        self.FALSE_exp = 'FALSE'
        ruri = uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        sslmode = m.group('sslmode')
        driver_args['database'] = db
        driver_args['user'] = user
        driver_args['host'] = host
        driver_args['port'] = int(port)
        driver_args['password'] = password
        if sslmode:
            driver_args['sslmode'] = sslmode
        # choose driver according to uri
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__,
                                          self.driver.__version__)
        else:
            self.__version__ = None

        def connector(driver_args=driver_args):
            return self.driver.connect(**driver_args)

        self.connector = connector
        if do_connect:
            self.reconnect()
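
    # Illustrative example (assumption, not in the original source): the URI
    # tail matched by REGEX_URI looks like
    #     postgres://myuser:mypass@localhost:5432/mydb?sslmode=require
    # where the password, port and sslmode parts are optional and the port
    # defaults to 5432.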

    def after_connection(self):
        # self.connection.set_client_encoding('UTF8')
        # pg8000 doesn't have a native set_client_encoding
        self.execute("SET CLIENT_ENCODING TO 'UTF8'")
        self.execute("SET standard_conforming_strings=on;")
        self.try_json()

    def _insert(self, table, fields):
        table_rname = table.sqlsafe
        if fields:
            keys = ','.join(f.sqlsafe_name for f, v in fields)
            values = ','.join(self.expand(v, f.type) for f, v in fields)
            if hasattr(table, '_id'):
                self._last_insert = (table._id, 1)
                return 'INSERT INTO %s(%s) VALUES (%s) RETURNING %s;' % (
                    table_rname, keys, values,
                    self.QUOTE_TEMPLATE % table._id.name)
            else:
                self._last_insert = None
                return 'INSERT INTO %s(%s) VALUES (%s);' % (
                    table_rname, keys, values)
        else:
            self._last_insert = None
            return self._insert_empty(table)
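
    # Illustrative example (assumption, not in the original source): for a
    # table "person" with an id field, _insert() produces SQL such as
    #     INSERT INTO "person"("name") VALUES ('Ann') RETURNING "id";
    # so lastrowid() below can read the new id from the RETURNING clause.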

    def lastrowid(self, table=None):
        if self._last_insert:
            return int(self.cursor.fetchone()[0])
        else:
            self.execute("select lastval()")
            return int(self.cursor.fetchone()[0])

    def try_json(self):
        if self.driver_name == "pg8000":
            supports_json = self.connection._server_version >= "9.2.0"
        elif (self.driver_name == "psycopg2" and
              self.driver.__version__ >= "2.0.12"):
            supports_json = self.connection.server_version >= 90200
        elif self.driver_name == "zxJDBC":
            supports_json = self.connection.dbversion >= "9.2.0"
        else:
            supports_json = None
        if supports_json:
            self.types["json"] = "JSON"
            if ((self.driver_name == "psycopg2" and
                 self.driver.__version__ >= '2.5.0') or
                (self.driver_name == "pg8000" and
                 self.driver.__version__ >= '1.10.2')):
                self.driver_auto_json = ['loads']
        else:
            self.db.logger.debug("Your database version does not support the JSON"
                                 " data type (using TEXT instead)")
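
    # Editorial note (based on the checks above): try_json() switches the
    # 'json' column type from TEXT to the native JSON type only when the
    # server reports version 9.2 or newer, and enables driver-side JSON
    # decoding for sufficiently recent psycopg2/pg8000 releases.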

    def LIKE(self, first, second, escape=None):
        """Case sensitive like operator"""
        if isinstance(second, Expression):
            second = self.expand(second, 'string')
        else:
            second = self.expand(second, 'string')
            if escape is None:
                escape = '\\'
                second = second.replace(escape, escape * 2)
        if first.type not in ('string', 'text', 'json'):
            return "(%s LIKE %s ESCAPE '%s')" % (
                self.CAST(self.expand(first), 'CHAR(%s)' % first.length),
                second, escape)
        else:
            return "(%s LIKE %s ESCAPE '%s')" % (self.expand(first),
                                                 second, escape)

    def ILIKE(self, first, second, escape=None):
        """Case insensitive like operator"""
        if isinstance(second, Expression):
            second = self.expand(second, 'string')
        else:
            second = self.expand(second, 'string')
            if escape is None:
                escape = '\\'
                second = second.replace(escape, escape * 2)
        if first.type not in ('string', 'text', 'json', 'list:string'):
            return "(%s ILIKE %s ESCAPE '%s')" % (
                self.CAST(self.expand(first), 'CHAR(%s)' % first.length),
                second, escape)
        else:
            return "(%s ILIKE %s ESCAPE '%s')" % (self.expand(first),
                                                  second, escape)
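
    # Illustrative example (assumption, not in the original source): for a
    # string field person.name, ILIKE generates SQL along the lines of
    #     ("person"."name" ILIKE 'ann%' ESCAPE '\')
    # giving case-insensitive matching, while LIKE keeps case sensitivity.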

    def REGEXP(self, first, second):
        return '(%s ~ %s)' % (self.expand(first),
                              self.expand(second, 'string'))

    # GIS functions

    def ST_ASGEOJSON(self, first, second):
        """
        http://postgis.org/docs/ST_AsGeoJSON.html
        """
        return 'ST_AsGeoJSON(%s,%s,%s,%s)' % (second['version'],
                                              self.expand(first),
                                              second['precision'],
                                              second['options'])

    def ST_ASTEXT(self, first):
        """
        http://postgis.org/docs/ST_AsText.html
        """
        return 'ST_AsText(%s)' % (self.expand(first))

    def ST_X(self, first):
        """
        http://postgis.org/docs/ST_X.html
        """
        return 'ST_X(%s)' % (self.expand(first))

    def ST_Y(self, first):
        """
        http://postgis.org/docs/ST_Y.html
        """
        return 'ST_Y(%s)' % (self.expand(first))

    def ST_CONTAINS(self, first, second):
        """
        http://postgis.org/docs/ST_Contains.html
        """
        return 'ST_Contains(%s,%s)' % (self.expand(first),
                                       self.expand(second, first.type))

    def ST_DISTANCE(self, first, second):
        """
        http://postgis.org/docs/ST_Distance.html
        """
        return 'ST_Distance(%s,%s)' % (self.expand(first),
                                       self.expand(second, first.type))

    def ST_EQUALS(self, first, second):
        """
        http://postgis.org/docs/ST_Equals.html
        """
        return 'ST_Equals(%s,%s)' % (self.expand(first),
                                     self.expand(second, first.type))

    def ST_INTERSECTS(self, first, second):
        """
        http://postgis.org/docs/ST_Intersects.html
        """
        return 'ST_Intersects(%s,%s)' % (self.expand(first),
                                         self.expand(second, first.type))

    def ST_OVERLAPS(self, first, second):
        """
        http://postgis.org/docs/ST_Overlaps.html
        """
        return 'ST_Overlaps(%s,%s)' % (self.expand(first),
                                       self.expand(second, first.type))

    def ST_SIMPLIFY(self, first, second):
        """
        http://postgis.org/docs/ST_Simplify.html
        """
        return 'ST_Simplify(%s,%s)' % (self.expand(first),
                                       self.expand(second, 'double'))

    def ST_SIMPLIFYPRESERVETOPOLOGY(self, first, second):
        """
        http://postgis.org/docs/ST_SimplifyPreserveTopology.html
        """
        return 'ST_SimplifyPreserveTopology(%s,%s)' % (
            self.expand(first), self.expand(second, 'double'))

    def ST_TOUCHES(self, first, second):
        """
        http://postgis.org/docs/ST_Touches.html
        """
        return 'ST_Touches(%s,%s)' % (self.expand(first),
                                      self.expand(second, first.type))

    def ST_WITHIN(self, first, second):
        """
        http://postgis.org/docs/ST_Within.html
        """
        return 'ST_Within(%s,%s)' % (self.expand(first),
                                     self.expand(second, first.type))

    def ST_DWITHIN(self, first, tup):
        """
        http://postgis.org/docs/ST_DWithin.html
        """
        second, third = tup
        return 'ST_DWithin(%s,%s,%s)' % (self.expand(first),
                                         self.expand(second, first.type),
                                         self.expand(third, 'double'))

    def represent(self, obj, fieldtype):
        field_is_type = fieldtype.startswith
        if field_is_type('geo'):
            srid = 4326  # postGIS default srid for geometry
            geotype, parms = fieldtype[:-1].split('(')
            parms = parms.split(',')
            if len(parms) >= 2:
                schema, srid = parms[:2]
            if field_is_type('geometry'):
                value = "ST_GeomFromText('%s',%s)" % (obj, srid)
            elif field_is_type('geography'):
                value = "ST_GeogFromText('SRID=%s;%s')" % (srid, obj)
            # else:
            #     raise SyntaxError('Invalid field type %s' % fieldtype)
            return value
        return BaseAdapter.represent(self, obj, fieldtype)
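
    # Illustrative example (assumption, not in the original source): a
    # geometry field with SRID 4326 turns a WKT value 'POINT(1 2)' into
    # ST_GeomFromText('POINT(1 2)',4326), while a geography field is rendered
    # as ST_GeogFromText('SRID=4326;POINT(1 2)').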

    def _drop(self, table, mode='restrict'):
        if mode not in ['restrict', 'cascade', '']:
            raise ValueError('Invalid mode: %s' % mode)
        return ['DROP TABLE ' + table.sqlsafe + ' ' + str(mode) + ';']

    def execute(self, *a, **b):
        if PY2 and self.driver_name == "pg8000":
            a = list(a)
            a[0] = a[0].decode('utf8')
        return BaseAdapter.execute(self, *a, **b)


class NewPostgreSQLAdapter(PostgreSQLAdapter):
    drivers = ('psycopg2', 'pg8000')

    types = {
        'boolean': 'CHAR(1)',
        'string': 'VARCHAR(%(length)s)',
        'text': 'TEXT',
        'json': 'TEXT',
        'password': 'VARCHAR(%(length)s)',
        'blob': 'BYTEA',
        'upload': 'VARCHAR(%(length)s)',
        'integer': 'INTEGER',
        'bigint': 'BIGINT',
        'float': 'FLOAT',
        'double': 'FLOAT8',
        'decimal': 'NUMERIC(%(precision)s,%(scale)s)',
        'date': 'DATE',
        'time': 'TIME',
        'datetime': 'TIMESTAMP',
        'id': 'SERIAL PRIMARY KEY',
        'reference': 'INTEGER REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s %(null)s %(unique)s',
        'list:integer': 'BIGINT[]',
        'list:string': 'TEXT[]',
        'list:reference': 'BIGINT[]',
        'geometry': 'GEOMETRY',
        'geography': 'GEOGRAPHY',
        'big-id': 'BIGSERIAL PRIMARY KEY',
        'big-reference': 'BIGINT REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s %(null)s %(unique)s',
        'reference FK': ', CONSTRAINT "FK_%(constraint_name)s" FOREIGN KEY (%(field_name)s) REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s',
        'reference TFK': ' CONSTRAINT "FK_%(foreign_table)s_PK" FOREIGN KEY (%(field_name)s) REFERENCES %(foreign_table)s (%(foreign_key)s) ON DELETE %(on_delete_action)s',
        }

    def parse_list_integers(self, value, field_type):
        return value

    def parse_list_references(self, value, field_type):
        return [self.parse_reference(r, field_type[5:]) for r in value]

    def parse_list_strings(self, value, field_type):
        return value

    def represent(self, obj, fieldtype):
        field_is_type = fieldtype.startswith
        if field_is_type('list:'):
            if not obj:
                obj = []
            elif not isinstance(obj, (list, tuple)):
                obj = [obj]
            if field_is_type('list:string'):
                obj = map(str, obj)
            else:
                obj = map(int, obj)
            return 'ARRAY[%s]' % ','.join(repr(item) for item in obj)
        return PostgreSQLAdapter.represent(self, obj, fieldtype)
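
    # Illustrative example (assumption, not in the original source): a
    # list:integer value [1, 2, 3] is rendered as ARRAY[1,2,3] and a
    # list:string value ['a', 'b'] as ARRAY['a','b'], matching the native
    # array column types declared above.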

    def CONTAINS(self, first, second, case_sensitive=True):
        if first.type.startswith('list'):
            f = self.expand(second, 'string')
            s = self.ANY(first)
            if case_sensitive is True:
                return self.EQ(f, s)
            else:
                return self.ILIKE(f, s, escape='\\')
        else:
            return PostgreSQLAdapter.CONTAINS(self, first, second,
                                              case_sensitive=case_sensitive)

    def ANY(self, first):
        return "ANY(%s)" % self.expand(first)

    def ILIKE(self, first, second, escape=None):
        if first and 'type' not in first:
            args = (first, self.expand(second))
            ilike = '(%s ILIKE %s)' % args
        else:
            ilike = PostgreSQLAdapter.ILIKE(self, first, second, escape=escape)
        return ilike

    def EQ(self, first, second=None):
        if first and 'type' not in first:
            eq = '(%s = %s)' % (first, self.expand(second))
        else:
            eq = PostgreSQLAdapter.EQ(self, first, second)
        return eq


class JDBCPostgreSQLAdapter(PostgreSQLAdapter):
    drivers = ('zxJDBC',)

    REGEX_URI = re.compile(
        r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?'
        r'@(?P<host>\[[^/]+\]|[^\:/]+)(\:(?P<port>[0-9]+))?'
        r'/(?P<db>.+)$')

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args={},
                 adapter_args={}, do_connect=True, after_connection=None):
        self.db = db
        self.dbengine = "postgres"
        self.uri = uri
        if do_connect:
            self.find_driver(adapter_args, uri)
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self._after_connection = after_connection
        self.find_or_make_work_folder()
        ruri = uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        msg = ('jdbc:postgresql://%s:%s/%s' % (host, port, db),
               user, password)

        def connector(msg=msg, driver_args=driver_args):
            return self.driver.connect(*msg, **driver_args)

        self.connector = connector
        if do_connect:
            self.reconnect()
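
    # Illustrative example (assumption, not in the original source): for a URI
    # tail of the form myuser:mypass@localhost:5432/mydb, the connector above
    # calls self.driver.connect('jdbc:postgresql://localhost:5432/mydb',
    # 'myuser', 'mypass', **driver_args).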

    def after_connection(self):
        self.connection.set_client_encoding('UTF8')
        self.execute('BEGIN;')
        self.execute("SET CLIENT_ENCODING TO 'UNICODE';")
        self.try_json()