master merge
[auf_rh_dae.git] / src / qbe / django_qbe / forms.py
1 # -*- coding: utf-8 -*-
import ast

from django import forms
from django.db import connections
from django.db.models import get_model
from django.db.models.fields import Field
from django.core.urlresolvers import reverse, NoReverseMatch
from django.conf import settings
from django.forms.formsets import BaseFormSet, formset_factory
from django.utils.importlib import import_module
from django.utils.translation import ugettext as _
from django.utils.translation import ugettext_lazy
from django.contrib.contenttypes.models import ContentType

from django_qbe.utils import get_models
from django_qbe.widgets import CriteriaInput
15
16
# Per-alias database settings; BaseQueryByExampleFormSet reads each alias's
# 'ENGINE' from here to import the backend's base/introspection modules.
try:
    DATABASES = settings.DATABASES
except AttributeError:
    # Backwards compatibility for Django versions without multi-database
    # support (before 1.2), which only had the single DATABASE_* settings.
    DATABASES = {
        'default': {
            'ENGINE': "django.db.backends.%s" % settings.DATABASE_ENGINE,
            'NAME': settings.DATABASE_NAME,
        }
    }

# Choices for the per-row "sort" select. Lazy translation: this tuple is
# built at import time, so an eager ugettext would freeze the labels in
# whatever language happened to be active when the module was imported.
SORT_CHOICES = (
    ("", ""),
    ("asc", ugettext_lazy("Ascending")),
    ("des", ugettext_lazy("Descending")),
)
34
35
class QueryByExampleForm(forms.Form):
    """
    One row of the query-by-example grid.

    A row names a model and one of its fields, and says whether that column
    is shown in the result, how it is sorted, and which criteria (an
    ``(operator, over)`` pair) filter or join on it.
    """
    # Labels use lazy translation because class attributes are evaluated at
    # import time, before any request has activated a language.
    show = forms.BooleanField(label=ugettext_lazy("Show"), required=False)
    model = forms.CharField(label=ugettext_lazy("Model"))
    field = forms.CharField(label=ugettext_lazy("Field"))
    criteria = forms.CharField(label=ugettext_lazy("Criteria"),
                               required=False)
    sort = forms.ChoiceField(label=ugettext_lazy("Sort"), choices=SORT_CHOICES,
                             required=False)

    def __init__(self, *args, **kwargs):
        """Install the widgets whose CSS classes drive the front-end script."""
        super(QueryByExampleForm, self).__init__(*args, **kwargs)
        # NOTE(review): the "qbeFill*"/"enable:"/"to:" class names look like
        # hooks read by the QBE JavaScript; they are passed through verbatim.
        self.fields['model'].widget = forms.Select(
            attrs={'class': "qbeFillModels to:field"})
        self.fields['sort'].widget = forms.Select(
            attrs={'disabled': "disabled", 'class': 'submitIfChecked'},
            choices=SORT_CHOICES)
        criteria_widget = CriteriaInput(attrs={'disabled': "disabled"})
        self.fields['criteria'].widget = criteria_widget
        criteria_widgets = getattr(criteria_widget, "widgets", [])
        if criteria_widgets:
            # A multi-part widget renders one input per sub-widget, named
            # criteria_0, criteria_1, ...; all of them must be enabled.
            criteria_names = ",".join(
                "criteria_%s" % index
                for index in range(len(criteria_widgets)))
        else:
            criteria_names = "criteria"
        self.fields['field'].widget = forms.Select(
            attrs={'class': "qbeFillFields enable:sort,%s" % criteria_names})

    def clean_model(self):
        """Normalise "app.Model" to the lowercase "app_model" label form."""
        model = self.cleaned_data['model']
        return model.lower().replace(".", "_")

    def clean_criteria(self):
        """
        Decode the submitted criteria into an ``(operator, over)`` pair.

        The value is parsed as a Python literal (presumably the repr of the
        CriteriaInput sub-widget values — confirm against the widget). Empty
        or unparsable input, or anything that is not a two-item sequence,
        yields ``(None, None)``.
        """
        criteria = self.cleaned_data['criteria']
        try:
            # SECURITY: this value comes straight from the request. The
            # previous eval() allowed a client to execute arbitrary Python;
            # literal_eval only accepts literals (strings, numbers, tuples,
            # lists, dicts, booleans, None).
            operator, over = ast.literal_eval(criteria)
        except (ValueError, TypeError, SyntaxError, RuntimeError):
            # RuntimeError covers recursion-depth errors on pathologically
            # nested input; all failures fall back to "no criteria".
            return (None, None)
        return (operator, over)
