1 """
2 Database API
3 (part of web.py)
4 """
5
6 __all__ = [
7 "UnknownParamstyle", "UnknownDB", "TransactionError",
8 "sqllist", "sqlors", "reparam", "sqlquote",
9 "SQLQuery", "SQLParam", "sqlparam",
10 "SQLLiteral", "sqlliteral",
11 "database", 'DB',
12 ]
13
14 import time
15 try:
16 import datetime
17 except ImportError:
18 datetime = None
19
20 try: set
21 except NameError:
22 from sets import Set as set
23
24 from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode
25
26 try:
27
28 from webapi import debug, config
29 except:
30 import sys
31 debug = sys.stderr
32 config = storage()
33
35 """raised for unsupported dbms"""
36 pass
37
40 ValueError.__init__(self)
41 self.text = text
42 self.pos = pos
44 return "unfinished expression in %s at char %d" % (
45 repr(self.text), self.pos)
46
48
50 """
51 raised for unsupported db paramstyles
52
53 (currently supported: qmark, numeric, format, pyformat)
54 """
55 pass
56
58 """
59 Parameter in SQLQuery.
60
61 >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
62 >>> q
63 <sql: "SELECT * FROM test WHERE name='joe'">
64 >>> q.query()
65 'SELECT * FROM test WHERE name=%s'
66 >>> q.values()
67 ['joe']
68 """
71
73 if paramstyle == 'qmark':
74 return '?'
75 elif paramstyle == 'numeric':
76 return ':1'
77 elif paramstyle is None or paramstyle in ['format', 'pyformat']:
78 return '%s'
79 raise UnknownParamstyle, paramstyle
80
82 return SQLQuery([self])
83
86
89
91 return str(self.value)
92
94 return '<param: %s>' % repr(self.value)
95
96 sqlparam = SQLParam
97
99 """
100 You can pass this sort of thing as a clause in any db function.
101 Otherwise, you can pass a dictionary to the keyword argument `vars`
102 and the function will call reparam for you.
103
104 Internally, consists of `items`, which is a list of strings and
105 SQLParams, which get concatenated to produce the actual query.
106 """
107
109 r"""Creates a new SQLQuery.
110
111 >>> SQLQuery("x")
112 <sql: 'x'>
113 >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
114 >>> q
115 <sql: 'SELECT * FROM test WHERE x=1'>
116 >>> q.query(), q.values()
117 ('SELECT * FROM test WHERE x=%s', [1])
118 >>> SQLQuery(SQLParam(1))
119 <sql: '1'>
120 """
121 if isinstance(items, list):
122 self.items = items
123 elif isinstance(items, SQLParam):
124 self.items = [items]
125 elif isinstance(items, SQLQuery):
126 self.items = list(items.items)
127 else:
128 self.items = [items]
129
130
131 for i, item in enumerate(self.items):
132 if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral):
133 self.items[i] = item.value.v
134
136 if isinstance(other, basestring):
137 items = [other]
138 elif isinstance(other, SQLQuery):
139 items = other.items
140 else:
141 return NotImplemented
142 return SQLQuery(self.items + items)
143
145 if isinstance(other, basestring):
146 items = [other]
147 else:
148 return NotImplemented
149
150 return SQLQuery(items + self.items)
151
153 if isinstance(other, basestring):
154 items = [other]
155 elif isinstance(other, SQLQuery):
156 items = other.items
157 else:
158 return NotImplemented
159 self.items.extend(items)
160 return self
161
163 return len(self.query())
164
def query(self, paramstyle=None):
    """
    Returns the query part of the sql query.
    >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
    >>> q.query()
    'SELECT * FROM test WHERE name=%s'
    >>> q.query(paramstyle='qmark')
    'SELECT * FROM test WHERE name=?'
    """
    # For %-based paramstyles a bare '%' in literal SQL text would be read
    # as a placeholder by the driver, so it is doubled. Chunks that already
    # contain '%%' are treated as pre-escaped and left untouched (backward
    # compatibility with callers that escaped by hand).
    escape_percent = paramstyle in ('format', 'pyformat')

    def render(item):
        if isinstance(item, SQLParam):
            return safestr(item.get_marker(paramstyle))
        text = safestr(item)
        if escape_percent and '%' in text and '%%' not in text:
            text = text.replace('%', '%%')
        return text

    return "".join([render(item) for item in self.items])
