1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2024-09-14 20:13:21 +02:00

Adding VariableManager class for v2

This commit is contained in:
James Cammarata 2014-11-01 14:34:14 -05:00
parent e1662422bf
commit 11822f0d57
8 changed files with 771 additions and 0 deletions

59
v2/ansible/plugins/cache/__init__.py vendored Normal file
View file

@ -0,0 +1,59 @@
# (c) 2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from collections import MutableMapping
from ansible import constants as C
from ansible.plugins import cache_loader
class FactCache(MutableMapping):
def __init__(self, *args, **kwargs):
self._plugin = cache_loader.get(C.CACHE_PLUGIN)
if self._plugin is None:
return
def __getitem__(self, key):
if key not in self:
raise KeyError
return self._plugin.get(key)
def __setitem__(self, key, value):
self._plugin.set(key, value)
def __delitem__(self, key):
self._plugin.delete(key)
def __contains__(self, key):
return self._plugin.contains(key)
def __iter__(self):
return iter(self._plugin.keys())
def __len__(self):
return len(self._plugin.keys())
def copy(self):
""" Return a primitive copy of the keys and values from the cache. """
return dict([(k, v) for (k, v) in self.iteritems()])
def keys(self):
return self._plugin.keys()
def flush(self):
""" Flush the fact cache of all keys. """
self._plugin.flush()

41
v2/ansible/plugins/cache/base.py vendored Normal file
View file

@ -0,0 +1,41 @@
# (c) 2014, Brian Coca, Josh Drake, et al
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
import exceptions
class BaseCacheModule(object):
def get(self, key):
raise exceptions.NotImplementedError
def set(self, key, value):
raise exceptions.NotImplementedError
def keys(self):
raise exceptions.NotImplementedError
def contains(self, key):
raise exceptions.NotImplementedError
def delete(self, key):
raise exceptions.NotImplementedError
def flush(self):
raise exceptions.NotImplementedError
def copy(self):
raise exceptions.NotImplementedError

191
v2/ansible/plugins/cache/memcached.py vendored Normal file
View file

@ -0,0 +1,191 @@
# (c) 2014, Brian Coca, Josh Drake, et al
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
import collections
import os
import sys
import time
import threading
from itertools import chain
from ansible import constants as C
from ansible.plugins.cache.base import BaseCacheModule
try:
import memcache
except ImportError:
print 'python-memcached is required for the memcached fact cache'
sys.exit(1)
class ProxyClientPool(object):
"""
Memcached connection pooling for thread/fork safety. Inspired by py-redis
connection pool.
Available connections are maintained in a deque and released in a FIFO manner.
"""
def __init__(self, *args, **kwargs):
self.max_connections = kwargs.pop('max_connections', 1024)
self.connection_args = args
self.connection_kwargs = kwargs
self.reset()
def reset(self):
self.pid = os.getpid()
self._num_connections = 0
self._available_connections = collections.deque(maxlen=self.max_connections)
self._locked_connections = set()
self._lock = threading.Lock()
def _check_safe(self):
if self.pid != os.getpid():
with self._lock:
if self.pid == os.getpid():
# bail out - another thread already acquired the lock
return
self.disconnect_all()
self.reset()
def get_connection(self):
self._check_safe()
try:
connection = self._available_connections.popleft()
except IndexError:
connection = self.create_connection()
self._locked_connections.add(connection)
return connection
def create_connection(self):
if self._num_connections >= self.max_connections:
raise RuntimeError("Too many memcached connections")
self._num_connections += 1
return memcache.Client(*self.connection_args, **self.connection_kwargs)
def release_connection(self, connection):
self._check_safe()
self._locked_connections.remove(connection)
self._available_connections.append(connection)
def disconnect_all(self):
for conn in chain(self._available_connections, self._locked_connections):
conn.disconnect_all()
def __getattr__(self, name):
def wrapped(*args, **kwargs):
return self._proxy_client(name, *args, **kwargs)
return wrapped
def _proxy_client(self, name, *args, **kwargs):
conn = self.get_connection()
try:
return getattr(conn, name)(*args, **kwargs)
finally:
self.release_connection(conn)
class CacheModuleKeys(collections.MutableSet):
"""
A set subclass that keeps track of insertion time and persists
the set in memcached.
"""
PREFIX = 'ansible_cache_keys'
def __init__(self, cache, *args, **kwargs):
self._cache = cache
self._keyset = dict(*args, **kwargs)
def __contains__(self, key):
return key in self._keyset
def __iter__(self):
return iter(self._keyset)
def __len__(self):
return len(self._keyset)
def add(self, key):
self._keyset[key] = time.time()
self._cache.set(self.PREFIX, self._keyset)
def discard(self, key):
del self._keyset[key]
self._cache.set(self.PREFIX, self._keyset)
def remove_by_timerange(self, s_min, s_max):
for k in self._keyset.keys():
t = self._keyset[k]
if s_min < t < s_max:
del self._keyset[k]
self._cache.set(self.PREFIX, self._keyset)
class CacheModule(BaseCacheModule):
def __init__(self, *args, **kwargs):
if C.CACHE_PLUGIN_CONNECTION:
connection = C.CACHE_PLUGIN_CONNECTION.split(',')
else:
connection = ['127.0.0.1:11211']
self._timeout = C.CACHE_PLUGIN_TIMEOUT
self._prefix = C.CACHE_PLUGIN_PREFIX
self._cache = ProxyClientPool(connection, debug=0)
self._keys = CacheModuleKeys(self._cache, self._cache.get(CacheModuleKeys.PREFIX) or [])
def _make_key(self, key):
return "{0}{1}".format(self._prefix, key)
def _expire_keys(self):
if self._timeout > 0:
expiry_age = time.time() - self._timeout
self._keys.remove_by_timerange(0, expiry_age)
def get(self, key):
value = self._cache.get(self._make_key(key))
# guard against the key not being removed from the keyset;
# this could happen in cases where the timeout value is changed
# between invocations
if value is None:
self.delete(key)
raise KeyError
return value
def set(self, key, value):
self._cache.set(self._make_key(key), value, time=self._timeout, min_compress_len=1)
self._keys.add(key)
def keys(self):
self._expire_keys()
return list(iter(self._keys))
def contains(self, key):
self._expire_keys()
return key in self._keys
def delete(self, key):
self._cache.delete(self._make_key(key))
self._keys.discard(key)
def flush(self):
for key in self.keys():
self.delete(key)
def copy(self):
return self._keys.copy()

