# -*- coding: utf-8 -*- # This code is part of Ansible, but is an independent component. # This particular file snippet, and this file snippet only, is based on # the config parser from here: https://github.com/emre/storm/blob/master/storm/parsers/ssh_config_parser.py # Copyright (C) <2013> # SPDX-License-Identifier: MIT from __future__ import (absolute_import, division, print_function) import os import re import traceback from operator import itemgetter __metaclass__ = type try: from paramiko.config import SSHConfig except ImportError: SSHConfig = object HAS_PARAMIKO = False PARAMIKO_IMPORT_ERROR = traceback.format_exc() else: HAS_PARAMIKO = True PARAMIKO_IMPORT_ERROR = None class StormConfig(SSHConfig): def parse(self, file_obj): """ Read an OpenSSH config from the given file object. @param file_obj: a file-like object to read the config file from @type file_obj: file """ order = 1 host = {"host": ['*'], "config": {}, } for line in file_obj: line = line.rstrip('\n').lstrip() if line == '': self._config.append({ 'type': 'empty_line', 'value': line, 'host': '', 'order': order, }) order += 1 continue if line.startswith('#'): self._config.append({ 'type': 'comment', 'value': line, 'host': '', 'order': order, }) order += 1 continue if '=' in line: # Ensure ProxyCommand gets properly split if line.lower().strip().startswith('proxycommand'): proxy_re = re.compile(r"^(proxycommand)\s*=*\s*(.*)", re.I) match = proxy_re.match(line) key, value = match.group(1).lower(), match.group(2) else: key, value = line.split('=', 1) key = key.strip().lower() else: # find first whitespace, and split there i = 0 while (i < len(line)) and not line[i].isspace(): i += 1 if i == len(line): raise Exception('Unparsable line: %r' % line) key = line[:i].lower() value = line[i:].lstrip() if key == 'host': self._config.append(host) value = value.split() host = { key: value, 'config': {}, 'type': 'entry', 'order': order } order += 1 elif key in ['identityfile', 'localforward', 'remoteforward']: if key in host['config']: host['config'][key].append(value) else: host['config'][key] = [value] elif key not in host['config']: host['config'].update({key: value}) self._config.append(host) class ConfigParser(object): """ Config parser for ~/.ssh/config files. """ def __init__(self, ssh_config_file=None): if not ssh_config_file: ssh_config_file = self.get_default_ssh_config_file() self.defaults = {} self.ssh_config_file = ssh_config_file if not os.path.exists(self.ssh_config_file): if not os.path.exists(os.path.dirname(self.ssh_config_file)): os.makedirs(os.path.dirname(self.ssh_config_file)) open(self.ssh_config_file, 'w+').close() os.chmod(self.ssh_config_file, 0o600) self.config_data = [] def get_default_ssh_config_file(self): return os.path.expanduser("~/.ssh/config") def load(self): config = StormConfig() with open(self.ssh_config_file) as fd: config.parse(fd) for entry in config.__dict__.get("_config"): if entry.get("host") == ["*"]: self.defaults.update(entry.get("config")) if entry.get("type") in ["comment", "empty_line"]: self.config_data.append(entry) continue host_item = { 'host': entry["host"][0], 'options': entry.get("config"), 'type': 'entry', 'order': entry.get("order", 0), } if len(entry["host"]) > 1: host_item.update({ 'host': " ".join(entry["host"]), }) # minor bug in paramiko.SSHConfig that duplicates # "Host *" entries. if entry.get("config") and len(entry.get("config")) > 0: self.config_data.append(host_item) return self.config_data def add_host(self, host, options): self.config_data.append({ 'host': host, 'options': options, 'order': self.get_last_index(), }) return self def update_host(self, host, options, use_regex=False): for index, host_entry in enumerate(self.config_data): if host_entry.get("host") == host or \ (use_regex and re.match(host, host_entry.get("host"))): if 'deleted_fields' in options: deleted_fields = options.pop("deleted_fields") for deleted_field in deleted_fields: del self.config_data[index]["options"][deleted_field] self.config_data[index]["options"].update(options) return self def search_host(self, search_string): results = [] for host_entry in self.config_data: if host_entry.get("type") != 'entry': continue if host_entry.get("host") == "*": continue searchable_information = host_entry.get("host") for key, value in host_entry.get("options").items(): if isinstance(value, list): value = " ".join(value) if isinstance(value, int): value = str(value) searchable_information += " " + value if search_string in searchable_information: results.append(host_entry) return results def delete_host(self, host): found = 0 for index, host_entry in enumerate(self.config_data): if host_entry.get("host") == host: del self.config_data[index] found += 1 if found == 0: raise ValueError('No host found') return self def delete_all_hosts(self): self.config_data = [] self.write_to_ssh_config() return self def dump(self): if len(self.config_data) < 1: return file_content = "" self.config_data = sorted(self.config_data, key=itemgetter("order")) for host_item in self.config_data: if host_item.get("type") in ['comment', 'empty_line']: file_content += host_item.get("value") + "\n" continue host_item_content = "Host {0}\n".format(host_item.get("host")) for key, value in host_item.get("options").items(): if isinstance(value, list): sub_content = "" for value_ in value: sub_content += " {0} {1}\n".format( key, value_ ) host_item_content += sub_content else: host_item_content += " {0} {1}\n".format( key, value ) file_content += host_item_content return file_content def write_to_ssh_config(self): with open(self.ssh_config_file, 'w+') as f: data = self.dump() if data: f.write(data) return self def get_last_index(self): last_index = 0 indexes = [] for item in self.config_data: if item.get("order"): indexes.append(item.get("order")) if len(indexes) > 0: last_index = max(indexes) return last_index