188
190 """
191 Returns the values of the parameters used in the sql query.
192 >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
193 >>> q.values()
194 ['joe']
195 """
196 return [i.value for i in self.items if isinstance(i, SQLParam)]
197
def join(items, sep=' '):
    """
    Joins multiple queries.

    `items` may be any iterable (list, tuple, generator, ...) of strings,
    SQLParams or SQLQuery objects; `sep` is placed between consecutive
    items. An empty iterable yields an empty SQLQuery.

    >>> SQLQuery.join(['a', 'b'], ', ')
    <sql: 'a, b'>
    >>> SQLQuery.join(x for x in ['a', 'b'])
    <sql: 'a b'>
    >>> SQLQuery.join([])
    <sql: ''>
    """
    # Materialize once so generators and other one-shot iterables work;
    # the previous implementation needed len() and slicing.
    items = list(items)
    if not items:
        return SQLQuery("")

    q = SQLQuery(items[0])
    for item in items[1:]:
        q += sep
        q += item
    return q

join = staticmethod(join)
215
217 try:
218 return self.query() % tuple([sqlify(x) for x in self.values()])
219 except (ValueError, TypeError):
220 return self.query()
221
224
227
229 return '<sql: %s>' % repr(str(self))
230
232 """
233 Protects a string from `sqlquote`.
234
235 >>> sqlquote('NOW()')
236 <sql: "'NOW()'">
237 >>> sqlquote(SQLLiteral('NOW()'))
238 <sql: 'NOW()'>
239 """
241 self.v = v
242
244 return self.v
245
246 sqlliteral = SQLLiteral
247
249 """
250 >>> _sqllist([1, 2, 3])
251 <sql: '(1, 2, 3)'>
252 """
253 items = []
254 items.append('(')
255 for i, v in enumerate(values):
256 if i != 0:
257 items.append(', ')
258 items.append(sqlparam(v))
259 items.append(')')
260 return SQLQuery(items)
261
263 """
264 Takes a string and a dictionary and interpolates the string
265 using values from the dictionary. Returns an `SQLQuery` for the result.
266
267 >>> reparam("s = $s", dict(s=True))
268 <sql: "s = 't'">
269 >>> reparam("s IN $s", dict(s=[1, 2]))
270 <sql: 's IN (1, 2)'>
271 """
272 dictionary = dictionary.copy()
273 vals = []
274 result = []
275 for live, chunk in _interpolate(string_):
276 if live:
277 v = eval(chunk, dictionary)
278 result.append(sqlquote(v))
279 else:
280 result.append(chunk)
281 return SQLQuery.join(result, '')
282
284 """
285 converts `obj` to its proper SQL version
286
287 >>> sqlify(None)
288 'NULL'
289 >>> sqlify(True)
290 "'t'"
291 >>> sqlify(3)
292 '3'
293 """
294
295
296
297 if obj is None:
298 return 'NULL'
299 elif obj is True:
300 return "'t'"
301 elif obj is False:
302 return "'f'"
303 elif datetime and isinstance(obj, datetime.datetime):
304 return repr(obj.isoformat())
305 else:
306 if isinstance(obj, unicode): obj = obj.encode('utf8')
307 return repr(obj)
308
310 """
311 Converts the arguments for use in something like a WHERE clause.
312
313 >>> sqllist(['a', 'b'])
314 'a, b'
315 >>> sqllist('a')
316 'a'
317 >>> sqllist(u'abc')
318 u'abc'
319 """
320 if isinstance(lst, basestring):
321 return lst
322 else:
323 return ', '.join(lst)
324
326 """
327 `left is a SQL clause like `tablename.arg = `
328 and `lst` is a list of values. Returns a reparam-style
329 pair featuring the SQL that ORs together the clause
330 for each item in the lst.
331
332 >>> sqlors('foo = ', [])
333 <sql: '1=2'>
334 >>> sqlors('foo = ', [1])
335 <sql: 'foo = 1'>
336 >>> sqlors('foo = ', 1)
337 <sql: 'foo = 1'>
338 >>> sqlors('foo = ', [1,2,3])
339 <sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
340 """
341 if isinstance(lst, iters):
342 lst = list(lst)
343 ln = len(lst)
344 if ln == 0:
345 return SQLQuery("1=2")
346 if ln == 1:
347 lst = lst[0]
348
349 if isinstance(lst, iters):
350 return SQLQuery(['('] +
351 sum([[left, sqlparam(x), ' OR '] for x in lst], []) +
352 ['1=2)']
353 )
354 else:
355 return left + sqlparam(lst)
356
def sqlwhere(dictionary, grouping=' AND '):
    """
    Converts a `dictionary` to an SQL WHERE clause `SQLQuery`.

    Each key becomes a column name and each value a bound parameter;
    the resulting `column = value` pieces are joined with `grouping`.

    >>> sqlwhere({'cust_id': 2, 'order_id':3})
    <sql: 'order_id = 3 AND cust_id = 2'>
    >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
    <sql: 'order_id = 3, cust_id = 2'>
    >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
    'a = %s AND b = %s'
    """
    conditions = []
    for column, value in dictionary.items():
        conditions.append(column + ' = ' + sqlparam(value))
    return SQLQuery.join(conditions, grouping)
