summaryrefslogtreecommitdiff
path: root/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py')
-rw-r--r--test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py b/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py
index fefd6b0f..2f77c03b 100644
--- a/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py
+++ b/test/lib/ansible_test/_util/target/pytest/plugins/ansible_pytest_collections.py
@@ -32,6 +32,50 @@ def collection_pypkgpath(self):
raise Exception('File "%s" not found in collection path "%s".' % (self.strpath, ANSIBLE_COLLECTIONS_PATH))
+def enable_assertion_rewriting_hook(): # type: () -> None
+ """
+ Enable pytest's AssertionRewritingHook on Python 3.x.
+ This is necessary because the Ansible collection loader intercepts imports before the pytest provided loader ever sees them.
+ """
+ import sys
+
+ if sys.version_info[0] == 2:
+ return # Python 2.x is not supported
+
+ hook_name = '_pytest.assertion.rewrite.AssertionRewritingHook'
+ hooks = [hook for hook in sys.meta_path if hook.__class__.__module__ + '.' + hook.__class__.__qualname__ == hook_name]
+
+ if len(hooks) != 1:
+ raise Exception('Found {} instance(s) of "{}" in sys.meta_path.'.format(len(hooks), hook_name))
+
+ assertion_rewriting_hook = hooks[0]
+
+ # This is based on `_AnsibleCollectionPkgLoaderBase.exec_module` from `ansible/utils/collection_loader/_collection_finder.py`.
+ def exec_module(self, module):
+ # short-circuit redirect; avoid reinitializing existing modules
+ if self._redirect_module: # pylint: disable=protected-access
+ return
+
+ # execute the module's code in its namespace
+ code_obj = self.get_code(self._fullname) # pylint: disable=protected-access
+
+ if code_obj is not None: # things like NS packages that can't have code on disk will return None
+ # This logic is loosely based on `AssertionRewritingHook._should_rewrite` from pytest.
+ # See: https://github.com/pytest-dev/pytest/blob/779a87aada33af444f14841a04344016a087669e/src/_pytest/assertion/rewrite.py#L209
+ should_rewrite = self._package_to_load == 'conftest' or self._package_to_load.startswith('test_') # pylint: disable=protected-access
+
+ if should_rewrite:
+ # noinspection PyUnresolvedReferences
+ assertion_rewriting_hook.exec_module(module)
+ else:
+ exec(code_obj, module.__dict__) # pylint: disable=exec-used
+
+ # noinspection PyProtectedMember
+ from ansible.utils.collection_loader._collection_finder import _AnsibleCollectionPkgLoaderBase
+
+ _AnsibleCollectionPkgLoaderBase.exec_module = exec_module
+
+
def pytest_configure():
"""Configure this pytest plugin."""
try:
@@ -40,6 +84,8 @@ def pytest_configure():
except AttributeError:
pytest_configure.executed = True
+ enable_assertion_rewriting_hook()
+
# noinspection PyProtectedMember
from ansible.utils.collection_loader._collection_finder import _AnsibleCollectionFinder