44
v2/ansible/plugins/cache/memory.py vendored Normal file
View file

@ -0,0 +1,44 @@
# (c) 2014, Brian Coca, Josh Drake, et al
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from ansible.plugins.cache.base import BaseCacheModule
class CacheModule(BaseCacheModule):
def __init__(self, *args, **kwargs):
self._cache = {}
def get(self, key):
return self._cache.get(key)
def set(self, key, value):
self._cache[key] = value
def keys(self):
return self._cache.keys()
def contains(self, key):
return key in self._cache
def delete(self, key):
del self._cache[key]
def flush(self):
self._cache = {}
def copy(self):
return self._cache.copy()

102
v2/ansible/plugins/cache/redis.py vendored Normal file
View file

@ -0,0 +1,102 @@
# (c) 2014, Brian Coca, Josh Drake, et al
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import collections
# FIXME: can we store these as something else before we ship it?
import sys
import time
import json
from ansible import constants as C
from ansible.plugins.cache.base import BaseCacheModule
try:
from redis import StrictRedis
except ImportError:
print "The 'redis' python module is required, 'pip install redis'"
sys.exit(1)
class CacheModule(BaseCacheModule):
"""
A caching module backed by redis.
Keys are maintained in a zset with their score being the timestamp
when they are inserted. This allows for the usage of 'zremrangebyscore'
to expire keys. This mechanism is used or a pattern matched 'scan' for
performance.
"""
def __init__(self, *args, **kwargs):
if C.CACHE_PLUGIN_CONNECTION:
connection = C.CACHE_PLUGIN_CONNECTION.split(':')
else:
connection = []
self._timeout = float(C.CACHE_PLUGIN_TIMEOUT)
self._prefix = C.CACHE_PLUGIN_PREFIX
self._cache = StrictRedis(*connection)
self._keys_set = 'ansible_cache_keys'
def _make_key(self, key):
return self._prefix + key
def get(self, key):
value = self._cache.get(self._make_key(key))
# guard against the key not being removed from the zset;
# this could happen in cases where the timeout value is changed
# between invocations
if value is None:
self.delete(key)
raise KeyError
return json.loads(value)
def set(self, key, value):
value2 = json.dumps(value)
if self._timeout > 0: # a timeout of 0 is handled as meaning 'never expire'
self._cache.setex(self._make_key(key), int(self._timeout), value2)
else:
self._cache.set(self._make_key(key), value2)
self._cache.zadd(self._keys_set, time.time(), key)
def _expire_keys(self):
if self._timeout > 0:
expiry_age = time.time() - self._timeout
self._cache.zremrangebyscore(self._keys_set, 0, expiry_age)
def keys(self):
self._expire_keys()
return self._cache.zrange(self._keys_set, 0, -1)
def contains(self, key):
self._expire_keys()
return (self._cache.zrank(self._keys_set, key) >= 0)
def delete(self, key):
self._cache.delete(self._make_key(key))
self._cache.zrem(self._keys_set, key)
def flush(self):
for key in self.keys():
self.delete(key)
def copy(self):
# FIXME: there is probably a better way to do this in redis
ret = dict()
for key in self.keys():
ret[key] = self.get(key)
return ret

