From 84eff2df9941af39d09e9d5014731305aeb7b480 Mon Sep 17 00:00:00 2001 From: twidi Date: Wed, 9 Jan 2013 09:07:09 +0100 Subject: [PATCH 1/6] Add middlewares hooks in database --- limpyd/database.py | 73 +++++++++++++++++++++++++++++++++++++++++++++- limpyd/fields.py | 11 ++++--- 2 files changed, 79 insertions(+), 5 deletions(-) diff --git a/limpyd/database.py b/limpyd/database.py index 9ea8f1d..f336096 100644 --- a/limpyd/database.py +++ b/limpyd/database.py @@ -1,6 +1,7 @@ # -*- coding:utf-8 -*- import redis +from collections import namedtuple from limpyd.exceptions import * @@ -14,6 +15,9 @@ db=0 ) +Command = namedtuple('Command', ['name', 'args', 'kwargs']) +Result = namedtuple('Result', ['value', ]) + class RedisDatabase(object): """ @@ -27,12 +31,18 @@ class RedisDatabase(object): """ _connections = {} # class level cache discard_cache = False + middlewares = [] - def __init__(self, **connection_settings): + def __init__(self, middlewares=None, **connection_settings): self._connection = None # Instance level cache self.reset(**(connection_settings or DEFAULT_CONNECTION_SETTINGS)) + # _models keep an entry for each defined model on this database self._models = dict() + + if middlewares is not None: + self.middlewares = middlewares + super(RedisDatabase, self).__init__() def connect(self, **settings): @@ -126,3 +136,64 @@ def has_scripting(self): except: self._has_scripting = False return self._has_scripting + + @property + def prepared_middlewares(self): + """ + Load, cache and return the list of usable middlewares, as a dict with + an entry for each usable method. + { + 'pre_command': [list, of, middlewares], + 'post_command': [list, of, middlewares], + } + Middlewares must be defined while declaring the database: + database = RedisDatabase(middlewares=[ + AMiddleware(), + AnoterMiddleware(some, parameter) + ], **connection_settings) + """ + + if not hasattr(self, '_prepared_middlewares'): + + self._prepared_middlewares = { + 'pre_command': [], + 'post_command': [], + } + + for middleware in self.middlewares: + middleware.database = self + + for middleware_type in self._prepared_middlewares: + if hasattr(middleware, middleware_type): + self._prepared_middlewares[middleware_type].append(middleware) + + self._prepared_middlewares['post_command'] = self._prepared_middlewares['post_command'][::-1] + + return self._prepared_middlewares + + def run_command(self, command, context=None): + """ + Run a redis command, passing it through all defined middlewares. + The command must be a Command namedtuple + """ + if context is None: + context = {} + + result = None + + for middleware in self.prepared_middlewares['pre_command']: + result = middleware.pre_command(command, context) + if result: + break + + if result is None: + method = getattr(self.connection, "%s" % command.name) + result = method(*command.args, **command.kwargs) + + if not isinstance(result, Result): + result = Result(result) + + for middleware in self.prepared_middlewares['post_command']: + result = middleware.post_command(command, result, context) + + return result.value diff --git a/limpyd/fields.py b/limpyd/fields.py index cfa040d..5fa61b3 100644 --- a/limpyd/fields.py +++ b/limpyd/fields.py @@ -7,6 +7,7 @@ from redis.client import Lock from limpyd.utils import make_key, memoize_command +from limpyd.database import Command from limpyd.exceptions import * log = getLogger(__name__) @@ -127,10 +128,12 @@ def _traverse_command(self, name, *args, **kwargs): # TODO: implement instance level cache if not name in self.available_commands: raise AttributeError("%s is not an available command for %s" % (name, self.__class__.__name__)) - attr = getattr(self.connection, "%s" % name) - key = self.key - log.debug(u"Requesting %s with key %s and args %s" % (name, key, args)) - result = attr(key, *args, **kwargs) + + log.debug(u"Requesting %s with key %s and args %s" % (name, self.key, args)) + command = Command(name, [self.key, ] + list(args), kwargs) + context = {'sender': self, } + result = self.database.run_command(command, context) + result = self.post_command( sender=self, name=name, From c591647cae01674c75d4742039a818d77d2eb91d Mon Sep 17 00:00:00 2001 From: twidi Date: Wed, 9 Jan 2013 09:23:11 +0100 Subject: [PATCH 2/6] Add a BaseMiddleware, and a LoggingMiddlware to log all commands --- limpyd/fields.py | 1 - limpyd/middlewares.py | 61 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 limpyd/middlewares.py diff --git a/limpyd/fields.py b/limpyd/fields.py index 5fa61b3..66441de 100644 --- a/limpyd/fields.py +++ b/limpyd/fields.py @@ -129,7 +129,6 @@ def _traverse_command(self, name, *args, **kwargs): if not name in self.available_commands: raise AttributeError("%s is not an available command for %s" % (name, self.__class__.__name__)) - log.debug(u"Requesting %s with key %s and args %s" % (name, self.key, args)) command = Command(name, [self.key, ] + list(args), kwargs) context = {'sender': self, } result = self.database.run_command(command, context) diff --git a/limpyd/middlewares.py b/limpyd/middlewares.py new file mode 100644 index 0000000..99fce08 --- /dev/null +++ b/limpyd/middlewares.py @@ -0,0 +1,61 @@ +# -*- coding:utf-8 -*- + +from time import time + +from limpyd.exceptions import ImplementationError + + +class BaseMiddleware(object): + @property + def database(self): + return self._database + + @database.setter + def database(self, value): + if hasattr(self, '_database'): + raise ImplementationError("Cannot change the database of a middleware") + self._database = value + + # minimal pre_command method: do nothing and return None + # def pre_command(self, command, context): + # pass + + # minimal post_command method: return the given result + # def post_command(self, command, result, context): + # return result + + +class LoggingMiddleware(BaseMiddleware): + """ + Middleware that takes a logger, and log commands and their result (and time + to run the command). + """ + def __init__(self, logger, log_results=True): + """ + The logger must be a defined and correctly initialized one (via logging) + The log_results flag indicates if only the commands or also their result + (with duration) are logged. + """ + self.logger = logger + self.log_results = log_results + super(LoggingMiddleware, self).__init__() + + @BaseMiddleware.database.setter + def database(self, value): + BaseMiddleware.database.fset(self, value) # super + self.database._command_logger_counter = 0 + + def pre_command(self, command, context): + self.database._command_logger_counter += 1 + context['_command_number'] = self.database._command_logger_counter + context['_start_time'] = time() + self.logger.info('[#%s] %s' % (context['_command_number'], str(command))) + + def post_command(self, command, result, context): + if self.log_results: + self.logger.info('[#%s, in %0.0fµs] %s' % ( + context['_command_number'], + (time() - context['_start_time']) * 1000000, + str(result)) + ) + return result From 2f27413acc9da90277195160612e2f790997b93d Mon Sep 17 00:00:00 2001 From: twidi Date: Wed, 9 Jan 2013 09:29:10 +0100 Subject: [PATCH 3/6] Commands used to update indexes do not pass through middleware --- limpyd/fields.py | 51 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/limpyd/fields.py b/limpyd/fields.py index 66441de..0b53a9d 100644 --- a/limpyd/fields.py +++ b/limpyd/fields.py @@ -230,12 +230,18 @@ def __init__(self, *args, **kwargs): self._creation_order = RedisField._creation_order RedisField._creation_order += 1 - def proxy_get(self): + def proxy_get(self, _direct=False): """ - A helper to easily call the proxy_getter of the field + A helper to easily call the proxy_getter of the field. + If _direct is True, don't use the _traverse_command method but directly + use the connection to redis """ - getter = getattr(self, self.proxy_getter) - return getter() + if _direct: + getter = getattr(self.connection, self.proxy_getter) + return getter(self.key) + else: + getter = getattr(self, self.proxy_getter) + return getter() def proxy_set(self, value): """ @@ -496,7 +502,7 @@ def values_for_indexing(self): """ Values for indexing must be a list, so return the simple value as a list """ - return [self.proxy_get()] + return [self.proxy_get(_direct=True)] def index(self, values=None): """ @@ -596,7 +602,7 @@ def values_for_indexing(self): """ Return all values in the field for (de)indexing """ - return self.proxy_get() + return self.proxy_get(_direct=True) def _add(self, command, *args, **kwargs): """ @@ -662,6 +668,17 @@ def zmembers(self): """ return self.zrange(0, -1) + def proxy_get(self, _direct=False): + """ + A helper to easily call the proxy_getter of the field. + If _direct is True, don't use the _traverse_command method but directly + use the connection to redis + """ + if _direct: + return self.connection.zrange(self.key, 0, -1) + else: + return super(SortedSetField, self).proxy_get() + def zadd(self, *args, **kwargs): """ We do the same computation of the zadd method of StrictRedis to keep keys @@ -780,6 +797,17 @@ def lmembers(self): """ return self.lrange(0, -1) + def proxy_get(self, _direct=False): + """ + A helper to easily call the proxy_getter of the field. + If _direct is True, don't use the _traverse_command method but directly + use the connection to redis + """ + if _direct: + return self.connection.lrange(self.key, 0, -1) + else: + return super(ListField, self).proxy_get() + def linsert(self, where, refvalue, value): return self._call_command('linsert', where, refvalue, value, _to_index=[value], _to_deindex=[]) @@ -844,6 +872,17 @@ class HashableField(SingleValueField): 'hset': '_set', } + def proxy_get(self, _direct=False): + """ + A helper to easily call the proxy_getter of the field. + If _direct is True, don't use the _traverse_command method but directly + use the connection to redis + """ + if _direct: + return self.connection.hget(self.key, self.name) + else: + return super(HashableField, self).proxy_get() + @property def key(self): return self._instance.key From cd4bf02a24ca1f02a17e38ea638c750bf4381c64 Mon Sep 17 00:00:00 2001 From: twidi Date: Wed, 9 Jan 2013 22:41:29 +0100 Subject: [PATCH 4/6] Command and Result are now objects with slots, not namedtuples namedtuples are immutables, but we must allow edition of commands and results inside the middleware, so we use real objects, but wuth __slots__ instead of __dict__, which is faster and much memory efficient --- limpyd/database.py | 42 +++++++++++++++++++++++++++++++++++++++--- limpyd/fields.py | 2 +- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/limpyd/database.py b/limpyd/database.py index f336096..c341925 100644 --- a/limpyd/database.py +++ b/limpyd/database.py @@ -1,7 +1,6 @@ # -*- coding:utf-8 -*- import redis -from collections import namedtuple from limpyd.exceptions import * @@ -15,8 +14,45 @@ db=0 ) -Command = namedtuple('Command', ['name', 'args', 'kwargs']) -Result = namedtuple('Result', ['value', ]) + +class Command(object): + """ + Object to pass the command through middlewares + """ + __slots__ = ('name', 'args', 'kwargs',) + + def __init__(self, name, *args, **kwargs): + self.name = name + self.args = args + self.kwargs = kwargs + + def __unicode__(self): + return u"Command(name='%s', args=%s, kwargs=%s)" % (self.name, self.args, self.kwargs) + + def __str__(self): + return unicode(self).encode('utf-8') + + def __repr__(self): + return str(self) + + +class Result(object): + """ + Object to pass the command's result through middlewares + """ + __slots__ = ('value',) + + def __init__(self, value): + self.value = value + + def __unicode__(self): + return u"Result(value=%s)" % self.value + + def __str__(self): + return unicode(self).encode('utf-8') + + def __repr__(self): + return str(self) class RedisDatabase(object): diff --git a/limpyd/fields.py b/limpyd/fields.py index 0b53a9d..c7e7ae2 100644 --- a/limpyd/fields.py +++ b/limpyd/fields.py @@ -129,7 +129,7 @@ def _traverse_command(self, name, *args, **kwargs): if not name in self.available_commands: raise AttributeError("%s is not an available command for %s" % (name, self.__class__.__name__)) - command = Command(name, [self.key, ] + list(args), kwargs) + command = Command(name, self.key, *args, **kwargs) context = {'sender': self, } result = self.database.run_command(command, context) From d04feae147d4a35c24844af4ee37ca51d6cdd76c Mon Sep 17 00:00:00 2001 From: twidi Date: Tue, 15 Jan 2013 22:50:25 +0100 Subject: [PATCH 5/6] Add option do LoggingMiddleware to hide/show duration --- limpyd/middlewares.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/limpyd/middlewares.py b/limpyd/middlewares.py index 99fce08..9acf1e5 100644 --- a/limpyd/middlewares.py +++ b/limpyd/middlewares.py @@ -30,14 +30,15 @@ class LoggingMiddleware(BaseMiddleware): Middleware that takes a logger, and log commands and their result (and time to run the command). """ - def __init__(self, logger, log_results=True): + def __init__(self, logger, log_results=True, log_time=True): """ The logger must be a defined and correctly initialized one (via logging) The log_results flag indicates if only the commands or also their result - (with duration) are logged. + (with duration, if log_time is True) are logged. """ self.logger = logger self.log_results = log_results + self.log_time = log_time super(LoggingMiddleware, self).__init__() @BaseMiddleware.database.setter @@ -48,14 +49,18 @@ def database(self, value): def pre_command(self, command, context): self.database._command_logger_counter += 1 context['_command_number'] = self.database._command_logger_counter - context['_start_time'] = time() - self.logger.info('[#%s] %s' % (context['_command_number'], str(command))) + if self.log_time: + context['_start_time'] = time() + self.logger.info(u'[#%s] %s' % (context['_command_number'], str(command))) def post_command(self, command, result, context): if self.log_results: - self.logger.info('[#%s, in %0.0fµs] %s' % ( - context['_command_number'], - (time() - context['_start_time']) * 1000000, - str(result)) - ) + log_str = u'[#%s] %s' + log_params = [context['_command_number'], str(result)] + if self.log_time: + log_str = u'[#%s, in %0.0fµs] %s' + duration = (time() - context['_start_time']) * 1000000 + log_params.insert(1, duration) + + self.logger.info(log_str % tuple(log_params)) return result From 2fece18c8c95b7e2ab9beb657029acf6d1812111 Mon Sep 17 00:00:00 2001 From: twidi Date: Tue, 15 Jan 2013 22:50:43 +0100 Subject: [PATCH 6/6] Add tests for middlewares --- run_tests.py | 4 +- tests/middlewares.py | 149 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 2 deletions(-) create mode 100644 tests/middlewares.py diff --git a/run_tests.py b/run_tests.py index d250c54..fe8f7ac 100644 --- a/run_tests.py +++ b/run_tests.py @@ -4,7 +4,7 @@ import argparse # FIXME: move tests in limpyd module, to prevent a relative import? -from tests import base, model, utils, collection, lock +from tests import base, model, utils, collection, lock, middlewares from tests.contrib import database, related, collection as contrib_collection @@ -37,7 +37,7 @@ else: # Run all the tests suites = [] - default_mods = [base, model, utils, collection, lock, ] + default_mods = [base, model, utils, collection, lock, middlewares, ] contrib_mods = [database, related, contrib_collection] for mod in default_mods + contrib_mods: suite = unittest.TestLoader().loadTestsFromModule(mod) diff --git a/tests/middlewares.py b/tests/middlewares.py new file mode 100644 index 0000000..7343977 --- /dev/null +++ b/tests/middlewares.py @@ -0,0 +1,149 @@ +# -*- coding:utf-8 -*- + +import unittest +import logging +from StringIO import StringIO + +from limpyd.middlewares import BaseMiddleware, LoggingMiddleware +from limpyd.database import RedisDatabase +from limpyd import model +from limpyd import fields + +from base import LimpydBaseTest, TEST_CONNECTION_SETTINGS + + +class ForceSetterMiddleware(BaseMiddleware): + """ + A test middleware that always save the same given value for all "set" calls + """ + def __init__(self, value): + super(ForceSetterMiddleware, self).__init__() + self.value = value + + def pre_command(self, command, context): + if command.name == 'hset': + command.kwargs = {} + command.args = (command.args[0], command.args[1], self.value) + + +class ForceGetterMiddleware(BaseMiddleware): + """ + A test middleware that always returns the same given value for all "get" calls + """ + def __init__(self, value): + super(ForceGetterMiddleware, self).__init__() + self.value = value + + def post_command(self, command, result, context): + if command.name == 'hget': + result.value = self.value + return result + + +class BaseTestModel(model.RedisModel): + abstract = True + cacheable = False + foo = fields.HashableField() + + +class MiddlewareTest(LimpydBaseTest): + def test_middleware_pre_command_method_should_be_called(self): + test_database = RedisDatabase(middlewares=[ + ForceSetterMiddleware(value='BAZ'), + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_middleware_pre_command_method_should_be_called' + + instance = TestModel(foo='bar') + self.assertEqual(instance.foo.hget(), 'BAZ') + + def test_middleware_post_command_method_should_be_called(self): + test_database = RedisDatabase(middlewares=[ + ForceGetterMiddleware(value='QUX'), + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_middleware_post_command_method_should_be_called' + + instance = TestModel(foo='bar') + + # the middleware will send "QUX" + self.assertEqual(instance.foo.hget(), 'QUX') + + # but for untouched command, we got the real values + self.assertEqual(instance.hmget('foo'), ['bar']) + + def test_database_can_accept_many_middlewares(self): + test_database = RedisDatabase(middlewares=[ + ForceSetterMiddleware(value='BAZ'), + ForceGetterMiddleware(value='QUX'), + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_database_can_accept_many_middlewares' + + instance = TestModel(foo='bar') + + # the getter middleware will send "QUX" + self.assertEqual(instance.foo.hget(), 'QUX') + + # but for untouched command, we got the value set by the setter middleware + self.assertEqual(instance.hmget('foo'), ['BAZ']) + + def test_logging_middleware(self): + + logger = logging.getLogger('limpyd.tests.middlewares.test_logging_middleware') + stream = StringIO() + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler(stream)) + + test_database = RedisDatabase(middlewares=[ + LoggingMiddleware(logger, log_time=False) + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_logging_middleware' + + instance = TestModel(foo='bar') + self.assertEqual(instance.foo.hget(), 'bar') + + log_lines = [line for line in stream.getvalue().split('\n') if line] + self.assertEqual(len(log_lines), 4) + self.assertEqual(log_lines[0], u"[#1] Command(name='hset', args=(u'test_logging_middleware:testmodel:1:hash', 'foo', 'bar'), kwargs={})") + self.assertEqual(log_lines[1], u"[#1] Result(value=1)") + self.assertEqual(log_lines[2], u"[#2] Command(name='hget', args=(u'test_logging_middleware:testmodel:1:hash', 'foo'), kwargs={})") + self.assertEqual(log_lines[3], u"[#2] Result(value=bar)") + + def test_logging_middleware_with_another(self): + + logger = logging.getLogger('limpyd.tests.middlewares.test_logging_middleware_with_another') + stream = StringIO() + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler(stream)) + + test_database = RedisDatabase(middlewares=[ + ForceSetterMiddleware(value='BAZ'), + LoggingMiddleware(logger, log_time=False) + ], **TEST_CONNECTION_SETTINGS) + + class TestModel(BaseTestModel): + database = test_database + namespace = 'test_logging_middleware_with_another' + + instance = TestModel(foo='bar') + self.assertEqual(instance.foo.hget(), 'BAZ') + + log_lines = [line for line in stream.getvalue().split('\n') if line] + self.assertEqual(len(log_lines), 4) + self.assertEqual(log_lines[0], u"[#1] Command(name='hset', args=(u'test_logging_middleware_with_another:testmodel:1:hash', 'foo', 'BAZ'), kwargs={})") + self.assertEqual(log_lines[1], u"[#1] Result(value=1)") + self.assertEqual(log_lines[2], u"[#2] Command(name='hget', args=(u'test_logging_middleware_with_another:testmodel:1:hash', 'foo'), kwargs={})") + self.assertEqual(log_lines[3], u"[#2] Result(value=BAZ)") + +if __name__ == '__main__': + unittest.main()