mirror of
				https://github.com/ansible-collections/community.general.git
				synced 2024-09-14 20:13:21 +02:00 
			
		
		
		
	* sanity: Add future boilerplate Signed-off-by: Abhijeet Kasurde <akasurde@redhat.com> * Module Utils Signed-off-by: Abhijeet Kasurde <akasurde@redhat.com> * Scripts Signed-off-by: Abhijeet Kasurde <akasurde@redhat.com> * sanity: Add future boilerplate * Tests Signed-off-by: Abhijeet Kasurde <akasurde@redhat.com> * CI failure Signed-off-by: Abhijeet Kasurde <akasurde@redhat.com>
		
			
				
	
	
		
			189 lines
		
	
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			189 lines
		
	
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # This code is part of Ansible, but is an independent component.
 | |
| # This particular file snippet, and this file snippet only, is BSD licensed.
 | |
| # Modules you write using this snippet, which is embedded dynamically by Ansible
 | |
| # still belong to the author of the module, and may assign their own license
 | |
| # to the complete work.
 | |
| #
 | |
| # Copyright (c) 2014, Toshio Kuratomi <tkuratomi@ansible.com>
 | |
| #
 | |
| # Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
 | |
| 
 | |
| from __future__ import (absolute_import, division, print_function)
 | |
| __metaclass__ = type
 | |
| 
 | |
| import re
 | |
| 
 | |
| 
 | |
# Heuristic patterns consulted by is_input_dangerous() to spot input that
# looks like an SQL injection attempt.
#
# Every pattern is compiled with re.DOTALL so that '.' also matches a newline;
# otherwise a payload could evade detection simply by placing a line break
# between its two halves (for example ';\nDROP TABLE ...').
#
# 1. A single or double quote followed, anywhere later, by a '--' comment marker.
PATTERN_1 = re.compile(r'(\'|\").*--', re.DOTALL)

# 2. UNION / INTERSECT / EXCEPT followed, anywhere later, by SELECT.
PATTERN_2 = re.compile(r'(UNION|INTERSECT|EXCEPT).*SELECT', re.IGNORECASE | re.DOTALL)

# 3. A ';' followed, anywhere later, by one of the listed SQL keywords.
PATTERN_3 = re.compile(r';.*(SELECT|UPDATE|INSERT|DELETE|DROP|TRUNCATE|ALTER)', re.IGNORECASE | re.DOTALL)
 | |
| 
 | |
| 
 | |
class SQLParseError(Exception):
    """Base error raised when an SQL identifier cannot be parsed or quoted."""


class UnclosedQuoteError(SQLParseError):
    """Raised when a quoted identifier has no matching closing quote."""
 | |
| 
 | |
| 
 | |
| # maps a type of identifier to the maximum number of dot levels that are
 | |
| # allowed to specify that identifier.  For example, a database column can be
 | |
| # specified by up to 4 levels: database.schema.table.column
 | |
| _PG_IDENTIFIER_TO_DOT_LEVEL = dict(
 | |
|     database=1,
 | |
|     schema=2,
 | |
|     table=3,
 | |
|     column=4,
 | |
|     role=1,
 | |
|     tablespace=1,
 | |
|     sequence=3,
 | |
|     publication=1,
 | |
| )
 | |
| _MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1, vars=1)
 | |
| 
 | |
| 
 | |
| def _find_end_quote(identifier, quote_char):
 | |
|     accumulate = 0
 | |
|     while True:
 | |
|         try:
 | |
|             quote = identifier.index(quote_char)
 | |
|         except ValueError:
 | |
|             raise UnclosedQuoteError
 | |
|         accumulate = accumulate + quote
 | |
|         try:
 | |
|             next_char = identifier[quote + 1]
 | |
|         except IndexError:
 | |
|             return accumulate
 | |
|         if next_char == quote_char:
 | |
|             try:
 | |
|                 identifier = identifier[quote + 2:]
 | |
|                 accumulate = accumulate + 2
 | |
|             except IndexError:
 | |
|                 raise UnclosedQuoteError
 | |
|         else:
 | |
|             return accumulate
 | |
| 
 | |
| 
 | |
| def _identifier_parse(identifier, quote_char):
 | |
|     if not identifier:
 | |
|         raise SQLParseError('Identifier name unspecified or unquoted trailing dot')
 | |
| 
 | |
|     already_quoted = False
 | |
|     if identifier.startswith(quote_char):
 | |
|         already_quoted = True
 | |
|         try:
 | |
|             end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
 | |
|         except UnclosedQuoteError:
 | |
|             already_quoted = False
 | |
|         else:
 | |
|             if end_quote < len(identifier) - 1:
 | |
|                 if identifier[end_quote + 1] == '.':
 | |
|                     dot = end_quote + 1
 | |
|                     first_identifier = identifier[:dot]
 | |