369
371 """
372 Ensures `a` is quoted properly for use in a SQL query.
373
374 >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
375 <sql: "WHERE x = 't' AND y = 3">
376 >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
377 <sql: "WHERE x = 't' AND y IN (2, 3)">
378 """
379 if isinstance(a, list):
380 return _sqllist(a)
381 else:
382 return sqlparam(a).sqlquery()
383
385 """Database transaction."""
387 self.ctx = ctx
388 self.transaction_count = transaction_count = len(ctx.transactions)
389
390 class transaction_engine:
391 """Transaction Engine used in top level transactions."""
392 def do_transact(self):
393 ctx.commit(unload=False)
394
395 def do_commit(self):
396 ctx.commit()
397
398 def do_rollback(self):
399 ctx.rollback()
400
401 class subtransaction_engine:
402 """Transaction Engine used in sub transactions."""
403 def query(self, q):
404 db_cursor = ctx.db.cursor()
405 ctx.db_execute(db_cursor, SQLQuery(q % transaction_count))
406
407 def do_transact(self):
408 self.query('SAVEPOINT webpy_sp_%s')
409
410 def do_commit(self):
411 self.query('RELEASE SAVEPOINT webpy_sp_%s')
412
413 def do_rollback(self):
414 self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')
415
416 class dummy_engine:
417 """Transaction Engine used instead of subtransaction_engine
418 when sub transactions are not supported."""
419 do_transact = do_commit = do_rollback = lambda self: None
420
421 if self.transaction_count:
422
423 if self.ctx.get('ignore_nested_transactions'):
424 self.engine = dummy_engine()
425 else:
426 self.engine = subtransaction_engine()
427 else:
428 self.engine = transaction_engine()
429
430 self.engine.do_transact()
431 self.ctx.transactions.append(self)
432
435
def __exit__(self, exctype, excvalue, traceback):
    """Finish the transaction when leaving a `with` block.

    Commits when the block completed normally and rolls back when it
    raised. Returns None, so any exception still propagates.
    """
    if exctype is None:
        self.commit()
    else:
        self.rollback()
441
443 if len(self.ctx.transactions) > self.transaction_count:
444 self.engine.do_commit()
445 self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
446
448 if len(self.ctx.transactions) > self.transaction_count:
449 self.engine.do_rollback()
450 self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
451
453 """Database"""
def __init__(self, db_module, keywords):
    """Creates a database.

    `db_module` is the DB-API driver module and `keywords` the connection
    arguments. The `driver` and `pooling` entries are removed from
    `keywords` here so they are not forwarded to the driver.
    """
    # 'driver' only selects the module; it is not a connection argument.
    keywords.pop('driver', None)

    self.db_module = db_module
    self.keywords = keywords

    # Holds the connection / transaction stack (threadeddict from utils;
    # presumably one value set per thread).
    self._ctx = threadeddict()

    self.printing = config.get('debug', False)
    self.supports_multiple_insert = False

    # Pooling needs DBUtils; the import only probes for its presence.
    try:
        import DBUtils
    except ImportError:
        dbutils_available = False
    else:
        dbutils_available = True

    # Pooling is on when requested (default True) and DBUtils is importable.
    self.has_pooling = self.keywords.pop('pooling', True) and dbutils_available
479
481 if not self._ctx.get('db'):
482 self._load_context(self._ctx)
483 return self._ctx
484 ctx = property(_getctx)
485
def _load_context(self, ctx):
    """Open a connection and install it, plus commit/rollback hooks, on `ctx`.

    Sets `ctx.dbq_count`, `ctx.transactions`, `ctx.db`, `ctx.db_execute`,
    `ctx.commit` and `ctx.rollback`.
    """
    ctx.dbq_count = 0
    ctx.transactions = []

    if self.has_pooling:
        connect = self._connect_with_pooling
    else:
        connect = self._connect
    ctx.db = connect(self.keywords)
    ctx.db_execute = self._db_execute

    # Connections without commit/rollback get no-op stand-ins so callers
    # can invoke them unconditionally.
    for method_name in ('commit', 'rollback'):
        if not hasattr(ctx.db, method_name):
            setattr(ctx.db, method_name, lambda: None)

    def commit(unload=True):
        ctx.db.commit()
        # With pooling, release the context after commit unless the caller
        # asked to keep it (e.g. when a transaction is being started).
        if unload and self.has_pooling:
            self._unload_context(self._ctx)

    def rollback():
        ctx.db.rollback()
        if self.has_pooling:
            self._unload_context(self._ctx)

    ctx.commit = commit
    ctx.rollback = rollback
