-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathcrosstab.py
68 lines (60 loc) · 2.39 KB
/
crosstab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from sqlalchemy.sql import FromClause, column, ColumnElement
from sqlalchemy.orm import Query
from sqlalchemy.ext.compiler import compiles
class crosstab(FromClause):
    """FROM-clause element describing a call to the ``crosstab()`` function.

    Args:
        stmt: Source query to pivot; either a selectable or an ORM ``Query``
            (a ``Query`` is replaced by its ``.selectable``).
        return_def: Definition of the returned columns: a list/tuple of
            columns, or a selectable whose ``.columns`` are used.
        categories: Optional second query (selectable or ORM ``Query``);
            ``None`` means no category query is stored.
        auto_order: When true, ``order_by('1,2')`` is applied to ``stmt`` and
            ``order_by('1')`` to ``categories`` (if given) so callers do not
            have to order their queries themselves.

    Raises:
        TypeError: If ``return_def`` is neither a list/tuple nor a selectable.
    """

    def __init__(self, stmt, return_def, categories=None, auto_order=True):
        is_sequence = isinstance(return_def, (list, tuple))
        # getattr() with a default: objects that do not expose
        # ``is_selectable`` at all (str, dict, ...) must be rejected with the
        # TypeError below instead of leaking an AttributeError.
        if not (is_sequence or getattr(return_def, 'is_selectable', False)):
            raise TypeError('return_def must be a selectable or tuple/list')

        self.stmt = stmt
        self.columns = return_def if is_sequence else return_def.columns
        self.categories = categories
        # Lists/tuples carry no name; selectables may.
        self.name = getattr(return_def, 'name', None)

        # Accept ORM queries by unwrapping them to their core selectable.
        if isinstance(self.stmt, Query):
            self.stmt = self.stmt.selectable
        if isinstance(self.categories, Query):
            self.categories = self.categories.selectable

        # Don't rely on the user to order their stuff.
        if auto_order:
            self.stmt = self.stmt.order_by('1,2')
            if self.categories is not None:
                self.categories = self.categories.order_by('1')

    def _populate_column_collection(self):
        """Fill ``self._columns`` from the ``(name, type)`` pairs in ``self.names``."""
        # NOTE(review): ``self.names`` is not assigned anywhere in this class,
        # so this method raises AttributeError unless something else sets it.
        # Confirm where the (name, type) pairs are supposed to come from.
        self._columns.update(
            # ``column()`` takes the type as ``type_``; ``type=`` is rejected
            # as an unexpected keyword argument.
            column(name, type_=type_)
            for name, type_ in self.names
        )
@compiles(crosstab, 'postgresql')
def visit_element(element, compiler, **kw):
    """Render a ``crosstab`` element for the PostgreSQL dialect.

    Produces ``crosstab($$<stmt>$$) AS (<column defs>)`` or, when a category
    query is present, ``crosstab($$<stmt>$$, $$<categories>$$) AS (...)``.

    Args:
        element: The ``crosstab`` construct being compiled.
        compiler: The SQL compiler used to render the inner selects and the
            column type clauses.
        **kw: Extra compiler keyword arguments (accepted, not forwarded).

    Returns:
        The SQL fragment as a string.
    """
    # Inner statements are rendered first, in this order, exactly as before.
    # NOTE(review): the rendered SQL is embedded in a $$...$$ literal; inner
    # SQL that itself contains "$$" would presumably end the literal early.
    source_sql = compiler.visit_select(element.stmt)
    categories_sql = None
    if element.categories is not None:
        categories_sql = compiler.visit_select(element.categories)

    # Column names are double-quoted in BOTH forms. Previously only the
    # two-argument form quoted them, so mixed-case or reserved-word column
    # names broke the single-argument form.
    column_defs = ", ".join(
        "\"%s\" %s" % (c.name, compiler.visit_typeclause(c))
        for c in element.c
    )

    if categories_sql is None:
        return """crosstab($$%s$$) AS (%s)""" % (source_sql, column_defs)
    return """crosstab($$%s$$, $$%s$$) AS (%s)""" % (
        source_sql,
        categories_sql,
        column_defs,
    )
from operator import add
from sqlalchemy import func, INTEGER
class row_total(ColumnElement):
    """Column expression representing the total of a set of columns.

    Args:
        cols: Iterable of column-like objects; each must expose ``.name``
            when the expression is compiled.
    """

    # The expression is typed as an integer.
    type = INTEGER()

    def __init__(self, cols):
        # Materialize once: an expression may be compiled more than once, and
        # a one-shot iterable (e.g. a generator) would be empty the second
        # time around.
        self.cols = list(cols)
@compiles(row_total)
def compile_row_total(element, compiler, **kw):
    """Render ``row_total`` as ``coalesce("a", 0)+coalesce("b", 0)+...``.

    Each column is wrapped in ``coalesce(..., 0)`` so a NULL cell counts as 0
    rather than turning the whole sum into NULL.

    Args:
        element: The ``row_total`` construct; its ``cols`` supply the names.
        compiler: The active SQL compiler (unused).
        **kw: Extra compiler keyword arguments (unused).

    Returns:
        The SQL expression as a string; ``"0"`` when there are no columns.
    """
    terms = ['coalesce("%s", 0)' % col.name for col in element.cols]
    # With no columns the join would yield an empty string, which is not a
    # valid SQL expression; the total of nothing is 0.
    return "+".join(terms) or "0"