From 0fdbac9b1b511309e4625a9a6b32797879d299a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jeremy=20Lain=C3=A9?= Date: Mon, 31 Jan 2022 21:08:09 +0100 Subject: [PATCH] [mdns] unreference mDNS protocol when connection is closed This allows cleanly shutting down the mDNS once all connections are closed. Otherwise we hit a ResourceWarning when the event loop exits. --- src/aioice/ice.py | 18 ++++++++++++++++-- tests/test_ice.py | 5 +++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/aioice/ice.py b/src/aioice/ice.py index ed5a168..2b5a5be 100644 --- a/src/aioice/ice.py +++ b/src/aioice/ice.py @@ -30,16 +30,27 @@ _mdns = threading.local() -async def get_or_create_mdns_protocol() -> mdns.MDnsProtocol: +async def get_or_create_mdns_protocol(subscriber: object) -> mdns.MDnsProtocol: if not hasattr(_mdns, "lock"): _mdns.lock = asyncio.Lock() _mdns.protocol = None + _mdns.subscribers = set() async with _mdns.lock: if _mdns.protocol is None: _mdns.protocol = await mdns.create_mdns_protocol() + _mdns.subscribers.add(subscriber) return _mdns.protocol +async def unref_mdns_protocol(subscriber: object) -> None: + if hasattr(_mdns, "lock"): + async with _mdns.lock: + _mdns.subscribers.discard(subscriber) + if _mdns.protocol and not _mdns.subscribers: + await _mdns.protocol.close() + _mdns.protocol = None + + def candidate_pair_priority( local: Candidate, remote: Candidate, ice_controlling: bool ) -> int: @@ -364,7 +375,7 @@ async def add_remote_candidate(self, remote_candidate: Candidate) -> None: # resolve mDNS candidate if mdns.is_mdns_hostname(remote_candidate.host): - mdns_protocol = await get_or_create_mdns_protocol() + mdns_protocol = await get_or_create_mdns_protocol(self) remote_addr = await mdns_protocol.resolve(remote_candidate.host) if remote_addr is None: self.__log_info( @@ -498,6 +509,9 @@ async def close(self) -> None: if self._check_list and not self._check_list_done: await self._check_list_state.put(ICE_FAILED) + # unreference mDNS + await unref_mdns_protocol(self) + self._nominated.clear() for protocol in self._protocols: await protocol.close() diff --git a/tests/test_ice.py b/tests/test_ice.py index f9ea86d..fb5e308 100644 --- a/tests/test_ice.py +++ b/tests/test_ice.py @@ -1092,6 +1092,9 @@ async def test_add_remote_candidate_mdns_bad(self): self.assertEqual(len(conn_a.remote_candidates), 0) self.assertEqual(conn_a._remote_candidates_end, False) + # close + await conn_a.close() + @asynctest async def test_add_remote_candidate_mdns_good(self): """ @@ -1118,6 +1121,8 @@ async def test_add_remote_candidate_mdns_good(self): self.assertEqual(conn_a.remote_candidates[0].host, "1.2.3.4") self.assertEqual(conn_a._remote_candidates_end, False) + # close + await conn_a.close() await publisher.close() @asynctest