516
517 - def _unload_context(self, ctx):
519
521 return self.db_module.connect(**keywords)
522
524 def get_pooled_db():
525 from DBUtils import PooledDB
526
527
528
529
530 if PooledDB.__version__.split('.') < '0.9.3'.split('.'):
531 return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
532 else:
533 return PooledDB.PooledDB(creator=self.db_module, **keywords)
534
535 if getattr(self, '_pooleddb', None) is None:
536 self._pooleddb = get_pooled_db()
537
538 return self._pooleddb.connection()
539
541 return self.ctx.db.cursor()
542
544 """Returns parameter marker based on paramstyle attribute if this database."""
545 style = getattr(self, 'paramstyle', 'pyformat')
546
547 if style == 'qmark':
548 return '?'
549 elif style == 'numeric':
550 return ':1'
551 elif style in ['format', 'pyformat']:
552 return '%s'
553 raise UnknownParamstyle, style
554
556 """executes an sql query"""
557 self.ctx.dbq_count += 1
558
559 try:
560 a = time.time()
561 paramstyle = getattr(self, 'paramstyle', 'pyformat')
562 out = cur.execute(sql_query.query(paramstyle), sql_query.values())
563 b = time.time()
564 except:
565 if self.printing:
566 print >> debug, 'ERR:', str(sql_query)
567 if self.ctx.transactions:
568 self.ctx.transactions[-1].rollback()
569 else:
570 self.ctx.rollback()
571 raise
572
573 if self.printing:
574 print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query))
575 return out
576
def _where(self, where, vars):
    """Normalize a `where` argument into something usable as a WHERE clause.

    Accepted forms:
      - int/long: shorthand for ``id = <value>`` (value bound as a parameter);
      - 2-item list/tuple ``(sql, values)``: the legacy reparam-style pair,
        where `sql` contains one ``%s`` marker per entry of `values`;
      - SQLQuery: returned unchanged;
      - anything else: a string interpolated with `vars` via `reparam`.

    Raises ValueError when a legacy pair's marker count does not match
    the number of values.
    """
    if isinstance(where, (int, long)):
        where = "id = " + sqlparam(where)
    elif isinstance(where, (list, tuple)) and len(where) == 2:
        # Backward compatibility. SQLQuery() takes a single `items`
        # argument, so the old SQLQuery(sql, values) call raised TypeError.
        # Rebuild the query by replacing each '%s' marker with the
        # corresponding bound parameter.
        sql, values = where
        values = list(values)
        chunks = safestr(sql).split('%s')
        if len(chunks) != len(values) + 1:
            raise ValueError("expected %d values for %r, got %d"
                             % (len(chunks) - 1, sql, len(values)))
        items = [chunks[0]]
        for value, chunk in zip(values, chunks[1:]):
            items.append(sqlparam(value))
            items.append(chunk)
        where = SQLQuery(items)
    elif isinstance(where, SQLQuery):
        pass
    else:
        where = reparam(where, vars)
    return where
588
def query(self, sql_query, vars=None, processed=False, _test=False):
    """
    Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
    If `processed=True`, `vars` is a `reparam`-style list to use
    instead of interpolating.

    Returns an iterator of storage rows when the statement produced a
    result set, otherwise the cursor's rowcount. Commits automatically
    when no transaction is open.

    >>> db = DB(None, {})
    >>> db.query("SELECT * FROM foo", _test=True)
    <sql: 'SELECT * FROM foo'>
    >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
    <sql: "SELECT * FROM foo WHERE x = 'f'">
    >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
    <sql: "SELECT * FROM foo WHERE x = 'f'">
    """
    if vars is None:
        vars = {}

    if not (processed or isinstance(sql_query, SQLQuery)):
        sql_query = reparam(sql_query, vars)

    if _test:
        return sql_query

    cursor = self._db_cursor()
    self._db_execute(cursor, sql_query)

    if not cursor.description:
        result = cursor.rowcount
    else:
        columns = [column[0] for column in cursor.description]

        def to_storage(row):
            return storage(dict(zip(columns, row)))

        def rows():
            row = cursor.fetchone()
            while row:
                yield to_storage(row)
                row = cursor.fetchone()

        result = iterbetter(rows())
        result.__len__ = lambda: int(cursor.rowcount)
        result.list = lambda: [to_storage(row) for row in cursor.fetchall()]

    if not self.ctx.transactions:
        self.ctx.commit()
    return result
