diff --git a/changelogs/fragments/2037-add-from-csv-filter.yml b/changelogs/fragments/2037-add-from-csv-filter.yml new file mode 100644 index 0000000000..d99c4cd0a8 --- /dev/null +++ b/changelogs/fragments/2037-add-from-csv-filter.yml @@ -0,0 +1,7 @@ +--- +add plugin.filter: + - name: from_csv + description: Converts CSV text input into list of dicts +minor_changes: + - csv module utils - new module_utils for shared functions between ``from_csv`` filter and ``read_csv`` module (https://github.com/ansible-collections/community.general/pull/2037). + - read_csv - refactored read_csv module to use shared csv functions from csv module_utils (https://github.com/ansible-collections/community.general/pull/2037). diff --git a/plugins/filter/from_csv.py b/plugins/filter/from_csv.py new file mode 100644 index 0000000000..13a18aa88a --- /dev/null +++ b/plugins/filter/from_csv.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2021, Andrew Pantuso (@ajpantuso) +# Copyright: (c) 2018, Dag Wieers (@dagwieers) +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.errors import AnsibleFilterError +from ansible.module_utils._text import to_native + +from ansible_collections.community.general.plugins.module_utils.csv import (initialize_dialect, read_csv, CSVError, + DialectNotAvailableError, + CustomDialectFailureError) + + +def from_csv(data, dialect='excel', fieldnames=None, delimiter=None, skipinitialspace=None, strict=None): + + dialect_params = { + "delimiter": delimiter, + "skipinitialspace": skipinitialspace, + "strict": strict, + } + + try: + dialect = initialize_dialect(dialect, **dialect_params) + except (CustomDialectFailureError, DialectNotAvailableError) as e: + raise AnsibleFilterError(to_native(e)) + + reader = read_csv(data, dialect, fieldnames) + + data_list = [] + + try: + for row in reader: + data_list.append(row) + except CSVError as e: + raise AnsibleFilterError("Unable to process file: %s" % to_native(e)) + + return data_list + + +class FilterModule(object): + + def filters(self): + return { + 'from_csv': from_csv + } diff --git a/plugins/module_utils/csv.py b/plugins/module_utils/csv.py new file mode 100644 index 0000000000..426e2eb279 --- /dev/null +++ b/plugins/module_utils/csv.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2021, Andrew Pantuso (@ajpantuso) +# Copyright: (c) 2018, Dag Wieers (@dagwieers) +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import csv +from io import BytesIO, StringIO + +from ansible.module_utils._text import to_native +from ansible.module_utils.six import PY3 + + +class CustomDialectFailureError(Exception): + pass + + +class DialectNotAvailableError(Exception): + pass + + +CSVError = csv.Error + + +def initialize_dialect(dialect, **kwargs): + # Add Unix dialect from Python 3 + class unix_dialect(csv.Dialect): + """Describe the usual properties of Unix-generated CSV files.""" + delimiter = ',' + quotechar = '"' + doublequote = True + skipinitialspace = False + lineterminator = '\n' + quoting = csv.QUOTE_ALL + + csv.register_dialect("unix", unix_dialect) + + if dialect not in csv.list_dialects(): + raise DialectNotAvailableError("Dialect '%s' is not supported by your version of python." % dialect) + + # Create a dictionary from only set options + dialect_params = dict((k, v) for k, v in kwargs.items() if v is not None) + if dialect_params: + try: + csv.register_dialect('custom', dialect, **dialect_params) + except TypeError as e: + raise CustomDialectFailureError("Unable to create custom dialect: %s" % to_native(e)) + dialect = 'custom' + + return dialect + + +def read_csv(data, dialect, fieldnames=None): + + data = to_native(data, errors='surrogate_or_strict') + + if PY3: + fake_fh = StringIO(data) + else: + fake_fh = BytesIO(data) + + reader = csv.DictReader(fake_fh, fieldnames=fieldnames, dialect=dialect) + + return reader diff --git a/plugins/modules/files/read_csv.py b/plugins/modules/files/read_csv.py index 24a77c0e28..c48efc7440 100644 --- a/plugins/modules/files/read_csv.py +++ b/plugins/modules/files/read_csv.py @@ -137,26 +137,12 @@ list: gid: 500 ''' -import csv -from io import BytesIO, StringIO - from ansible.module_utils.basic import AnsibleModule -from ansible.module_utils._text import to_text -from ansible.module_utils.six import PY3 +from ansible.module_utils._text import to_native - -# Add Unix dialect from Python 3 -class unix_dialect(csv.Dialect): - """Describe the usual properties of Unix-generated CSV files.""" - delimiter = ',' - quotechar = '"' - doublequote = True - skipinitialspace = False - lineterminator = '\n' - quoting = csv.QUOTE_ALL - - -csv.register_dialect("unix", unix_dialect) +from ansible_collections.community.general.plugins.module_utils.csv import (initialize_dialect, read_csv, CSVError, + DialectNotAvailableError, + CustomDialectFailureError) def main(): @@ -180,38 +166,24 @@ def main(): fieldnames = module.params['fieldnames'] unique = module.params['unique'] - if dialect not in csv.list_dialects(): - module.fail_json(msg="Dialect '%s' is not supported by your version of python." % dialect) + dialect_params = { + "delimiter": module.params['delimiter'], + "skipinitialspace": module.params['skipinitialspace'], + "strict": module.params['strict'], + } - dialect_options = dict( - delimiter=module.params['delimiter'], - skipinitialspace=module.params['skipinitialspace'], - strict=module.params['strict'], - ) - - # Create a dictionary from only set options - dialect_params = dict((k, v) for k, v in dialect_options.items() if v is not None) - if dialect_params: - try: - csv.register_dialect('custom', dialect, **dialect_params) - except TypeError as e: - module.fail_json(msg="Unable to create custom dialect: %s" % to_text(e)) - dialect = 'custom' + try: + dialect = initialize_dialect(dialect, **dialect_params) + except (CustomDialectFailureError, DialectNotAvailableError) as e: + module.fail_json(msg=to_native(e)) try: with open(path, 'rb') as f: data = f.read() except (IOError, OSError) as e: - module.fail_json(msg="Unable to open file: %s" % to_text(e)) + module.fail_json(msg="Unable to open file: %s" % to_native(e)) - if PY3: - # Manually decode on Python3 so that we can use the surrogateescape error handler - data = to_text(data, errors='surrogate_or_strict') - fake_fh = StringIO(data) - else: - fake_fh = BytesIO(data) - - reader = csv.DictReader(fake_fh, fieldnames=fieldnames, dialect=dialect) + reader = read_csv(data, dialect, fieldnames) if key and key not in reader.fieldnames: module.fail_json(msg="Key '%s' was not found in the CSV header fields: %s" % (key, ', '.join(reader.fieldnames))) @@ -223,16 +195,16 @@ def main(): try: for row in reader: data_list.append(row) - except csv.Error as e: - module.fail_json(msg="Unable to process file: %s" % to_text(e)) + except CSVError as e: + module.fail_json(msg="Unable to process file: %s" % to_native(e)) else: try: for row in reader: if unique and row[key] in data_dict: module.fail_json(msg="Key '%s' is not unique for value '%s'" % (key, row[key])) data_dict[row[key]] = row - except csv.Error as e: - module.fail_json(msg="Unable to process file: %s" % to_text(e)) + except CSVError as e: + module.fail_json(msg="Unable to process file: %s" % to_native(e)) module.exit_json(dict=data_dict, list=data_list) diff --git a/tests/integration/targets/filter_from_csv/aliases b/tests/integration/targets/filter_from_csv/aliases new file mode 100644 index 0000000000..f04737b845 --- /dev/null +++ b/tests/integration/targets/filter_from_csv/aliases @@ -0,0 +1,2 @@ +shippable/posix/group2 +skip/python2.6 # filters are controller only, and we no longer support Python 2.6 on the controller diff --git a/tests/integration/targets/filter_from_csv/tasks/main.yml b/tests/integration/targets/filter_from_csv/tasks/main.yml new file mode 100644 index 0000000000..aafb28fbb0 --- /dev/null +++ b/tests/integration/targets/filter_from_csv/tasks/main.yml @@ -0,0 +1,49 @@ +#################################################################### +# WARNING: These are designed specifically for Ansible tests # +# and should not be used as examples of how to write Ansible roles # +#################################################################### + +- name: Parse valid csv input + assert: + that: + - "valid_comma_separated | community.general.from_csv == expected_result" + +- name: Parse valid csv input containing spaces with/without skipinitialspace=True + assert: + that: + - "valid_comma_separated_spaces | community.general.from_csv(skipinitialspace=True) == expected_result" + - "valid_comma_separated_spaces | community.general.from_csv != expected_result" + +- name: Parse valid csv input with no headers with/without specifiying fieldnames + assert: + that: + - "valid_comma_separated_no_headers | community.general.from_csv(fieldnames=['id','name','role']) == expected_result" + - "valid_comma_separated_no_headers | community.general.from_csv != expected_result" + +- name: Parse valid pipe-delimited csv input with/without delimiter=| + assert: + that: + - "valid_pipe_separated | community.general.from_csv(delimiter='|') == expected_result" + - "valid_pipe_separated | community.general.from_csv != expected_result" + +- name: Register result of invalid csv input when strict=False + debug: + var: "invalid_comma_separated | community.general.from_csv" + register: _invalid_csv_strict_false + +- name: Test invalid csv input when strict=False is successful + assert: + that: + - _invalid_csv_strict_false is success + +- name: Register result of invalid csv input when strict=True + debug: + var: "invalid_comma_separated | community.general.from_csv(strict=True)" + register: _invalid_csv_strict_true + ignore_errors: True + +- name: Test invalid csv input when strict=True is failed + assert: + that: + - _invalid_csv_strict_true is failed + - _invalid_csv_strict_true.msg is match('Unable to process file:.*') diff --git a/tests/integration/targets/filter_from_csv/vars/main.yml b/tests/integration/targets/filter_from_csv/vars/main.yml new file mode 100644 index 0000000000..5801bc20dc --- /dev/null +++ b/tests/integration/targets/filter_from_csv/vars/main.yml @@ -0,0 +1,26 @@ +valid_comma_separated: | + id,name,role + 1,foo,bar + 2,bar,baz +valid_comma_separated_spaces: | + id,name,role + 1, foo, bar + 2, bar, baz +valid_comma_separated_no_headers: | + 1,foo,bar + 2,bar,baz +valid_pipe_separated: | + id|name|role + 1|foo|bar + 2|bar|baz +invalid_comma_separated: | + id,name,role + 1,foo,bar + 2,"b"ar",baz +expected_result: + - id: '1' + name: foo + role: bar + - id: '2' + name: bar + role: baz diff --git a/tests/unit/plugins/module_utils/test_csv.py b/tests/unit/plugins/module_utils/test_csv.py new file mode 100644 index 0000000000..b31915d66d --- /dev/null +++ b/tests/unit/plugins/module_utils/test_csv.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import pytest + +from ansible_collections.community.general.plugins.module_utils import csv + + +VALID_CSV = [ + ( + 'excel', + {}, + None, + "id,name,role\n1,foo,bar\n2,bar,baz", + [ + { + "id": "1", + "name": "foo", + "role": "bar", + }, + { + "id": "2", + "name": "bar", + "role": "baz", + }, + ] + ), + ( + 'excel', + {"skipinitialspace": True}, + None, + "id,name,role\n1, foo, bar\n2, bar, baz", + [ + { + "id": "1", + "name": "foo", + "role": "bar", + }, + { + "id": "2", + "name": "bar", + "role": "baz", + }, + ] + ), + ( + 'excel', + {"delimiter": '|'}, + None, + "id|name|role\n1|foo|bar\n2|bar|baz", + [ + { + "id": "1", + "name": "foo", + "role": "bar", + }, + { + "id": "2", + "name": "bar", + "role": "baz", + }, + ] + ), + ( + 'unix', + {}, + None, + "id,name,role\n1,foo,bar\n2,bar,baz", + [ + { + "id": "1", + "name": "foo", + "role": "bar", + }, + { + "id": "2", + "name": "bar", + "role": "baz", + }, + ] + ), + ( + 'excel', + {}, + ['id', 'name', 'role'], + "1,foo,bar\n2,bar,baz", + [ + { + "id": "1", + "name": "foo", + "role": "bar", + }, + { + "id": "2", + "name": "bar", + "role": "baz", + }, + ] + ), +] + +INVALID_CSV = [ + ( + 'excel', + {'strict': True}, + None, + 'id,name,role\n1,"f"oo",bar\n2,bar,baz', + ), +] + +INVALID_DIALECT = [ + ( + 'invalid', + {}, + None, + "id,name,role\n1,foo,bar\n2,bar,baz", + ), +] + + +@pytest.mark.parametrize("dialect,dialect_params,fieldnames,data,expected", VALID_CSV) +def test_valid_csv(data, dialect, dialect_params, fieldnames, expected): + dialect = csv.initialize_dialect(dialect, **dialect_params) + reader = csv.read_csv(data, dialect, fieldnames) + result = True + + for idx, row in enumerate(reader): + for k, v in row.items(): + if expected[idx][k] != v: + result = False + break + + assert result + + +@pytest.mark.parametrize("dialect,dialect_params,fieldnames,data", INVALID_CSV) +def test_invalid_csv(data, dialect, dialect_params, fieldnames): + dialect = csv.initialize_dialect(dialect, **dialect_params) + reader = csv.read_csv(data, dialect, fieldnames) + result = False + + try: + for row in reader: + continue + except csv.CSVError: + result = True + + assert result + + +@pytest.mark.parametrize("dialect,dialect_params,fieldnames,data", INVALID_DIALECT) +def test_invalid_dialect(data, dialect, dialect_params, fieldnames): + result = False + + try: + dialect = csv.initialize_dialect(dialect, **dialect_params) + except csv.DialectNotAvailableError: + result = True + + assert result