diff --git a/pillar/api/nodes/__init__.py b/pillar/api/nodes/__init__.py index d34ce484..dde813b7 100644 --- a/pillar/api/nodes/__init__.py +++ b/pillar/api/nodes/__init__.py @@ -23,7 +23,7 @@ blueprint = Blueprint('nodes_api', __name__) ROLES_FOR_SHARING = {u'subscriber', u'demo'} -def only_for_node_type_decorator(required_node_type_name): +def only_for_node_type_decorator(*required_node_type_names): """Returns a decorator that checks its first argument's node type. If the node type is not of the required node type, returns None, @@ -33,12 +33,19 @@ def only_for_node_type_decorator(required_node_type_name): >>> @deco ... def handle_comment(node): pass + >>> deco = only_for_node_type_decorator('comment', 'post') + >>> @deco + ... def handle_comment_or_post(node): pass + """ + # Convert to a set for efficient 'x in required_node_type_names' queries. + required_node_type_names = set(required_node_type_names) + def only_for_node_type(wrapped): @functools.wraps(wrapped) def wrapper(node, *args, **kwargs): - if node.get('node_type') != required_node_type_name: + if node.get('node_type') not in required_node_type_names: return return wrapped(node, *args, **kwargs) @@ -46,7 +53,7 @@ def only_for_node_type_decorator(required_node_type_name): return wrapper only_for_node_type.__doc__ = "Decorator, immediately returns when " \ - "the first argument is not of type %s." % required_node_type_name + "the first argument is not of type %s." % required_node_type_names return only_for_node_type