Skip to content

Commit

Permalink
Fix encryption tests (#2018)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp authored Nov 26, 2024
1 parent 36480f9 commit d8274b7
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .evergreen/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ functions:
"run tests":
- command: subprocess.exec
params:
include_expansions_in_env: ["TEST_DATA_LAKE", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
include_expansions_in_env: ["TEST_DATA_LAKE", "PYTHON_BINARY", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
binary: bash
working_dir: "src"
args:
Expand Down
21 changes: 12 additions & 9 deletions test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,9 @@ async def test_03_bulk_batch_split(self):
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
self.listener.reset()
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)

async def test_04_bulk_batch_split(self):
limits_doc = json_data("limits", "limits-doc.json")
Expand All @@ -1244,7 +1246,9 @@ async def test_04_bulk_batch_split(self):
doc2.update(limits_doc)
self.listener.reset()
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)

async def test_05_insert_succeeds_just_under_16MiB(self):
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
Expand Down Expand Up @@ -1482,19 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
KEYVAULT_COLL = "datakeys"
client: AsyncMongoClient

async def asyncSetUp(self):
self.client = self.simple_client()
async def _setup(self):
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
await create_key_vault(keyvault, self.DEK)

async def _test_explicit(self, expectation):
await self._setup()
client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
async_client_context.client,
OPTS,
)
self.addAsyncCleanup(client_encryption.close)

ciphertext = await client_encryption.encrypt(
"string0",
Expand All @@ -1506,6 +1509,7 @@ async def _test_explicit(self, expectation):
self.assertEqual(await client_encryption.decrypt(ciphertext), "string0")

async def _test_automatic(self, expectation_extjson, payload):
await self._setup()
encrypted_db = "db"
encrypted_coll = "coll"
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
Expand All @@ -1520,7 +1524,6 @@ async def _test_automatic(self, expectation_extjson, payload):
client = await self.async_rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
)
self.addAsyncCleanup(client.aclose)

coll = client.get_database(encrypted_db).get_collection(
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
Expand Down Expand Up @@ -1594,6 +1597,7 @@ async def test_automatic(self):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client_test = await self.async_rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
)
Expand Down Expand Up @@ -1626,7 +1630,6 @@ async def asyncSetUp(self):
self.ciphertext = await client_encryption.encrypt(
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
)
await client_encryption.close()

self.client_listener = OvertCommandListener()
self.topology_listener = TopologyEventListener()
Expand Down Expand Up @@ -1821,6 +1824,7 @@ async def test_case_8(self):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
class TestDecryptProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client = async_client_context.client
await self.client.db.drop_collection("decryption_events")
await create_key_vault(self.client.keyvault.datakeys)
Expand Down Expand Up @@ -2256,6 +2260,7 @@ async def test_06_named_kms_providers_apply_tls_options_kmip(self):
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client = async_client_context.client
await create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
Expand Down Expand Up @@ -2605,8 +2610,6 @@ async def AsyncMongoClient(**kwargs):
assert isinstance(res["encrypted_indexed"], Binary)
assert isinstance(res["encrypted_unindexed"], Binary)

await client_encryption.close()


# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
Expand Down
4 changes: 0 additions & 4 deletions test/asynchronous/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,10 +499,6 @@ async def asyncSetUp(self):
# process file-level runOnRequirements
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
if not await self.should_run_on(run_on_spec):
# Explicitly close async clients here
# to prevent leaky monitor tasks
if not _IS_SYNC:
await async_client_context.client.close()
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")

# add any special-casing for skipping tests here
Expand Down
2 changes: 0 additions & 2 deletions test/asynchronous/utils_spec_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,6 @@ async def run_scenario(self, scenario_def, test):
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addAsyncCleanup(client.close)

# Create session0 and session1.
sessions = {}
Expand Down
21 changes: 12 additions & 9 deletions test/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,9 @@ def test_03_bulk_batch_split(self):
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
self.listener.reset()
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)

def test_04_bulk_batch_split(self):
limits_doc = json_data("limits", "limits-doc.json")
Expand All @@ -1240,7 +1242,9 @@ def test_04_bulk_batch_split(self):
doc2.update(limits_doc)
self.listener.reset()
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)

def test_05_insert_succeeds_just_under_16MiB(self):
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
Expand Down Expand Up @@ -1476,19 +1480,18 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
KEYVAULT_COLL = "datakeys"
client: MongoClient

def setUp(self):
self.client = self.simple_client()
def _setup(self):
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
create_key_vault(keyvault, self.DEK)

def _test_explicit(self, expectation):
self._setup()
client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
client_context.client,
OPTS,
)
self.addCleanup(client_encryption.close)

ciphertext = client_encryption.encrypt(
"string0",
Expand All @@ -1500,6 +1503,7 @@ def _test_explicit(self, expectation):
self.assertEqual(client_encryption.decrypt(ciphertext), "string0")

def _test_automatic(self, expectation_extjson, payload):
self._setup()
encrypted_db = "db"
encrypted_coll = "coll"
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
Expand All @@ -1514,7 +1518,6 @@ def _test_automatic(self, expectation_extjson, payload):
client = self.rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
)
self.addCleanup(client.close)

coll = client.get_database(encrypted_db).get_collection(
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
Expand Down Expand Up @@ -1588,6 +1591,7 @@ def test_automatic(self):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
class TestDeadlockProse(EncryptionIntegrationTest):
def setUp(self):
super().setUp()
self.client_test = self.rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
)
Expand Down Expand Up @@ -1618,7 +1622,6 @@ def setUp(self):
self.ciphertext = client_encryption.encrypt(
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
)
client_encryption.close()

self.client_listener = OvertCommandListener()
self.topology_listener = TopologyEventListener()
Expand Down Expand Up @@ -1813,6 +1816,7 @@ def test_case_8(self):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
class TestDecryptProse(EncryptionIntegrationTest):
def setUp(self):
super().setUp()
self.client = client_context.client
self.client.db.drop_collection("decryption_events")
create_key_vault(self.client.keyvault.datakeys)
Expand Down Expand Up @@ -2248,6 +2252,7 @@ def test_06_named_kms_providers_apply_tls_options_kmip(self):
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest):
def setUp(self):
super().setUp()
self.client = client_context.client
create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
Expand Down Expand Up @@ -2589,8 +2594,6 @@ def MongoClient(**kwargs):
assert isinstance(res["encrypted_indexed"], Binary)
assert isinstance(res["encrypted_unindexed"], Binary)

client_encryption.close()


# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
class TestRangeQueryProse(EncryptionIntegrationTest):
Expand Down
4 changes: 0 additions & 4 deletions test/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,6 @@ def setUp(self):
# process file-level runOnRequirements
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
if not self.should_run_on(run_on_spec):
# Explicitly close async clients here
# to prevent leaky monitor tasks
if not _IS_SYNC:
client_context.client.close()
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")

# add any special-casing for skipping tests here
Expand Down
2 changes: 0 additions & 2 deletions test/utils_spec_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,6 @@ def run_scenario(self, scenario_def, test):
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addCleanup(client.close)

# Create session0 and session1.
sessions = {}
Expand Down

0 comments on commit d8274b7

Please sign in to comment.