From d877c146ab750c0f1505b4180e4170163d854fb5 Mon Sep 17 00:00:00 2001 From: Sloane Hertel Date: Wed, 20 Dec 2017 14:57:47 -0500 Subject: [PATCH] [cloud] ec2_group fix CIDR with host bits set - fixes #25403 (#29605) * WIP adds network subnetting functions * adds functions to convert between netmask and masklen * adds functions to verify netmask and masklen * adds function to dtermine network and subnet from address / mask pair * network_common: add a function to get the first 48 bits in a IPv6 address. ec2_group: only use network bits of a CIDR. * Add tests for CIDRs with host bits set. * ec2_group: add warning if CIDR isn't the networking address. * Fix pep8. * Improve wording. * fix import for network utils * Update tests to use pytest instead of unittest * add test for to_ipv6_network() * Fix PEP8 --- .../module_utils/network/common/utils.py | 107 ++++++++ lib/ansible/modules/cloud/amazon/ec2_group.py | 18 +- .../targets/ec2_group/tasks/main.yml | 106 ++++++++ .../module_utils/network/common/test_utils.py | 255 +++++++++++------- 4 files changed, 385 insertions(+), 101 deletions(-) diff --git a/lib/ansible/module_utils/network/common/utils.py b/lib/ansible/module_utils/network/common/utils.py index cf3faaf91b..276d941cde 100644 --- a/lib/ansible/module_utils/network/common/utils.py +++ b/lib/ansible/module_utils/network/common/utils.py @@ -31,8 +31,11 @@ import operator import socket from itertools import chain +from struct import pack +from socket import inet_aton, inet_ntoa from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils.six.moves import zip from ansible.module_utils.basic import AnsibleFallbackNotFound try: @@ -45,6 +48,7 @@ except ImportError: OPERATORS = frozenset(['ge', 'gt', 'eq', 'neq', 'lt', 'le']) ALIASES = frozenset([('min', 'ge'), ('max', 'le'), ('exactly', 'eq'), ('neq', 'ne')]) +VALID_MASKS = [2**8 - 2**i for i in range(0, 9)] def to_list(val): @@ -426,3 +430,106 @@ class Template: if marker in data: return True return False + + +def is_netmask(val): + parts = str(val).split('.') + if not len(parts) == 4: + return False + for part in parts: + try: + if int(part) not in VALID_MASKS: + raise ValueError + except ValueError: + return False + return True + + +def is_masklen(val): + try: + return 0 <= int(val) <= 32 + except ValueError: + return False + + +def to_bits(val): + """ converts a netmask to bits """ + bits = '' + for octet in val.split('.'): + bits += bin(int(octet))[2:].zfill(8) + return str + + +def to_netmask(val): + """ converts a masklen to a netmask """ + if not is_masklen(val): + raise ValueError('invalid value for masklen') + + bits = 0 + for i in range(32 - int(val), 32): + bits |= (1 << i) + + return inet_ntoa(pack('>I', bits)) + + +def to_masklen(val): + """ converts a netmask to a masklen """ + if not is_netmask(val): + raise ValueError('invalid value for netmask: %s' % val) + + bits = list() + for x in val.split('.'): + octet = bin(int(x)).count('1') + bits.append(octet) + + return sum(bits) + + +def to_subnet(addr, mask, dotted_notation=False): + """ coverts an addr / mask pair to a subnet in cidr notation """ + try: + if not is_masklen(mask): + raise ValueError + cidr = int(mask) + mask = to_netmask(mask) + except ValueError: + cidr = to_masklen(mask) + + addr = addr.split('.') + mask = mask.split('.') + + network = list() + for s_addr, s_mask in zip(addr, mask): + network.append(str(int(s_addr) & int(s_mask))) + + if dotted_notation: + return '%s %s' % ('.'.join(network), to_netmask(cidr)) + return '%s/%s' % ('.'.join(network), cidr) + + +def to_ipv6_network(addr): + """ IPv6 addresses are eight groupings. The first three groupings (48 bits) comprise the network address. """ + + # Split by :: to identify omitted zeros + ipv6_prefix = addr.split('::')[0] + + # Get the first three groups, or as many as are found + :: + found_groups = [] + for group in ipv6_prefix.split(':'): + found_groups.append(group) + if len(found_groups) == 3: + break + if len(found_groups) < 3: + found_groups.append('::') + + # Concatenate network address parts + network_addr = '' + for group in found_groups: + if group != '::': + network_addr += str(group) + network_addr += str(':') + + # Ensure network address ends with :: + if not network_addr.endswith('::'): + network_addr += str(':') + return network_addr diff --git a/lib/ansible/modules/cloud/amazon/ec2_group.py b/lib/ansible/modules/cloud/amazon/ec2_group.py index 281db4ea84..12928332fe 100644 --- a/lib/ansible/modules/cloud/amazon/ec2_group.py +++ b/lib/ansible/modules/cloud/amazon/ec2_group.py @@ -289,6 +289,7 @@ from ansible.module_utils.ec2 import camel_dict_to_snake_dict from ansible.module_utils.ec2 import HAS_BOTO3 from ansible.module_utils.ec2 import boto3_tag_list_to_ansible_dict, ansible_dict_to_boto3_tag_list, compare_aws_tags from ansible.module_utils.ec2 import AWSRetry +from ansible.module_utils.network.common.utils import to_ipv6_network, to_subnet import traceback try: @@ -521,7 +522,22 @@ def update_rules_description(module, client, rule_type, group_id, ip_permissions def authorize_ip(type, changed, client, group, groupRules, ip, ip_permission, module, rule, ethertype): # If rule already exists, don't later delete it - for thisip in ip: + for this_ip in ip: + + split_addr = this_ip.split('/') + if len(split_addr) == 2: + # this_ip is a IPv4 or IPv6 CIDR that may or may not have host bits set + # Get the network bits. + try: + thisip = to_subnet(split_addr[0], split_addr[1]) + except ValueError: + thisip = to_ipv6_network(split_addr[0]) + "/" + split_addr[1] + if thisip != this_ip: + module.warn("One of your CIDR addresses ({0}) has host bits set. To get rid of this warning, " + "check the network mask and make sure that only network bits are set: {1}.".format(this_ip, thisip)) + else: + thisip = this_ip + rule_id = make_rule_key(type, rule, group['GroupId'], thisip) if rule_id in groupRules: diff --git a/test/integration/targets/ec2_group/tasks/main.yml b/test/integration/targets/ec2_group/tasks/main.yml index f2b91b1770..b169f7bee2 100644 --- a/test/integration/targets/ec2_group/tasks/main.yml +++ b/test/integration/targets/ec2_group/tasks/main.yml @@ -435,6 +435,112 @@ - 'result.changed' - 'result.group_id.startswith("sg-")' + # ============================================================ + + - name: test adding a rule with a IPv4 CIDR with host bits set (expected changed=true) + ec2_group: + name: '{{ec2_group_name}}' + description: '{{ec2_group_description}}' + ec2_region: '{{ec2_region}}' + ec2_access_key: '{{ec2_access_key}}' + ec2_secret_key: '{{ec2_secret_key}}' + security_token: '{{security_token}}' + state: present + # set purge_rules to false so we don't get a false positive from previously added rules + purge_rules: false + rules: + - proto: "tcp" + ports: + - 8195 + cidr_ip: 10.0.0.1/8 + register: result + + - name: assert state=present (expected changed=true) + assert: + that: + - 'result.changed' + - 'result.group_id.startswith("sg-")' + + # ============================================================ + + - name: test adding the same rule with a IPv4 CIDR with host bits set (expected changed=false and a warning) + ec2_group: + name: '{{ec2_group_name}}' + description: '{{ec2_group_description}}' + ec2_region: '{{ec2_region}}' + ec2_access_key: '{{ec2_access_key}}' + ec2_secret_key: '{{ec2_secret_key}}' + security_token: '{{security_token}}' + state: present + # set purge_rules to false so we don't get a false positive from previously added rules + purge_rules: false + rules: + - proto: "tcp" + ports: + - 8195 + cidr_ip: 10.0.0.1/8 + register: result + + - name: assert state=present (expected changed=false and a warning) + assert: + that: + # No way to assert for warnings? + - 'not result.changed' + - 'result.group_id.startswith("sg-")' + + # ============================================================ + + - name: test adding a rule with a IPv6 CIDR with host bits set (expected changed=true) + ec2_group: + name: '{{ec2_group_name}}' + description: '{{ec2_group_description}}' + ec2_region: '{{ec2_region}}' + ec2_access_key: '{{ec2_access_key}}' + ec2_secret_key: '{{ec2_secret_key}}' + security_token: '{{security_token}}' + state: present + # set purge_rules to false so we don't get a false positive from previously added rules + purge_rules: false + rules: + - proto: "tcp" + ports: + - 8196 + cidr_ipv6: '2001:db00::1/24' + register: result + + - name: assert state=present (expected changed=true) + assert: + that: + - 'result.changed' + - 'result.group_id.startswith("sg-")' + + # ============================================================ + + - name: test adding a rule again with a IPv6 CIDR with host bits set (expected changed=false and a warning) + ec2_group: + name: '{{ec2_group_name}}' + description: '{{ec2_group_description}}' + ec2_region: '{{ec2_region}}' + ec2_access_key: '{{ec2_access_key}}' + ec2_secret_key: '{{ec2_secret_key}}' + security_token: '{{security_token}}' + state: present + # set purge_rules to false so we don't get a false positive from previously added rules + purge_rules: false + rules: + - proto: "tcp" + ports: + - 8196 + cidr_ipv6: '2001:db00::1/24' + register: result + + - name: assert state=present (expected changed=false and a warning) + assert: + that: + # No way to assert for warnings? + - 'not result.changed' + - 'result.group_id.startswith("sg-")' + # ============================================================ - name: test state=absent (expected changed=true) ec2_group: diff --git a/test/units/module_utils/network/common/test_utils.py b/test/units/module_utils/network/common/test_utils.py index 734d81ac79..bedc1edd20 100644 --- a/test/units/module_utils/network/common/test_utils.py +++ b/test/units/module_utils/network/common/test_utils.py @@ -18,134 +18,189 @@ # along with Ansible. If not, see . # Make coding more python3-ish -from __future__ import (absolute_import, division) +from __future__ import absolute_import, division, print_function __metaclass__ = type -from ansible.compat.tests import unittest +import pytest from ansible.module_utils.network.common.utils import to_list, sort_list from ansible.module_utils.network.common.utils import dict_diff, dict_merge from ansible.module_utils.network.common.utils import conditional, Template +from ansible.module_utils.network.common.utils import to_masklen, to_netmask, to_subnet, to_ipv6_network +from ansible.module_utils.network.common.utils import is_masklen, is_netmask -class TestModuleUtilsNetworkCommon(unittest.TestCase): +def test_to_list(): + for scalar in ('string', 1, True, False, None): + assert isinstance(to_list(scalar), list) - def test_to_list(self): - for scalar in ('string', 1, True, False, None): - self.assertTrue(isinstance(to_list(scalar), list)) + for container in ([1, 2, 3], {'one': 1}): + assert isinstance(to_list(container), list) - for container in ([1, 2, 3], {'one': 1}): - self.assertTrue(isinstance(to_list(container), list)) + test_list = [1, 2, 3] + assert id(test_list) != id(to_list(test_list)) - test_list = [1, 2, 3] - self.assertNotEqual(id(test_list), id(to_list(test_list))) - def test_sort(self): - data = [3, 1, 2] - self.assertEqual([1, 2, 3], sort_list(data)) +def test_sort(): + data = [3, 1, 2] + assert [1, 2, 3] == sort_list(data) - string_data = '123' - self.assertEqual(string_data, sort_list(string_data)) + string_data = '123' + assert string_data == sort_list(string_data) - def test_dict_diff(self): - base = dict(obj2=dict(), b1=True, b2=False, b3=False, - one=1, two=2, three=3, obj1=dict(key1=1, key2=2), - l1=[1, 3], l2=[1, 2, 3], l4=[4], - nested=dict(n1=dict(n2=2))) - other = dict(b1=True, b2=False, b3=True, b4=True, - one=1, three=4, four=4, obj1=dict(key1=2), - l1=[2, 1], l2=[3, 2, 1], l3=[1], - nested=dict(n1=dict(n2=2, n3=3))) +def test_dict_diff(): + base = dict(obj2=dict(), b1=True, b2=False, b3=False, + one=1, two=2, three=3, obj1=dict(key1=1, key2=2), + l1=[1, 3], l2=[1, 2, 3], l4=[4], + nested=dict(n1=dict(n2=2))) - result = dict_diff(base, other) + other = dict(b1=True, b2=False, b3=True, b4=True, + one=1, three=4, four=4, obj1=dict(key1=2), + l1=[2, 1], l2=[3, 2, 1], l3=[1], + nested=dict(n1=dict(n2=2, n3=3))) - # string assertions - self.assertNotIn('one', result) - self.assertNotIn('two', result) - self.assertEqual(result['three'], 4) - self.assertEqual(result['four'], 4) + result = dict_diff(base, other) - # dict assertions - self.assertIn('obj1', result) - self.assertIn('key1', result['obj1']) - self.assertNotIn('key2', result['obj1']) + # string assertions + assert 'one' not in result + assert 'two' not in result + assert result['three'] == 4 + assert result['four'] == 4 - # list assertions - self.assertEqual(result['l1'], [2, 1]) - self.assertNotIn('l2', result) - self.assertEqual(result['l3'], [1]) - self.assertNotIn('l4', result) + # dict assertions + assert 'obj1' in result + assert 'key1' in result['obj1'] + assert 'key2' not in result['obj1'] - # nested assertions - self.assertIn('obj1', result) - self.assertEqual(result['obj1']['key1'], 2) - self.assertNotIn('key2', result['obj1']) + # list assertions + assert result['l1'] == [2, 1] + assert 'l2' not in result + assert result['l3'] == [1] + assert 'l4' not in result - # bool assertions - self.assertNotIn('b1', result) - self.assertNotIn('b2', result) - self.assertTrue(result['b3']) - self.assertTrue(result['b4']) + # nested assertions + assert 'obj1' in result + assert result['obj1']['key1'] == 2 + assert 'key2' not in result['obj1'] - def test_dict_merge(self): - base = dict(obj2=dict(), b1=True, b2=False, b3=False, - one=1, two=2, three=3, obj1=dict(key1=1, key2=2), - l1=[1, 3], l2=[1, 2, 3], l4=[4], - nested=dict(n1=dict(n2=2))) + # bool assertions + assert 'b1' not in result + assert 'b2' not in result + assert result['b3'] + assert result['b4'] - other = dict(b1=True, b2=False, b3=True, b4=True, - one=1, three=4, four=4, obj1=dict(key1=2), - l1=[2, 1], l2=[3, 2, 1], l3=[1], - nested=dict(n1=dict(n2=2, n3=3))) - result = dict_merge(base, other) +def test_dict_merge(): + base = dict(obj2=dict(), b1=True, b2=False, b3=False, + one=1, two=2, three=3, obj1=dict(key1=1, key2=2), + l1=[1, 3], l2=[1, 2, 3], l4=[4], + nested=dict(n1=dict(n2=2))) - # string assertions - self.assertIn('one', result) - self.assertIn('two', result) - self.assertEqual(result['three'], 4) - self.assertEqual(result['four'], 4) + other = dict(b1=True, b2=False, b3=True, b4=True, + one=1, three=4, four=4, obj1=dict(key1=2), + l1=[2, 1], l2=[3, 2, 1], l3=[1], + nested=dict(n1=dict(n2=2, n3=3))) - # dict assertions - self.assertIn('obj1', result) - self.assertIn('key1', result['obj1']) - self.assertIn('key2', result['obj1']) + result = dict_merge(base, other) - # list assertions - self.assertEqual(result['l1'], [1, 2, 3]) - self.assertIn('l2', result) - self.assertEqual(result['l3'], [1]) - self.assertIn('l4', result) + # string assertions + assert 'one' in result + assert 'two' in result + assert result['three'] == 4 + assert result['four'] == 4 - # nested assertions - self.assertIn('obj1', result) - self.assertEqual(result['obj1']['key1'], 2) - self.assertIn('key2', result['obj1']) + # dict assertions + assert 'obj1' in result + assert 'key1' in result['obj1'] + assert 'key2' in result['obj1'] - # bool assertions - self.assertIn('b1', result) - self.assertIn('b2', result) - self.assertTrue(result['b3']) - self.assertTrue(result['b4']) + # list assertions + assert result['l1'] == [1, 2, 3] + assert 'l2' in result + assert result['l3'] == [1] + assert 'l4' in result - def test_conditional(self): - self.assertTrue(conditional(10, 10)) - self.assertTrue(conditional('10', '10')) - self.assertTrue(conditional('foo', 'foo')) - self.assertTrue(conditional(True, True)) - self.assertTrue(conditional(False, False)) - self.assertTrue(conditional(None, None)) - self.assertTrue(conditional("ge(1)", 1)) - self.assertTrue(conditional("gt(1)", 2)) - self.assertTrue(conditional("le(2)", 2)) - self.assertTrue(conditional("lt(3)", 2)) - self.assertTrue(conditional("eq(1)", 1)) - self.assertTrue(conditional("neq(0)", 1)) - self.assertTrue(conditional("min(1)", 1)) - self.assertTrue(conditional("max(1)", 1)) - self.assertTrue(conditional("exactly(1)", 1)) + # nested assertions + assert 'obj1' in result + assert result['obj1']['key1'] == 2 + assert 'key2' in result['obj1'] - def test_template(self): - tmpl = Template() - self.assertEqual('foo', tmpl('{{ test }}', {'test': 'foo'})) + # bool assertions + assert 'b1' in result + assert 'b2' in result + assert result['b3'] + assert result['b4'] + + +def test_conditional(): + assert conditional(10, 10) + assert conditional('10', '10') + assert conditional('foo', 'foo') + assert conditional(True, True) + assert conditional(False, False) + assert conditional(None, None) + assert conditional("ge(1)", 1) + assert conditional("gt(1)", 2) + assert conditional("le(2)", 2) + assert conditional("lt(3)", 2) + assert conditional("eq(1)", 1) + assert conditional("neq(0)", 1) + assert conditional("min(1)", 1) + assert conditional("max(1)", 1) + assert conditional("exactly(1)", 1) + + +def test_template(): + tmpl = Template() + assert 'foo' == tmpl('{{ test }}', {'test': 'foo'}) + + +def test_to_masklen(): + assert 24 == to_masklen('255.255.255.0') + + +def test_to_masklen_invalid(): + with pytest.raises(ValueError): + to_masklen('255') + + +def test_to_netmask(): + assert '255.0.0.0' == to_netmask(8) + assert '255.0.0.0' == to_netmask('8') + + +def test_to_netmask_invalid(): + with pytest.raises(ValueError): + to_netmask(128) + + +def test_to_subnet(): + result = to_subnet('192.168.1.1', 24) + assert '192.168.1.0/24' == result + + result = to_subnet('192.168.1.1', 24, dotted_notation=True) + assert '192.168.1.0 255.255.255.0' == result + + +def test_to_subnet_invalid(): + with pytest.raises(ValueError): + to_subnet('foo', 'bar') + + +def test_is_masklen(): + assert is_masklen(32) + assert not is_masklen(33) + assert not is_masklen('foo') + + +def test_is_netmask(): + assert is_netmask('255.255.255.255') + assert not is_netmask(24) + assert not is_netmask('foo') + + +def test_to_ipv6_network(): + assert '2001:db8::' == to_ipv6_network('2001:db8::') + assert '2001:0db8:85a3::' == to_ipv6_network('2001:0db8:85a3:0000:0000:8a2e:0370:7334') + assert '2001:0db8:85a3::' == to_ipv6_network('2001:0db8:85a3:0:0:8a2e:0370:7334')