630
def select(self, tables, vars=None, what='*', where=None, order=None, group=None,
           limit=None, offset=None, _test=False):
    """
    Selects `what` from `tables` with clauses `where`, `order`,
    `group`, `limit`, and `offset`. Uses vars to interpolate.
    Otherwise, each clause can be a SQLQuery. Clauses left as None
    are omitted.

    >>> db = DB(None, {})
    >>> db.select('foo', _test=True)
    <sql: 'SELECT * FROM foo'>
    >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
    <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
    """
    if vars is None:
        vars = {}

    clauses = []
    for keyword, value in self.sql_clauses(what, tables, where, group, order, limit, offset):
        if value is None:
            continue
        clauses.append(self.gen_clause(keyword, value, vars))

    statement = SQLQuery.join(clauses)
    if _test:
        return statement
    return self.query(statement, processed=True)
650
def where(self, table, what='*', order=None, group=None, limit=None,
          offset=None, _test=False, **kwargs):
    """
    Selects from `table` where keys are equal to values in `kwargs`.
    With no keyword conditions every row is selected (no WHERE clause).

    >>> db = DB(None, {})
    >>> db.where('foo', bar_id=3, _test=True)
    <sql: 'SELECT * FROM foo WHERE bar_id = 3'>
    >>> db.where('foo', source=2, crust='dewey', _test=True)
    <sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
    >>> db.where('foo', _test=True)
    <sql: 'SELECT * FROM foo'>
    """
    conditions = [k + ' = ' + sqlquote(v) for k, v in kwargs.iteritems()]

    # select() only drops clauses whose value is None. An empty SQLQuery is
    # falsy (len 0) and would leave a dangling "WHERE" keyword, producing
    # invalid SQL, so pass None when there is nothing to filter on.
    if conditions:
        where = SQLQuery.join(conditions, ' AND ')
    else:
        where = None

    return self.select(table, what=what, order=order,
                       group=group, limit=limit, offset=offset, _test=_test,
                       where=where)
668
def sql_clauses(self, what, tables, where, group, order, limit, offset):
    """Return the ordered (keyword, value) pairs that make up a SELECT.

    `select` skips pairs whose value is None. Dialect subclasses override
    this to reorder or rename clauses.
    """
    head = (('SELECT', what), ('FROM', sqllist(tables)))
    tail = tuple(zip(('WHERE', 'GROUP BY', 'ORDER BY', 'LIMIT', 'OFFSET'),
                     (where, group, order, limit, offset)))
    return head + tail
678
680 if isinstance(val, (int, long)):
681 if sql == 'WHERE':
682 nout = 'id = ' + sqlquote(val)
683 else:
684 nout = SQLQuery(val)
685
686 elif isinstance(val, (list, tuple)) and len(val) == 2:
687 nout = SQLQuery(val[0], val[1])
688 elif isinstance(val, SQLQuery):
689 nout = val
690 else:
691 nout = reparam(val, vars)
692
693 def xjoin(a, b):
694 if a and b: return a + ' ' + b
695 else: return a or b
696
697 return xjoin(sql, nout)
698
def insert(self, tablename, seqname=None, _test=False, **values):
    """
    Inserts `values` into `tablename`. Returns current sequence ID.
    Set `seqname` to the ID if it's not the default, or to `False`
    if there isn't one.

    >>> db = DB(None, {})
    >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
    >>> q
    <sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
    >>> q.query()
    'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
    >>> q.values()
    [2, 'bob']
    """
    def parenthesize(part):
        return "(" + part + ")"

    if not values:
        sql_query = SQLQuery("INSERT INTO %s DEFAULT VALUES" % tablename)
    else:
        columns = SQLQuery.join(values.keys(), ', ')
        params = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
        sql_query = ("INSERT INTO %s " % tablename) + parenthesize(columns) \
            + ' VALUES ' + parenthesize(params)

    if _test:
        return sql_query

    cursor = self._db_cursor()
    if seqname is not False:
        sql_query = self._process_insert_query(sql_query, tablename, seqname)

    if not isinstance(sql_query, tuple):
        self._db_execute(cursor, sql_query)
    else:
        # Backend supplied the insert and the id lookup as two statements
        # to run one after the other on the same cursor.
        insert_stmt, id_stmt = sql_query
        self._db_execute(cursor, insert_stmt)
        self._db_execute(cursor, id_stmt)

    try:
        new_id = cursor.fetchone()[0]
    except Exception:
        # No row to read back (e.g. no sequence lookup was appended).
        new_id = None

    if not self.ctx.transactions:
        self.ctx.commit()
    return new_id
