Package web :: Package web :: Module db
[hide private]
[frames | no frames]

Source Code for Module web.web.db

   1  """ 
   2  Database API 
   3  (part of web.py) 
   4  """ 
   5   
   6  __all__ = [ 
   7    "UnknownParamstyle", "UnknownDB", "TransactionError",  
   8    "sqllist", "sqlors", "reparam", "sqlquote", 
   9    "SQLQuery", "SQLParam", "sqlparam", 
  10    "SQLLiteral", "sqlliteral", 
  11    "database", 'DB', 
  12  ] 
  13   
  14  import time 
  15  try: 
  16      import datetime 
  17  except ImportError: 
  18      datetime = None 
  19   
  20  try: set 
  21  except NameError: 
  22      from sets import Set as set 
  23       
  24  from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode 
  25   
  26  try: 
  27      # db module can work independent of web.py 
  28      from webapi import debug, config 
  29  except: 
  30      import sys 
  31      debug = sys.stderr 
  32      config = storage() 
  33   
class UnknownDB(Exception):
    """Raised when the requested DBMS name is not supported."""
37
38 -class _ItplError(ValueError):
39 - def __init__(self, text, pos):
40 ValueError.__init__(self) 41 self.text = text 42 self.pos = pos
43 - def __str__(self):
44 return "unfinished expression in %s at char %d" % ( 45 repr(self.text), self.pos)
46
class TransactionError(Exception):
    """Exception type exported for transaction failures.

    Nothing in this module's visible code raises it.
    """
48
class UnknownParamstyle(Exception):
    """
    Raised when a db paramstyle is not supported.

    (currently supported: qmark, numeric, format, pyformat)
    """
class SQLParam:
    """
    A single parameter value inside an SQLQuery.

        >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
        >>> q
        <sql: "SELECT * FROM test WHERE name='joe'">
        >>> q.query()
        'SELECT * FROM test WHERE name=%s'
        >>> q.values()
        ['joe']
    """

    def __init__(self, value):
        self.value = value

    def get_marker(self, paramstyle='pyformat'):
        """Return the placeholder text for `paramstyle`.

        Raises UnknownParamstyle for any style other than qmark,
        numeric, format, pyformat or None.
        """
        if paramstyle == 'qmark':
            return '?'
        if paramstyle == 'numeric':
            return ':1'
        if paramstyle is None or paramstyle in ['format', 'pyformat']:
            return '%s'
        raise UnknownParamstyle(paramstyle)

    def sqlquery(self):
        """Wrap this parameter in a one-item SQLQuery."""
        return SQLQuery([self])

    def __add__(self, other):
        query = self.sqlquery()
        return query + other

    def __radd__(self, other):
        query = self.sqlquery()
        return other + query

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return '<param: %s>' % repr(self.value)


sqlparam = SQLParam
class SQLQuery:
    """
    You can pass this sort of thing as a clause in any db function.
    Otherwise, you can pass a dictionary to the keyword argument `vars`
    and the function will call reparam for you.

    Internally, consists of `items`, which is a list of strings and
    SQLParams, which get concatenated to produce the actual query.
    """
    # tested in sqlquote's docstring

    def __init__(self, items=None):
        r"""Creates a new SQLQuery.

        `items` may be a list (stored as-is, not copied), a single
        SQLParam, another SQLQuery (its items are copied) or any other
        object, which becomes the only item.  Omitting it (or passing
        None) creates an empty query with its own fresh item list.

            >>> SQLQuery("x")
            <sql: 'x'>
            >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
            >>> q
            <sql: 'SELECT * FROM test WHERE x=1'>
            >>> q.query(), q.values()
            ('SELECT * FROM test WHERE x=%s', [1])
            >>> SQLQuery(SQLParam(1))
            <sql: '1'>
            >>> q = SQLQuery()
            >>> q += 'x'
            >>> SQLQuery()
            <sql: ''>
        """
        if items is None:
            # A literal `[]` default would be shared by every instance and
            # then mutated by __iadd__; build a new list per instance.
            self.items = []
        elif isinstance(items, list):
            self.items = items
        elif isinstance(items, SQLParam):
            self.items = [items]
        elif isinstance(items, SQLQuery):
            self.items = list(items.items)
        else:
            self.items = [items]

        # Take care of SQLLiterals: they are spliced in as raw text
        # instead of being sent as parameters.
        for i, item in enumerate(self.items):
            if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral):
                self.items[i] = item.value.v

    def __add__(self, other):
        if isinstance(other, basestring):
            items = [other]
        elif isinstance(other, SQLQuery):
            items = other.items
        else:
            return NotImplemented
        return SQLQuery(self.items + items)

    def __radd__(self, other):
        if isinstance(other, basestring):
            items = [other]
        else:
            return NotImplemented

        return SQLQuery(items + self.items)

    def __iadd__(self, other):
        # In-place: extends this query's own item list.
        if isinstance(other, basestring):
            items = [other]
        elif isinstance(other, SQLQuery):
            items = other.items
        else:
            return NotImplemented
        self.items.extend(items)
        return self

    def __len__(self):
        # Length of the rendered query text; also what makes an empty
        # query falsy.
        return len(self.query())

    def query(self, paramstyle=None):
        """
        Returns the query part of the sql query.
            >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
            >>> q.query()
            'SELECT * FROM test WHERE name=%s'
            >>> q.query(paramstyle='qmark')
            'SELECT * FROM test WHERE name=?'
        """
        s = []
        for x in self.items:
            if isinstance(x, SQLParam):
                x = x.get_marker(paramstyle)
                s.append(safestr(x))
            else:
                x = safestr(x)
                # automatically escape % characters in the query
                # For backward compatibility, ignore escaping when the
                # query looks already escaped
                if paramstyle in ['format', 'pyformat']:
                    if '%' in x and '%%' not in x:
                        x = x.replace('%', '%%')
                s.append(x)
        return "".join(s)

    def values(self):
        """
        Returns the values of the parameters used in the sql query.
            >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
            >>> q.values()
            ['joe']
        """
        return [i.value for i in self.items if isinstance(i, SQLParam)]

    def join(items, sep=' '):
        """
        Joins multiple queries.

            >>> SQLQuery.join(['a', 'b'], ', ')
            <sql: 'a, b'>
        """
        if len(items) == 0:
            return SQLQuery("")

        q = SQLQuery(items[0])
        for item in items[1:]:
            q += sep
            q += item
        return q

    join = staticmethod(join)

    def _str(self):
        # Human-readable rendering with the values inlined; falls back to
        # the bare query text when the substitution cannot be performed.
        try:
            return self.query() % tuple([sqlify(x) for x in self.values()])
        except (ValueError, TypeError):
            return self.query()

    def __str__(self):
        return safestr(self._str())

    def __unicode__(self):
        return safeunicode(self._str())

    def __repr__(self):
        return '<sql: %s>' % repr(str(self))