182
v2/ansible/vars/__init__.py Normal file
View file

@ -0,0 +1,182 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
from collections import defaultdict
from ansible.parsing.yaml import DataLoader
from ansible.plugins.cache import FactCache
class VariableManager:
def __init__(self, inventory_path=None, loader=None):
self._fact_cache = FactCache()
self._vars_cache = defaultdict(dict)
self._extra_vars = defaultdict(dict)
self._host_vars_files = defaultdict(dict)
self._group_vars_files = defaultdict(dict)
if not loader:
self._loader = DataLoader()
else:
self._loader = loader
@property
def extra_vars(self):
''' ensures a clean copy of the extra_vars are made '''
return self._extra_vars.copy()
def set_extra_vars(self, value):
''' ensures a clean copy of the extra_vars are used to set the value '''
assert isinstance(value, dict)
self._extra_vars = value.copy()
def _merge_dicts(self, a, b):
'''
Recursively merges dict b into a, so that keys
from b take precedence over keys from a.
'''
result = dict()
# FIXME: do we need this from utils, or should it just
# be merged into this definition?
#_validate_both_dicts(a, b)
for dicts in a, b:
# next, iterate over b keys and values
for k, v in dicts.iteritems():
# if there's already such key in a
# and that key contains dict
if k in result and isinstance(result[k], dict):
# merge those dicts recursively
result[k] = self._merge_dicts(a[k], v)
else:
# otherwise, just copy a value from b to a
result[k] = v
return result
def get_vars(self, play=None, host=None, task=None):
'''
Returns the variables, with optional "context" given via the parameters
for the play, host, and task (which could possibly result in different
sets of variables being returned due to the additional context).
The order of precedence is:
- play->roles->get_default_vars (if there is a play context)
- group_vars_files[host] (if there is a host context)
- host_vars_files[host] (if there is a host context)
- host->get_vars (if there is a host context)
- fact_cache[host] (if there is a host context)
- vars_cache[host] (if there is a host context)
- play vars (if there is a play context)
- play vars_files (if there's no host context, ignore
file names that cannot be templated)
- task->get_vars (if there is a task context)
- extra vars
'''
vars = defaultdict(dict)
if play:
# first we compile any vars specified in defaults/main.yml
# for all roles within the specified play
for role in play.get_roles():
vars = self._merge_dicts(vars, role.get_default_vars())
if host:
# next, if a host is specified, we load any vars from group_vars
# files and then any vars from host_vars files which may apply to
# this host or the groups it belongs to
for group in host.get_groups():
if group in self._group_vars_files:
vars = self._merge_dicts(vars, self._group_vars_files[group])
host_name = host.get_name()
if host_name in self._host_vars_files:
vars = self._merge_dicts(vars, self._host_vars_files[host_name])
# then we merge in vars specified for this host
vars = self._merge_dicts(vars, host.get_vars())
# next comes the facts cache and the vars cache, respectively
vars = self._merge_dicts(vars, self._fact_cache.get(host.get_name(), dict()))
vars = self._merge_dicts(vars, self._vars_cache.get(host.get_name(), dict()))
if play:
vars = self._merge_dicts(vars, play.get_vars())
for vars_file in play.get_vars_files():
# Try templating the vars_file. If an unknown var error is raised,
# ignore it - unless a host is specified
# TODO ...
data = self._loader.load_from_file(vars_file)
vars = self._merge_dicts(vars, data)
if task:
vars = self._merge_dicts(vars, task.get_vars())
vars = self._merge_dicts(vars, self._extra_vars)
return vars
def _get_inventory_basename(self, path):
'''
Returns the bsaename minus the extension of the given path, so the
bare filename can be matched against host/group names later
'''
(name, ext) = os.path.splitext(os.path.basename(path))
return name
def _load_inventory_file(self, path):
'''
helper function, which loads the file and gets the
basename of the file without the extension
'''
data = self._loader.load_from_file(path)
name = self._get_inventory_basename(path)
return (name, data)
def add_host_vars_file(self, path):
'''
Loads and caches a host_vars file in the _host_vars_files dict,
where the key to that dictionary is the basename of the file, minus
the extension, for matching against a given inventory host name
'''
(name, data) = self._load_inventory_file(path)
self._host_vars_files[name] = data
def add_group_vars_file(self, path):
'''
Loads and caches a host_vars file in the _host_vars_files dict,
where the key to that dictionary is the basename of the file, minus
the extension, for matching against a given inventory host name
'''
(name, data) = self._load_inventory_file(path)
self._group_vars_files[name] = data

