diff --git a/changelogs/fragments/5367-consul-refactor.yaml b/changelogs/fragments/5367-consul-refactor.yaml new file mode 100644 index 0000000000..2012d69cc5 --- /dev/null +++ b/changelogs/fragments/5367-consul-refactor.yaml @@ -0,0 +1,2 @@ +minor_changes: + - consul - minor refactoring (https://github.com/ansible-collections/community.general/pull/5367). diff --git a/plugins/modules/clustering/consul/consul.py b/plugins/modules/clustering/consul/consul.py index 152d4577a1..0d75bde2eb 100644 --- a/plugins/modules/clustering/consul/consul.py +++ b/plugins/modules/clustering/consul/consul.py @@ -241,7 +241,7 @@ from ansible.module_utils.basic import AnsibleModule def register_with_consul(module): - state = module.params.get('state') + state = module.params['state'] if state == 'present': add(module) @@ -267,10 +267,8 @@ def add(module): def remove(module): ''' removes a service or a check ''' - service_id = module.params.get('service_id') or module.params.get('service_name') - check_id = module.params.get('check_id') or module.params.get('check_name') - if not (service_id or check_id): - module.fail_json(msg='services and checks are removed by id or name. please supply a service id/name or a check id/name') + service_id = module.params['service_id'] or module.params['service_name'] + check_id = module.params['check_id'] or module.params['check_name'] if service_id: remove_service(module, service_id) else: @@ -343,63 +341,61 @@ def remove_service(module, service_id): consul_api = get_consul_api(module) service = get_service_by_id_or_name(consul_api, service_id) if service: - consul_api.agent.service.deregister(service_id, token=module.params.get('token')) + consul_api.agent.service.deregister(service_id, token=module.params['token']) module.exit_json(changed=True, id=service_id) module.exit_json(changed=False, id=service_id) def get_consul_api(module): - consulClient = consul.Consul(host=module.params.get('host'), - port=module.params.get('port'), - scheme=module.params.get('scheme'), - verify=module.params.get('validate_certs'), - token=module.params.get('token')) + consulClient = consul.Consul(host=module.params['host'], + port=module.params['port'], + scheme=module.params['scheme'], + verify=module.params['validate_certs'], + token=module.params['token']) consulClient.agent.service = PatchedConsulAgentService(consulClient) return consulClient def get_service_by_id_or_name(consul_api, service_id_or_name): ''' iterate the registered services and find one with the given id ''' - for name, service in consul_api.agent.services().items(): - if service['ID'] == service_id_or_name or service['Service'] == service_id_or_name: + for dummy, service in consul_api.agent.services().items(): + if service_id_or_name in (service['ID'], service['Service']): return ConsulService(loaded=service) def parse_check(module): - if len([p for p in (module.params.get('script'), module.params.get('ttl'), module.params.get('tcp'), module.params.get('http')) if p]) > 1: + _checks = [module.params[p] for p in ('script', 'ttl', 'tcp', 'http') if module.params[p]] + + if len(_checks) > 1: module.fail_json( msg='checks are either script, tcp, http or ttl driven, supplying more than one does not make sense') - if module.params.get('check_id') or module.params.get('script') or module.params.get('ttl') or module.params.get('tcp') or module.params.get('http'): - + if module.params['check_id'] or _checks: return ConsulCheck( - module.params.get('check_id'), - module.params.get('check_name'), - module.params.get('check_node'), - module.params.get('check_host'), - module.params.get('script'), - module.params.get('interval'), - module.params.get('ttl'), - module.params.get('notes'), - module.params.get('tcp'), - module.params.get('http'), - module.params.get('timeout'), - module.params.get('service_id'), + module.params['check_id'], + module.params['check_name'], + module.params['check_node'], + module.params['check_host'], + module.params['script'], + module.params['interval'], + module.params['ttl'], + module.params['notes'], + module.params['tcp'], + module.params['http'], + module.params['timeout'], + module.params['service_id'], ) def parse_service(module): - if module.params.get('service_name'): - return ConsulService( - module.params.get('service_id'), - module.params.get('service_name'), - module.params.get('service_address'), - module.params.get('service_port'), - module.params.get('tags'), - ) - elif not module.params.get('service_name'): - module.fail_json(msg="service_name is required to configure a service.") + return ConsulService( + module.params['service_id'], + module.params['service_name'], + module.params['service_address'], + module.params['service_port'], + module.params['tags'], + ) class ConsulService(object): @@ -502,10 +498,10 @@ class ConsulCheck(object): if interval is None: raise Exception('tcp check must specify interval') - regex = r"(?P.*)(?::)(?P(?:[0-9]+))$" + regex = r"(?P.*):(?P(?:[0-9]+))$" match = re.match(regex, tcp) - if match is None: + if not match: raise Exception('tcp check must be in host:port format') self.check = consul.Check.tcp(match.group('host').strip('[]'), int(match.group('port')), self.interval) @@ -513,7 +509,7 @@ class ConsulCheck(object): def validate_duration(self, name, duration): if duration: duration_units = ['ns', 'us', 'ms', 's', 'm', 'h'] - if not any((duration.endswith(suffix) for suffix in duration_units)): + if not any(duration.endswith(suffix) for suffix in duration_units): duration = "{0}s".format(duration) return duration @@ -589,6 +585,10 @@ def main(): tags=dict(type='list', elements='str'), token=dict(no_log=True) ), + required_if=[ + ('state', 'present', ['service_name']), + ('state', 'absent', ['service_id', 'service_name', 'check_id', 'check_name'], True), + ], supports_check_mode=False, ) @@ -598,7 +598,7 @@ def main(): register_with_consul(module) except ConnectionError as e: module.fail_json(msg='Could not connect to consul agent at %s:%s, error was %s' % ( - module.params.get('host'), module.params.get('port'), str(e))) + module.params['host'], module.params['port'], str(e))) except Exception as e: module.fail_json(msg=str(e))