diff --git a/test/lang/python/requirements.txt b/test/lang/python/requirements.txt index 8b46b1a7d9232..76398f91c6eec 100644 --- a/test/lang/python/requirements.txt +++ b/test/lang/python/requirements.txt @@ -1,4 +1,4 @@ psycopg==3.1.10 psycopg-binary==3.1.10 -psycopg2==2.8.6 +psycopg2==2.9.9 SQLAlchemy==1.3.20 diff --git a/test/lang/python/smoketest.py b/test/lang/python/smoketest.py index f8f7c563c398b..8ecd05fa14833 100644 --- a/test/lang/python/smoketest.py +++ b/test/lang/python/smoketest.py @@ -82,50 +82,48 @@ def test_sqlalchemy(self) -> None: def test_psycopg2_subscribe(self) -> None: """Test SUBSCRIBE with psycopg2 via server cursors.""" - with psycopg2.connect(MATERIALIZED_URL) as conn: - conn.set_session(autocommit=True) - with conn.cursor() as cur: - # Create a table with one row of data. - cur.execute("CREATE TABLE psycopg2_subscribe (a int, b text)") - cur.execute("INSERT INTO psycopg2_subscribe VALUES (1, 'a')") - conn.set_session(autocommit=False) + conn = psycopg2.connect(MATERIALIZED_URL) + conn.set_session(autocommit=True) + with conn.cursor() as cur: + # Create a table with one row of data. + cur.execute("CREATE TABLE psycopg2_subscribe (a int, b text)") + cur.execute("INSERT INTO psycopg2_subscribe VALUES (1, 'a')") + conn.set_session(autocommit=False) + + # Start SUBSCRIBE using the binary copy protocol. + cur.execute("DECLARE cur CURSOR FOR SUBSCRIBE psycopg2_subscribe") + cur.execute("FETCH ALL cur") + + # Validate the first row, but ignore the timestamp column. + row = cur.fetchone() + if row is not None: + (ts, diff, a, b) = row + self.assertEqual(diff, 1) + self.assertEqual(a, 1) + self.assertEqual(b, "a") + else: + self.fail("row is None") - # Start SUBSCRIBE using the binary copy protocol. - cur.execute("DECLARE cur CURSOR FOR SUBSCRIBE psycopg2_subscribe") - cur.execute("FETCH ALL cur") + self.assertEqual(cur.fetchone(), None) - # Validate the first row, but ignore the timestamp column. - row = cur.fetchone() - if row is not None: - (ts, diff, a, b) = row - self.assertEqual(diff, 1) - self.assertEqual(a, 1) - self.assertEqual(b, "a") - else: - self.fail("row is None") + # Insert another row from another connection to simulate an + # update arriving. + with psycopg2.connect(MATERIALIZED_URL) as conn2: + conn2.set_session(autocommit=True) + with conn2.cursor() as cur2: + cur2.execute("INSERT INTO psycopg2_subscribe VALUES (2, 'b')") - self.assertEqual(cur.fetchone(), None) + # Validate the new row, again ignoring the timestamp column. + cur.execute("FETCH ALL cur") + row = cur.fetchone() + assert row is not None - # Insert another row from another connection to simulate an - # update arriving. - with psycopg2.connect(MATERIALIZED_URL) as conn2: - conn2.set_session(autocommit=True) - with conn2.cursor() as cur2: - cur2.execute("INSERT INTO psycopg2_subscribe VALUES (2, 'b')") + (ts, diff, a, b) = row + self.assertEqual(diff, 1) + self.assertEqual(a, 2) + self.assertEqual(b, "b") - # Validate the new row, again ignoring the timestamp column. - cur.execute("FETCH ALL cur") - row = cur.fetchone() - - if row is not None: - (ts, diff, a, b) = row - self.assertEqual(diff, 1) - self.assertEqual(a, 2) - self.assertEqual(b, "b") - else: - self.fail("row None") - - self.assertEqual(cur.fetchone(), None) + self.assertEqual(cur.fetchone(), None) def test_psycopg3_subscribe_copy(self) -> None: """Test SUBSCRIBE with psycopg3 via its new binary COPY decoding support.""" @@ -177,16 +175,9 @@ def test_psycopg3_subscribe_copy(self) -> None: # The subscribe won't end until we send a cancel request. conn.cancel() - with self.assertRaises(Exception) as context: + with self.assertRaises(psycopg.errors.QueryCanceled): copy.read_row() - self.assertTrue( - "canceling statement due to user request" - in str(context.exception) - ) - # There might be problem with stream and the cancellation message. Skip until - # resolved. - @unittest.skip("https://github.com/psycopg/psycopg3/issues/30") def test_psycopg3_subscribe_stream(self) -> None: """Test subscribe with psycopg3 via its new streaming query support.""" with psycopg.connect(MATERIALIZED_URL) as conn: @@ -223,4 +214,70 @@ def test_psycopg3_subscribe_stream(self) -> None: # The subscribe won't end until we send a cancel request. conn.cancel() - self.assertEqual(next(stream, None), None) + with self.assertRaises(psycopg.errors.QueryCanceled): + next(stream) + + def test_psycopg3_subscribe_terminate_connection(self) -> None: + """Test terminating a bare subscribe with psycopg3. + + This test ensures that Materialize notices a TCP connection close when a + bare SUBSCRIBE statement (i.e., one not wrapped in a COPY statement) is + producing no rows. + """ + + # Create two connections: one to create a subscription and one to + # query metadata about the subscription. + with psycopg.connect(MATERIALIZED_URL) as metadata_conn: + with psycopg.connect(MATERIALIZED_URL) as subscribe_conn: + try: + metadata_session_id = metadata_conn.pgconn.backend_pid + subscribe_session_id = subscribe_conn.pgconn.backend_pid + + # Subscribe to the list of active subscriptions in + # Materialize. + metadata = metadata_conn.cursor().stream( + "SUBSCRIBE (SELECT session_id FROM mz_internal.mz_subscriptions)" + ) + + # Ensure we see our own subscription in `mz_subscriptions`. + (_ts, diff, pid) = next(metadata) + self.assertEqual(int(pid), metadata_session_id) + self.assertEqual(diff, 1) + + # Create a dummy subscribe that we know will only ever + # produce a single row, but, as far as Materialize can tell, + # has the potential to produce future updates. This ensures + # the SUBSCRIBE operation will be blocked inside of + # Materialize waiting for more rows. + # + # IMPORTANT: this must use a bare `SUBSCRIBE` statement, + # rather than a `SUBSCRIBE` inside of a `COPY` operation, to + # test the code path that previously had the bug. + stream = subscribe_conn.cursor().stream( + "SUBSCRIBE (SELECT * FROM mz_tables LIMIT 1)" + ) + next(stream) + + # Ensure we see the dummy subscription added to + # `mz_subscriptions`. + (_ts, diff, pid) = next(metadata) + self.assertEqual(int(pid), subscribe_session_id) + self.assertEqual(diff, 1) + + # Kill the dummy subscription by forcibly closing the + # connection. + subscribe_conn.close() + + # Ensure we see the dummy subscription removed from + # `mz_subscriptions`. + (_ts, diff, pid) = next(metadata) + self.assertEqual(int(pid), subscribe_session_id) + self.assertEqual(diff, -1) + + finally: + # Ensure the connections are always closed, even if an + # assertion fails partway through the test, as otherwise the + # `with` context manager will hang forever waiting for the + # subscribes to gracefully terminate, which they never will. + subscribe_conn.close() + metadata_conn.close()