21
v2/test/vars/__init__.py Normal file
View file

@ -0,0 +1,21 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

View file

@ -0,0 +1,131 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock
from ansible.vars import VariableManager
from test.mock.loader import DictDataLoader
class TestVariableManager(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def test_basic_manager(self):
v = VariableManager()
self.assertEqual(v.get_vars(), dict())
self.assertEqual(
v._merge_dicts(
dict(a=1),
dict(b=2)
), dict(a=1, b=2)
)
self.assertEqual(
v._merge_dicts(
dict(a=1, c=dict(foo='bar')),
dict(b=2, c=dict(baz='bam'))
), dict(a=1, b=2, c=dict(foo='bar', baz='bam'))
)
def test_manager_extra_vars(self):
extra_vars = dict(a=1, b=2, c=3)
v = VariableManager()
v.set_extra_vars(extra_vars)
self.assertEqual(v.get_vars(), extra_vars)
self.assertIsNot(v.extra_vars, extra_vars)
def test_manager_host_vars_file(self):
fake_loader = DictDataLoader({
"host_vars/hostname1.yml": """
foo: bar
"""
})
v = VariableManager(loader=fake_loader)
v.add_host_vars_file("host_vars/hostname1.yml")
self.assertIn("hostname1", v._host_vars_files)
self.assertEqual(v._host_vars_files["hostname1"], dict(foo="bar"))
mock_host = MagicMock()
mock_host.get_name.return_value = "hostname1"
mock_host.get_vars.return_value = dict()
mock_host.get_groups.return_value = ()
self.assertEqual(v.get_vars(host=mock_host), dict(foo="bar"))
def test_manager_group_vars_file(self):
fake_loader = DictDataLoader({
"group_vars/somegroup.yml": """
foo: bar
"""
})
v = VariableManager(loader=fake_loader)
v.add_group_vars_file("group_vars/somegroup.yml")
self.assertIn("somegroup", v._group_vars_files)
self.assertEqual(v._group_vars_files["somegroup"], dict(foo="bar"))
mock_host = MagicMock()
mock_host.get_name.return_value = "hostname1"
mock_host.get_vars.return_value = dict()
mock_host.get_groups.return_value = ["somegroup"]
self.assertEqual(v.get_vars(host=mock_host), dict(foo="bar"))
def test_manager_play_vars(self):
mock_play = MagicMock()
mock_play.get_vars.return_value = dict(foo="bar")
mock_play.get_roles.return_value = []
mock_play.get_vars_files.return_value = []
v = VariableManager()
self.assertEqual(v.get_vars(play=mock_play), dict(foo="bar"))
def test_manager_play_vars_files(self):
fake_loader = DictDataLoader({
"/path/to/somefile.yml": """
foo: bar
"""
})
mock_play = MagicMock()
mock_play.get_vars.return_value = dict()
mock_play.get_roles.return_value = []
mock_play.get_vars_files.return_value = ['/path/to/somefile.yml']
v = VariableManager(loader=fake_loader)
self.assertEqual(v.get_vars(play=mock_play), dict(foo="bar"))
def test_manager_task_vars(self):
mock_task = MagicMock()
mock_task.get_vars.return_value = dict(foo="bar")
v = VariableManager()
self.assertEqual(v.get_vars(task=mock_task), dict(foo="bar"))