summaryrefslogtreecommitdiff
path: root/python/tests/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/tests/protocol.py')
-rw-r--r--python/tests/protocol.py583
1 files changed, 583 insertions, 0 deletions
diff --git a/python/tests/protocol.py b/python/tests/protocol.py
new file mode 100644
index 0000000000..5cd7938be3
--- /dev/null
+++ b/python/tests/protocol.py
@@ -0,0 +1,583 @@
+import asyncio
+from contextlib import contextmanager
+import os
+import socket
+from tempfile import TemporaryDirectory
+
+import avocado
+
+from qemu.aqmp import ConnectError, Runstate
+from qemu.aqmp.protocol import AsyncProtocol, StateError
+from qemu.aqmp.util import asyncio_run, create_task
+
+
+class NullProtocol(AsyncProtocol[None]):
+ """
+ NullProtocol is a test mockup of an AsyncProtocol implementation.
+
+ It adds a fake_session instance variable that enables a code path
+ that bypasses the actual connection logic, but still allows the
+ reader/writers to start.
+
+ Because the message type is defined as None, an asyncio.Event named
+ 'trigger_input' is created that prohibits the reader from
+ incessantly being able to yield None; this event can be poked to
+ simulate an incoming message.
+
+ For testing symmetry with do_recv, an interface is added to "send" a
+ Null message.
+
+ For testing purposes, a "simulate_disconnection" method is also
+ added which allows us to trigger a bottom half disconnect without
+ injecting any real errors into the reader/writer loops; in essence
+ it performs exactly half of what disconnect() normally does.
+ """
+ def __init__(self, name=None):
+ self.fake_session = False
+ self.trigger_input: asyncio.Event
+ super().__init__(name)
+
+ async def _establish_session(self):
+ self.trigger_input = asyncio.Event()
+ await super()._establish_session()
+
+ async def _do_accept(self, address, ssl=None):
+ if not self.fake_session:
+ await super()._do_accept(address, ssl)
+
+ async def _do_connect(self, address, ssl=None):
+ if not self.fake_session:
+ await super()._do_connect(address, ssl)
+
+ async def _do_recv(self) -> None:
+ await self.trigger_input.wait()
+ self.trigger_input.clear()
+
+ def _do_send(self, msg: None) -> None:
+ pass
+
+ async def send_msg(self) -> None:
+ await self._outgoing.put(None)
+
+ async def simulate_disconnect(self) -> None:
+ """
+ Simulates a bottom-half disconnect.
+
+ This method schedules a disconnection but does not wait for it
+ to complete. This is used to put the loop into the DISCONNECTING
+ state without fully quiescing it back to IDLE. This is normally
+ something you cannot coax AsyncProtocol to do on purpose, but it
+ will be similar to what happens with an unhandled Exception in
+ the reader/writer.
+
+ Under normal circumstances, the library design requires you to
+ await on disconnect(), which awaits the disconnect task and
+ returns bottom half errors as a pre-condition to allowing the
+ loop to return back to IDLE.
+ """
+ self._schedule_disconnect()
+
+
+class LineProtocol(AsyncProtocol[str]):
+ def __init__(self, name=None):
+ super().__init__(name)
+ self.rx_history = []
+
+ async def _do_recv(self) -> str:
+ raw = await self._readline()
+ msg = raw.decode()
+ self.rx_history.append(msg)
+ return msg
+
+ def _do_send(self, msg: str) -> None:
+ assert self._writer is not None
+ self._writer.write(msg.encode() + b'\n')
+
+ async def send_msg(self, msg: str) -> None:
+ await self._outgoing.put(msg)
+
+
+def run_as_task(coro, allow_cancellation=False):
+ """
+ Run a given coroutine as a task.
+
+ Optionally, wrap it in a try..except block that allows this
+ coroutine to be canceled gracefully.
+ """
+ async def _runner():
+ try:
+ await coro
+ except asyncio.CancelledError:
+ if allow_cancellation:
+ return
+ raise
+ return create_task(_runner())
+
+
+@contextmanager
+def jammed_socket():
+ """
+ Opens up a random unused TCP port on localhost, then jams it.
+ """
+ socks = []
+
+ try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(('127.0.0.1', 0))
+ sock.listen(1)
+ address = sock.getsockname()
+
+ socks.append(sock)
+
+ # I don't *fully* understand why, but it takes *two* un-accepted
+ # connections to start jamming the socket.
+ for _ in range(2):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.connect(address)
+ socks.append(sock)
+
+ yield address
+
+ finally:
+ for sock in socks:
+ sock.close()
+
+
+class Smoke(avocado.Test):
+
+ def setUp(self):
+ self.proto = NullProtocol()
+
+ def test__repr__(self):
+ self.assertEqual(
+ repr(self.proto),
+ "<NullProtocol runstate=IDLE>"
+ )
+
+ def testRunstate(self):
+ self.assertEqual(
+ self.proto.runstate,
+ Runstate.IDLE
+ )
+
+ def testDefaultName(self):
+ self.assertEqual(
+ self.proto.name,
+ None
+ )
+
+ def testLogger(self):
+ self.assertEqual(
+ self.proto.logger.name,
+ 'qemu.aqmp.protocol'
+ )
+
+ def testName(self):
+ self.proto = NullProtocol('Steve')
+
+ self.assertEqual(
+ self.proto.name,
+ 'Steve'
+ )
+
+ self.assertEqual(
+ self.proto.logger.name,
+ 'qemu.aqmp.protocol.Steve'
+ )
+
+ self.assertEqual(
+ repr(self.proto),
+ "<NullProtocol name='Steve' runstate=IDLE>"
+ )
+
+
+class TestBase(avocado.Test):
+
+ def setUp(self):
+ self.proto = NullProtocol(type(self).__name__)
+ self.assertEqual(self.proto.runstate, Runstate.IDLE)
+ self.runstate_watcher = None
+
+ def tearDown(self):
+ self.assertEqual(self.proto.runstate, Runstate.IDLE)
+
+ async def _asyncSetUp(self):
+ pass
+
+ async def _asyncTearDown(self):
+ if self.runstate_watcher:
+ await self.runstate_watcher
+
+ @staticmethod
+ def async_test(async_test_method):
+ """
+ Decorator; adds SetUp and TearDown to async tests.
+ """
+ async def _wrapper(self, *args, **kwargs):
+ loop = asyncio.get_event_loop()
+ loop.set_debug(True)
+
+ await self._asyncSetUp()
+ await async_test_method(self, *args, **kwargs)
+ await self._asyncTearDown()
+
+ return _wrapper
+
+ # Definitions
+
+ # The states we expect a "bad" connect/accept attempt to transition through
+ BAD_CONNECTION_STATES = (
+ Runstate.CONNECTING,
+ Runstate.DISCONNECTING,
+ Runstate.IDLE,
+ )
+
+ # The states we expect a "good" session to transition through
+ GOOD_CONNECTION_STATES = (
+ Runstate.CONNECTING,
+ Runstate.RUNNING,
+ Runstate.DISCONNECTING,
+ Runstate.IDLE,
+ )
+
+ # Helpers
+
+ async def _watch_runstates(self, *states):
+ """
+ This launches a task alongside (most) tests below to confirm that
+ the sequence of runstate changes that occur is exactly as
+ anticipated.
+ """
+ async def _watcher():
+ for state in states:
+ new_state = await self.proto.runstate_changed()
+ self.assertEqual(
+ new_state,
+ state,
+ msg=f"Expected state '{state.name}'",
+ )
+
+ self.runstate_watcher = create_task(_watcher())
+ # Kick the loop and force the task to block on the event.
+ await asyncio.sleep(0)
+
+
+class State(TestBase):
+
+ @TestBase.async_test
+ async def testSuperfluousDisconnect(self):
+ """
+ Test calling disconnect() while already disconnected.
+ """
+ await self._watch_runstates(
+ Runstate.DISCONNECTING,
+ Runstate.IDLE,
+ )
+ await self.proto.disconnect()
+
+
+class Connect(TestBase):
+ """
+ Tests primarily related to calling Connect().
+ """
+ async def _bad_connection(self, family: str):
+ assert family in ('INET', 'UNIX')
+
+ if family == 'INET':
+ await self.proto.connect(('127.0.0.1', 0))
+ elif family == 'UNIX':
+ await self.proto.connect('/dev/null')
+
+ async def _hanging_connection(self):
+ with jammed_socket() as addr:
+ await self.proto.connect(addr)
+
+ async def _bad_connection_test(self, family: str):
+ await self._watch_runstates(*self.BAD_CONNECTION_STATES)
+
+ with self.assertRaises(ConnectError) as context:
+ await self._bad_connection(family)
+
+ self.assertIsInstance(context.exception.exc, OSError)
+ self.assertEqual(
+ context.exception.error_message,
+ "Failed to establish connection"
+ )
+
+ @TestBase.async_test
+ async def testBadINET(self):
+ """
+ Test an immediately rejected call to an IP target.
+ """
+ await self._bad_connection_test('INET')
+
+ @TestBase.async_test
+ async def testBadUNIX(self):
+ """
+ Test an immediately rejected call to a UNIX socket target.
+ """
+ await self._bad_connection_test('UNIX')
+
+ @TestBase.async_test
+ async def testCancellation(self):
+ """
+ Test what happens when a connection attempt is aborted.
+ """
+ # Note that accept() cannot be cancelled outright, as it isn't a task.
+ # However, we can wrap it in a task and cancel *that*.
+ await self._watch_runstates(*self.BAD_CONNECTION_STATES)
+ task = run_as_task(self._hanging_connection(), allow_cancellation=True)
+
+ state = await self.proto.runstate_changed()
+ self.assertEqual(state, Runstate.CONNECTING)
+
+ # This is insider baseball, but the connection attempt has
+ # yielded *just* before the actual connection attempt, so kick
+ # the loop to make sure it's truly wedged.
+ await asyncio.sleep(0)
+
+ task.cancel()
+ await task
+
+ @TestBase.async_test
+ async def testTimeout(self):
+ """
+ Test what happens when a connection attempt times out.
+ """
+ await self._watch_runstates(*self.BAD_CONNECTION_STATES)
+ task = run_as_task(self._hanging_connection())
+
+ # More insider baseball: to improve the speed of this test while
+ # guaranteeing that the connection even gets a chance to start,
+ # verify that the connection hangs *first*, then await the
+ # result of the task with a nearly-zero timeout.
+
+ state = await self.proto.runstate_changed()
+ self.assertEqual(state, Runstate.CONNECTING)
+ await asyncio.sleep(0)
+
+ with self.assertRaises(asyncio.TimeoutError):
+ await asyncio.wait_for(task, timeout=0)
+
+ @TestBase.async_test
+ async def testRequire(self):
+ """
+ Test what happens when a connection attempt is made while CONNECTING.
+ """
+ await self._watch_runstates(*self.BAD_CONNECTION_STATES)
+ task = run_as_task(self._hanging_connection(), allow_cancellation=True)
+
+ state = await self.proto.runstate_changed()
+ self.assertEqual(state, Runstate.CONNECTING)
+
+ with self.assertRaises(StateError) as context:
+ await self._bad_connection('UNIX')
+
+ self.assertEqual(
+ context.exception.error_message,
+ "NullProtocol is currently connecting."
+ )
+ self.assertEqual(context.exception.state, Runstate.CONNECTING)
+ self.assertEqual(context.exception.required, Runstate.IDLE)
+
+ task.cancel()
+ await task
+
+ @TestBase.async_test
+ async def testImplicitRunstateInit(self):
+ """
+ Test what happens if we do not wait on the runstate event until
+ AFTER a connection is made, i.e., connect()/accept() themselves
+ initialize the runstate event. All of the above tests force the
+ initialization by waiting on the runstate *first*.
+ """
+ task = run_as_task(self._hanging_connection(), allow_cancellation=True)
+
+ # Kick the loop to coerce the state change
+ await asyncio.sleep(0)
+ assert self.proto.runstate == Runstate.CONNECTING
+
+ # We already missed the transition to CONNECTING
+ await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)
+
+ task.cancel()
+ await task
+
+
+class Accept(Connect):
+ """
+ All of the same tests as Connect, but using the accept() interface.
+ """
+ async def _bad_connection(self, family: str):
+ assert family in ('INET', 'UNIX')
+
+ if family == 'INET':
+ await self.proto.accept(('example.com', 1))
+ elif family == 'UNIX':
+ await self.proto.accept('/dev/null')
+
+ async def _hanging_connection(self):
+ with TemporaryDirectory(suffix='.aqmp') as tmpdir:
+ sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
+ await self.proto.accept(sock)
+
+
+class FakeSession(TestBase):
+
+ def setUp(self):
+ super().setUp()
+ self.proto.fake_session = True
+
+ async def _asyncSetUp(self):
+ await super()._asyncSetUp()
+ await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
+
+ async def _asyncTearDown(self):
+ await self.proto.disconnect()
+ await super()._asyncTearDown()
+
+ ####
+
+ @TestBase.async_test
+ async def testFakeConnect(self):
+
+ """Test the full state lifecycle (via connect) with a no-op session."""
+ await self.proto.connect('/not/a/real/path')
+ self.assertEqual(self.proto.runstate, Runstate.RUNNING)
+
+ @TestBase.async_test
+ async def testFakeAccept(self):
+ """Test the full state lifecycle (via accept) with a no-op session."""
+ await self.proto.accept('/not/a/real/path')
+ self.assertEqual(self.proto.runstate, Runstate.RUNNING)
+
+ @TestBase.async_test
+ async def testFakeRecv(self):
+ """Test receiving a fake/null message."""
+ await self.proto.accept('/not/a/real/path')
+
+ logname = self.proto.logger.name
+ with self.assertLogs(logname, level='DEBUG') as context:
+ self.proto.trigger_input.set()
+ self.proto.trigger_input.clear()
+ await asyncio.sleep(0) # Kick reader.
+
+ self.assertEqual(
+ context.output,
+ [f"DEBUG:{logname}:<-- None"],
+ )
+
+ @TestBase.async_test
+ async def testFakeSend(self):
+ """Test sending a fake/null message."""
+ await self.proto.accept('/not/a/real/path')
+
+ logname = self.proto.logger.name
+ with self.assertLogs(logname, level='DEBUG') as context:
+ # Cheat: Send a Null message to nobody.
+ await self.proto.send_msg()
+ # Kick writer; awaiting on a queue.put isn't sufficient to yield.
+ await asyncio.sleep(0)
+
+ self.assertEqual(
+ context.output,
+ [f"DEBUG:{logname}:--> None"],
+ )
+
+ async def _prod_session_api(
+ self,
+ current_state: Runstate,
+ error_message: str,
+ accept: bool = True
+ ):
+ with self.assertRaises(StateError) as context:
+ if accept:
+ await self.proto.accept('/not/a/real/path')
+ else:
+ await self.proto.connect('/not/a/real/path')
+
+ self.assertEqual(context.exception.error_message, error_message)
+ self.assertEqual(context.exception.state, current_state)
+ self.assertEqual(context.exception.required, Runstate.IDLE)
+
+ @TestBase.async_test
+ async def testAcceptRequireRunning(self):
+ """Test that accept() cannot be called when Runstate=RUNNING"""
+ await self.proto.accept('/not/a/real/path')
+
+ await self._prod_session_api(
+ Runstate.RUNNING,
+ "NullProtocol is already connected and running.",
+ accept=True,
+ )
+
+ @TestBase.async_test
+ async def testConnectRequireRunning(self):
+ """Test that connect() cannot be called when Runstate=RUNNING"""
+ await self.proto.accept('/not/a/real/path')
+
+ await self._prod_session_api(
+ Runstate.RUNNING,
+ "NullProtocol is already connected and running.",
+ accept=False,
+ )
+
+ @TestBase.async_test
+ async def testAcceptRequireDisconnecting(self):
+ """Test that accept() cannot be called when Runstate=DISCONNECTING"""
+ await self.proto.accept('/not/a/real/path')
+
+ # Cheat: force a disconnect.
+ await self.proto.simulate_disconnect()
+
+ await self._prod_session_api(
+ Runstate.DISCONNECTING,
+ ("NullProtocol is disconnecting."
+ " Call disconnect() to return to IDLE state."),
+ accept=True,
+ )
+
+ @TestBase.async_test
+ async def testConnectRequireDisconnecting(self):
+ """Test that connect() cannot be called when Runstate=DISCONNECTING"""
+ await self.proto.accept('/not/a/real/path')
+
+ # Cheat: force a disconnect.
+ await self.proto.simulate_disconnect()
+
+ await self._prod_session_api(
+ Runstate.DISCONNECTING,
+ ("NullProtocol is disconnecting."
+ " Call disconnect() to return to IDLE state."),
+ accept=False,
+ )
+
+
+class SimpleSession(TestBase):
+
+ def setUp(self):
+ super().setUp()
+ self.server = LineProtocol(type(self).__name__ + '-server')
+
+ async def _asyncSetUp(self):
+ await super()._asyncSetUp()
+ await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
+
+ async def _asyncTearDown(self):
+ await self.proto.disconnect()
+ try:
+ await self.server.disconnect()
+ except EOFError:
+ pass
+ await super()._asyncTearDown()
+
+ @TestBase.async_test
+ async def testSmoke(self):
+ with TemporaryDirectory(suffix='.aqmp') as tmpdir:
+ sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
+ server_task = create_task(self.server.accept(sock))
+
+ # give the server a chance to start listening [...]
+ await asyncio.sleep(0)
+ await self.proto.connect(sock)