746
748 """
749 Inserts multiple rows into `tablename`. The `values` must be a list of dictioanries,
750 one for each row to be inserted, each with the same set of keys.
751 Returns the list of ids of the inserted rows.
752 Set `seqname` to the ID if it's not the default, or to `False`
753 if there isn't one.
754
755 >>> db = DB(None, {})
756 >>> db.supports_multiple_insert = True
757 >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
758 >>> db.multiple_insert('person', values=values, _test=True)
759 <sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
760 """
761 if not values:
762 return []
763
764 if not self.supports_multiple_insert:
765 out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
766 if seqname is False:
767 return None
768 else:
769 return out
770
771 keys = values[0].keys()
772
773
774
775 for v in values:
776 if v.keys() != keys:
777 raise ValueError, 'Bad data'
778
779 sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))
780
781 data = []
782 for row in values:
783 d = SQLQuery.join([SQLParam(row[k]) for k in keys], ', ')
784 data.append('(' + d + ')')
785 sql_query += SQLQuery.join(data, ', ')
786
787 if _test: return sql_query
788
789 db_cursor = self._db_cursor()
790 if seqname is not False:
791 sql_query = self._process_insert_query(sql_query, tablename, seqname)
792
793 if isinstance(sql_query, tuple):
794
795
796 q1, q2 = sql_query
797 self._db_execute(db_cursor, q1)
798 self._db_execute(db_cursor, q2)
799 else:
800 self._db_execute(db_cursor, sql_query)
801
802 try:
803 out = db_cursor.fetchone()[0]
804 out = range(out-len(values)+1, out+1)
805 except Exception:
806 out = None
807
808 if not self.ctx.transactions:
809 self.ctx.commit()
810 return out
811
812
def update(self, tables, where, vars=None, _test=False, **values):
    """
    Update `tables` with clause `where` (interpolated using `vars`)
    and setting `values`. Returns the cursor's rowcount.

    >>> db = DB(None, {})
    >>> name = 'Joseph'
    >>> q = db.update('foo', where='name = $name', name='bob', age=2,
    ...     created=SQLLiteral('NOW()'), vars=locals(), _test=True)
    >>> q
    <sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
    >>> q.query()
    'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
    >>> q.values()
    [2, 'bob', 'Joseph']
    """
    if vars is None:
        vars = {}
    condition = self._where(where, vars)

    statement = "UPDATE " + sqllist(tables)
    statement = statement + " SET " + sqlwhere(values, ', ')
    statement = statement + " WHERE " + condition

    if _test:
        return statement

    cursor = self._db_cursor()
    self._db_execute(cursor, statement)
    if not self.ctx.transactions:
        self.ctx.commit()
    return cursor.rowcount
844
def delete(self, table, where, using=None, vars=None, _test=False):
    """
    Deletes from `table` with clauses `where` and `using`.
    Returns the cursor's rowcount.

    The USING clause is emitted before WHERE, which is the order both
    PostgreSQL (`DELETE FROM t USING ... WHERE ...`) and MySQL's
    multi-table DELETE require.

    >>> db = DB(None, {})
    >>> name = 'Joe'
    >>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
    <sql: "DELETE FROM foo WHERE name = 'Joe'">
    >>> db.delete('foo', where='foo.bar_id = bar.id', using='bar', _test=True)
    <sql: 'DELETE FROM foo USING bar WHERE foo.bar_id = bar.id'>
    """
    if vars is None: vars = {}
    where = self._where(where, vars)

    q = 'DELETE FROM ' + table
    # USING must precede WHERE; appending it afterwards produced invalid SQL.
    if using: q += ' USING ' + sqllist(using)
    if where: q += ' WHERE ' + where

    if _test: return q

    db_cursor = self._db_cursor()
    self._db_execute(db_cursor, q)
    if not self.ctx.transactions:
        self.ctx.commit()
    return db_cursor.rowcount
868
871
873 """Start a transaction."""
874 return Transaction(self.ctx)
875
876 -class PostgresDB(DB):
877 """Postgres driver."""
def __init__(self, **keywords):
    """Postgres connection; accepts `pw`/`db` as aliases for `password`/`database`."""
    # Translate web.py's short option names to the driver's names.
    for short_name, driver_name in (('pw', 'password'), ('db', 'database')):
        if short_name in keywords:
            keywords[driver_name] = keywords.pop(short_name)

    preferred_driver = keywords.pop('driver', None)
    db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=preferred_driver)

    if db_module.__name__ == "psycopg2":
        # Register psycopg2's UNICODE typecaster for text results.
        import psycopg2.extensions
        psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

    self.dbname = "postgres"
    self.paramstyle = db_module.paramstyle
    DB.__init__(self, db_module, keywords)
    self.supports_multiple_insert = True
    # Cache for the sequence-name lookup; filled on first use.
    self._sequences = None
