ejabberd: ajout d'un cache pour réduire _largement_ les requêtes SQL.
authorProgfou <jean-christophe.andre@auf.org>
Sun, 7 Jun 2009 19:16:45 +0000 (02:16 +0700)
committerProgfou <jean-christophe.andre@auf.org>
Sun, 7 Jun 2009 19:16:45 +0000 (02:16 +0700)
ejabberd/auth-mysql.py

index 35376d2..74fc1d2 100755 (executable)
@@ -21,17 +21,56 @@ _passwd = 'password'
 _db = 'mail'
 _timeout = 2
 _query = "SELECT * FROM auforg_virtual WHERE source=%s AND LENGTH(password)>1"
+_log_filename = '/var/log/ejabberd/auth-mysql.log'
+_cache_positive_ttl = (15*60)
+_cache_negative_ttl = (1*60)
+
+# ce qui suit est à usage interne, ne pas toucher...
+_find_user_cache = {}
+_authenticate_user_cache = {}
 
 def find_user(user, host):
+    """
+    Returns (found, in_cache)
+    where 'found' means the (user,host) was found in the database
+    and 'in_cache' means it was found from the local cache
+    """
+    now = time.time()
+
+    global _find_user_cache
+    c = _find_user_cache.get((user,host))
+    if c != None:
+        if c['found'] and now < (c['time'] + _cache_positive_ttl):
+            return (True, True)
+        if not c['found'] and now < (c['time'] + _cache_negative_ttl):
+            return (False, True)
+
     global _host, _user, _passwd, _db
     db = MySQLdb.connect(host=_host, user=_user, passwd=_passwd,
                          db=_db, connect_timeout=_timeout)
     cur = db.cursor(MySQLdb.cursors.DictCursor)
     nrows = cur.execute(_query, ('%s@%s' % (user,host), ))
     del cur, db
-    return (nrows > 0)
+    found = (nrows > 0)
+    _find_user_cache[(user,host)] = {'found': found, 'time': now}
+    return (found, False)
 
 def authenticate_user(user, host, password):
+    """
+    Returns (valid, in_cache)
+    where 'valid' means the (user,host,password) was valid in the database
+    and 'in_cache' means it was validated against the local cache
+    """
+    now = time.time()
+
+    global _authenticate_user_cache
+    c = _authenticate_user_cache.get((user,host,password))
+    if c != None:
+        if c['valid'] and now < (c['time'] + _cache_positive_ttl):
+            return (True, True)
+        if not c['valid'] and now < (c['time'] + _cache_negative_ttl):
+            return (False, True)
+
     global _host, _user, _passwd, _db
     db = MySQLdb.connect(host=_host, user=_user, passwd=_passwd,
                          db=_db, connect_timeout=_timeout)
@@ -39,15 +78,17 @@ def authenticate_user(user, host, password):
     nrows = cur.execute(_query, ('%s@%s' % (user,host), ))
     users = cur.fetchall()
     del cur, db
-    if nrows < 1:
-        return False
-    for user in users:
-        if crypt.crypt(password, user['password']) == user['password']:
-            return True
-    return False
+    valid = False
+    if nrows > 0:
+        for u in users:
+            if crypt.crypt(password, u['password']) == u['password']:
+                valid = True
+                break
+    _authenticate_user_cache[(user,host,password)] = {'valid': valid, 'time': now}
+    return (valid, False)
 
 def main():
-    log_file = open('/var/log/ejabberd/auth-mysql.log', 'a')
+    log_file = open(_log_filename, 'a')
     while True:
         try:
             nread = sys.stdin.read(2)
@@ -67,13 +108,13 @@ def main():
                 log_file.write('%s operation=%s user=%s host=%s\n'
                                % (now, operation, user, host))
                 log_file.flush()
-                result = authenticate_user(user, host, password)
+                (result, in_cache) = authenticate_user(user, host, password)
             elif operation == 'isuser':
                 (user, host) = data.split(':', 1)
                 log_file.write('%s operation=%s user=%s host=%s\n'
                                % (now, operation, user, host))
                 log_file.flush()
-                result = find_user(user, host)
+                (result, in_cache) = find_user(user, host)
             elif operation == 'setpass':
                 (user, host, password) = data.split(':', 2)
                 log_file.write('%s operation=%s user=%s host=%s\n'
@@ -81,9 +122,15 @@ def main():
                 log_file.flush()
                 #result = set_user_password(user, host, password)
                 result = False
+                in_cache = False
             else:
                 result = False
-            log_file.write('%s => result=%s\n' % (now, result))
+                in_cache = False
+            if in_cache:
+                in_cache = ' (from cache)'
+            else:
+                in_cache = ''
+            log_file.write('%s => result=%s%s\n' % (now, result, in_cache))
             log_file.flush()
             sys.stdout.write(struct.pack('>hh', 2, result and 1 or 0))
             sys.stdout.flush()