[#5200] Gestion des one-to-many et many-to-many fields
authorEric Mc Sween <eric.mcsween@auf.org>
Thu, 17 Jan 2013 19:41:43 +0000 (14:41 -0500)
committerEric Mc Sween <eric.mcsween@auf.org>
Thu, 17 Jan 2013 19:41:43 +0000 (14:41 -0500)
auf/django/permissions/__init__.py
tests/simpletests/__init__.py
tests/simpletests/models.py
tests/simpletests/tests.py

index 0d3f02b..a948a3f 100644 (file)
@@ -7,7 +7,7 @@ import urlparse
 from django.conf import settings
 from django.contrib.auth.views import redirect_to_login
 from django.core.exceptions import ImproperlyConfigured, PermissionDenied
-from django.db.models import Q
+from django.db.models import Q, Manager
 from django.http import HttpResponseForbidden
 from django.template.loader import render_to_string
 from django.utils.importlib import import_module
@@ -130,6 +130,16 @@ def qeval(obj, q):
     """
     Evaluates a Q object on an instance of a model.
     """
+    def q_getattr(x, attr):
+        y = getattr(x, attr)
+        if y is None:
+            return
+        elif isinstance(y, Manager):
+            for z in y.all():
+                yield z
+        else:
+            yield y
+
     # Evaluate all children
     for child in q.children:
         if isinstance(child, Q):
@@ -139,62 +149,75 @@ def qeval(obj, q):
             bits = filter.split('__')
             path = bits[:-1]
             lookup = bits[-1]
-            obj_value = obj
+
+            # Traverse the attribute path and find all candidate values
+            candidates = [obj]
             for attr in path:
-                if obj_value is None:
-                    break
-                obj_value = getattr(obj_value, attr)
-            if obj_value is None:
-                result = value is None or (lookup == 'isnull' and value)
-            elif lookup == 'exact':
-                result = obj_value == value
+                candidates = list(itertools.chain.from_iterable(
+                    q_getattr(x, attr) for x in candidates
+                ))
+
+            if lookup == 'exact':
+                result = any(x == value for x in candidates)
             elif lookup == 'iexact':
-                result = obj_value.lower() == value.lower()
+                result = any(x.lower() == value.lower() for x in candidates)
             elif lookup == 'contains':
-                result = value in obj_value
+                result = any(value in x for x in candidates)
             elif lookup == 'icontains':
-                result = value.lower() in obj_value.lower()
+                result = any(value.lower() in x.lower() for x in candidates)
             elif lookup == 'in':
-                result = obj_value in value
+                result = any(x in value for x in candidates)
             elif lookup == 'gt':
-                result = obj_value > value
+                result = any(x > value for x in candidates)
             elif lookup == 'gte':
-                result = obj_value >= value
+                result = any(x >= value for x in candidates)
             elif lookup == 'lt':
-                result = obj_value < value
+                result = any(x < value for x in candidates)
             elif lookup == 'lte':
-                result = obj_value <= value
+                result = any(x <= value for x in candidates)
             elif lookup == 'startswith':
-                result = obj_value.startswith(value)
+                result = any(x.startswith(value) for x in candidates)
             elif lookup == 'istartswith':
-                result = obj_value.lower().istartswith(value.lower())
+                result = any(
+                    x.lower().istartswith(value.lower())
+                    for x in candidates
+                )
             elif lookup == 'endswith':
-                result = obj_value.lower().iendswith(value.lower())
+                result = any(
+                    x.lower().iendswith(value.lower()) for x in candidates
+                )
             elif lookup == 'range':
-                result = value[0] <= obj_value <= value[1]
+                result = any(value[0] <= x <= value[1] for x in candidates)
             elif lookup == 'year':
-                result = obj_value.year == value
+                result = any(x.year == value for x in candidates)
             elif lookup == 'month':
-                result = obj_value.month == value
+                result = any(x.month == value for x in candidates)
             elif lookup == 'day':
-                result = obj_value.day == value
+                result = any(x.day == value for x in candidates)
             elif lookup == 'week_day':
-                result = (obj_value.weekday() + 1) % 7 + 1 == value
+                result = any(
+                    (x.weekday() + 1) % 7 + 1 == value for x in candidates
+                )
             elif lookup == 'isnull':
-                # We took care of the case where obj_value is None earlier,
-                # so at this point, obj_value is not None
-                result = not value
+                if value:
+                    return not candidates
+                else:
+                    return bool(candidates)
             elif lookup == 'search':
                 raise NotImplementedError(
                     'qeval does not implement "__search"'
                 )
             elif lookup == 'regex':
-                result = bool(re.search(value, obj_value))
+                result = any(bool(re.search(value, x)) for x in candidates)
             elif lookup == 'iregex':
-                result = bool(re.search(value, obj_value, re.I))
+                result = any(
+                    bool(re.search(value, x, re.I)) for x in candidates
+                )
             else:
-                obj_value = getattr(obj_value, lookup)
-                result = obj_value == value
+                candidates = list(itertools.chain.from_iterable(
+                    q_getattr(x, lookup) for x in candidates
+                ))
+                result = any(x == value for x in candidates)
 
         # See if we can shortcut
         if (result and q.connector == Q.OR) \
index b843d5b..5d0fbf3 100644 (file)
@@ -3,7 +3,7 @@ from __future__ import absolute_import
 from auf.django.permissions import Role
 from django.db.models import Q
 
-from tests.simpletests.models import Food
+from tests.simpletests.models import Food, Recipe
 
 
 def role_provider(user):
@@ -32,6 +32,9 @@ class VegetarianRole(Role):
                 return Q(is_meat=True) | Q(name__contains='canned')
             elif perm == 'paint':
                 return Q(owner__username__startswith='a')
+        elif model is Recipe:
+            if perm == 'eat':
+                return ~Q(ingredients__is_meat=True)
         return False
 
 
index b3f54b9..4bbb7b7 100644 (file)
@@ -9,3 +9,11 @@ class Food(models.Model):
 
     def __unicode__(self):
         return self.name
+
+
+class Recipe(models.Model):
+    name = models.CharField(max_length=255)
+    ingredients = models.ManyToManyField(Food)
+
+    def __unicode__(self):
+        return self.name
index 6c67080..dbfc170 100644 (file)
@@ -3,7 +3,7 @@ from __future__ import absolute_import
 from django.contrib.auth.models import User
 from django.test import TransactionTestCase
 
-from tests.simpletests.models import Food
+from tests.simpletests.models import Food, Recipe
 
 
 class HasPermTestCase(TransactionTestCase):
@@ -20,6 +20,14 @@ class HasPermTestCase(TransactionTestCase):
         self.celery = Food.objects.create(name=u'celery', is_meat=False)
         self.steak = Food.objects.create(name=u'steak', is_meat=True)
         self.soup = Food.objects.create(name=u'canned soup', is_meat=False)
+        self.vegetable_soup = Recipe.objects.create(
+            name=u'vegetable soup'
+        )
+        self.vegetable_soup.ingredients = [self.carrot, self.celery]
+        self.beef_soup = Recipe.objects.create(
+            name=u'beef soup'
+        )
+        self.beef_soup.ingredients = [self.carrot, self.celery, self.steak]
 
     def test_global_permissions(self):
         self.assertTrue(self.bob.has_perm('eat'))
@@ -34,6 +42,8 @@ class HasPermTestCase(TransactionTestCase):
         self.assertFalse(self.alice.has_perm('give', self.carrot))
         self.assertTrue(self.alice.has_perm('paint', self.carrot))
         self.assertFalse(self.alice.has_perm('paint', self.steak))
+        self.assertTrue(self.alice.has_perm('eat', self.vegetable_soup))
+        self.assertFalse(self.alice.has_perm('eat', self.beef_soup))
 
     def test_queryset_filtering(self):
         self.assertEqual(