From e049cc7d44674c840d866c0e22d1a74dc885c115 Mon Sep 17 00:00:00 2001
From: Dmitriy <dimastbk@proton.me>
Date: Mon, 19 Aug 2024 11:29:36 +0500
Subject: [PATCH] fix: check writer is closing in AIOKafkaConnection.send

---
 aiokafka/conn.py    |  6 ++++
 requirements-ci.txt |  1 +
 tests/test_conn.py  | 82 +++++++++++++++++++++++++++++++++++++++++----
 3 files changed, 83 insertions(+), 6 deletions(-)

diff --git a/aiokafka/conn.py b/aiokafka/conn.py
index 333f9fad..a6132c63 100644
--- a/aiokafka/conn.py
+++ b/aiokafka/conn.py
@@ -457,6 +457,12 @@ def send(self, request, expect_response=True):
                 f"No connection to broker at {self._host}:{self._port}"
             )
 
+        if self._writer.is_closing():
+            self.close(reason=CloseReason.CONNECTION_BROKEN)
+            raise Errors.KafkaConnectionError(
+                f"Connection at {self._host}:{self._port} is closing"
+            )
+
         correlation_id = self._next_correlation_id()
         header = request.build_request_header(
             correlation_id=correlation_id, client_id=self._client_id
diff --git a/requirements-ci.txt b/requirements-ci.txt
index 2294b51d..4502fab0 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -14,3 +14,4 @@ Pygments==2.15.0
 gssapi==1.8.3
 async-timeout==4.0.1
 cramjam==2.8.0
+uvloop=0.19.0
diff --git a/tests/test_conn.py b/tests/test_conn.py
index f0f4a075..17837c2e 100644
--- a/tests/test_conn.py
+++ b/tests/test_conn.py
@@ -1,10 +1,13 @@
 import asyncio
 import gc
+import socket
 import struct
-from typing import Any
+from typing import Any, AsyncIterable, Iterable, Tuple
 from unittest import mock
 
 import pytest
+import pytest_asyncio
+import uvloop
 
 from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn
 from aiokafka.errors import (
@@ -144,7 +147,7 @@ async def test_send_to_closed(self):
         with self.assertRaises(KafkaConnectionError):
             await conn.send(request)
 
-        conn._writer = mock.MagicMock()
+        conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
         conn._writer.write.side_effect = OSError("mocked writer is closed")
 
         with self.assertRaises(KafkaConnectionError):
@@ -173,7 +176,7 @@ async def second_resp(*args: Any, **kw: Any):
             return resp
 
         reader.readexactly.side_effect = [first_resp(), second_resp()]
-        writer = mock.MagicMock()
+        writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
 
         conn._reader = reader
         conn._writer = writer
@@ -208,7 +211,7 @@ async def second_resp(*args: Any, **kw: Any):
             return resp
 
         reader.readexactly.side_effect = [first_resp(), second_resp()]
-        writer = mock.MagicMock()
+        writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
 
         conn._reader = reader
         conn._writer = writer
@@ -237,7 +240,7 @@ async def invoke_osserror(*a, **kw):
         # setup reader
         reader = mock.MagicMock()
         reader.readexactly.return_value = invoke_osserror()
-        writer = mock.MagicMock()
+        writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
 
         conn._reader = reader
         conn._writer = writer
@@ -394,7 +397,7 @@ async def test__send_sasl_token(self):
         # setup connection with mocked transport and protocol
         conn = AIOKafkaConnection(host="", port=9999)
         conn.close = mock.MagicMock()
-        conn._writer = mock.MagicMock()
+        conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
         out_buffer = []
         conn._writer.write = mock.Mock(side_effect=out_buffer.append)
         conn._reader = mock.MagicMock()
@@ -424,3 +427,70 @@ async def test__send_sasl_token(self):
             conn._send_sasl_token(b"Super data")
         # We don't need to close 2ce
         self.assertEqual(conn.close.call_count, 1)
+
+
+class TestClosedSocket:
+    @pytest.fixture(
+        params=(
+            asyncio.DefaultEventLoopPolicy(),
+            uvloop.EventLoopPolicy(),
+        ),
+    )
+    def event_loop(
+        self, request: pytest.FixtureRequest
+    ) -> Iterable[asyncio.AbstractEventLoop]:
+        loop: asyncio.AbstractEventLoop = request.param.new_event_loop()
+        yield loop
+        loop.close()
+
+    @pytest.fixture()
+    def server(self, unused_tcp_port: int) -> Iterable[Tuple[str, int, socket.socket]]:
+        host = "localhost"
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.bind((host, unused_tcp_port))
+        sock.listen(8)
+        sock.setblocking(False)
+
+        yield host, unused_tcp_port, sock
+
+        sock.close()
+
+    @pytest_asyncio.fixture()
+    async def conn(
+        self, server: Tuple[str, int, socket.socket]
+    ) -> AsyncIterable[AIOKafkaConnection]:
+        host, port, _ = server
+
+        conn = AIOKafkaConnection(host=host, port=port, request_timeout_ms=1000)
+        conn._create_reader_task = mock.Mock()
+
+        yield conn
+
+        fut = conn.close()
+        if fut:
+            await fut
+
+    @pytest.mark.asyncio
+    async def test_send_to_closed_socket(
+        self, server: Tuple[str, int, socket.socket], conn: AIOKafkaConnection
+    ) -> None:
+        host, port, sock = server
+
+        request = MetadataRequest([])
+
+        with pytest.raises(
+            KafkaConnectionError,
+            match=f"KafkaConnectionError: No connection to broker at {host}:{port}",
+        ):
+            await conn.send(request)
+
+        await conn.connect()
+
+        sock.close()
+        await asyncio.sleep(0.1)
+
+        with pytest.raises(
+            KafkaConnectionError,
+            match=f"KafkaConnectionError: Connection at {host}:{port} is closing",
+        ):
+            await conn.send(request)