76
77
class BaseQueryByExampleFormSet(BaseFormSet):
    """
    Formset that builds and runs a raw SQL query from QueryByExampleForm rows.

    Each row names a model (``app_label_modulename``) and one of its fields.
    It may mark that column to be shown, sorted, filtered with a backend
    lookup operator, or joined to another ``app.model.field``. clean() turns
    the rows into SELECT/FROM/WHERE/ORDER BY parts. get_raw_query() and
    get_results() then assemble and execute them.

    Pass ``using=<alias>`` to query a database other than "default".
    """

    # Class-level defaults. __init__ rebinds the mutable containers on each
    # instance so that query state (notably ``_models``, which
    # get_query_parts() fills in place) is never shared between instances.
    _selects = []
    _froms = []
    _wheres = []
    _sorts = []
    _params = []
    _models = {}
    _raw_query = None
    _db_alias = "default"
    _db_operators = {}
    _db_table_names = []
    _db_operations = None

    def __init__(self, *args, **kwargs):
        """
        Bind the formset to a database alias and load that backend's helpers.

        The optional ``using`` keyword names the alias (default "default").
        An alias missing from DATABASES falls back to "default". All other
        arguments are passed to BaseFormSet.
        """
        self._selects = []
        self._froms = []
        self._wheres = []
        self._sorts = []
        self._params = []
        self._models = {}
        self._db_alias = kwargs.pop("using", "default")
        if self._db_alias not in DATABASES:
            self._db_alias = "default"
        # Use the connection for the requested alias, not always "default",
        # so the SQL runs on the same database whose backend was loaded.
        self._db_connection = connections[self._db_alias]
        self._load_backend(DATABASES[self._db_alias]['ENGINE'])
        super(BaseQueryByExampleFormSet, self).__init__(*args, **kwargs)

    def _load_backend(self, engine):
        """
        Cache the lookup operators, DatabaseOperations and table names of
        the backend package ``engine``.

        If the backend's ``base`` or ``introspection`` module cannot be
        imported, the class-level defaults are kept.
        """
        try:
            base_mod = import_module("%s.base" % engine)
            intros_mod = import_module("%s.introspection" % engine)
        except ImportError:
            return
        self._db_operators = base_mod.DatabaseWrapper.operators
        DatabaseOperations = base_mod.DatabaseOperations
        try:
            self._db_operations = DatabaseOperations(self._db_connection)
        except TypeError:
            # Some engines have no params to instance DatabaseOperations
            self._db_operations = DatabaseOperations()
        intros_db = intros_mod.DatabaseIntrospection(self._db_connection)
        # Normalise both results to a set so the union works whether the
        # introspection API hands back sets or lists.
        table_names = set(intros_db.django_table_names())
        table_names.update(intros_db.table_names())
        self._db_table_names = list(table_names)

    def clean(self):
        """
        Check that at least one field is selected and cache the query parts.

        Raises forms.ValidationError when no non-join row has "show" checked.
        """
        if any(self.errors):
            # Don't bother validating the formset unless each form is valid
            # on its own.
            return
        selects, froms, wheres, sorts, params = self.get_query_parts()
        if not selects:
            validation_message = _(u"At least you must check a row to get.")
            raise forms.ValidationError(validation_message)
        self._selects = selects
        self._froms = froms
        self._wheres = wheres
        self._sorts = sorts
        self._params = params

    def translate_model_to_db_table(self, model_name):
        """
        Map an ``app_label_modelname`` string to the model's DB table name.

        The string is split at each underscore in turn, so app labels or
        model names that themselves contain underscores still resolve. Each
        split is looked up in ContentType. Returns ``model_name`` unchanged
        when no split matches an installed model.
        """
        parts = model_name.split("_")
        for i in range(1, len(parts)):
            app_label = "_".join(parts[:i])
            name = "_".join(parts[i:])
            try:
                ct = ContentType.objects.get(app_label=app_label, model=name)
            except ContentType.DoesNotExist:
                continue
            model = ct.model_class()
            if model is not None:
                return model._meta.db_table
        return model_name

    def get_query_parts(self):
        """
        Build the SQL fragments described by the formset's cleaned rows.

        Returns ``(selects, froms, wheres, sorts, params)``, a tuple of lists:

        - selects: quoted ``table.column`` expressions to select
        - froms: quoted table names for the FROM clause
        - wheres: join equalities and operator lookups (with ``%s``
          placeholders)
        - sorts: ORDER BY expressions, with " DESC" for descending rows
        - params: the values bound to the WHERE placeholders, in order
        """
        selects = []
        froms = []
        wheres = []
        sorts = []
        params = []
        label_to_model = None
        lookup_cast = self._db_operations.lookup_cast
        qn = self._db_operations.quote_name
        for data in self.cleaned_data:
            if not ("model" in data and "field" in data):
                # A blank row (e.g. an unfilled extra form) has no data.
                # Skip it without dropping the rows that follow it.
                continue
            model = data["model"]
            # HACK: Workaround to handle tables created by django for its
            # own use: get_models() is asked for auto-created and deferred
            # models too, and a matching label is swapped for its db_table.
            if label_to_model is None:
                label_to_model = self._get_label_to_model_map()
            if model in label_to_model:
                model_class = label_to_model[model]
                model = model_class._meta.db_table
                self._models[model] = model_class
            field = data["field"]
            sort = data["sort"]
            criteria = data["criteria"]
            operator, over = criteria
            db_field = u"%s.%s" % (qn(model), qn(field))
            # Empty criteria decode to (None, None); don't call .lower()
            # on None.
            is_join = bool(operator) and operator.lower() == 'join'
            if data["show"] and not is_join:
                selects.append(db_field)
            if sort == "des":
                sorts.append(u"%s DESC" % db_field)
            elif sort:
                sorts.append(db_field)
            if all(criteria):
                if is_join:
                    self._add_join(model, field, over, wheres, froms)
                elif operator in self._db_operators:
                    params.append(self._get_lookup(operator, over))
                    wheres.append(u"%s %s" % (lookup_cast(operator) % db_field,
                                              self._db_operators[operator]))
            if qn(model) not in froms and model in self._db_table_names:
                froms.append(qn(model))
        return selects, froms, wheres, sorts, params

    def _get_label_to_model_map(self):
        """Map "app_label_modulename" labels to model classes (first wins)."""
        app_models = get_models(include_auto_created=True,
                                include_deferred=True)
        label_to_model = {}
        for app_model in app_models:
            label = u"%s_%s" % (app_model._meta.app_label,
                                app_model._meta.module_name)
            label_to_model.setdefault(label, app_model)
        return label_to_model

    def _add_join(self, model, field, over, wheres, froms):
        """
        Add the join between ``model.field`` and the ``app.model.field``
        target named by ``over``.

        The equality is appended to ``wheres`` and the target table to
        ``froms``, but only when the target table exists in the database.
        """
        qn = self._db_operations.quote_name
        over_split = over.lower().rsplit(".", 1)
        if len(over_split) != 2:
            # No "model.field" separator: there is nothing to join against.
            return
        join_table, join_column = over_split
        join_model = qn(self.translate_model_to_db_table(
            join_table.replace(".", "_")))
        join_field = qn(join_column)
        if model in self._models:
            # Use the field's actual DB column. db_column is None unless it
            # was set explicitly, so it cannot be used here.
            source_column = self._models[model]._meta.get_field(field).column
        else:
            # Unknown model: assume the conventional "<field>_id" column.
            # It is quoted as a whole, so the SQL stays valid.
            source_column = u"%s_id" % field
        join = u"%s.%s = %s.%s" % (join_model, join_field,
                                   qn(model), qn(source_column))
        if (join not in wheres
                and self._unquote_name(join_model) in self._db_table_names):
            wheres.append(join)
            if join_model not in froms:
                froms.append(join_model)

    def get_raw_query(self, limit=None, offset=None, count=False,
                      add_extra_ids=False, add_params=False):
        """
        Assemble the SQL statement from the parts cached by clean().

        ``count`` replaces the select list with ``COUNT(*)`` and drops the
        ORDER BY clause. ``add_extra_ids`` adds each selected table's
        primary key after every column. ``limit`` and ``offset`` are applied
        when they parse as integers. ``add_params`` appends the bound
        parameters as an SQL comment (presumably for display only). A
        preset ``_raw_query`` is returned as-is.
        """
        if self._raw_query:
            return self._raw_query
        if self._sorts:
            order_by = u"ORDER BY %s" % (", ".join(self._sorts))
        else:
            order_by = u""
        if self._wheres:
            wheres = u"WHERE %s" % (" AND ".join(self._wheres))
        else:
            wheres = u""
        if count:
            selects = (u"COUNT(*) as count", )
            order_by = u""
        elif add_extra_ids:
            selects = self._get_selects_with_extra_ids()
        else:
            selects = self._selects
        limits = u""
        if limit:
            try:
                limits = u"LIMIT %s" % int(limit)
            except ValueError:
                pass
        offsets = u""
        if offset:
            try:
                offsets = u"OFFSET %s" % int(offset)
            except ValueError:
                pass
        sql = u"""SELECT %s FROM %s %s %s %s %s;""" \
            % (", ".join(selects),
               ", ".join(self._froms),
               wheres,
               order_by,
               limits,
               offsets)
        if add_params:
            # Params may be numbers or other non-strings; format each one.
            return u"%s /* %s */" % (
                sql, u", ".join(u"%s" % (param,) for param in self._params))
        return sql

    def get_results(self, limit=None, offset=None, query=None,
                    admin_name=None, row_number=False):
        """
        Execute the query and fetch all resulting rows.

        ``query`` overrides the SQL built by get_raw_query(). The cached
        params are bound in either case. With ``admin_name``, each cell
        becomes ``(value, admin change URL or None)``, resolved in that
        admin namespace. With ``row_number``, each row is prefixed with its
        1-based position.
        """
        add_extra_ids = admin_name is not None
        if query:
            sql = query
        else:
            sql = self.get_raw_query(limit=limit, offset=offset,
                                     add_extra_ids=add_extra_ids)
        if settings.DEBUG:
            print(sql)
        cursor = self._db_connection.cursor()
        try:
            cursor.execute(sql, tuple(self._params))
            query_results = cursor.fetchall()
        finally:
            cursor.close()
        if admin_name:
            return self._get_results_with_admin_urls(query_results, offset,
                                                     admin_name, row_number)
        if row_number:
            return [[r + 1] + list(row) for r, row in enumerate(query_results)]
        return query_results

    def _get_results_with_admin_urls(self, query_results, offset, admin_name,
                                     row_number):
        """
        Pair each selected value with its admin change URL.

        Rows come from a query built with extra ids, so columns arrive in
        (value, primary key) pairs.
        """
        selects = self._get_selects_with_extra_ids()
        try:
            offset = int(offset)
        except (TypeError, ValueError):
            # offset defaults to None, which int() rejects with TypeError.
            offset = 0
        results = []
        for r, row in enumerate(query_results):
            if row_number:
                number = r + offset + 1
                result = [(number, u"#row%s" % number)]
            else:
                result = []
            for i in range(0, len(row), 2):
                appmodel, _column = selects[i].split(".")
                appmodel = self._unquote_name(appmodel)
                try:
                    if appmodel in self._models:
                        meta = self._models[appmodel]._meta
                        appmodel = u"%s_%s" % (meta.app_label,
                                               meta.module_name)
                    admin_url = reverse("%s:%s_change" % (admin_name,
                                                          appmodel),
                                        args=[row[i + 1]])
                except NoReverseMatch:
                    admin_url = None
                result.append((row[i], admin_url))
            results.append(result)
        return results

    def get_count(self):
        """Return the number of rows the query matches, as a float."""
        results = self.get_results(query=self.get_raw_query(count=True))
        if results:
            return float(results[0][0])
        return len(self.get_results())

    def get_model(self, db_prefix, model):
        """
        Return the model class for ``db_prefix``/``model``.

        Falls back to models seen in get_query_parts(), keyed by table name.
        Returns None if the model is unknown.
        """
        # This name resolves to the module-level django get_model, not to
        # this method: method bodies do not see the class namespace.
        klass = get_model(db_prefix, model)
        if klass is None:
            klass = self._models.get("%s_%s" % (db_prefix, model))
        return klass

    def get_labels(self, add_extra_ids=False, row_number=False):
        """
        Return the capitalised verbose names of the selected columns.

        A leading "#" label is added when ``row_number`` is set. The extra
        primary-key columns are included when ``add_extra_ids`` is set.
        """
        if row_number:
            labels = [_(u"#")]
        else:
            labels = []
        if add_extra_ids:
            selects = self._get_selects_with_extra_ids()
        else:
            selects = self._selects
        if selects and isinstance(selects, (tuple, list)):
            for select in selects:
                # Unquote with the backend's own quote characters, not just
                # MySQL backticks.
                table, column = [self._unquote_name(part)
                                 for part in select.split(".", 1)]
                model = self._models.get(table)
                if model is None:
                    db_prefix, _sep, model_name = table.partition("_")
                    model = self.get_model(db_prefix, model_name)
                if model is None:
                    label = column.replace("_", " ")
                else:
                    label = model._meta.get_field_by_name(column)[0].verbose_name
                labels.append(label.capitalize())
        return labels

    def _unquote_name(self, name):
        """Strip the backend's identifier quote characters from ``name``."""
        quoted_space = self._db_operations.quote_name("")
        if name.startswith(quoted_space[0]) and name.endswith(quoted_space[1]):
            return name[1:-1]
        return name

    def _get_lookup(self, operator, over):
        """Return the DB-prepared parameter for lookup ``operator``/``over``."""
        lookup = Field().get_db_prep_lookup(operator, over,
                                            connection=self._db_connection,
                                            prepared=True)
        if isinstance(lookup, (tuple, list)):
            return lookup[0]
        return lookup

    def _get_selects_with_extra_ids(self):
        """Return the selects with each table's primary key after each one."""
        qn = self._db_operations.quote_name
        selects = []
        for select in self._selects:
            appmodel, _column = select.split(".")
            appmodel = self._unquote_name(appmodel)
            selects.append(select)
            if appmodel in self._models:
                pk_name = self._models[appmodel]._meta.pk.name
            else:
                pk_name = u"id"
            selects.append("%s.%s" % (qn(appmodel), qn(pk_name)))
        return selects
391
# The concrete formset class: QueryByExampleForm rows validated and turned
# into SQL by BaseQueryByExampleFormSet, with one blank extra row and a
# per-row delete checkbox.
QueryByExampleFormSet = formset_factory(
    QueryByExampleForm,
    formset=BaseQueryByExampleFormSet,
    can_delete=True,
    extra=1,
)