896
def _process_insert_query(self, query, tablename, seqname):
    """Append a `currval` lookup so the insert yields the new id.

    With `seqname` None, the conventional "<tablename>_id_seq" sequence is
    used if it exists; otherwise the query is returned without a lookup.
    """
    if seqname is None:
        default_sequence = tablename + "_id_seq"
        if default_sequence in self._get_all_sequences():
            seqname = default_sequence

    if seqname:
        query += "; SELECT currval('%s')" % seqname
    return query
908
910 """Query postgres to find names of all sequences used in this database."""
911 if self._sequences is None:
912 q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"
913 self._sequences = set([c.relname for c in self.query(q)])
914 return self._sequences
915
def _connect(self, keywords):
    """Open a direct connection and set its client encoding to UTF-8."""
    connection = DB._connect(self, keywords)
    connection.set_client_encoding('UTF8')
    return connection
920
def _connect_with_pooling(self, keywords):
    """Get a pooled connection and set its client encoding to UTF-8.

    The driver connection is reached through two `_con` attributes of the
    pooled wrapper; NOTE(review): this relies on DBUtils internals —
    confirm against the DBUtils version in use.
    """
    pooled = DB._connect_with_pooling(self, keywords)
    driver_connection = pooled._con._con
    driver_connection.set_client_encoding('UTF8')
    return pooled
925
928 import MySQLdb as db
929 if 'pw' in keywords:
930 keywords['passwd'] = keywords['pw']
931 del keywords['pw']
932
933 if 'charset' not in keywords:
934 keywords['charset'] = 'utf8'
935 elif keywords['charset'] is None:
936 del keywords['charset']
937
938 self.paramstyle = db.paramstyle = 'pyformat'
939 self.dbname = "mysql"
940 DB.__init__(self, db, keywords)
941 self.supports_multiple_insert = True
942
945
947 """Import the first available driver or preferred driver.
948 """
949 if preferred:
950 drivers = [preferred]
951
952 for d in drivers:
953 try:
954 return __import__(d, None, None, ['x'])
955 except ImportError:
956 pass
957 raise ImportError("Unable to import " + " or ".join(drivers))
958
961 db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None))
962
963 if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]:
964 db.paramstyle = 'qmark'
965
966 self.paramstyle = db.paramstyle
967 keywords['database'] = keywords.pop('db')
968 self.dbname = "sqlite"
969 DB.__init__(self, db, keywords)
970
973
def query(self, *a, **kw):
    """Run `DB.query`, removing the `__len__` it attaches to row iterators.

    NOTE(review): presumably because the sqlite cursor's rowcount is not
    meaningful for SELECT results — confirm.
    """
    result = DB.query(self, *a, **kw)
    if isinstance(result, iterbetter):
        del result.__len__
    return result
979
981 """Firebird Database.
982 """
984 try:
985 import kinterbasdb as db
986 except Exception:
987 db = None
988 pass
989 if 'pw' in keywords:
990 keywords['passwd'] = keywords['pw']
991 del keywords['pw']
992 keywords['database'] = keywords['db']
993 del keywords['db']
994 DB.__init__(self, db, keywords)
995
996 - def delete(self, table, where=None, using=None, vars=None, _test=False):
1000
def sql_clauses(self, what, tables, where, group, order, limit, offset):
    """Firebird dialect: FIRST/SKIP come right after SELECT, before the columns."""
    leading = (
        ('SELECT', ''),
        ('FIRST', limit),
        ('SKIP', offset),
        ('', what),
    )
    trailing = (
        ('FROM', sqllist(tables)),
        ('WHERE', where),
        ('GROUP BY', group),
        ('ORDER BY', order),
    )
    return leading + trailing
1012
1015 import pymssql as db
1016 if 'pw' in keywords:
1017 keywords['password'] = keywords.pop('pw')
1018 keywords['database'] = keywords.pop('db')
1019 self.dbname = "mssql"
1020 DB.__init__(self, db, keywords)
1021
def sql_clauses(self, what, tables, where, group, order, limit, offset):
    """MSSQL dialect: a TOP clause takes the place of LIMIT.

    NOTE(review): this yields "SELECT <what> TOP n FROM ...", whereas
    T-SQL expects TOP before the select list ("SELECT TOP n <what> ...").
    The MSSQL doctest in this file pins the current order, so the two
    would have to change together. A bare OFFSET is also not valid T-SQL
    (it belongs to ORDER BY ... OFFSET n ROWS) — confirm before relying on it.
    """
    clauses = [('SELECT', what), ('TOP', limit), ('FROM', sqllist(tables))]
    clauses += [('WHERE', where), ('GROUP BY', group),
                ('ORDER BY', order), ('OFFSET', offset)]
    return tuple(clauses)
