diff options
Diffstat (limited to 'test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py')
-rw-r--r-- | test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py b/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py index fefd6b0f..2f77c03b 100644 --- a/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py +++ b/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py @@ -32,6 +32,50 @@ def collection_pypkgpath(self): raise Exception('File "%s" not found in collection path "%s".' % (self.strpath, ANSIBLE_COLLECTIONS_PATH)) +def enable_assertion_rewriting_hook(): # type: () -> None + """ + Enable pytest's AssertionRewritingHook on Python 3.x. + This is necessary because the Ansible collection loader intercepts imports before the pytest provided loader ever sees them. + """ + import sys + + if sys.version_info[0] == 2: + return # Python 2.x is not supported + + hook_name = '_pytest.assertion.rewrite.AssertionRewritingHook' + hooks = [hook for hook in sys.meta_path if hook.__class__.__module__ + '.' + hook.__class__.__qualname__ == hook_name] + + if len(hooks) != 1: + raise Exception('Found {} instance(s) of "{}" in sys.meta_path.'.format(len(hooks), hook_name)) + + assertion_rewriting_hook = hooks[0] + + # This is based on `_AnsibleCollectionPkgLoaderBase.exec_module` from `ansible/utils/collection_loader/_collection_finder.py`. + def exec_module(self, module): + # short-circuit redirect; avoid reinitializing existing modules + if self._redirect_module: # pylint: disable=protected-access + return + + # execute the module's code in its namespace + code_obj = self.get_code(self._fullname) # pylint: disable=protected-access + + if code_obj is not None: # things like NS packages that can't have code on disk will return None + # This logic is loosely based on `AssertionRewritingHook._should_rewrite` from pytest. + # See: https://github.com/pytest-dev/pytest/blob/779a87aada33af444f14841a04344016a087669e/src/_pytest/assertion/rewrite.py#L209 + should_rewrite = self._package_to_load == 'conftest' or self._package_to_load.startswith('test_') # pylint: disable=protected-access + + if should_rewrite: + # noinspection PyUnresolvedReferences + assertion_rewriting_hook.exec_module(module) + else: + exec(code_obj, module.__dict__) # pylint: disable=exec-used + + # noinspection PyProtectedMember + from ansible.utils.collection_loader._collection_finder import _AnsibleCollectionPkgLoaderBase + + _AnsibleCollectionPkgLoaderBase.exec_module = exec_module + + def pytest_configure(): """Configure this pytest plugin.""" try: @@ -40,6 +84,8 @@ def pytest_configure(): except AttributeError: pytest_configure.executed = True + enable_assertion_rewriting_hook() + # noinspection PyProtectedMember from ansible.utils.collection_loader._collection_finder import _AnsibleCollectionFinder |