230
class SQLLiteral:
    """
    Marks a string as raw SQL so that `sqlquote` leaves it unquoted.

        >>> sqlquote('NOW()')
        <sql: "'NOW()'">
        >>> sqlquote(SQLLiteral('NOW()'))
        <sql: 'NOW()'>
    """

    def __init__(self, v):
        # `v` is the raw text; SQLQuery reads it back through this attribute.
        self.v = v

    def __repr__(self):
        literal = self.v
        return literal


sqlliteral = SQLLiteral
def _sqllist(values):
    """
    Build a parenthesised, comma-separated list of parameters.

        >>> _sqllist([1, 2, 3])
        <sql: '(1, 2, 3)'>
    """
    parts = ['(']
    for position, value in enumerate(values):
        if position:
            parts.append(', ')
        parts.append(sqlparam(value))
    parts.append(')')
    return SQLQuery(parts)
261
def reparam(string_, dictionary):
    """
    Takes a string and a dictionary and interpolates the string
    using values from the dictionary. Returns an `SQLQuery` for the result.

    Each `$name` / `${expr}` chunk found by `_interpolate` is evaluated
    as a Python expression with `dictionary` as its namespace and the
    result is passed through `sqlquote`.

        >>> reparam("s = $s", dict(s=True))
        <sql: "s = 't'">
        >>> reparam("s IN $s", dict(s=[1, 2]))
        <sql: 's IN (1, 2)'>
    """
    dictionary = dictionary.copy() # eval mucks with it
    result = []
    for live, chunk in _interpolate(string_):
        if live:
            # SECURITY NOTE(review): the chunk comes from the query
            # template and is eval'ed as arbitrary Python. The template
            # string itself must never be built from untrusted input.
            v = eval(chunk, dictionary)
            result.append(sqlquote(v))
        else:
            result.append(chunk)
    return SQLQuery.join(result, '')
def sqlify(obj):
    """
    Render `obj` as the text of an SQL literal.

        >>> sqlify(None)
        'NULL'
        >>> sqlify(True)
        "'t'"
        >>> sqlify(3)
        '3'
    """
    # Identity tests are required here: `1 == True` and
    # `hash(1) == hash(True)`, so equality or a dict lookup would
    # confuse integers with booleans.
    if obj is None:
        return 'NULL'
    if obj is True:
        return "'t'"
    if obj is False:
        return "'f'"
    # `datetime` is None when the module could not be imported.
    if datetime and isinstance(obj, datetime.datetime):
        return repr(obj.isoformat())
    if isinstance(obj, unicode):
        obj = obj.encode('utf8')
    return repr(obj)
