diff --git a/tests/test_client.py b/tests/test_client.py index cc1691db..91594f36 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -218,3 +218,75 @@ async def test_deliver_timeout(): await client.unsubscribe(["$SYS/broker/uptime"]) await client.disconnect() await broker.shutdown() + +@pytest.mark.asyncio +async def test_cancel_publish_qos1(): + """ + Tests that timeouts on published messages will clean up in flight messages + """ + data = b"data" + broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins") + await broker.start() + client_pub = MQTTClient() + await client_pub.connect("mqtt://127.0.0.1/") + assert client_pub.session.inflight_out_count == 0 + fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_1)) + assert len(client_pub._handler._puback_waiters) == 0 + while len(client_pub._handler._puback_waiters) == 0 or fut.done(): + await asyncio.sleep(0) + assert len(client_pub._handler._puback_waiters) == 1 + assert client_pub.session.inflight_out_count == 1 + fut.cancel() + await asyncio.wait([fut]) + assert len(client_pub._handler._puback_waiters) == 0 + assert client_pub.session.inflight_out_count == 0 + await client_pub.disconnect() + await broker.shutdown() + +@pytest.mark.asyncio +async def test_cancel_publish_qos2_pubrec(): + """ + Tests that timeouts on published messages will clean up in flight messages + """ + data = b"data" + broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins") + await broker.start() + client_pub = MQTTClient() + await client_pub.connect("mqtt://127.0.0.1/") + assert client_pub.session.inflight_out_count == 0 + fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2)) + assert len(client_pub._handler._pubrec_waiters) == 0 + while len(client_pub._handler._pubrec_waiters) == 0 or fut.done() or fut.cancelled(): + await asyncio.sleep(0) + assert len(client_pub._handler._pubrec_waiters) == 1 + assert client_pub.session.inflight_out_count == 1 + fut.cancel() + await asyncio.sleep(1) + await asyncio.wait([fut]) + assert len(client_pub._handler._pubrec_waiters) == 0 + assert client_pub.session.inflight_out_count == 0 + await client_pub.disconnect() + await broker.shutdown() + +@pytest.mark.asyncio +async def test_cancel_publish_qos2_pubcomp(): + """ + Tests that timeouts on published messages will clean up in flight messages + """ + data = b"data" + broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins") + await broker.start() + client_pub = MQTTClient() + await client_pub.connect("mqtt://127.0.0.1/") + assert client_pub.session.inflight_out_count == 0 + fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2)) + assert len(client_pub._handler._pubcomp_waiters) == 0 + while len(client_pub._handler._pubcomp_waiters) == 0 or fut.done(): + await asyncio.sleep(0) + assert len(client_pub._handler._pubcomp_waiters) == 1 + fut.cancel() + await asyncio.wait([fut]) + assert len(client_pub._handler._pubcomp_waiters) == 0 + assert client_pub.session.inflight_out_count == 0 + await client_pub.disconnect() + await broker.shutdown()