1031
1033 """Test LIMIT.
1034
1035 Fake presence of pymssql module for running tests.
1036 >>> import sys
1037 >>> sys.modules['pymssql'] = sys.modules['sys']
1038
1039 MSSQL has TOP clause instead of LIMIT clause.
1040 >>> db = MSSQLDB(db='test', user='joe', pw='secret')
1041 >>> db.select('foo', limit=4, _test=True)
1042 <sql: 'SELECT * TOP 4 FROM foo'>
1043 """
1044 pass
1045
1048 import cx_Oracle as db
1049 if 'pw' in keywords:
1050 keywords['password'] = keywords.pop('pw')
1051
1052
1053 keywords['dsn'] = keywords.pop('db')
1054 self.dbname = 'oracle'
1055 db.paramstyle = 'numeric'
1056 self.paramstyle = db.paramstyle
1057
1058
1059 keywords.pop('pooling', None)
1060 DB.__init__(self, db, keywords)
1061
1063 if seqname is None:
1064
1065 return query
1066 else:
1067 return query + "; SELECT %s.currval FROM dual" % seqname
1068
1069 _databases = {}
1071 """Creates appropriate database using params.
1072
1073 Pooling will be enabled if DBUtils module is available.
1074 Pooling can be disabled by passing pooling=False in params.
1075 """
1076 dbn = params.pop('dbn')
1077 if dbn in _databases:
1078 return _databases[dbn](**params)
1079 else:
1080 raise UnknownDB, dbn
1081
1083 """
1084 Register a database.
1085
1086 >>> class LegacyDB(DB):
1087 ... def __init__(self, **params):
1088 ... pass
1089 ...
1090 >>> register_database('legacy', LegacyDB)
1091 >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret')
1092 """
1093 _databases[name] = clazz
1094
1095 register_database('mysql', MySQLDB)
1096 register_database('postgres', PostgresDB)
1097 register_database('sqlite', SqliteDB)
1098 register_database('firebird', FirebirdDB)
1099 register_database('mssql', MSSQLDB)
1100 register_database('oracle', OracleDB)
1101
1103 """
1104 Takes a format string and returns a list of 2-tuples of the form
1105 (boolean, string) where boolean says whether string should be evaled
1106 or not.
1107
1108 from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
1109 """
1110 from tokenize import tokenprog
1111
1112 def matchorfail(text, pos):
1113 match = tokenprog.match(text, pos)
1114 if match is None:
1115 raise _ItplError(text, pos)
1116 return match, match.end()
1117
1118 namechars = "abcdefghijklmnopqrstuvwxyz" \
1119 "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
1120 chunks = []
1121 pos = 0
1122
1123 while 1:
1124 dollar = format.find("$", pos)
1125 if dollar < 0:
1126 break
1127 nextchar = format[dollar + 1]
1128
1129 if nextchar == "{":
1130 chunks.append((0, format[pos:dollar]))
1131 pos, level = dollar + 2, 1
1132 while level:
1133 match, pos = matchorfail(format, pos)
1134 tstart, tend = match.regs[3]
1135 token = format[tstart:tend]
1136 if token == "{":
1137 level = level + 1
1138 elif token == "}":
1139 level = level - 1
1140 chunks.append((1, format[dollar + 2:pos - 1]))
1141
1142 elif nextchar in namechars:
1143 chunks.append((0, format[pos:dollar]))
1144 match, pos = matchorfail(format, dollar + 1)
1145 while pos < len(format):
1146 if format[pos] == "." and \
1147 pos + 1 < len(format) and format[pos + 1] in namechars:
1148 match, pos = matchorfail(format, pos + 1)
1149 elif format[pos] in "([":
1150 pos, level = pos + 1, 1
1151 while level:
1152 match, pos = matchorfail(format, pos)
1153 tstart, tend = match.regs[3]
1154 token = format[tstart:tend]
1155 if token[0] in "([":
1156 level = level + 1
1157 elif token[0] in ")]":
1158 level = level - 1
1159 else:
1160 break
1161 chunks.append((1, format[dollar + 1:pos]))
1162 else:
1163 chunks.append((0, format[pos:dollar + 1]))
1164 pos = dollar + 1 + (nextchar == "$")
1165
1166 if pos < len(format):
1167 chunks.append((0, format[pos:]))
1168 return chunks
1169
1170 if __name__ == "__main__":
1171 import doctest
1172 doctest.testmod()
1173