309 -def sqllist(lst):
310 """ 311 Converts the arguments for use in something like a WHERE clause. 312 313 >>> sqllist(['a', 'b']) 314 'a, b' 315 >>> sqllist('a') 316 'a' 317 >>> sqllist(u'abc') 318 u'abc' 319 """ 320 if isinstance(lst, basestring): 321 return lst 322 else: 323 return ', '.join(lst) 324
def sqlors(left, lst):
    """
    `left` is a SQL clause like `tablename.arg = `
    and `lst` is a list of values. Returns an `SQLQuery` that ORs
    together the clause for each item in `lst`.

        >>> sqlors('foo = ', [])
        <sql: '1=2'>
        >>> sqlors('foo = ', [1])
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', 1)
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', [1,2,3])
        <sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
    """
    if isinstance(lst, iters):
        lst = list(lst)
        if not lst:
            # nothing to match: a condition that is always false
            return SQLQuery("1=2")
        if len(lst) == 1:
            # a single value is handled like a scalar below
            lst = lst[0]

    if not isinstance(lst, iters):
        return left + sqlparam(lst)

    pieces = ['(']
    for value in lst:
        pieces.extend([left, sqlparam(value), ' OR '])
    # the trailing always-false term closes the dangling ' OR '
    pieces.append('1=2)')
    return SQLQuery(pieces)
356
def sqlwhere(dictionary, grouping=' AND '):
    """
    Build an `SQLQuery` of `key = value` terms from `dictionary`,
    joined with `grouping`.

        >>> sqlwhere({'cust_id': 2, 'order_id':3})
        <sql: 'order_id = 3 AND cust_id = 2'>
        >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
        <sql: 'order_id = 3, cust_id = 2'>
        >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
        'a = %s AND b = %s'
    """
    conditions = []
    for column, value in dictionary.items():
        conditions.append(column + ' = ' + sqlparam(value))
    return SQLQuery.join(conditions, grouping)
def sqlquote(a):
    """
    Wrap `a` as a query parameter; a list becomes a parenthesised
    parameter list.

        >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
        <sql: "WHERE x = 't' AND y = 3">
        >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
        <sql: "WHERE x = 't' AND y IN (2, 3)">
    """
    if not isinstance(a, list):
        return sqlparam(a).sqlquery()
    return _sqllist(a)
class Transaction:
    """Database transaction.

    Creating an instance starts the transaction and pushes it on
    `ctx.transactions`.  Usable as a context manager: leaving the block
    commits, or rolls back if an exception was raised.
    """

    def __init__(self, ctx):
        self.ctx = ctx
        # Depth of the transaction stack before this one is pushed;
        # zero means this is a top-level transaction.
        self.transaction_count = transaction_count = len(ctx.transactions)

        class transaction_engine:
            """Transaction Engine used in top level transactions."""
            def do_transact(self):
                ctx.commit(unload=False)

            def do_commit(self):
                ctx.commit()

            def do_rollback(self):
                ctx.rollback()

        class subtransaction_engine:
            """Transaction Engine used in sub transactions."""
            def query(self, q):
                # `q` holds a %s slot for the savepoint number.
                db_cursor = ctx.db.cursor()
                ctx.db_execute(db_cursor, SQLQuery(q % transaction_count))

            def do_transact(self):
                self.query('SAVEPOINT webpy_sp_%s')

            def do_commit(self):
                self.query('RELEASE SAVEPOINT webpy_sp_%s')

            def do_rollback(self):
                self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')

        class dummy_engine:
            """Transaction Engine used instead of subtransaction_engine
            when sub transactions are not supported."""
            def do_transact(self):
                return None

            def do_commit(self):
                return None

            def do_rollback(self):
                return None

        if not self.transaction_count:
            self.engine = transaction_engine()
        elif self.ctx.get('ignore_nested_transactions'):
            # nested transactions are not supported in some databases
            self.engine = dummy_engine()
        else:
            self.engine = subtransaction_engine()

        self.engine.do_transact()
        self.ctx.transactions.append(self)

    def __enter__(self):
        return self

    def __exit__(self, exctype, excvalue, traceback):
        if exctype is None:
            self.commit()
        else:
            self.rollback()

    def commit(self):
        """Commit, unless this transaction has already been finished."""
        if len(self.ctx.transactions) <= self.transaction_count:
            return
        self.engine.do_commit()
        # Drop this transaction and anything stacked above it.
        self.ctx.transactions = self.ctx.transactions[:self.transaction_count]

    def rollback(self):
        """Roll back, unless this transaction has already been finished."""
        if len(self.ctx.transactions) <= self.transaction_count:
            return
        self.engine.do_rollback()
        # Drop this transaction and anything stacked above it.
        self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
