summaryrefslogtreecommitdiffstats
path: root/test/lib/ansible_test/_util/target/sanity/import/importer.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/lib/ansible_test/_util/target/sanity/import/importer.py')
-rw-r--r--test/lib/ansible_test/_util/target/sanity/import/importer.py573
1 files changed, 573 insertions, 0 deletions
diff --git a/test/lib/ansible_test/_util/target/sanity/import/importer.py b/test/lib/ansible_test/_util/target/sanity/import/importer.py
new file mode 100644
index 0000000..44a5ddc
--- /dev/null
+++ b/test/lib/ansible_test/_util/target/sanity/import/importer.py
@@ -0,0 +1,573 @@
+"""Import the given python module(s) and report error(s) encountered."""
+from __future__ import (absolute_import, division, print_function)
+__metaclass__ = type
+
+
+def main():
+ """
+ Main program function used to isolate globals from imported code.
+ Changes to globals in imported modules on Python 2.x will overwrite our own globals.
+ """
+ import os
+ import sys
+ import types
+
+ # preload an empty ansible._vendor module to prevent use of any embedded modules during the import test
+ vendor_module_name = 'ansible._vendor'
+
+ vendor_module = types.ModuleType(vendor_module_name)
+ vendor_module.__file__ = os.path.join(os.path.sep.join(os.path.abspath(__file__).split(os.path.sep)[:-8]), 'lib/ansible/_vendor/__init__.py')
+ vendor_module.__path__ = []
+ vendor_module.__package__ = vendor_module_name
+
+ sys.modules[vendor_module_name] = vendor_module
+
+ import ansible
+ import contextlib
+ import datetime
+ import json
+ import re
+ import runpy
+ import subprocess
+ import traceback
+ import warnings
+
+ ansible_path = os.path.dirname(os.path.dirname(ansible.__file__))
+ temp_path = os.environ['SANITY_TEMP_PATH'] + os.path.sep
+ external_python = os.environ.get('SANITY_EXTERNAL_PYTHON')
+ yaml_to_json_path = os.environ.get('SANITY_YAML_TO_JSON')
+ collection_full_name = os.environ.get('SANITY_COLLECTION_FULL_NAME')
+ collection_root = os.environ.get('ANSIBLE_COLLECTIONS_PATH')
+ import_type = os.environ.get('SANITY_IMPORTER_TYPE')
+
+ try:
+ # noinspection PyCompatibility
+ from importlib import import_module
+ except ImportError:
+ def import_module(name, package=None): # type: (str, str | None) -> types.ModuleType
+ assert package is None
+ __import__(name)
+ return sys.modules[name]
+
+ from io import BytesIO, TextIOWrapper
+
+ try:
+ from importlib.util import spec_from_loader, module_from_spec
+ from importlib.machinery import SourceFileLoader, ModuleSpec # pylint: disable=unused-import
+ except ImportError:
+ has_py3_loader = False
+ else:
+ has_py3_loader = True
+
+ if collection_full_name:
+ # allow importing code from collections when testing a collection
+ from ansible.module_utils.common.text.converters import to_bytes, to_text, to_native, text_type
+
+ # noinspection PyProtectedMember
+ from ansible.utils.collection_loader._collection_finder import _AnsibleCollectionFinder
+ from ansible.utils.collection_loader import _collection_finder
+
+ yaml_to_dict_cache = {}
+
+ # unique ISO date marker matching the one present in yaml_to_json.py
+ iso_date_marker = 'isodate:f23983df-f3df-453c-9904-bcd08af468cc:'
+ iso_date_re = re.compile('^%s([0-9]{4})-([0-9]{2})-([0-9]{2})$' % iso_date_marker)
+
+ def parse_value(value):
+ """Custom value parser for JSON deserialization that recognizes our internal ISO date format."""
+ if isinstance(value, text_type):
+ match = iso_date_re.search(value)
+
+ if match:
+ value = datetime.date(int(match.group(1)), int(match.group(2)), int(match.group(3)))
+
+ return value
+
+ def object_hook(data):
+ """Object hook for custom ISO date deserialization from JSON."""
+ return dict((key, parse_value(value)) for key, value in data.items())
+
+ def yaml_to_dict(yaml, content_id):
+ """
+ Return a Python dict version of the provided YAML.
+ Conversion is done in a subprocess since the current Python interpreter does not have access to PyYAML.
+ """
+ if content_id in yaml_to_dict_cache:
+ return yaml_to_dict_cache[content_id]
+
+ try:
+ cmd = [external_python, yaml_to_json_path]
+ proc = subprocess.Popen([to_bytes(c) for c in cmd], # pylint: disable=consider-using-with
+ stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout_bytes, stderr_bytes = proc.communicate(to_bytes(yaml))
+
+ if proc.returncode != 0:
+ raise Exception('command %s failed with return code %d: %s' % ([to_native(c) for c in cmd], proc.returncode, to_native(stderr_bytes)))
+
+ data = yaml_to_dict_cache[content_id] = json.loads(to_text(stdout_bytes), object_hook=object_hook)
+
+ return data
+ except Exception as ex:
+ raise Exception('internal importer error - failed to parse yaml: %s' % to_native(ex))
+
+ _collection_finder._meta_yml_to_dict = yaml_to_dict # pylint: disable=protected-access
+
+ collection_loader = _AnsibleCollectionFinder(paths=[collection_root])
+ # noinspection PyProtectedMember
+ collection_loader._install() # pylint: disable=protected-access
+ else:
+ # do not support collection loading when not testing a collection
+ collection_loader = None
+
+ if collection_loader and import_type == 'plugin':
+ # do not unload ansible code for collection plugin (not module) tests
+ # doing so could result in the collection loader being initialized multiple times
+ pass
+ else:
+ # remove all modules under the ansible package, except the preloaded vendor module
+ list(map(sys.modules.pop, [m for m in sys.modules if m.partition('.')[0] == ansible.__name__ and m != vendor_module_name]))
+
+ if import_type == 'module':
+ # pre-load an empty ansible package to prevent unwanted code in __init__.py from loading
+ # this more accurately reflects the environment that AnsiballZ runs modules under
+ # it also avoids issues with imports in the ansible package that are not allowed
+ ansible_module = types.ModuleType(ansible.__name__)
+ ansible_module.__file__ = ansible.__file__
+ ansible_module.__path__ = ansible.__path__
+ ansible_module.__package__ = ansible.__package__
+
+ sys.modules[ansible.__name__] = ansible_module
+
+ class ImporterAnsibleModuleException(Exception):
+ """Exception thrown during initialization of ImporterAnsibleModule."""
+
+ class ImporterAnsibleModule:
+ """Replacement for AnsibleModule to support import testing."""
+ def __init__(self, *args, **kwargs):
+ raise ImporterAnsibleModuleException()
+
+ class RestrictedModuleLoader:
+ """Python module loader that restricts inappropriate imports."""
+ def __init__(self, path, name, restrict_to_module_paths):
+ self.path = path
+ self.name = name
+ self.loaded_modules = set()
+ self.restrict_to_module_paths = restrict_to_module_paths
+
+ def find_spec(self, fullname, path=None, target=None): # pylint: disable=unused-argument
+ # type: (RestrictedModuleLoader, str, list[str], types.ModuleType | None ) -> ModuleSpec | None | ImportError
+ """Return the spec from the loader or None"""
+ loader = self._get_loader(fullname, path=path)
+ if loader is not None:
+ if has_py3_loader:
+ # loader is expected to be Optional[importlib.abc.Loader], but RestrictedModuleLoader does not inherit from importlib.abc.Loder
+ return spec_from_loader(fullname, loader) # type: ignore[arg-type]
+ raise ImportError("Failed to import '%s' due to a bug in ansible-test. Check importlib imports for typos." % fullname)
+ return None
+
+ def find_module(self, fullname, path=None):
+ # type: (RestrictedModuleLoader, str, list[str]) -> RestrictedModuleLoader | None
+ """Return self if the given fullname is restricted, otherwise return None."""
+ return self._get_loader(fullname, path=path)
+
+ def _get_loader(self, fullname, path=None):
+ # type: (RestrictedModuleLoader, str, list[str]) -> RestrictedModuleLoader | None
+ """Return self if the given fullname is restricted, otherwise return None."""
+ if fullname in self.loaded_modules:
+ return None # ignore modules that are already being loaded
+
+ if is_name_in_namepace(fullname, ['ansible']):
+ if not self.restrict_to_module_paths:
+ return None # for non-modules, everything in the ansible namespace is allowed
+
+ if fullname in ('ansible.module_utils.basic',):
+ return self # intercept loading so we can modify the result
+
+ if is_name_in_namepace(fullname, ['ansible.module_utils', self.name]):
+ return None # module_utils and module under test are always allowed
+
+ if any(os.path.exists(candidate_path) for candidate_path in convert_ansible_name_to_absolute_paths(fullname)):
+ return self # restrict access to ansible files that exist
+
+ return None # ansible file does not exist, do not restrict access
+
+ if is_name_in_namepace(fullname, ['ansible_collections']):
+ if not collection_loader:
+ return self # restrict access to collections when we are not testing a collection
+
+ if not self.restrict_to_module_paths:
+ return None # for non-modules, everything in the ansible namespace is allowed
+
+ if is_name_in_namepace(fullname, ['ansible_collections...plugins.module_utils', self.name]):
+ return None # module_utils and module under test are always allowed
+
+ if collection_loader.find_module(fullname, path):
+ return self # restrict access to collection files that exist
+
+ return None # collection file does not exist, do not restrict access
+
+ # not a namespace we care about
+ return None
+
+ def create_module(self, spec): # pylint: disable=unused-argument
+ # type: (RestrictedModuleLoader, ModuleSpec) -> None
+ """Return None to use default module creation."""
+ return None
+
+ def exec_module(self, module):
+ # type: (RestrictedModuleLoader, types.ModuleType) -> None | ImportError
+ """Execute the module if the name is ansible.module_utils.basic and otherwise raise an ImportError"""
+ fullname = module.__spec__.name
+ if fullname == 'ansible.module_utils.basic':
+ self.loaded_modules.add(fullname)
+ for path in convert_ansible_name_to_absolute_paths(fullname):
+ if not os.path.exists(path):
+ continue
+ loader = SourceFileLoader(fullname, path)
+ spec = spec_from_loader(fullname, loader)
+ real_module = module_from_spec(spec)
+ loader.exec_module(real_module)
+ real_module.AnsibleModule = ImporterAnsibleModule # type: ignore[attr-defined]
+ real_module._load_params = lambda *args, **kwargs: {} # type: ignore[attr-defined] # pylint: disable=protected-access
+ sys.modules[fullname] = real_module
+ return None
+ raise ImportError('could not find "%s"' % fullname)
+ raise ImportError('import of "%s" is not allowed in this context' % fullname)
+
+ def load_module(self, fullname):
+ # type: (RestrictedModuleLoader, str) -> types.ModuleType | ImportError
+ """Return the module if the name is ansible.module_utils.basic and otherwise raise an ImportError."""
+ if fullname == 'ansible.module_utils.basic':
+ module = self.__load_module(fullname)
+
+ # stop Ansible module execution during AnsibleModule instantiation
+ module.AnsibleModule = ImporterAnsibleModule # type: ignore[attr-defined]
+ # no-op for _load_params since it may be called before instantiating AnsibleModule
+ module._load_params = lambda *args, **kwargs: {} # type: ignore[attr-defined] # pylint: disable=protected-access
+
+ return module
+
+ raise ImportError('import of "%s" is not allowed in this context' % fullname)
+
+ def __load_module(self, fullname):
+ # type: (RestrictedModuleLoader, str) -> types.ModuleType
+ """Load the requested module while avoiding infinite recursion."""
+ self.loaded_modules.add(fullname)
+ return import_module(fullname)
+
+ def run(restrict_to_module_paths):
+ """Main program function."""
+ base_dir = os.getcwd()
+ messages = set()
+
+ for path in sys.argv[1:] or sys.stdin.read().splitlines():
+ name = convert_relative_path_to_name(path)
+ test_python_module(path, name, base_dir, messages, restrict_to_module_paths)
+
+ if messages:
+ sys.exit(10)
+
+ def test_python_module(path, name, base_dir, messages, restrict_to_module_paths):
+ """Test the given python module by importing it.
+ :type path: str
+ :type name: str
+ :type base_dir: str
+ :type messages: set[str]
+ :type restrict_to_module_paths: bool
+ """
+ if name in sys.modules:
+ return # cannot be tested because it has already been loaded
+
+ is_ansible_module = (path.startswith('lib/ansible/modules/') or path.startswith('plugins/modules/')) and os.path.basename(path) != '__init__.py'
+ run_main = is_ansible_module
+
+ if path == 'lib/ansible/modules/async_wrapper.py':
+ # async_wrapper is a non-standard Ansible module (does not use AnsibleModule) so we cannot test the main function
+ run_main = False
+
+ capture_normal = Capture()
+ capture_main = Capture()
+
+ run_module_ok = False
+
+ try:
+ with monitor_sys_modules(path, messages):
+ with restrict_imports(path, name, messages, restrict_to_module_paths):
+ with capture_output(capture_normal):
+ import_module(name)
+
+ if run_main:
+ run_module_ok = is_ansible_module
+
+ with monitor_sys_modules(path, messages):
+ with restrict_imports(path, name, messages, restrict_to_module_paths):
+ with capture_output(capture_main):
+ runpy.run_module(name, run_name='__main__', alter_sys=True)
+ except ImporterAnsibleModuleException:
+ # module instantiated AnsibleModule without raising an exception
+ if not run_module_ok:
+ if is_ansible_module:
+ report_message(path, 0, 0, 'module-guard', "AnsibleModule instantiation not guarded by `if __name__ == '__main__'`", messages)
+ else:
+ report_message(path, 0, 0, 'non-module', "AnsibleModule instantiated by import of non-module", messages)
+ except BaseException as ex: # pylint: disable=locally-disabled, broad-except
+ # intentionally catch all exceptions, including calls to sys.exit
+ exc_type, _exc, exc_tb = sys.exc_info()
+ message = str(ex)
+ results = list(reversed(traceback.extract_tb(exc_tb)))
+ line = 0
+ offset = 0
+ full_path = os.path.join(base_dir, path)
+ base_path = base_dir + os.path.sep
+ source = None
+
+ # avoid line wraps in messages
+ message = re.sub(r'\n *', ': ', message)
+
+ for result in results:
+ if result[0] == full_path:
+ # save the line number for the file under test
+ line = result[1] or 0
+
+ if not source and result[0].startswith(base_path) and not result[0].startswith(temp_path):
+ # save the first path and line number in the traceback which is in our source tree
+ source = (os.path.relpath(result[0], base_path), result[1] or 0, 0)
+
+ if isinstance(ex, SyntaxError):
+ # SyntaxError has better information than the traceback
+ if ex.filename == full_path: # pylint: disable=locally-disabled, no-member
+ # syntax error was reported in the file under test
+ line = ex.lineno or 0 # pylint: disable=locally-disabled, no-member
+ offset = ex.offset or 0 # pylint: disable=locally-disabled, no-member
+ elif ex.filename.startswith(base_path) and not ex.filename.startswith(temp_path): # pylint: disable=locally-disabled, no-member
+ # syntax error was reported in our source tree
+ source = (os.path.relpath(ex.filename, base_path), ex.lineno or 0, ex.offset or 0) # pylint: disable=locally-disabled, no-member
+
+ # remove the filename and line number from the message
+ # either it was extracted above, or it's not really useful information
+ message = re.sub(r' \(.*?, line [0-9]+\)$', '', message)
+
+ if source and source[0] != path:
+ message += ' (at %s:%d:%d)' % (source[0], source[1], source[2])
+
+ report_message(path, line, offset, 'traceback', '%s: %s' % (exc_type.__name__, message), messages)
+ finally:
+ capture_report(path, capture_normal, messages)
+ capture_report(path, capture_main, messages)
+
+ def is_name_in_namepace(name, namespaces):
+ """Returns True if the given name is one of the given namespaces, otherwise returns False."""
+ name_parts = name.split('.')
+
+ for namespace in namespaces:
+ namespace_parts = namespace.split('.')
+ length = min(len(name_parts), len(namespace_parts))
+
+ truncated_name = name_parts[0:length]
+ truncated_namespace = namespace_parts[0:length]
+
+ # empty parts in the namespace are treated as wildcards
+ # to simplify the comparison, use those empty parts to indicate the positions in the name to be empty as well
+ for idx, part in enumerate(truncated_namespace):
+ if not part:
+ truncated_name[idx] = part
+
+ # example: name=ansible, allowed_name=ansible.module_utils
+ # example: name=ansible.module_utils.system.ping, allowed_name=ansible.module_utils
+ if truncated_name == truncated_namespace:
+ return True
+
+ return False
+
+ def check_sys_modules(path, before, messages):
+ """Check for unwanted changes to sys.modules.
+ :type path: str
+ :type before: dict[str, module]
+ :type messages: set[str]
+ """
+ after = sys.modules
+ removed = set(before.keys()) - set(after.keys())
+ changed = set(key for key, value in before.items() if key in after and value != after[key])
+
+ # additions are checked by our custom PEP 302 loader, so we don't need to check them again here
+
+ for module in sorted(removed):
+ report_message(path, 0, 0, 'unload', 'unloading of "%s" in sys.modules is not supported' % module, messages)
+
+ for module in sorted(changed):
+ report_message(path, 0, 0, 'reload', 'reloading of "%s" in sys.modules is not supported' % module, messages)
+
+ def convert_ansible_name_to_absolute_paths(name):
+ """Calculate the module path from the given name.
+ :type name: str
+ :rtype: list[str]
+ """
+ return [
+ os.path.join(ansible_path, name.replace('.', os.path.sep)),
+ os.path.join(ansible_path, name.replace('.', os.path.sep)) + '.py',
+ ]
+
+ def convert_relative_path_to_name(path):
+ """Calculate the module name from the given path.
+ :type path: str
+ :rtype: str
+ """
+ if path.endswith('/__init__.py'):
+ clean_path = os.path.dirname(path)
+ else:
+ clean_path = path
+
+ clean_path = os.path.splitext(clean_path)[0]
+
+ name = clean_path.replace(os.path.sep, '.')
+
+ if collection_loader:
+ # when testing collections the relative paths (and names) being tested are within the collection under test
+ name = 'ansible_collections.%s.%s' % (collection_full_name, name)
+ else:
+ # when testing ansible all files being imported reside under the lib directory
+ name = name[len('lib/'):]
+
+ return name
+
+ class Capture:
+ """Captured output and/or exception."""
+ def __init__(self):
+ # use buffered IO to simulate StringIO; allows Ansible's stream patching to behave without warnings
+ self.stdout = TextIOWrapper(BytesIO())
+ self.stderr = TextIOWrapper(BytesIO())
+
+ def capture_report(path, capture, messages):
+ """Report on captured output.
+ :type path: str
+ :type capture: Capture
+ :type messages: set[str]
+ """
+ # since we're using buffered IO, flush before checking for data
+ capture.stdout.flush()
+ capture.stderr.flush()
+ stdout_value = capture.stdout.buffer.getvalue()
+ if stdout_value:
+ first = stdout_value.decode().strip().splitlines()[0].strip()
+ report_message(path, 0, 0, 'stdout', first, messages)
+
+ stderr_value = capture.stderr.buffer.getvalue()
+ if stderr_value:
+ first = stderr_value.decode().strip().splitlines()[0].strip()
+ report_message(path, 0, 0, 'stderr', first, messages)
+
+ def report_message(path, line, column, code, message, messages):
+ """Report message if not already reported.
+ :type path: str
+ :type line: int
+ :type column: int
+ :type code: str
+ :type message: str
+ :type messages: set[str]
+ """
+ message = '%s:%d:%d: %s: %s' % (path, line, column, code, message)
+
+ if message not in messages:
+ messages.add(message)
+ print(message)
+
+ @contextlib.contextmanager
+ def restrict_imports(path, name, messages, restrict_to_module_paths):
+ """Restrict available imports.
+ :type path: str
+ :type name: str
+ :type messages: set[str]
+ :type restrict_to_module_paths: bool
+ """
+ restricted_loader = RestrictedModuleLoader(path, name, restrict_to_module_paths)
+
+ # noinspection PyTypeChecker
+ sys.meta_path.insert(0, restricted_loader)
+ sys.path_importer_cache.clear()
+
+ try:
+ yield
+ finally:
+ if import_type == 'plugin' and not collection_loader:
+ from ansible.utils.collection_loader._collection_finder import _AnsibleCollectionFinder
+ _AnsibleCollectionFinder._remove() # pylint: disable=protected-access
+
+ if sys.meta_path[0] != restricted_loader:
+ report_message(path, 0, 0, 'metapath', 'changes to sys.meta_path[0] are not permitted', messages)
+
+ while restricted_loader in sys.meta_path:
+ # noinspection PyTypeChecker
+ sys.meta_path.remove(restricted_loader)
+
+ sys.path_importer_cache.clear()
+
+ @contextlib.contextmanager
+ def monitor_sys_modules(path, messages):
+ """Monitor sys.modules for unwanted changes, reverting any additions made to our own namespaces."""
+ snapshot = sys.modules.copy()
+
+ try:
+ yield
+ finally:
+ check_sys_modules(path, snapshot, messages)
+
+ for key in set(sys.modules.keys()) - set(snapshot.keys()):
+ if is_name_in_namepace(key, ('ansible', 'ansible_collections')):
+ del sys.modules[key] # only unload our own code since we know it's native Python
+
+ @contextlib.contextmanager
+ def capture_output(capture):
+ """Capture sys.stdout and sys.stderr.
+ :type capture: Capture
+ """
+ old_stdout = sys.stdout
+ old_stderr = sys.stderr
+
+ sys.stdout = capture.stdout
+ sys.stderr = capture.stderr
+
+ # clear all warnings registries to make all warnings available
+ for module in sys.modules.values():
+ try:
+ # noinspection PyUnresolvedReferences
+ module.__warningregistry__.clear()
+ except AttributeError:
+ pass
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('error')
+
+ if collection_loader and import_type == 'plugin':
+ warnings.filterwarnings(
+ "ignore",
+ "AnsibleCollectionFinder has already been configured")
+
+ if sys.version_info[0] == 2:
+ warnings.filterwarnings(
+ "ignore",
+ "Python 2 is no longer supported by the Python core team. Support for it is now deprecated in cryptography,"
+ " and will be removed in a future release.")
+ warnings.filterwarnings(
+ "ignore",
+ "Python 2 is no longer supported by the Python core team. Support for it is now deprecated in cryptography,"
+ " and will be removed in the next release.")
+
+ if sys.version_info[:2] == (3, 5):
+ warnings.filterwarnings(
+ "ignore",
+ "Python 3.5 support will be dropped in the next release ofcryptography. Please upgrade your Python.")
+ warnings.filterwarnings(
+ "ignore",
+ "Python 3.5 support will be dropped in the next release of cryptography. Please upgrade your Python.")
+
+ try:
+ yield
+ finally:
+ sys.stdout = old_stdout
+ sys.stderr = old_stderr
+
+ run(import_type == 'module')
+
+
+if __name__ == '__main__':
+ main()