|                     next_identifier = identifier[dot + 1:]
 | |
|                     further_identifiers = _identifier_parse(next_identifier, quote_char)
 | |
|                     further_identifiers.insert(0, first_identifier)
 | |
|                 else:
 | |
|                     raise SQLParseError('User escaped identifiers must escape extra quotes')
 | |
|             else:
 | |
|                 further_identifiers = [identifier]
 | |
| 
 | |
|     if not already_quoted:
 | |
|         try:
 | |
|             dot = identifier.index('.')
 | |
|         except ValueError:
 | |
|             identifier = identifier.replace(quote_char, quote_char * 2)
 | |
|             identifier = ''.join((quote_char, identifier, quote_char))
 | |
|             further_identifiers = [identifier]
 | |
|         else:
 | |
|             if dot == 0 or dot >= len(identifier) - 1:
 | |
|                 identifier = identifier.replace(quote_char, quote_char * 2)
 | |
|                 identifier = ''.join((quote_char, identifier, quote_char))
 | |
|                 further_identifiers = [identifier]
 | |
|             else:
 | |
|                 first_identifier = identifier[:dot]
 | |
|                 next_identifier = identifier[dot + 1:]
 | |
|                 further_identifiers = _identifier_parse(next_identifier, quote_char)
 | |
|                 first_identifier = first_identifier.replace(quote_char, quote_char * 2)
 | |
|                 first_identifier = ''.join((quote_char, first_identifier, quote_char))
 | |
|                 further_identifiers.insert(0, first_identifier)
 | |
| 
 | |
|     return further_identifiers
 | |
| 
 | |
| 
 | |
def pg_quote_identifier(identifier, id_type):
    """Return ``identifier`` quoted with double quotes for PostgreSQL.

    :arg identifier: the possibly dotted identifier to quote
    :arg id_type: a key of ``_PG_IDENTIFIER_TO_DOT_LEVEL`` such as 'table'
    :returns: the quoted parts joined with dots
    :raises SQLParseError: if the identifier cannot be parsed or has more
        parts than ``id_type`` allows
    """
    parts = _identifier_parse(identifier, quote_char='"')
    allowed = _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]
    if len(parts) > allowed:
        raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, allowed))
    return '.'.join(parts)
 | |
| 
 | |
| 
 | |
def mysql_quote_identifier(identifier, id_type):
    """Return ``identifier`` quoted with backticks for MySQL.

    :arg identifier: the possibly dotted identifier to quote
    :arg id_type: a key of ``_MYSQL_IDENTIFIER_TO_DOT_LEVEL`` such as 'table'
    :returns: the quoted parts joined with dots
    :raises SQLParseError: if the identifier cannot be parsed or has too
        many parts for ``id_type``
    """
    parts = _identifier_parse(identifier, quote_char='`')
    allowed = _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]
    if len(parts) - 1 > allowed:
        raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, allowed))

    # A part that is exactly a quoted asterisk is emitted as a bare '*',
    # so the result reads `db`.* rather than `db`.`*`.
    return '.'.join('*' if part == '`*`' else part for part in parts)
 | |
| 
 | |
| 
 | |
def is_input_dangerous(string):
    """Report whether the given string looks potentially dangerous.

    Intended as a guard against SQL injection.  An empty or otherwise
    falsy value is never considered dangerous.

    Note: rely on this only where psycopg2's parametrized
      cursor.execute cannot be used
      (typically with DDL queries).
    """
    if not string:
        return False

    suspicious_patterns = (PATTERN_1, PATTERN_2, PATTERN_3)
    return any(pattern.search(string) for pattern in suspicious_patterns)
 | |
| 
 | |
| 
 | |
def check_input(module, *args):
    """Fail the module if any argument looks like an SQL injection attempt.

    Wrapper around is_input_dangerous().  Each positional argument is
    checked; a list argument has each of its items checked.  ``None`` and
    booleans are skipped, text is checked as is, and any other value is
    checked through its ``str()`` form.  List items are normalised the same
    way as top-level arguments, so a list holding non-string values (for
    example integers) no longer makes the regular expression search raise
    TypeError.

    :arg module: object providing ``fail_json(msg=...)``; it is called once,
        listing every offending value, when at least one is found
    :arg args: the values to check
    """
    # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3, so text
    # is never pushed through str(), which could fail on non-ASCII input.
    text_types = (str, type(u''))

    dangerous_elements = []

    for arg in args:
        candidates = arg if isinstance(arg, list) else [arg]
        for item in candidates:
            if item is None or isinstance(item, bool):
                continue
            if not isinstance(item, text_types):
                item = str(item)
            if is_input_dangerous(item):
                dangerous_elements.append(item)

    if dangerous_elements:
        module.fail_json(msg="Passed input '%s' is "
                             "potentially dangerous" % ', '.join(dangerous_elements))
 |