451
class DB:
    """Database

    Base class wrapping a DB-API driver module.  A per-thread context
    (`self.ctx`) holds the connection, a query counter and the stack of
    open transactions.  Subclasses set `paramstyle`, `dbname` and
    override `sql_clauses` / `_process_insert_query` where the SQL
    dialect differs.
    """

    def __init__(self, db_module, keywords):
        """Creates a database.

        `db_module` is the driver module and `keywords` the arguments
        later passed to its `connect` function.  The `driver` and
        `pooling` entries are removed from `keywords` here.
        """
        # some DB implementations take an optional parameter `driver` to
        # use a specific driver module, but it should not be passed to connect
        keywords.pop('driver', None)

        self.db_module = db_module
        self.keywords = keywords

        self._ctx = threadeddict()
        # flag to enable/disable printing queries
        self.printing = config.get('debug', False)
        self.supports_multiple_insert = False

        try:
            import DBUtils
            # enable pooling if DBUtils module is available.
            self.has_pooling = True
        except ImportError:
            self.has_pooling = False

        # Pooling can be disabled by passing pooling=False in the keywords.
        self.has_pooling = self.keywords.pop('pooling', True) and self.has_pooling

    def _getctx(self):
        # Connect lazily: the context is (re)loaded whenever it has no `db`.
        if not self._ctx.get('db'):
            self._load_context(self._ctx)
        return self._ctx
    ctx = property(_getctx)

    def _load_context(self, ctx):
        """Open a connection and install the commit/rollback helpers on `ctx`."""
        ctx.dbq_count = 0
        ctx.transactions = [] # stack of transactions

        if self.has_pooling:
            ctx.db = self._connect_with_pooling(self.keywords)
        else:
            ctx.db = self._connect(self.keywords)
        ctx.db_execute = self._db_execute

        # Give connections without commit/rollback no-op stand-ins.
        if not hasattr(ctx.db, 'commit'):
            ctx.db.commit = lambda: None

        if not hasattr(ctx.db, 'rollback'):
            ctx.db.rollback = lambda: None

        def commit(unload=True):
            # do db commit and release the connection if pooling is enabled.
            ctx.db.commit()
            if unload and self.has_pooling:
                self._unload_context(self._ctx)

        def rollback():
            # do db rollback and release the connection if pooling is enabled.
            ctx.db.rollback()
            if self.has_pooling:
                self._unload_context(self._ctx)

        ctx.commit = commit
        ctx.rollback = rollback

    def _unload_context(self, ctx):
        del ctx.db

    def _connect(self, keywords):
        return self.db_module.connect(**keywords)

    def _version_key(version):
        """Convert a dotted version string to a tuple of ints.

        Only the leading digits of each component are used, so '0.9.3b'
        compares like '0.9.3'; a component without digits counts as 0.
        """
        key = []
        for part in version.split('.'):
            digits = ''
            for ch in part:
                if not ch.isdigit():
                    break
                digits += ch
            key.append(int(digits or 0))
        return tuple(key)
    _version_key = staticmethod(_version_key)

    def _connect_with_pooling(self, keywords):
        """Return a connection from a lazily created DBUtils pool."""
        def get_pooled_db():
            from DBUtils import PooledDB

            # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
            # see Bug#122112

            # Compare numerically: as lists of strings, '0.9.10' would
            # sort before '0.9.3'.
            if self._version_key(PooledDB.__version__) < self._version_key('0.9.3'):
                return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
            else:
                return PooledDB.PooledDB(creator=self.db_module, **keywords)

        if getattr(self, '_pooleddb', None) is None:
            self._pooleddb = get_pooled_db()

        return self._pooleddb.connection()

    def _db_cursor(self):
        return self.ctx.db.cursor()

    def _param_marker(self):
        """Returns parameter marker based on paramstyle attribute of this database."""
        style = getattr(self, 'paramstyle', 'pyformat')

        if style == 'qmark':
            return '?'
        elif style == 'numeric':
            return ':1'
        elif style in ['format', 'pyformat']:
            return '%s'
        raise UnknownParamstyle(style)

    def _db_execute(self, cur, sql_query):
        """Executes an sql query on cursor `cur`.

        On any error the innermost open transaction (or, without one,
        the whole context) is rolled back and the error is re-raised.
        """
        self.ctx.dbq_count += 1

        try:
            a = time.time()
            paramstyle = getattr(self, 'paramstyle', 'pyformat')
            out = cur.execute(sql_query.query(paramstyle), sql_query.values())
            b = time.time()
        except:
            # Deliberately broad: whatever went wrong, roll back before
            # re-raising.
            if self.printing:
                debug.write('ERR: %s\n' % str(sql_query))
            if self.ctx.transactions:
                self.ctx.transactions[-1].rollback()
            else:
                self.ctx.rollback()
            raise

        if self.printing:
            debug.write('%s (%s): %s\n' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query)))
        return out

    def _where(self, where, vars):
        """Normalise a `where` argument into an SQLQuery."""
        if isinstance(where, (int, long)):
            where = "id = " + sqlparam(where)
        #@@@ for backward-compatibility
        elif isinstance(where, (list, tuple)) and len(where) == 2:
            # NOTE(review): SQLQuery.__init__ takes a single argument, so
            # this legacy two-argument call raises TypeError - confirm
            # whether the branch is still needed.
            where = SQLQuery(where[0], where[1])
        elif isinstance(where, SQLQuery):
            pass
        else:
            where = reparam(where, vars)
        return where

    def query(self, sql_query, vars=None, processed=False, _test=False):
        """
        Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
        If `processed=True`, `vars` is a `reparam`-style list to use
        instead of interpolating.

        Returns an iterator of row `storage` objects when the statement
        produced a result set, otherwise the cursor's row count.

            >>> db = DB(None, {})
            >>> db.query("SELECT * FROM foo", _test=True)
            <sql: 'SELECT * FROM foo'>
            >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
            <sql: "SELECT * FROM foo WHERE x = 'f'">
            >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
            <sql: "SELECT * FROM foo WHERE x = 'f'">
        """
        if vars is None: vars = {}

        if not processed and not isinstance(sql_query, SQLQuery):
            sql_query = reparam(sql_query, vars)

        if _test: return sql_query

        db_cursor = self._db_cursor()
        self._db_execute(db_cursor, sql_query)

        if db_cursor.description:
            names = [x[0] for x in db_cursor.description]
            def iterwrapper():
                row = db_cursor.fetchone()
                while row:
                    yield storage(dict(zip(names, row)))
                    row = db_cursor.fetchone()
            out = iterbetter(iterwrapper())
            out.__len__ = lambda: int(db_cursor.rowcount)
            out.list = lambda: [storage(dict(zip(names, x))) \
                               for x in db_cursor.fetchall()]
        else:
            out = db_cursor.rowcount

        if not self.ctx.transactions:
            self.ctx.commit()
        return out

    def select(self, tables, vars=None, what='*', where=None, order=None, group=None,
               limit=None, offset=None, _test=False):
        """
        Selects `what` from `tables` with clauses `where`, `order`,
        `group`, `limit`, and `offset`. Uses vars to interpolate.
        Otherwise, each clause can be a SQLQuery.

            >>> db = DB(None, {})
            >>> db.select('foo', _test=True)
            <sql: 'SELECT * FROM foo'>
            >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
            <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
        """
        if vars is None: vars = {}
        sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset)
        # clauses whose value is None are left out of the statement
        clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None]
        qout = SQLQuery.join(clauses)
        if _test: return qout
        return self.query(qout, processed=True)

    def where(self, table, what='*', order=None, group=None, limit=None,
              offset=None, _test=False, **kwargs):
        """
        Selects from `table` where keys are equal to values in `kwargs`.

            >>> db = DB(None, {})
            >>> db.where('foo', bar_id=3, _test=True)
            <sql: 'SELECT * FROM foo WHERE bar_id = 3'>
            >>> db.where('foo', source=2, crust='dewey', _test=True)
            <sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
        """
        where = []
        for k, v in kwargs.iteritems():
            where.append(k + ' = ' + sqlquote(v))
        return self.select(table, what=what, order=order,
               group=group, limit=limit, offset=offset, _test=_test,
               where=SQLQuery.join(where, ' AND '))

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        """Return (keyword, value) pairs in the order this dialect emits them."""
        return (
            ('SELECT', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
            ('LIMIT', limit),
            ('OFFSET', offset))

    def gen_clause(self, sql, val, vars):
        """Render one clause: keyword `sql` followed by the value `val`."""
        if isinstance(val, (int, long)):
            if sql == 'WHERE':
                nout = 'id = ' + sqlquote(val)
            else:
                nout = SQLQuery(val)
        #@@@
        elif isinstance(val, (list, tuple)) and len(val) == 2:
            # NOTE(review): SQLQuery.__init__ takes a single argument, so
            # this legacy two-argument call raises TypeError - confirm
            # whether the branch is still needed.
            nout = SQLQuery(val[0], val[1]) # backwards-compatibility
        elif isinstance(val, SQLQuery):
            nout = val
        else:
            nout = reparam(val, vars)

        def xjoin(a, b):
            # join with a space only when both sides are non-empty
            if a and b: return a + ' ' + b
            else: return a or b

        return xjoin(sql, nout)

    def insert(self, tablename, seqname=None, _test=False, **values):
        """
        Inserts `values` into `tablename`. Returns current sequence ID.
        Set `seqname` to the ID if it's not the default, or to `False`
        if there isn't one.

            >>> db = DB(None, {})
            >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
            >>> q
            <sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
            >>> q.query()
            'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
            >>> q.values()
            [2, 'bob']
        """
        def q(x): return "(" + x + ")"

        if values:
            _keys = SQLQuery.join(values.keys(), ', ')
            _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
            sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values)
        else:
            sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename)

        if _test: return sql_query

        db_cursor = self._db_cursor()
        if seqname is not False:
            sql_query = self._process_insert_query(sql_query, tablename, seqname)

        if isinstance(sql_query, tuple):
            # for some databases, a separate query has to be made to find
            # the id of the inserted row.
            q1, q2 = sql_query
            self._db_execute(db_cursor, q1)
            self._db_execute(db_cursor, q2)
        else:
            self._db_execute(db_cursor, sql_query)

        # Best effort: when no id can be fetched the result is None.
        try:
            out = db_cursor.fetchone()[0]
        except Exception:
            out = None

        if not self.ctx.transactions:
            self.ctx.commit()
        return out

    def multiple_insert(self, tablename, values, seqname=None, _test=False):
        """
        Inserts multiple rows into `tablename`. The `values` must be a list of dictionaries,
        one for each row to be inserted, each with the same set of keys.
        Returns the list of ids of the inserted rows.
        Set `seqname` to the ID if it's not the default, or to `False`
        if there isn't one.

        Raises ValueError when the rows do not all have the same keys.

            >>> db = DB(None, {})
            >>> db.supports_multiple_insert = True
            >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
            >>> db.multiple_insert('person', values=values, _test=True)
            <sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
        """
        if not values:
            return []

        if not self.supports_multiple_insert:
            # fall back to one INSERT per row
            out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
            if seqname is False:
                return None
            else:
                return out

        keys = values[0].keys()
        #@@ make sure all keys are valid

        # make sure all rows have same keys. Compare as sets: two dicts
        # holding the same keys may still list them in different orders.
        expected = set(keys)
        for v in values:
            if set(v.keys()) != expected:
                raise ValueError('Bad data')

        sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))

        data = []
        for row in values:
            d = SQLQuery.join([SQLParam(row[k]) for k in keys], ', ')
            data.append('(' + d + ')')
        sql_query += SQLQuery.join(data, ', ')

        if _test: return sql_query

        db_cursor = self._db_cursor()
        if seqname is not False:
            sql_query = self._process_insert_query(sql_query, tablename, seqname)

        if isinstance(sql_query, tuple):
            # for some databases, a separate query has to be made to find
            # the id of the inserted row.
            q1, q2 = sql_query
            self._db_execute(db_cursor, q1)
            self._db_execute(db_cursor, q2)
        else:
            self._db_execute(db_cursor, sql_query)

        try:
            out = db_cursor.fetchone()[0]
            # The fetched id is taken as the last inserted one and the
            # preceding ids are derived by counting back from it.
            out = range(out-len(values)+1, out+1)
        except Exception:
            out = None

        if not self.ctx.transactions:
            self.ctx.commit()
        return out

    def update(self, tables, where, vars=None, _test=False, **values):
        """
        Update `tables` with clause `where` (interpolated using `vars`)
        and setting `values`.  Returns the cursor's row count.

            >>> db = DB(None, {})
            >>> name = 'Joseph'
            >>> q = db.update('foo', where='name = $name', name='bob', age=2,
            ...     created=SQLLiteral('NOW()'), vars=locals(), _test=True)
            >>> q
            <sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
            >>> q.query()
            'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
            >>> q.values()
            [2, 'bob', 'Joseph']
        """
        if vars is None: vars = {}
        where = self._where(where, vars)

        query = (
          "UPDATE " + sqllist(tables) +
          " SET " + sqlwhere(values, ', ') +
          " WHERE " + where)

        if _test: return query

        db_cursor = self._db_cursor()
        self._db_execute(db_cursor, query)
        if not self.ctx.transactions:
            self.ctx.commit()
        return db_cursor.rowcount

    def delete(self, table, where, using=None, vars=None, _test=False):
        """
        Deletes from `table` with clauses `where` and `using`.
        Returns the cursor's row count.

            >>> db = DB(None, {})
            >>> name = 'Joe'
            >>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
            <sql: "DELETE FROM foo WHERE name = 'Joe'">
            >>> db.delete('foo', where='foo.id = bar.foo_id', using='bar', _test=True)
            <sql: 'DELETE FROM foo USING bar WHERE foo.id = bar.foo_id'>
        """
        if vars is None: vars = {}
        where = self._where(where, vars)

        # Start from an SQLQuery so that _db_execute always gets one,
        # even when no clause is appended below.
        q = SQLQuery('DELETE FROM ' + table)
        # USING has to come before WHERE in a DELETE statement.
        if using: q += ' USING ' + sqllist(using)
        if where: q += ' WHERE ' + where

        if _test: return q

        db_cursor = self._db_cursor()
        self._db_execute(db_cursor, q)
        if not self.ctx.transactions:
            self.ctx.commit()
        return db_cursor.rowcount

    def _process_insert_query(self, query, tablename, seqname):
        # Hook for subclasses: may return a modified query, or a pair of
        # queries where the second one fetches the id of the new row.
        return query

    def transaction(self):
        """Start a transaction."""
        return Transaction(self.ctx)
