summaryrefslogtreecommitdiff
path: root/lib/ansible/utils/unsafe_proxy.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/ansible/utils/unsafe_proxy.py')
-rw-r--r--lib/ansible/utils/unsafe_proxy.py265
1 files changed, 254 insertions, 11 deletions
diff --git a/lib/ansible/utils/unsafe_proxy.py b/lib/ansible/utils/unsafe_proxy.py
index d78ebf6e..683f6e27 100644
--- a/lib/ansible/utils/unsafe_proxy.py
+++ b/lib/ansible/utils/unsafe_proxy.py
@@ -57,7 +57,6 @@ from collections.abc import Mapping, Set
from ansible.module_utils._text import to_bytes, to_text
from ansible.module_utils.common.collections import is_sequence
-from ansible.module_utils.six import string_types, binary_type, text_type
from ansible.utils.native_jinja import NativeJinjaText
@@ -68,16 +67,256 @@ class AnsibleUnsafe(object):
__UNSAFE__ = True
-class AnsibleUnsafeBytes(binary_type, AnsibleUnsafe):
- def decode(self, *args, **kwargs):
- """Wrapper method to ensure type conversions maintain unsafe context"""
- return AnsibleUnsafeText(super(AnsibleUnsafeBytes, self).decode(*args, **kwargs))
+class AnsibleUnsafeBytes(bytes, AnsibleUnsafe):
+ def _strip_unsafe(self):
+ return super().__bytes__()
+ def __reduce__(self, /):
+ return (self.__class__, (self._strip_unsafe(),))
-class AnsibleUnsafeText(text_type, AnsibleUnsafe):
- def encode(self, *args, **kwargs):
- """Wrapper method to ensure type conversions maintain unsafe context"""
- return AnsibleUnsafeBytes(super(AnsibleUnsafeText, self).encode(*args, **kwargs))
+ def __str__(self, /): # pylint: disable=invalid-str-returned
+ return self.decode()
+
+ def __bytes__(self, /): # pylint: disable=invalid-bytes-returned
+ return self
+
+ def __repr__(self, /): # pylint: disable=invalid-repr-returned
+ return AnsibleUnsafeText(super().__repr__())
+
+ def __format__(self, format_spec, /): # pylint: disable=invalid-format-returned
+ return AnsibleUnsafeText(super().__format__(format_spec))
+
+ def __getitem__(self, key, /):
+ if isinstance(key, int):
+ return super().__getitem__(key)
+ return self.__class__(super().__getitem__(key))
+
+ def __reversed__(self, /):
+ return self[::-1]
+
+ def __add__(self, value, /):
+ return self.__class__(super().__add__(value))
+
+ def __radd__(self, value, /):
+ return self.__class__(value.__add__(self))
+
+ def __mul__(self, value, /):
+ return self.__class__(super().__mul__(value))
+
+ __rmul__ = __mul__
+
+ def __mod__(self, value, /):
+ return self.__class__(super().__mod__(value))
+
+ def __rmod__(self, value, /):
+ return self.__class__(super().__rmod__(value))
+
+ def capitalize(self, /):
+ return self.__class__(super().capitalize())
+
+ def center(self, width, fillchar=b' ', /):
+ return self.__class__(super().center(width, fillchar))
+
+ def decode(self, /, encoding='utf-8', errors='strict'):
+ return AnsibleUnsafeText(super().decode(encoding=encoding, errors=errors))
+
+ def removeprefix(self, prefix, /):
+ return self.__class__(super().removeprefix(prefix))
+
+ def removesuffix(self, suffix, /):
+ return self.__class__(super().removesuffix(suffix))
+
+ def expandtabs(self, /, tabsize=8):
+ return self.__class__(super().expandtabs(tabsize))
+
+ def join(self, iterable_of_bytes, /):
+ return self.__class__(super().join(iterable_of_bytes))
+
+ def ljust(self, width, fillchar=b' ', /):
+ return self.__class__(super().ljust(width, fillchar))
+
+ def lower(self, /):
+ return self.__class__(super().lower())
+
+ def lstrip(self, chars=None, /):
+ return self.__class__(super().lstrip(chars))
+
+ def partition(self, sep, /):
+ cls = self.__class__
+ return tuple(cls(e) for e in super().partition(sep))
+
+ def replace(self, old, new, count=-1, /):
+ return self.__class__(super().replace(old, new, count))
+
+ def rjust(self, width, fillchar=b' ', /):
+ return self.__class__(super().rjust(width, fillchar))
+
+ def rpartition(self, sep, /):
+ cls = self.__class__
+ return tuple(cls(e) for e in super().rpartition(sep))
+
+ def rstrip(self, chars=None, /):
+ return self.__class__(super().rstrip(chars))
+
+ def split(self, /, sep=None, maxsplit=-1):
+ cls = self.__class__
+ return [cls(e) for e in super().split(sep=sep, maxsplit=maxsplit)]
+
+ def rsplit(self, /, sep=None, maxsplit=-1):
+ cls = self.__class__
+ return [cls(e) for e in super().rsplit(sep=sep, maxsplit=maxsplit)]
+
+ def splitlines(self, /, keepends=False):
+ cls = self.__class__
+ return [cls(e) for e in super().splitlines(keepends=keepends)]
+
+ def strip(self, chars=None, /):
+ return self.__class__(super().strip(chars))
+
+ def swapcase(self, /):
+ return self.__class__(super().swapcase())
+
+ def title(self, /):
+ return self.__class__(super().title())
+
+ def translate(self, table, /, delete=b''):
+ return self.__class__(super().translate(table, delete=delete))
+
+ def upper(self, /):
+ return self.__class__(super().upper())
+
+ def zfill(self, width, /):
+ return self.__class__(super().zfill(width))
+
+
+class AnsibleUnsafeText(str, AnsibleUnsafe):
+ def _strip_unsafe(self, /):
+ return super().__str__()
+
+ def __reduce__(self, /):
+ return (self.__class__, (self._strip_unsafe(),))
+
+ def __str__(self, /): # pylint: disable=invalid-str-returned
+ return self
+
+ def __repr__(self, /): # pylint: disable=invalid-repr-returned
+ return self.__class__(super().__repr__())
+
+ def __format__(self, format_spec, /): # pylint: disable=invalid-format-returned
+ return self.__class__(super().__format__(format_spec))
+
+ def __getitem__(self, key, /):
+ return self.__class__(super().__getitem__(key))
+
+ def __iter__(self, /):
+ cls = self.__class__
+ return (cls(c) for c in super().__iter__())
+
+ def __reversed__(self, /):
+ return self[::-1]
+
+ def __add__(self, value, /):
+ return self.__class__(super().__add__(value))
+
+ def __radd__(self, value, /):
+ return self.__class__(value.__add__(self))
+
+ def __mul__(self, value, /):
+ return self.__class__(super().__mul__(value))
+
+ __rmul__ = __mul__
+
+ def __mod__(self, value, /):
+ return self.__class__(super().__mod__(value))
+
+ def __rmod__(self, value, /):
+ return self.__class__(super().__rmod__(value))
+
+ def capitalize(self, /):
+ return self.__class__(super().capitalize())
+
+ def casefold(self, /):
+ return self.__class__(super().casefold())
+
+ def center(self, width, fillchar=' ', /):
+ return self.__class__(super().center(width, fillchar))
+
+ def encode(self, /, encoding='utf-8', errors='strict'):
+ return AnsibleUnsafeBytes(super().encode(encoding=encoding, errors=errors))
+
+ def removeprefix(self, prefix, /):
+ return self.__class__(super().removeprefix(prefix))
+
+ def removesuffix(self, suffix, /):
+ return self.__class__(super().removesuffix(suffix))
+
+ def expandtabs(self, /, tabsize=8):
+ return self.__class__(super().expandtabs(tabsize))
+
+ def format(self, /, *args, **kwargs):
+ return self.__class__(super().format(*args, **kwargs))
+
+ def format_map(self, mapping, /):
+ return self.__class__(super().format_map(mapping))
+
+ def join(self, iterable, /):
+ return self.__class__(super().join(iterable))
+
+ def ljust(self, width, fillchar=' ', /):
+ return self.__class__(super().ljust(width, fillchar))
+
+ def lower(self, /):
+ return self.__class__(super().lower())
+
+ def lstrip(self, chars=None, /):
+ return self.__class__(super().lstrip(chars))
+
+ def partition(self, sep, /):
+ cls = self.__class__
+ return tuple(cls(e) for e in super().partition(sep))
+
+ def replace(self, old, new, count=-1, /):
+ return self.__class__(super().replace(old, new, count))
+
+ def rjust(self, width, fillchar=' ', /):
+ return self.__class__(super().rjust(width, fillchar))
+
+ def rpartition(self, sep, /):
+ cls = self.__class__
+ return tuple(cls(e) for e in super().rpartition(sep))
+
+ def rstrip(self, chars=None, /):
+ return self.__class__(super().rstrip(chars))
+
+ def split(self, /, sep=None, maxsplit=-1):
+ cls = self.__class__
+ return [cls(e) for e in super().split(sep=sep, maxsplit=maxsplit)]
+
+ def rsplit(self, /, sep=None, maxsplit=-1):
+ cls = self.__class__
+ return [cls(e) for e in super().rsplit(sep=sep, maxsplit=maxsplit)]
+
+ def splitlines(self, /, keepends=False):
+ cls = self.__class__
+ return [cls(e) for e in super().splitlines(keepends=keepends)]
+
+ def strip(self, chars=None, /):
+ return self.__class__(super().strip(chars))
+
+ def swapcase(self, /):
+ return self.__class__(super().swapcase())
+
+ def title(self, /):
+ return self.__class__(super().title())
+
+ def translate(self, table, /):
+ return self.__class__(super().translate(table))
+
+ def upper(self, /):
+ return self.__class__(super().upper())
+
+ def zfill(self, width, /):
+ return self.__class__(super().zfill(width))
class NativeJinjaUnsafeText(NativeJinjaText, AnsibleUnsafeText):
@@ -112,9 +351,9 @@ def wrap_var(v):
v = _wrap_sequence(v)
elif isinstance(v, NativeJinjaText):
v = NativeJinjaUnsafeText(v)
- elif isinstance(v, binary_type):
+ elif isinstance(v, bytes):
v = AnsibleUnsafeBytes(v)
- elif isinstance(v, text_type):
+ elif isinstance(v, str):
v = AnsibleUnsafeText(v)
return v
@@ -126,3 +365,7 @@ def to_unsafe_bytes(*args, **kwargs):
def to_unsafe_text(*args, **kwargs):
return wrap_var(to_text(*args, **kwargs))
+
+
+def _is_unsafe(obj):
+ return getattr(obj, '__UNSAFE__', False) is True