diff options
Diffstat (limited to 'lib/ansible/plugins/connection/__init__.py')
-rw-r--r-- | lib/ansible/plugins/connection/__init__.py | 116 |
1 files changed, 76 insertions, 40 deletions
diff --git a/lib/ansible/plugins/connection/__init__.py b/lib/ansible/plugins/connection/__init__.py index daa683ce..5f7e282f 100644 --- a/lib/ansible/plugins/connection/__init__.py +++ b/lib/ansible/plugins/connection/__init__.py @@ -2,10 +2,12 @@ # (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com> # (c) 2017, Peter Sprygada <psprygad@redhat.com> # (c) 2017 Ansible Project -from __future__ import (absolute_import, division, print_function) +from __future__ import (annotations, absolute_import, division, print_function) __metaclass__ = type +import collections.abc as c import fcntl +import io import os import shlex import typing as t @@ -14,8 +16,11 @@ from abc import abstractmethod from functools import wraps from ansible import constants as C -from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.common.text.converters import to_bytes, to_text +from ansible.playbook.play_context import PlayContext from ansible.plugins import AnsiblePlugin +from ansible.plugins.become import BecomeBase +from ansible.plugins.shell import ShellBase from ansible.utils.display import Display from ansible.plugins.loader import connection_loader, get_shell_plugin from ansible.utils.path import unfrackpath @@ -27,10 +32,15 @@ __all__ = ['ConnectionBase', 'ensure_connect'] BUFSIZE = 65536 +P = t.ParamSpec('P') +T = t.TypeVar('T') -def ensure_connect(func): + +def ensure_connect( + func: c.Callable[t.Concatenate[ConnectionBase, P], T], +) -> c.Callable[t.Concatenate[ConnectionBase, P], T]: @wraps(func) - def wrapped(self, *args, **kwargs): + def wrapped(self: ConnectionBase, *args: P.args, **kwargs: P.kwargs) -> T: if not self._connected: self._connect() return func(self, *args, **kwargs) @@ -57,9 +67,16 @@ class ConnectionBase(AnsiblePlugin): supports_persistence = False force_persistence = False - default_user = None + default_user: str | None = None - def __init__(self, play_context, new_stdin, shell=None, *args, **kwargs): + def __init__( + self, + play_context: PlayContext, + new_stdin: io.TextIOWrapper | None = None, + shell: ShellBase | None = None, + *args: t.Any, + **kwargs: t.Any, + ) -> None: super(ConnectionBase, self).__init__() @@ -67,18 +84,17 @@ class ConnectionBase(AnsiblePlugin): if not hasattr(self, '_play_context'): # Backwards compat: self._play_context isn't really needed, using set_options/get_option self._play_context = play_context - if not hasattr(self, '_new_stdin'): - self._new_stdin = new_stdin + # Delete once the deprecation period is over for WorkerProcess._new_stdin + if not hasattr(self, '__new_stdin'): + self.__new_stdin = new_stdin if not hasattr(self, '_display'): # Backwards compat: self._display isn't really needed, just import the global display and use that. self._display = display - if not hasattr(self, '_connected'): - self._connected = False self.success_key = None self.prompt = None self._connected = False - self._socket_path = None + self._socket_path: str | None = None # helper plugins self._shell = shell @@ -88,23 +104,32 @@ class ConnectionBase(AnsiblePlugin): shell_type = play_context.shell if play_context.shell else getattr(self, '_shell_type', None) self._shell = get_shell_plugin(shell_type=shell_type, executable=self._play_context.executable) - self.become = None + self.become: BecomeBase | None = None + + @property + def _new_stdin(self) -> io.TextIOWrapper | None: + display.deprecated( + "The connection's stdin object is deprecated. " + "Call display.prompt_until(msg) instead.", + version='2.19', + ) + return self.__new_stdin - def set_become_plugin(self, plugin): + def set_become_plugin(self, plugin: BecomeBase) -> None: self.become = plugin @property - def connected(self): + def connected(self) -> bool: '''Read-only property holding whether the connection to the remote host is active or closed.''' return self._connected @property - def socket_path(self): + def socket_path(self) -> str | None: '''Read-only property holding the connection socket path for this remote host''' return self._socket_path @staticmethod - def _split_ssh_args(argstring): + def _split_ssh_args(argstring: str) -> list[str]: """ Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to @@ -115,17 +140,17 @@ class ConnectionBase(AnsiblePlugin): @property @abstractmethod - def transport(self): + def transport(self) -> str: """String used to identify this Connection class from other classes""" pass @abstractmethod - def _connect(self): + def _connect(self: T) -> T: """Connect to the host we've been initialized with""" @ensure_connect @abstractmethod - def exec_command(self, cmd, in_data=None, sudoable=True): + def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]: """Run a command on the remote host. :arg cmd: byte string containing the command @@ -193,36 +218,36 @@ class ConnectionBase(AnsiblePlugin): @ensure_connect @abstractmethod - def put_file(self, in_path, out_path): + def put_file(self, in_path: str, out_path: str) -> None: """Transfer a file from local to remote""" pass @ensure_connect @abstractmethod - def fetch_file(self, in_path, out_path): + def fetch_file(self, in_path: str, out_path: str) -> None: """Fetch a file from remote to local; callers are expected to have pre-created the directory chain for out_path""" pass @abstractmethod - def close(self): + def close(self) -> None: """Terminate the connection""" pass - def connection_lock(self): + def connection_lock(self) -> None: f = self._play_context.connection_lockfd display.vvvv('CONNECTION: pid %d waiting for lock on %d' % (os.getpid(), f), host=self._play_context.remote_addr) fcntl.lockf(f, fcntl.LOCK_EX) display.vvvv('CONNECTION: pid %d acquired lock on %d' % (os.getpid(), f), host=self._play_context.remote_addr) - def connection_unlock(self): + def connection_unlock(self) -> None: f = self._play_context.connection_lockfd fcntl.lockf(f, fcntl.LOCK_UN) display.vvvv('CONNECTION: pid %d released lock on %d' % (os.getpid(), f), host=self._play_context.remote_addr) - def reset(self): + def reset(self) -> None: display.warning("Reset is not implemented for this connection") - def update_vars(self, variables): + def update_vars(self, variables: dict[str, t.Any]) -> None: ''' Adds 'magic' variables relating to connections to the variable dictionary provided. In case users need to access from the play, this is a legacy from runner. @@ -238,7 +263,7 @@ class ConnectionBase(AnsiblePlugin): elif varname == 'ansible_connection': # its me mom! value = self._load_name - elif varname == 'ansible_shell_type': + elif varname == 'ansible_shell_type' and self._shell: # its my cousin ... value = self._shell._load_name else: @@ -271,9 +296,15 @@ class NetworkConnectionBase(ConnectionBase): # Do not use _remote_is_local in other connections _remote_is_local = True - def __init__(self, play_context, new_stdin, *args, **kwargs): + def __init__( + self, + play_context: PlayContext, + new_stdin: io.TextIOWrapper | None = None, + *args: t.Any, + **kwargs: t.Any, + ) -> None: super(NetworkConnectionBase, self).__init__(play_context, new_stdin, *args, **kwargs) - self._messages = [] + self._messages: list[tuple[str, str]] = [] self._conn_closed = False self._network_os = self._play_context.network_os @@ -281,7 +312,7 @@ class NetworkConnectionBase(ConnectionBase): self._local = connection_loader.get('local', play_context, '/dev/null') self._local.set_options() - self._sub_plugin = {} + self._sub_plugin: dict[str, t.Any] = {} self._cached_variables = (None, None, None) # reconstruct the socket_path and set instance values accordingly @@ -300,10 +331,10 @@ class NetworkConnectionBase(ConnectionBase): return method raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) - def exec_command(self, cmd, in_data=None, sudoable=True): + def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]: return self._local.exec_command(cmd, in_data, sudoable) - def queue_message(self, level, message): + def queue_message(self, level: str, message: str) -> None: """ Adds a message to the queue of messages waiting to be pushed back to the controller process. @@ -313,19 +344,19 @@ class NetworkConnectionBase(ConnectionBase): """ self._messages.append((level, message)) - def pop_messages(self): + def pop_messages(self) -> list[tuple[str, str]]: messages, self._messages = self._messages, [] return messages - def put_file(self, in_path, out_path): + def put_file(self, in_path: str, out_path: str) -> None: """Transfer a file from local to remote""" return self._local.put_file(in_path, out_path) - def fetch_file(self, in_path, out_path): + def fetch_file(self, in_path: str, out_path: str) -> None: """Fetch a file from remote to local""" return self._local.fetch_file(in_path, out_path) - def reset(self): + def reset(self) -> None: ''' Reset the connection ''' @@ -334,12 +365,17 @@ class NetworkConnectionBase(ConnectionBase): self.close() self.queue_message('vvvv', 'reset call on connection instance') - def close(self): + def close(self) -> None: self._conn_closed = True if self._connected: self._connected = False - def set_options(self, task_keys=None, var_options=None, direct=None): + def set_options( + self, + task_keys: dict[str, t.Any] | None = None, + var_options: dict[str, t.Any] | None = None, + direct: dict[str, t.Any] | None = None, + ) -> None: super(NetworkConnectionBase, self).set_options(task_keys=task_keys, var_options=var_options, direct=direct) if self.get_option('persistent_log_messages'): warning = "Persistent connection logging is enabled for %s. This will log ALL interactions" % self._play_context.remote_addr @@ -354,7 +390,7 @@ class NetworkConnectionBase(ConnectionBase): except AttributeError: pass - def _update_connection_state(self): + def _update_connection_state(self) -> None: ''' Reconstruct the connection socket_path and check if it exists @@ -377,6 +413,6 @@ class NetworkConnectionBase(ConnectionBase): self._connected = True self._socket_path = socket_path - def _log_messages(self, message): + def _log_messages(self, message: str) -> None: if self.get_option('persistent_log_messages'): self.queue_message('log', message) |