class PostgresDB(DB):
    """Postgres driver."""

    def __init__(self, **keywords):
        # Short keyword names are translated to the names the driver
        # expects.  If `db` is not provided the postgres driver will take
        # it from the PGDATABASE environment variable.
        for short, full in (('pw', 'password'), ('db', 'database')):
            if short in keywords:
                keywords[full] = keywords.pop(short)

        preferred = keywords.pop('driver', None)
        db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=preferred)
        if db_module.__name__ == "psycopg2":
            import psycopg2.extensions
            psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

        self.dbname = "postgres"
        self.paramstyle = db_module.paramstyle
        DB.__init__(self, db_module, keywords)
        self.supports_multiple_insert = True
        # cache for _get_all_sequences, filled on first use
        self._sequences = None

    def _process_insert_query(self, query, tablename, seqname):
        """Append a `currval` lookup so that the insert returns the new id."""
        if seqname is None:
            # when seqname is not provided guess the seqname and make
            # sure it exists
            guessed = tablename + "_id_seq"
            if guessed in self._get_all_sequences():
                seqname = guessed

        if not seqname:
            return query

        query += "; SELECT currval('%s')" % seqname
        return query

    def _get_all_sequences(self):
        """Query postgres to find names of all sequences used in this database."""
        if self._sequences is None:
            q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"
            names = set()
            for row in self.query(q):
                names.add(row.relname)
            self._sequences = names
        return self._sequences

    def _connect(self, keywords):
        connection = DB._connect(self, keywords)
        connection.set_client_encoding('UTF8')
        return connection

    def _connect_with_pooling(self, keywords):
        connection = DB._connect_with_pooling(self, keywords)
        # NOTE(review): reaches through two `_con` attributes, presumably
        # to get at the raw driver connection inside the pool wrappers -
        # confirm against the DBUtils version in use.
        connection._con._con.set_client_encoding('UTF8')
        return connection
