diff --git a/spec/integ/crypto/crypto.spec.ts b/spec/integ/crypto/crypto.spec.ts index d7cc0896956..2c5be4657db 100644 --- a/spec/integ/crypto/crypto.spec.ts +++ b/spec/integ/crypto/crypto.spec.ts @@ -23,7 +23,14 @@ import { MockResponse, MockResponseFunction } from "fetch-mock"; import Olm from "@matrix-org/olm"; import * as testUtils from "../../test-utils/test-utils"; -import { CRYPTO_BACKENDS, getSyncResponse, InitCrypto, mkEventCustom, syncPromise } from "../../test-utils/test-utils"; +import { + advanceTimersUntil, + CRYPTO_BACKENDS, + getSyncResponse, + InitCrypto, + mkEventCustom, + syncPromise, +} from "../../test-utils/test-utils"; import * as testData from "../../test-utils/test-data"; import { BOB_SIGNED_CROSS_SIGNING_KEYS_DATA, @@ -2767,7 +2774,9 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, fetchMock.get("express:/_matrix/client/v3/room_keys/keys", keyBackupData); // should be able to restore from 4S - const importResult = await aliceClient.restoreKeyBackupWithSecretStorage(check!.backupInfo!); + const importResult = await advanceTimersUntil( + aliceClient.restoreKeyBackupWithSecretStorage(check!.backupInfo!), + ); expect(importResult.imported).toStrictEqual(1); }); diff --git a/spec/integ/crypto/megolm-backup.spec.ts b/spec/integ/crypto/megolm-backup.spec.ts index cff127a2742..d12b7e5486e 100644 --- a/spec/integ/crypto/megolm-backup.spec.ts +++ b/spec/integ/crypto/megolm-backup.spec.ts @@ -23,10 +23,17 @@ import { SyncResponder } from "../../test-utils/SyncResponder"; import { E2EKeyReceiver } from "../../test-utils/E2EKeyReceiver"; import { E2EKeyResponder } from "../../test-utils/E2EKeyResponder"; import { mockInitialApiRequests } from "../../test-utils/mockEndpoints"; -import { awaitDecryption, CRYPTO_BACKENDS, InitCrypto, syncPromise } from "../../test-utils/test-utils"; +import { + advanceTimersUntil, + awaitDecryption, + CRYPTO_BACKENDS, + InitCrypto, + syncPromise, +} from "../../test-utils/test-utils"; import * as testData from "../../test-utils/test-data"; import { KeyBackupInfo } from "../../../src/crypto-api/keybackup"; import { IKeyBackup } from "../../../src/crypto/backup"; +import { flushPromises } from "../../test-utils/flushPromises"; const ROOM_ID = testData.TEST_ROOM_ID; @@ -110,9 +117,9 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe /** an object which intercepts `/keys/query` requests on the test homeserver */ let e2eKeyResponder: E2EKeyResponder; - jest.useFakeTimers(); - beforeEach(async () => { + jest.useFakeTimers(); + // anything that we don't have a specific matcher for silently returns a 404 fetchMock.catch(404); fetchMock.config.warnOnFallback = false; @@ -134,6 +141,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe await jest.runAllTimersAsync(); fetchMock.mockReset(); + jest.restoreAllMocks(); }); async function initTestClient(opts: Partial = {}): Promise { @@ -149,64 +157,131 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe return client; } - it("Alice checks key backups when receiving a message she can't decrypt", async function () { - const syncResponse = { + describe("Key backup check on UTD message", () => { + // sync response which contains an encrypted event + const SYNC_RESPONSE = { next_batch: 1, - rooms: { - join: { - [ROOM_ID]: { - timeline: { - events: [testData.ENCRYPTED_EVENT], - }, - }, - }, - }, + rooms: { join: { [ROOM_ID]: { timeline: { events: [testData.ENCRYPTED_EVENT] } } } }, }; - fetchMock.get("express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", (url, request) => { - // check that the version is correct - const version = new URLSearchParams(new URL(url).search).get("version"); - if (version == "1") { - return testData.CURVE25519_KEY_BACKUP_DATA; - } else { - return { - status: 403, - body: { - current_version: "1", - errcode: "M_WRONG_ROOM_KEYS_VERSION", - error: "Wrong backup version.", - }, - }; + const EXPECTED_URL = + [ + "https://alice-server.com/_matrix/client/v3/room_keys/keys", + encodeURIComponent(testData.TEST_ROOM_ID), + encodeURIComponent(testData.MEGOLM_SESSION_DATA.session_id), + ].join("/") + "?version=1"; + + /** Flush promises enough times to get the crypto stacks to make the backup request */ + async function flushBackupRequest() { + // we have to run flushPromises lots of times. It seems like each time the rust code touches indexeddb, + // it needs another round of flushPromises to progress, or something. + for (let i = 0; i < 10; i++) { + await flushPromises(); } + } + + beforeEach(async () => { + fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + + // ignore requests to send room key requests + fetchMock.put("express:/_matrix/client/v3/sendToDevice/m.room_key_request/:request_id", {}); + + aliceClient = await initTestClient(); + const aliceCrypto = aliceClient.getCrypto()!; + await aliceCrypto.storeSessionBackupPrivateKey( + Buffer.from(testData.BACKUP_DECRYPTION_KEY_BASE64, "base64"), + testData.SIGNED_BACKUP_DATA.version!, + ); + + // start after saving the private key + await aliceClient.startClient(); + + // tell Alice to trust the dummy device that signed the backup, and re-check the backup. + // XXX: should we automatically re-check after a device becomes verified? + await waitForDeviceList(); + await aliceClient.getCrypto()!.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); + await aliceClient.getCrypto()!.checkKeyBackupAndEnable(); }); - fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + it("Alice checks key backups when receiving a message she can't decrypt", async () => { + fetchMock.get("express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", (url, request) => { + // check that the version is correct + const version = new URLSearchParams(new URL(url).search).get("version"); + if (version == "1") { + return testData.CURVE25519_KEY_BACKUP_DATA; + } else { + return { + status: 403, + body: { + current_version: "1", + errcode: "M_WRONG_ROOM_KEYS_VERSION", + error: "Wrong backup version.", + }, + }; + } + }); - aliceClient = await initTestClient(); - const aliceCrypto = aliceClient.getCrypto()!; - await aliceCrypto.storeSessionBackupPrivateKey( - Buffer.from(testData.BACKUP_DECRYPTION_KEY_BASE64, "base64"), - testData.SIGNED_BACKUP_DATA.version!, - ); + // Send Alice a message that she won't be able to decrypt, and check that she fetches the key from the backup. + syncResponder.sendOrQueueSyncResponse(SYNC_RESPONSE); + await syncPromise(aliceClient); - // start after saving the private key - await aliceClient.startClient(); + const room = aliceClient.getRoom(ROOM_ID)!; + const event = room.getLiveTimeline().getEvents()[0]; + await advanceTimersUntil(awaitDecryption(event, { waitOnDecryptionFailure: true })); - // tell Alice to trust the dummy device that signed the backup, and re-check the backup. - // XXX: should we automatically re-check after a device becomes verified? - await waitForDeviceList(); - await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); - await aliceClient.getCrypto()!.checkKeyBackupAndEnable(); + expect(event.getContent()).toEqual(testData.CLEAR_EVENT.content); + }); + + it("handles error on backup query gracefully", async () => { + jest.spyOn(console, "error").mockImplementation(() => {}); + + fetchMock.get( + "express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", + { status: 404, body: { errcode: "M_NOT_FOUND" } }, + { name: "getKey" }, + ); - // Now, send Alice a message that she won't be able to decrypt, and check that she fetches the key from the backup. - syncResponder.sendOrQueueSyncResponse(syncResponse); - await syncPromise(aliceClient); + // Send Alice a message that she won't be able to decrypt + syncResponder.sendOrQueueSyncResponse(SYNC_RESPONSE); + await flushBackupRequest(); + + const calls = fetchMock.calls("getKey"); + expect(calls.length).toEqual(1); + expect(calls[0][0]).toEqual(EXPECTED_URL); + + await flushBackupRequest(); + + // we should not have logged an error. + // eslint-disable-next-line no-console + expect(console.error).not.toHaveBeenCalled(); + }); + + it("Only queries once", async () => { + fetchMock.get( + "express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", + { status: 404, body: { errcode: "M_NOT_FOUND" } }, + { name: "getKey" }, + ); - const room = aliceClient.getRoom(ROOM_ID)!; - const event = room.getLiveTimeline().getEvents()[0]; - await awaitDecryption(event, { waitOnDecryptionFailure: true }); + // Send Alice a message that she won't be able to decrypt + syncResponder.sendOrQueueSyncResponse(SYNC_RESPONSE); + await flushBackupRequest(); + const calls = fetchMock.calls("getKey"); + expect(calls.length).toEqual(1); + expect(calls[0][0]).toEqual(EXPECTED_URL); - expect(event.getContent()).toEqual(testData.CLEAR_EVENT.content); + fetchMock.resetHistory(); + + // another message + const event2 = { ...testData.ENCRYPTED_EVENT, event_id: "$event2" }; + const syncResponse2 = { + next_batch: 1, + rooms: { join: { [ROOM_ID]: { timeline: { events: [event2] } } } }, + }; + syncResponder.sendOrQueueSyncResponse(syncResponse2); + await flushBackupRequest(); + expect(fetchMock.calls("getKey").length).toEqual(0); + }); }); describe("recover from backup", () => { @@ -240,14 +315,16 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe onKeyCached = resolve; }); - const result = await aliceClient.restoreKeyBackupWithRecoveryKey( - testData.BACKUP_DECRYPTION_KEY_BASE58, - undefined, - undefined, - check!.backupInfo!, - { - cacheCompleteCallback: () => onKeyCached(), - }, + const result = await advanceTimersUntil( + aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + undefined, + undefined, + check!.backupInfo!, + { + cacheCompleteCallback: () => onKeyCached(), + }, + ), ); expect(result.imported).toStrictEqual(1); @@ -255,7 +332,9 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe await awaitKeyCached; // The key should be now cached - const afterCache = await aliceClient.restoreKeyBackupWithCache(undefined, undefined, check!.backupInfo!); + const afterCache = await advanceTimersUntil( + aliceClient.restoreKeyBackupWithCache(undefined, undefined, check!.backupInfo!), + ); expect(afterCache.imported).toStrictEqual(1); }); @@ -278,11 +357,13 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe const check = await aliceCrypto.checkKeyBackupAndEnable(); - const result = await aliceClient.restoreKeyBackupWithRecoveryKey( - testData.BACKUP_DECRYPTION_KEY_BASE58, - ROOM_ID, - testData.MEGOLM_SESSION_DATA.session_id, - check!.backupInfo!, + const result = await advanceTimersUntil( + aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + ROOM_ID, + testData.MEGOLM_SESSION_DATA.session_id, + check!.backupInfo!, + ), ); expect(result.imported).toStrictEqual(1); diff --git a/spec/test-utils/test-utils.ts b/spec/test-utils/test-utils.ts index efe0b41037c..d7592803050 100644 --- a/spec/test-utils/test-utils.ts +++ b/spec/test-utils/test-utils.ts @@ -537,8 +537,6 @@ export async function awaitDecryption( }); } -export const emitPromise = (e: EventEmitter, k: string): Promise => new Promise((r) => e.once(k, r)); - export const mkPusher = (extra: Partial = {}): IPusher => ({ app_display_name: "app", app_id: "123", @@ -561,3 +559,25 @@ CRYPTO_BACKENDS["rust-sdk"] = (client: MatrixClient) => client.initRustCrypto(); if (global.Olm) { CRYPTO_BACKENDS["libolm"] = (client: MatrixClient) => client.initCrypto(); } + +export const emitPromise = (e: EventEmitter, k: string): Promise => new Promise((r) => e.once(k, r)); + +/** + * Advance the fake timers in a loop until the given promise resolves or rejects. + * + * Returns the result of the promise. + * + * This can be useful when there are multiple steps in the code which require an iteration of the event loop. + */ +export async function advanceTimersUntil(promise: Promise): Promise { + let resolved = false; + promise.finally(() => { + resolved = true; + }); + + while (!resolved) { + await jest.advanceTimersByTimeAsync(1); + } + + return await promise; +} diff --git a/src/rust-crypto/backup.ts b/src/rust-crypto/backup.ts index 784e5b59995..a040b37ed55 100644 --- a/src/rust-crypto/backup.ts +++ b/src/rust-crypto/backup.ts @@ -29,7 +29,7 @@ import { logger } from "../logger"; import { ClientPrefix, IHttpOpts, MatrixError, MatrixHttpApi, Method } from "../http-api"; import { CryptoEvent, IMegolmSessionData } from "../crypto"; import { TypedEventEmitter } from "../models/typed-event-emitter"; -import { encodeUri } from "../utils"; +import { encodeUri, immediate } from "../utils"; import { OutgoingRequestProcessor } from "./OutgoingRequestProcessor"; import { sleep } from "../utils"; import { BackupDecryptor } from "../common-crypto/CryptoBackend"; @@ -460,7 +460,7 @@ export class RustBackupDecryptor implements BackupDecryptor { for (const [sessionId, sessionData] of Object.entries(ciphertexts)) { try { const decrypted = JSON.parse( - await this.decryptionKey.decryptV1( + this.decryptionKey.decryptV1( sessionData.session_data.ephemeral, sessionData.session_data.mac, sessionData.session_data.ciphertext, @@ -468,6 +468,9 @@ export class RustBackupDecryptor implements BackupDecryptor { ); decrypted.session_id = sessionId; keys.push(decrypted); + + // there might be lots of sessions, so don't hog the event loop + await immediate(); } catch (e) { logger.log("Failed to decrypt megolm session from backup", e, sessionData); } diff --git a/src/rust-crypto/rust-crypto.ts b/src/rust-crypto/rust-crypto.ts index fc39fcf47c9..d02bd4004aa 100644 --- a/src/rust-crypto/rust-crypto.ts +++ b/src/rust-crypto/rust-crypto.ts @@ -160,43 +160,72 @@ export class RustCrypto extends TypedEventEmitter { - const backupKeys: RustSdkCryptoJs.BackupKeys = await this.olmMachine.getBackupKeys(); - if (!backupKeys.decryptionKey) return; - const version = backupKeys.backupVersion; - + public startQueryKeyBackupRateLimited(targetRoomId: string, targetSessionId: string): void { const now = new Date().getTime(); - if ( - !this.sessionLastCheckAttemptedTime[targetSessionId!] || - now - this.sessionLastCheckAttemptedTime[targetSessionId!] > KEY_BACKUP_CHECK_RATE_LIMIT - ) { + const lastCheck = this.sessionLastCheckAttemptedTime[targetSessionId]; + if (!lastCheck || now - lastCheck > KEY_BACKUP_CHECK_RATE_LIMIT) { this.sessionLastCheckAttemptedTime[targetSessionId!] = now; - - const path = encodeUri("/room_keys/keys/$roomId/$sessionId", { - $roomId: targetRoomId, - $sessionId: targetSessionId, + this.queryKeyBackup(targetRoomId, targetSessionId).catch((e) => { + this.logger.error(`Unhandled error while checking key backup for session ${targetSessionId}`, e); }); + } else { + const lastCheckStr = new Date(lastCheck).toISOString(); + this.logger.debug( + `Not checking key backup for session ${targetSessionId} (last checked at ${lastCheckStr})`, + ); + } + } + + /** + * Helper for {@link RustCrypto#startQueryKeyBackupRateLimited}. + * + * Requests the backup and imports it. Doesn't do any rate-limiting. + * + * @param targetRoomId - ID of the room that the session is used in. + * @param targetSessionId - ID of the session for which to check backup. + */ + private async queryKeyBackup(targetRoomId: string, targetSessionId: string): Promise { + const backupKeys: RustSdkCryptoJs.BackupKeys = await this.olmMachine.getBackupKeys(); + if (!backupKeys.decryptionKey) { + this.logger.debug(`Not checking key backup for session ${targetSessionId} (no decryption key)`); + return; + } + + this.logger.debug(`Checking key backup for session ${targetSessionId}`); + + const version = backupKeys.backupVersion; + const path = encodeUri("/room_keys/keys/$roomId/$sessionId", { + $roomId: targetRoomId, + $sessionId: targetSessionId, + }); - const res = await this.http.authedRequest(Method.Get, path, { version }, undefined, { + let res: KeyBackupSession; + try { + res = await this.http.authedRequest(Method.Get, path, { version }, undefined, { prefix: ClientPrefix.V3, }); + } catch (e) { + this.logger.info(`No luck requesting key backup for session ${targetSessionId}: ${e}`); + return; + } - if (this.stopped) return; + if (this.stopped) return; - const backupDecryptor = new RustBackupDecryptor(backupKeys.decryptionKey); - if (res) { - const sessionsToImport: Record = {}; - sessionsToImport[targetSessionId] = res; - const keys = await backupDecryptor.decryptSessions(sessionsToImport); - for (const k of keys) { - k.room_id = targetRoomId!; - } - await this.importRoomKeys(keys); - } + const backupDecryptor = new RustBackupDecryptor(backupKeys.decryptionKey); + const sessionsToImport: Record = { [targetSessionId]: res }; + const keys = await backupDecryptor.decryptSessions(sessionsToImport); + for (const k of keys) { + k.room_id = targetRoomId; } + await this.importRoomKeys(keys); } /** @@ -1633,7 +1662,10 @@ class EventDecryptor { session: content.sender_key + "|" + content.session_id, }, ); - this.crypto.queryKeyBackupRateLimited(event.getRoomId()!, event.getWireContent().session_id!); + this.crypto.startQueryKeyBackupRateLimited( + event.getRoomId()!, + event.getWireContent().session_id!, + ); break; } case RustSdkCryptoJs.DecryptionErrorCode.UnknownMessageIndex: { @@ -1644,7 +1676,10 @@ class EventDecryptor { session: content.sender_key + "|" + content.session_id, }, ); - this.crypto.queryKeyBackupRateLimited(event.getRoomId()!, event.getWireContent().session_id!); + this.crypto.startQueryKeyBackupRateLimited( + event.getRoomId()!, + event.getWireContent().session_id!, + ); break; } // We don't map MismatchedIdentityKeys for now, as there is no equivalent in legacy.