925
class MySQLDB(DB):
    """MySQL driver (uses the MySQLdb module)."""

    def __init__(self, **keywords):
        import MySQLdb as db
        if 'pw' in keywords:
            keywords['passwd'] = keywords.pop('pw')

        # Default to utf8; an explicit charset=None removes the keyword.
        charset = keywords.setdefault('charset', 'utf8')
        if charset is None:
            del keywords['charset']

        self.paramstyle = db.paramstyle = 'pyformat' # it's both, like psycopg
        self.dbname = "mysql"
        DB.__init__(self, db, keywords)
        self.supports_multiple_insert = True

    def _process_insert_query(self, query, tablename, seqname):
        # The new row's id is fetched with a second query.
        id_query = SQLQuery('SELECT last_insert_id();')
        return query, id_query
945
def import_driver(drivers, preferred=None):
    """Import the preferred driver, or else the first available one.

    When `preferred` is given it is the only module tried.  Raises
    ImportError naming every module that was tried when none imports.
    """
    candidates = drivers
    if preferred:
        candidates = [preferred]

    for name in candidates:
        try:
            # the non-empty fromlist makes __import__ return the named
            # submodule instead of the top-level package
            module = __import__(name, None, None, ['x'])
        except ImportError:
            continue
        return module
    raise ImportError("Unable to import " + " or ".join(candidates))
958
class SqliteDB(DB):
    """SQLite driver."""

    def __init__(self, **keywords):
        preferred = keywords.pop('driver', None)
        db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=preferred)

        # these two drivers are forced to the qmark paramstyle
        if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]:
            db.paramstyle = 'qmark'

        self.paramstyle = db.paramstyle
        # `db` is required here: a missing key raises KeyError
        keywords['database'] = keywords.pop('db')
        self.dbname = "sqlite"
        DB.__init__(self, db, keywords)

    def _process_insert_query(self, query, tablename, seqname):
        # The new row's id is fetched with a second query.
        id_query = SQLQuery('SELECT last_insert_rowid();')
        return query, id_query

    def query(self, *a, **kw):
        """Run `DB.query`, then drop `__len__` from an iterator result."""
        result = DB.query(self, *a, **kw)
        if isinstance(result, iterbetter):
            # DB.query attaches a rowcount-based `__len__`; presumably
            # the sqlite rowcount is not usable for that - TODO confirm.
            del result.__len__
        return result
979
class FirebirdDB(DB):
    """Firebird Database.
    """

    def __init__(self, **keywords):
        # Any failure to import the driver is tolerated; the instance is
        # then created with `None` as its driver module.
        try:
            import kinterbasdb as db
        except Exception:
            db = None
        if 'pw' in keywords:
            keywords['passwd'] = keywords.pop('pw')
        # `db` is required here: a missing key raises KeyError
        keywords['database'] = keywords.pop('db')
        DB.__init__(self, db, keywords)

    def delete(self, table, where=None, using=None, vars=None, _test=False):
        # firebird doesn't support using clause, so `using` is ignored
        return DB.delete(self, table, where, None, vars, _test)

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        # FIRST/SKIP take the place of LIMIT/OFFSET and go before the
        # column list, hence the empty SELECT entry followed by a
        # keyword-less entry for `what`.
        clauses = (
            ('SELECT', ''),
            ('FIRST', limit),
            ('SKIP', offset),
            ('', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
        )
        return clauses
1012
class MSSQLDB(DB):
    """Microsoft SQL Server driver (uses the pymssql module)."""

    def __init__(self, **keywords):
        import pymssql as db
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')
        # `db` is required here: a missing key raises KeyError
        keywords['database'] = keywords.pop('db')
        self.dbname = "mssql"
        DB.__init__(self, db, keywords)

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        # T-SQL expects `SELECT TOP n <columns>`: TOP goes before the
        # select list.  As in FirebirdDB.sql_clauses, an empty SELECT
        # entry followed by a keyword-less entry for `what` produces
        # that order.
        return (
            ('SELECT', ''),
            ('TOP', limit),
            ('', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
            ('OFFSET', offset))

    def _test(self):
        """Test LIMIT.

            Fake presence of pymssql module for running tests.
            >>> import sys
            >>> sys.modules['pymssql'] = sys.modules['sys']

            MSSQL has TOP clause instead of LIMIT clause.
            >>> db = MSSQLDB(db='test', user='joe', pw='secret')
            >>> db.select('foo', limit=4, _test=True)
            <sql: 'SELECT TOP 4 * FROM foo'>
            >>> db.select('foo', _test=True)
            <sql: 'SELECT * FROM foo'>
        """
        pass
1045
class OracleDB(DB):
    """Oracle driver (uses the cx_Oracle module)."""

    def __init__(self, **keywords):
        import cx_Oracle as db
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')

        #@@ TODO: use db.makedsn if host, port is specified
        keywords['dsn'] = keywords.pop('db')
        self.dbname = 'oracle'
        db.paramstyle = 'numeric'
        self.paramstyle = db.paramstyle

        # oracle doesn't support pooling
        # NOTE(review): removing the key only discards the caller's
        # choice; DB.__init__ then defaults `pooling` to True, so pooling
        # is still enabled whenever DBUtils imports - confirm the intent.
        keywords.pop('pooling', None)
        DB.__init__(self, db, keywords)

    def _process_insert_query(self, query, tablename, seqname):
        # It is not possible to get seq name from table name in Oracle,
        # so without an explicit `seqname` the query is left untouched.
        if seqname is None:
            return query
        return query + "; SELECT %s.currval FROM dual" % seqname


# registry of database classes by name, filled by register_database()
_databases = {}
def database(dburl=None, **params):
    """Creates appropriate database using params.

    `params` must contain `dbn`, the registered database name; the
    remaining params go to that database class.  `dburl` is accepted
    but not used.  Raises UnknownDB for an unregistered name.

    Pooling will be enabled if DBUtils module is available.
    Pooling can be disabled by passing pooling=False in params.
    """
    dbn = params.pop('dbn')
    if dbn not in _databases:
        raise UnknownDB(dbn)
    db_class = _databases[dbn]
    return db_class(**params)
1081
def register_database(name, clazz):
    """
    Make the database class `clazz` available to `database()` under
    the name `name`, replacing any earlier registration of that name.

        >>> class LegacyDB(DB):
        ...     def __init__(self, **params):
        ...        pass
        ...
        >>> register_database('legacy', LegacyDB)
        >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret')
    """
    _databases[name] = clazz


# built-in drivers
register_database('mysql', MySQLDB)
register_database('postgres', PostgresDB)
register_database('sqlite', SqliteDB)
register_database('firebird', FirebirdDB)
register_database('mssql', MSSQLDB)
register_database('oracle', OracleDB)
1102 -def _interpolate(format):
1103 """ 1104 Takes a format string and returns a list of 2-tuples of the form 1105 (boolean, string) where boolean says whether string should be evaled 1106 or not. 1107 1108 from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee) 1109 """ 1110 from tokenize import tokenprog 1111 1112 def matchorfail(text, pos): 1113 match = tokenprog.match(text, pos) 1114 if match is None: 1115 raise _ItplError(text, pos) 1116 return match, match.end() 1117 1118 namechars = "abcdefghijklmnopqrstuvwxyz" \ 1119 "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; 1120 chunks = [] 1121 pos = 0 1122 1123 while 1: 1124 dollar = format.find("$", pos) 1125 if dollar < 0: 1126 break 1127 nextchar = format[dollar + 1] 1128 1129 if nextchar == "{": 1130 chunks.append((0, format[pos:dollar])) 1131 pos, level = dollar + 2, 1 1132 while level: 1133 match, pos = matchorfail(format, pos) 1134 tstart, tend = match.regs[3] 1135 token = format[tstart:tend] 1136 if token == "{": 1137 level = level + 1 1138 elif token == "}": 1139 level = level - 1 1140 chunks.append((1, format[dollar + 2:pos - 1])) 1141 1142 elif nextchar in namechars: 1143 chunks.append((0, format[pos:dollar])) 1144 match, pos = matchorfail(format, dollar + 1) 1145 while pos < len(format): 1146 if format[pos] == "." 
and \ 1147 pos + 1 < len(format) and format[pos + 1] in namechars: 1148 match, pos = matchorfail(format, pos + 1) 1149 elif format[pos] in "([": 1150 pos, level = pos + 1, 1 1151 while level: 1152 match, pos = matchorfail(format, pos) 1153 tstart, tend = match.regs[3] 1154 token = format[tstart:tend] 1155 if token[0] in "([": 1156 level = level + 1 1157 elif token[0] in ")]": 1158 level = level - 1 1159 else: 1160 break 1161 chunks.append((1, format[dollar + 1:pos])) 1162 else: 1163 chunks.append((0, format[pos:dollar + 1])) 1164 pos = dollar + 1 + (nextchar == "$") 1165 1166 if pos < len(format): 1167 chunks.append((0, format[pos:])) 1168 return chunks 1169 1170 if __name__ == "__main__": 1171 import doctest 1172 doctest.testmod() 1173