diff --git a/Signal/Provisioning/ProvisioningCoordinatorImpl.swift b/Signal/Provisioning/ProvisioningCoordinatorImpl.swift index 869bf81ef63..8c2a2ffe986 100644 --- a/Signal/Provisioning/ProvisioningCoordinatorImpl.swift +++ b/Signal/Provisioning/ProvisioningCoordinatorImpl.swift @@ -4,6 +4,7 @@ // import Foundation +import LibSignalClient public import SignalServiceKit public class ProvisioningCoordinatorImpl: ProvisioningCoordinator { @@ -246,7 +247,7 @@ public class ProvisioningCoordinatorImpl: ProvisioningCoordinator { if FeatureFlags.linkAndSyncSecondary, - let ephemeralBackupKey = EphemeralBackupKey(provisioningMessage: provisionMessage) + let ephemeralBackupKey = BackupKey(provisioningMessage: provisionMessage) { do { try await self.linkAndSyncManager.waitForBackupAndRestore( diff --git a/Signal/Registration/RegistrationCoordinatorDependencies.swift b/Signal/Registration/RegistrationCoordinatorDependencies.swift index 565bc2a16fe..6756f02da73 100644 --- a/Signal/Registration/RegistrationCoordinatorDependencies.swift +++ b/Signal/Registration/RegistrationCoordinatorDependencies.swift @@ -17,6 +17,7 @@ public struct RegistrationCoordinatorDependencies { public let featureFlags: RegistrationCoordinatorImpl.Shims.FeatureFlags public let keyValueStoreFactory: KeyValueStoreFactory public let localUsernameManager: LocalUsernameManager + public let messageBackupKeyMaterial: MessageBackupKeyMaterial public let messageBackupErrorPresenter: MessageBackupErrorPresenter public let messageBackupManager: MessageBackupManager public let messagePipelineSupervisor: RegistrationCoordinatorImpl.Shims.MessagePipelineSupervisor @@ -51,6 +52,7 @@ public struct RegistrationCoordinatorDependencies { featureFlags: RegistrationCoordinatorImpl.Wrappers.FeatureFlags(), keyValueStoreFactory: DependenciesBridge.shared.keyValueStoreFactory, localUsernameManager: DependenciesBridge.shared.localUsernameManager, + messageBackupKeyMaterial: DependenciesBridge.shared.messageBackupKeyMaterial, messageBackupErrorPresenter: DependenciesBridge.shared.messageBackupErrorPresenter, messageBackupManager: DependenciesBridge.shared.messageBackupManager, messagePipelineSupervisor: RegistrationCoordinatorImpl.Wrappers.MessagePipelineSupervisor(SSKEnvironment.shared.messagePipelineSupervisorRef), diff --git a/Signal/Registration/RegistrationCoordinatorImpl.swift b/Signal/Registration/RegistrationCoordinatorImpl.swift index 5f3a95ae573..b9f36ff425b 100644 --- a/Signal/Registration/RegistrationCoordinatorImpl.swift +++ b/Signal/Registration/RegistrationCoordinatorImpl.swift @@ -464,10 +464,14 @@ public class RegistrationCoordinatorImpl: RegistrationCoordinator { auth: identity.chatServiceAuth ) } + // Get Backup Key + let backupKey = try self.deps.db.read { tx in + return try self.deps.messageBackupKeyMaterial.backupKey(type: .messages, tx: tx) + } try await self.deps.messageBackupManager.importEncryptedBackup( fileUrl: fileUrl, localIdentifiers: identity.localIdentifiers, - mode: .remote + backupKey: backupKey ) self.inMemoryState.hasRestoredFromLocalMessageBackup = true Logger.info("Finished restore") diff --git a/Signal/src/ViewControllers/AppSettings/Internal/InternalSettingsViewController.swift b/Signal/src/ViewControllers/AppSettings/Internal/InternalSettingsViewController.swift index 4ea34919bb0..dbd30fe3e26 100644 --- a/Signal/src/ViewControllers/AppSettings/Internal/InternalSettingsViewController.swift +++ b/Signal/src/ViewControllers/AppSettings/Internal/InternalSettingsViewController.swift @@ -263,6 +263,7 @@ private extension InternalSettingsViewController { } func exportMessageBackupProto() { + let messageBackupKeyMaterial = DependenciesBridge.shared.messageBackupKeyMaterial let messageBackupManager = DependenciesBridge.shared.messageBackupManager let tsAccountManager = DependenciesBridge.shared.tsAccountManager @@ -286,9 +287,12 @@ private extension InternalSettingsViewController { Task { do { + let backupKey = try SSKEnvironment.shared.databaseStorageRef.read { tx in + try messageBackupKeyMaterial.backupKey(type: .messages, tx: tx.asV2Read) + } let metadata = try await messageBackupManager.exportEncryptedBackup( localIdentifiers: localIdentifiers, - mode: .remote + backupKey: backupKey ) await MainActor.run { let actionSheet = ActionSheetController(title: "Choose backup destination:") diff --git a/Signal/src/ViewControllers/AppSettings/Linked Devices/LinkDeviceViewController.swift b/Signal/src/ViewControllers/AppSettings/Linked Devices/LinkDeviceViewController.swift index bd0ce198f19..ee42bd858ac 100644 --- a/Signal/src/ViewControllers/AppSettings/Linked Devices/LinkDeviceViewController.swift +++ b/Signal/src/ViewControllers/AppSettings/Linked Devices/LinkDeviceViewController.swift @@ -192,7 +192,7 @@ class LinkDeviceViewController: OWSViewController { return mrbk } - let ephemeralBackupKey: EphemeralBackupKey? + let ephemeralBackupKey: BackupKey? if isLinkAndSyncEnabled, deviceProvisioningUrl.capabilities.contains(where: { $0 == .linknsync }) diff --git a/Signal/test/Provisioning/ProvisioningCoordinatorTest.swift b/Signal/test/Provisioning/ProvisioningCoordinatorTest.swift index a6b3820c339..9fe42262148 100644 --- a/Signal/test/Provisioning/ProvisioningCoordinatorTest.swift +++ b/Signal/test/Provisioning/ProvisioningCoordinatorTest.swift @@ -191,12 +191,12 @@ private class MockLinkAndSyncManager: LinkAndSyncManager { func setIsLinkAndSyncEnabledOnPrimary(_ isEnabled: Bool, tx: DBWriteTransaction) {} - func generateEphemeralBackupKey() -> EphemeralBackupKey { + func generateEphemeralBackupKey() -> BackupKey { return .forTesting() } func waitForLinkingAndUploadBackup( - ephemeralBackupKey: EphemeralBackupKey, + ephemeralBackupKey: BackupKey, tokenId: DeviceProvisioningTokenId ) async throws(PrimaryLinkNSyncError) { return @@ -205,7 +205,7 @@ private class MockLinkAndSyncManager: LinkAndSyncManager { func waitForBackupAndRestore( localIdentifiers: LocalIdentifiers, auth: ChatServiceAuth, - ephemeralBackupKey: EphemeralBackupKey + ephemeralBackupKey: BackupKey ) async throws(SecondaryLinkNSyncError) { return } diff --git a/Signal/test/Registration/RegistrationCoordinatorTest.swift b/Signal/test/Registration/RegistrationCoordinatorTest.swift index 53764395b52..b41c035aafe 100644 --- a/Signal/test/Registration/RegistrationCoordinatorTest.swift +++ b/Signal/test/Registration/RegistrationCoordinatorTest.swift @@ -114,6 +114,7 @@ public class RegistrationCoordinatorTest: XCTestCase { featureFlags: featureFlags, keyValueStoreFactory: InMemoryKeyValueStoreFactory(), localUsernameManager: localUsernameManagerMock, + messageBackupKeyMaterial: MessageBackupKeyMaterialMock(), messageBackupErrorPresenter: NoOpMessageBackupErrorPresenter(), messageBackupManager: MessageBackupManagerMock(), messagePipelineSupervisor: mockMessagePipelineSupervisor, diff --git a/SignalServiceKit/Cryptography/Cryptography.swift b/SignalServiceKit/Cryptography/Cryptography.swift index 6cdf87882f5..08c11f8060b 100644 --- a/SignalServiceKit/Cryptography/Cryptography.swift +++ b/SignalServiceKit/Cryptography/Cryptography.swift @@ -14,7 +14,7 @@ public enum Cryptography { public static func computeSHA256DigestOfFile(at url: URL) throws -> Data { let file = try LocalFileHandle(url: url) var sha256 = SHA256() - var buffer = Data(count: diskPageSize) + var buffer = Data(count: Constants.diskPageSize) var bytesRead: Int repeat { bytesRead = try file.read(into: &buffer) @@ -63,15 +63,16 @@ public protocol EncryptedFileHandle { } public extension Cryptography { - - fileprivate static let hmac256KeyLength = 32 - fileprivate static let hmac256OutputLength = 32 - fileprivate static let aescbcIVLength = 16 - fileprivate static let aesKeySize = 32 - fileprivate static let aescbcBlockLength = 16 - fileprivate static var concatenatedEncryptionKeyLength: Int { aesKeySize + hmac256KeyLength } - /// Optimize reads/writes by reading this many bytes at once; best balance of performance/memory use from testing in practice. - fileprivate static let diskPageSize = 8192 + enum Constants { + static let hmac256KeyLength = 32 + static let hmac256OutputLength = 32 + static let aescbcIVLength = 16 + static let aesKeySize = 32 + static let aescbcBlockLength = 16 + static var concatenatedEncryptionKeyLength: Int { aesKeySize + hmac256KeyLength } + /// Optimize reads/writes by reading this many bytes at once; best balance of performance/memory use from testing in practice. + static let diskPageSize = 8192 + } static func paddedSize(unpaddedSize: UInt) -> UInt { // In order to obsfucate attachment size on the wire, we round up @@ -84,7 +85,7 @@ public extension Cryptography { static func randomAttachmentEncryptionKey() -> Data { // The metadata "key" is actually a concatentation of the // encryption key and the hmac key. - return Randomness.generateRandomBytes(UInt(concatenatedEncryptionKeyLength)) + return Randomness.generateRandomBytes(UInt(Constants.concatenatedEncryptionKeyLength)) } /// Encrypt an input file to a provided output file location. @@ -135,7 +136,7 @@ public extension Cryptography { encryptionKey inputKey: Data?, applyExtraPadding: Bool ) throws -> EncryptionMetadata { - if let inputKey, inputKey.count != concatenatedEncryptionKeyLength { + if let inputKey, inputKey.count != Constants.concatenatedEncryptionKeyLength { throw OWSAssertionError("Invalid encryption key length") } @@ -155,12 +156,12 @@ public extension Cryptography { let outputFile = try FileHandle(forWritingTo: encryptedUrl) let inputKey = inputKey ?? randomAttachmentEncryptionKey() - let encryptionKey = inputKey.prefix(aesKeySize) - let hmacKey = inputKey.suffix(hmac256KeyLength) + let encryptionKey = inputKey.prefix(Constants.aesKeySize) + let hmacKey = inputKey.suffix(Constants.hmac256KeyLength) return try _encryptAttachment( enumerateInputInBlocks: { closure in - var buffer = Data(count: diskPageSize) + var buffer = Data(count: Constants.diskPageSize) var totalBytesRead: UInt = 0 var bytesRead: Int repeat { @@ -208,15 +209,15 @@ public extension Cryptography { let outputFileHandle = try FileHandle(forWritingTo: outputFileURL) let inputKey = inputKey ?? randomAttachmentEncryptionKey() - let encryptionKey = inputKey.prefix(aesKeySize) - let hmacKey = inputKey.suffix(hmac256KeyLength) + let encryptionKey = inputKey.prefix(Constants.aesKeySize) + let hmacKey = inputKey.suffix(Constants.hmac256KeyLength) return try _encryptAttachment( enumerateInputInBlocks: { closure in var totalBytesRead: UInt = 0 var bytesRead: Int repeat { - let data = try encryptedFileHandle.read(upToCount: UInt32(diskPageSize)) + let data = try encryptedFileHandle.read(upToCount: UInt32(Constants.diskPageSize)) bytesRead = data.count if bytesRead > 0 { totalBytesRead += UInt(bytesRead) @@ -250,13 +251,13 @@ public extension Cryptography { iv: Data? = nil, applyExtraPadding: Bool = false ) throws -> (Data, EncryptionMetadata) { - if let inputKey, inputKey.count != concatenatedEncryptionKeyLength { + if let inputKey, inputKey.count != Constants.concatenatedEncryptionKeyLength { throw OWSAssertionError("Invalid encryption key length") } let inputKey = inputKey ?? randomAttachmentEncryptionKey() - let encryptionKey = inputKey.prefix(aesKeySize) - let hmacKey = inputKey.suffix(hmac256KeyLength) + let encryptionKey = inputKey.prefix(Constants.aesKeySize) + let hmacKey = inputKey.suffix(Constants.hmac256KeyLength) var outputData = Data() let encryptionMetadata = try _encryptAttachment( @@ -305,12 +306,12 @@ public extension Cryptography { let iv: Data if let inputIV { - if inputIV.count != aescbcIVLength { + if inputIV.count != Constants.aescbcIVLength { throw OWSAssertionError("Invalid IV length") } iv = inputIV } else { - iv = Randomness.generateRandomBytes(UInt(aescbcIVLength)) + iv = Randomness.generateRandomBytes(UInt(Constants.aescbcIVLength)) } var hmac = HMAC(key: .init(data: hmacKey)) @@ -466,7 +467,7 @@ public extension Cryptography { at: encryptedUrl, metadata: metadata, // Most efficient to write one page size at a time. - outputBlockSize: UInt32(diskPageSize) + outputBlockSize: UInt32(Constants.diskPageSize) ) { plaintextDataBlock in outputFile.write(plaintextDataBlock) } @@ -504,7 +505,7 @@ public extension Cryptography { metadata: metadata, validateHmacAndDigest: false, // Most efficient to write one page size at a time. - outputBlockSize: UInt32(diskPageSize) + outputBlockSize: UInt32(Constants.diskPageSize) ) { plaintextDataBlock in outputFile.write(plaintextDataBlock) } @@ -594,7 +595,7 @@ public extension Cryptography { if validateHmacAndDigest { // The metadata "key" is actually a concatentation of the // encryption key and the hmac key. - let hmacKey = metadata.key.suffix(hmac256KeyLength) + let hmacKey = metadata.key.suffix(Constants.hmac256KeyLength) hmac = HMAC(key: .init(data: hmacKey)) if metadata.digest != nil { @@ -641,7 +642,7 @@ public extension Cryptography { if validateHmacAndDigest, var hmac { // Add the last padding bytes to the hmac/digest. - var remainingPaddingLength = aescbcIVLength + inputFile.ciphertextLength - inputFile.file.offsetInFile + var remainingPaddingLength = Constants.aescbcIVLength + inputFile.ciphertextLength - inputFile.file.offsetInFile while remainingPaddingLength > 0 { let lengthToRead = min(remainingPaddingLength, 1024 * 16) let paddingCiphertext = try inputFile.file.readData(ofLength: Int(lengthToRead)) @@ -657,7 +658,7 @@ public extension Cryptography { // At this point we are done with the EncryptedFileHandle, so grab its internal // FileHandle for reading directly. // (This breaks EncryptedFileHandle's invariants and renders it unuseable). - let inputFileHmac = try inputFile.file.readData(ofLength: hmac256OutputLength) + let inputFileHmac = try inputFile.file.readData(ofLength: Constants.hmac256OutputLength) guard hmacResult.ows_constantTimeIsEqual(to: inputFileHmac) else { Logger.debug("Bad hmac. Their hmac: \(inputFileHmac.hexadecimalString), our hmac: \(hmacResult.hexadecimalString)") throw OWSAssertionError("Bad hmac") @@ -722,7 +723,7 @@ public extension Cryptography { /// CCCryptor documentation says: "the output length is never larger than the input length plus the block size." /// To ensure we always have enough room in the buffer, we allocate two block lengths. /// `numBytesInPlaintextBuffer` indicates how many bytes (starting from 0) contain non-stale content. - private var plaintextBuffer = Data(repeating: 0, count: aescbcBlockLength * 2) + private var plaintextBuffer = Data(repeating: 0, count: Constants.aescbcBlockLength * 2) private var numBytesInPlaintextBuffer = 0 init( @@ -734,22 +735,22 @@ public extension Cryptography { throw OWSAssertionError("Missing attachment file.") } - guard encryptionKey.count == (aesKeySize + hmac256KeyLength) else { + guard encryptionKey.count == (Constants.aesKeySize + Constants.hmac256KeyLength) else { throw OWSAssertionError("Encryption key shorter than combined key length") } self.file = try LocalFileHandle(url: encryptedUrl) - let cryptoOverheadLength = aescbcIVLength + hmac256OutputLength + let cryptoOverheadLength = Constants.aescbcIVLength + Constants.hmac256OutputLength self.ciphertextLength = file.fileLength - cryptoOverheadLength // The metadata "key" is actually a concatentation of the // encryption key and the hmac key. - self.encryptionKey = encryptionKey.prefix(aesKeySize) + self.encryptionKey = encryptionKey.prefix(Constants.aesKeySize) // This first N bytes of the encrypted file are the IV - self.iv = try file.readData(ofLength: aescbcIVLength) - guard iv.count == aescbcIVLength else { + self.iv = try file.readData(ofLength: Constants.aescbcIVLength) + guard iv.count == Constants.aescbcIVLength else { throw OWSAssertionError("Failed to read IV") } @@ -765,17 +766,17 @@ public extension Cryptography { // determine the pkcs7 padding length. let prePaddingBlockOffset = file.fileLength // Not the hmac - - hmac256OutputLength + - Constants.hmac256OutputLength // Start of the previous block which has the pkcs7 padding - - aescbcBlockLength + - Constants.aescbcBlockLength // Start of the block before that which has its iv - - aescbcBlockLength + - Constants.aescbcBlockLength try file.seek(toFileOffset: prePaddingBlockOffset) // Read the preceding block, use it as the IV. - let paddingBlockIV = try file.readData(ofLength: aescbcBlockLength) + let paddingBlockIV = try file.readData(ofLength: Constants.aescbcBlockLength) // Read the block itself - let paddingBlockCiphertext = try file.readData(ofLength: aescbcBlockLength) + let paddingBlockCiphertext = try file.readData(ofLength: Constants.aescbcBlockLength) // Decrypt, but use ecb instead of cbc mode; we _want_ the plaintext // of the pkcs7 padding bytes; doing the block cipher XOR'ing ourselves @@ -786,7 +787,7 @@ public extension Cryptography { options: .ecbMode, key: self.encryptionKey, // Irrelevant in ecb mode. - iv: Data(repeating: 0, count: aescbcBlockLength) + iv: Data(repeating: 0, count: Constants.aescbcBlockLength) ) var paddingBlockPlaintext = try paddingCipherContext.update(paddingBlockCiphertext) @@ -803,11 +804,11 @@ public extension Cryptography { self._plaintextLength = ciphertextLength - Int(paddingLength) // Move the file handle to the start of the encrypted data (after IV) - try file.seek(toFileOffset: aescbcIVLength) + try file.seek(toFileOffset: Constants.aescbcIVLength) } // We should be just after the iv at this point. - owsAssertDebug(file.offsetInFile == aescbcIVLength) + owsAssertDebug(file.offsetInFile == Constants.aescbcIVLength) self.cipherContext = try CipherContext( operation: .decrypt, @@ -835,14 +836,14 @@ public extension Cryptography { // The offset in the encrypted file rounds down to the start of the block. // Add 1 because the first block in the encrypted file is the iv which isn't // represented in the virtual plaintext's address space. - var (desiredBlock, desiredOffsetInBlock) = toOffset.quotientAndRemainder(dividingBy: aescbcBlockLength) + var (desiredBlock, desiredOffsetInBlock) = toOffset.quotientAndRemainder(dividingBy: Constants.aescbcBlockLength) desiredBlock += 1 // The preceding block serves as the iv for decryption. let ivBlock = desiredBlock - 1 - let ivOffset = ivBlock * aescbcBlockLength + let ivOffset = ivBlock * Constants.aescbcBlockLength try file.seek(toFileOffset: ivOffset) - let iv = try file.readData(ofLength: aescbcBlockLength) + let iv = try file.readData(ofLength: Constants.aescbcBlockLength) // Initialize a new context with the preceding block as the iv. self.cipherContext = try CipherContext( @@ -925,9 +926,9 @@ public extension Cryptography { result += 16 - remainder } // Read at most the page size; no point in reading more. - result = min(result, diskPageSize) + result = min(result, Constants.diskPageSize) // But never read past the end of the file. - result = min(result, aescbcIVLength + ciphertextLength - file.offsetInFile) + result = min(result, Constants.aescbcIVLength + ciphertextLength - file.offsetInFile) return result } var numCiphertextBytesToRead = computeNumCiphertextBytesToRead() diff --git a/SignalServiceKit/Dependencies/DependenciesBridge.swift b/SignalServiceKit/Dependencies/DependenciesBridge.swift index 8f32708759a..f5c15918e42 100644 --- a/SignalServiceKit/Dependencies/DependenciesBridge.swift +++ b/SignalServiceKit/Dependencies/DependenciesBridge.swift @@ -106,6 +106,7 @@ public class DependenciesBridge { public let mediaBandwidthPreferenceStore: MediaBandwidthPreferenceStore public let mediaGalleryResourceManager: MediaGalleryResourceManager public let messageBackupErrorPresenter: MessageBackupErrorPresenter + public let messageBackupKeyMaterial: MessageBackupKeyMaterial public let messageBackupManager: MessageBackupManager public let messageStickerManager: MessageStickerManager public let mrbkStore: MediaRootBackupKeyStore @@ -228,6 +229,7 @@ public class DependenciesBridge { mediaBandwidthPreferenceStore: MediaBandwidthPreferenceStore, mediaGalleryResourceManager: MediaGalleryResourceManager, messageBackupErrorPresenter: MessageBackupErrorPresenter, + messageBackupKeyMaterial: MessageBackupKeyMaterial, messageBackupManager: MessageBackupManager, messageStickerManager: MessageStickerManager, mrbkStore: MediaRootBackupKeyStore, @@ -347,6 +349,7 @@ public class DependenciesBridge { self.mediaBandwidthPreferenceStore = mediaBandwidthPreferenceStore self.mediaGalleryResourceManager = mediaGalleryResourceManager self.messageBackupErrorPresenter = messageBackupErrorPresenter + self.messageBackupKeyMaterial = messageBackupKeyMaterial self.messageBackupManager = messageBackupManager self.messageStickerManager = messageStickerManager self.mrbkStore = mrbkStore diff --git a/SignalServiceKit/Devices/LinkAndSyncManager.swift b/SignalServiceKit/Devices/LinkAndSyncManager.swift index 77a31efcd63..9d5a0a8ff91 100644 --- a/SignalServiceKit/Devices/LinkAndSyncManager.swift +++ b/SignalServiceKit/Devices/LinkAndSyncManager.swift @@ -3,23 +3,19 @@ // SPDX-License-Identifier: AGPL-3.0-only // -public struct EphemeralBackupKey { - public let data: Data +public import LibSignalClient - fileprivate init(_ data: Data) { - self.data = data - } - - public init?(provisioningMessage: ProvisionMessage) { +extension BackupKey { + public convenience init?(provisioningMessage: ProvisionMessage) { guard let data = provisioningMessage.ephemeralBackupKey else { return nil } - self.init(data) + try? self.init(contents: Array(data)) } #if TESTABLE_BUILD - public static func forTesting() -> EphemeralBackupKey { - return EphemeralBackupKey(Randomness.generateRandomBytes(UInt(SVR.DerivedKey.backupKeyLength))) + public static func forTesting() -> BackupKey { + return try! BackupKey(contents: Array(Randomness.generateRandomBytes(UInt(SVR.DerivedKey.backupKeyLength)))) } #endif } @@ -53,14 +49,14 @@ public protocol LinkAndSyncManager { /// This key should be included in the provisioning message and then used to encrypt the backup proto we send. /// /// - returns The ephemeral key to use, or nil if link'n'sync should not be used. - func generateEphemeralBackupKey() -> EphemeralBackupKey + func generateEphemeralBackupKey() -> BackupKey /// **Call this on the primary device!** /// Once the primary sends the provisioning message to the linked device, call this method /// to wait on the linked device to link, generate a backup, and upload it. Once this method returns, /// the primary's role is complete and the user can exit. func waitForLinkingAndUploadBackup( - ephemeralBackupKey: EphemeralBackupKey, + ephemeralBackupKey: BackupKey, tokenId: DeviceProvisioningTokenId ) async throws(PrimaryLinkNSyncError) @@ -71,7 +67,7 @@ public protocol LinkAndSyncManager { func waitForBackupAndRestore( localIdentifiers: LocalIdentifiers, auth: ChatServiceAuth, - ephemeralBackupKey: EphemeralBackupKey + ephemeralBackupKey: BackupKey ) async throws(SecondaryLinkNSyncError) } @@ -118,14 +114,14 @@ public class LinkAndSyncManagerImpl: LinkAndSyncManager { kvStore.setBool(isEnabled, key: Constants.enabledOnPrimaryKey, transaction: tx) } - public func generateEphemeralBackupKey() -> EphemeralBackupKey { + public func generateEphemeralBackupKey() -> BackupKey { owsAssertDebug(FeatureFlags.linkAndSyncTogglePrimary || FeatureFlags.linkAndSyncOverridePrimary) owsAssertDebug(tsAccountManager.registrationStateWithMaybeSneakyTransaction.isPrimaryDevice == true) - return EphemeralBackupKey(Randomness.generateRandomBytes(UInt(SVR.DerivedKey.backupKeyLength))) + return try! BackupKey(contents: Array(Randomness.generateRandomBytes(UInt(SVR.DerivedKey.backupKeyLength)))) } public func waitForLinkingAndUploadBackup( - ephemeralBackupKey: EphemeralBackupKey, + ephemeralBackupKey: BackupKey, tokenId: DeviceProvisioningTokenId ) async throws(PrimaryLinkNSyncError) { guard FeatureFlags.linkAndSyncTogglePrimary || FeatureFlags.linkAndSyncOverridePrimary else { @@ -161,7 +157,7 @@ public class LinkAndSyncManagerImpl: LinkAndSyncManager { public func waitForBackupAndRestore( localIdentifiers: LocalIdentifiers, auth: ChatServiceAuth, - ephemeralBackupKey: EphemeralBackupKey + ephemeralBackupKey: BackupKey ) async throws(SecondaryLinkNSyncError) { guard FeatureFlags.linkAndSyncSecondary else { owsFailDebug("link'n'sync not available") @@ -222,13 +218,13 @@ public class LinkAndSyncManagerImpl: LinkAndSyncManager { } private func generateBackup( - ephemeralBackupKey: EphemeralBackupKey, + ephemeralBackupKey: BackupKey, localIdentifiers: LocalIdentifiers ) async throws(PrimaryLinkNSyncError) -> Upload.EncryptedBackupUploadMetadata { do { return try await messageBackupManager.exportEncryptedBackup( localIdentifiers: localIdentifiers, - mode: .linknsync(ephemeralBackupKey) + backupKey: ephemeralBackupKey ) } catch let error { owsFailDebug("Unable to generate link'n'sync backup: \(error)") @@ -317,14 +313,14 @@ public class LinkAndSyncManagerImpl: LinkAndSyncManager { private func downloadEphemeralBackup( waitForBackupResponse: Requests.WaitForLinkNSyncBackupUploadResponse, - ephemeralBackupKey: EphemeralBackupKey + ephemeralBackupKey: BackupKey ) async throws(SecondaryLinkNSyncError) -> URL { do { return try await attachmentDownloadManager.downloadTransientAttachment( metadata: AttachmentDownloads.DownloadMetadata( mimeType: MimeType.applicationOctetStream.rawValue, cdnNumber: waitForBackupResponse.cdn, - encryptionKey: ephemeralBackupKey.data, + encryptionKey: ephemeralBackupKey.serialize().asData, source: .linkNSyncBackup(cdnKey: waitForBackupResponse.key) ) ).awaitable() @@ -340,13 +336,13 @@ public class LinkAndSyncManagerImpl: LinkAndSyncManager { private func restoreEphemeralBackup( fileUrl: URL, localIdentifiers: LocalIdentifiers, - ephemeralBackupKey: EphemeralBackupKey + ephemeralBackupKey: BackupKey ) async throws(SecondaryLinkNSyncError) { do { try await messageBackupManager.importEncryptedBackup( fileUrl: fileUrl, localIdentifiers: localIdentifiers, - mode: .linknsync(ephemeralBackupKey) + backupKey: ephemeralBackupKey ) } catch { owsFailDebug("Unable to restore link'n'sync backup: \(error)") diff --git a/SignalServiceKit/Devices/OWSDeviceProvisioner.swift b/SignalServiceKit/Devices/OWSDeviceProvisioner.swift index eb5f8ece9c1..4ed2ed7e0d3 100644 --- a/SignalServiceKit/Devices/OWSDeviceProvisioner.swift +++ b/SignalServiceKit/Devices/OWSDeviceProvisioner.swift @@ -24,7 +24,7 @@ public final class OWSDeviceProvisioner { private let profileKey: Data private let masterKey: Data private let mrbk: Data - private let ephemeralBackupKey: EphemeralBackupKey? + private let ephemeralBackupKey: BackupKey? private let readReceiptsEnabled: Bool private let provisioningService: DeviceProvisioningService @@ -41,7 +41,7 @@ public final class OWSDeviceProvisioner { profileKey: Data, masterKey: Data, mrbk: Data, - ephemeralBackupKey: EphemeralBackupKey?, + ephemeralBackupKey: BackupKey?, readReceiptsEnabled: Bool, provisioningService: DeviceProvisioningService, schedulers: Schedulers @@ -99,7 +99,7 @@ public final class OWSDeviceProvisioner { messageBuilder.setMasterKey(masterKey) messageBuilder.setMediaRootBackupKey(mrbk) if let ephemeralBackupKey { - messageBuilder.setEphemeralBackupKey(ephemeralBackupKey.data) + messageBuilder.setEphemeralBackupKey(ephemeralBackupKey.serialize().asData) } let plainTextProvisionMessage = try messageBuilder.buildSerializedData() diff --git a/SignalServiceKit/Environment/AppSetup.swift b/SignalServiceKit/Environment/AppSetup.swift index 24cc01be9b6..500d4b5c2fe 100644 --- a/SignalServiceKit/Environment/AppSetup.swift +++ b/SignalServiceKit/Environment/AppSetup.swift @@ -337,7 +337,11 @@ public class AppSetup { twoFAManager: SVR2.Wrappers.OWS2FAManager(ows2FAManager) ) - let messageBackupKeyMaterial = MessageBackupKeyMaterialImpl(svr: svr) + let mrbkStore = MediaRootBackupKeyStore(keyValueStoreFactory: keyValueStoreFactory) + let messageBackupKeyMaterial = MessageBackupKeyMaterialImpl( + mrbkStore: mrbkStore, + svr: svr + ) let messageBackupRequestManager = MessageBackupRequestManagerImpl( dateProvider: dateProvider, db: db, @@ -1027,7 +1031,6 @@ public class AppSetup { let backupThreadStore = MessageBackupThreadStore(threadStore: threadStore) let backupInteractionStore = MessageBackupInteractionStore(interactionStore: interactionStore) let backupStoryStore = MessageBackupStoryStore(storyStore: storyStore) - let mrbkStore = MediaRootBackupKeyStore(keyValueStoreFactory: keyValueStoreFactory) let messageBackupErrorPresenter = messageBackupErrorPresenterFactory.build( db: db, @@ -1287,6 +1290,7 @@ public class AppSetup { mediaBandwidthPreferenceStore: mediaBandwidthPreferenceStore, mediaGalleryResourceManager: mediaGalleryResourceManager, messageBackupErrorPresenter: messageBackupErrorPresenter, + messageBackupKeyMaterial: messageBackupKeyMaterial, messageBackupManager: messageBackupManager, messageStickerManager: messageStickerManager, mrbkStore: mrbkStore, diff --git a/SignalServiceKit/MessageBackup/Archivers/ChatItem/MessageBackupMessageAttachmentArchiver.swift b/SignalServiceKit/MessageBackup/Archivers/ChatItem/MessageBackupMessageAttachmentArchiver.swift index 75634fb88cb..c422b813030 100644 --- a/SignalServiceKit/MessageBackup/Archivers/ChatItem/MessageBackupMessageAttachmentArchiver.swift +++ b/SignalServiceKit/MessageBackup/Archivers/ChatItem/MessageBackupMessageAttachmentArchiver.swift @@ -423,7 +423,7 @@ internal class MessageBackupMessageAttachmentArchiver: MessageBackupProtoArchive internal static func currentUploadEra() throws -> String { // TODO: [Backups] use actual subscription id. For now use a fixed, // arbitrary id, so that it never changes. - let backupSubscriptionId = Data(repeating: 5, count: 32) + let backupSubscriptionId = Data(repeating: 4, count: 32) return try Attachment.uploadEra(backupSubscriptionId: backupSubscriptionId) } diff --git a/SignalServiceKit/MessageBackup/Attachments/BackupAttachmentDownloadManager.swift b/SignalServiceKit/MessageBackup/Attachments/BackupAttachmentDownloadManager.swift index 941eefd17cd..ec06073e606 100644 --- a/SignalServiceKit/MessageBackup/Attachments/BackupAttachmentDownloadManager.swift +++ b/SignalServiceKit/MessageBackup/Attachments/BackupAttachmentDownloadManager.swift @@ -285,7 +285,7 @@ public class BackupAttachmentDownloadManagerImpl: BackupAttachmentDownloadManage self.tsAccountManager.localIdentifiers(tx: tx)?.aci, currentUploadEra, try self.needsToQueryListMedia(currentUploadEra: currentUploadEra, tx: tx), - try messageBackupKeyMaterial.backupKey(mode: .remote, tx: tx) + try messageBackupKeyMaterial.backupKey(type: .media, tx: tx) ) } guard needsToQuery else { @@ -297,7 +297,7 @@ public class BackupAttachmentDownloadManagerImpl: BackupAttachmentDownloadManage } let messageBackupAuth = try await messageBackupRequestManager.fetchBackupServiceAuth( - for: .download(.media), + for: .media, localAci: localAci, auth: .implicit() ) diff --git a/SignalServiceKit/MessageBackup/Attachments/BackupAttachmentUploadManager.swift b/SignalServiceKit/MessageBackup/Attachments/BackupAttachmentUploadManager.swift index 70e10ed7f27..f37d4b844a4 100644 --- a/SignalServiceKit/MessageBackup/Attachments/BackupAttachmentUploadManager.swift +++ b/SignalServiceKit/MessageBackup/Attachments/BackupAttachmentUploadManager.swift @@ -203,7 +203,7 @@ public class BackupAttachmentUploadManagerImpl: BackupAttachmentUploadManager { let messageBackupAuth: MessageBackupServiceAuth do { messageBackupAuth = try await messageBackupRequestManager.fetchBackupServiceAuth( - for: .upload(.media), + for: .media, localAci: localAci, auth: .implicit() ) diff --git a/SignalServiceKit/MessageBackup/Attachments/OrphanedBackupAttachmentManager.swift b/SignalServiceKit/MessageBackup/Attachments/OrphanedBackupAttachmentManager.swift index e7e028baaa0..bb8a3dae262 100644 --- a/SignalServiceKit/MessageBackup/Attachments/OrphanedBackupAttachmentManager.swift +++ b/SignalServiceKit/MessageBackup/Attachments/OrphanedBackupAttachmentManager.swift @@ -345,7 +345,7 @@ public class OrphanedBackupAttachmentManagerImpl: OrphanedBackupAttachmentManage let messageBackupAuth: MessageBackupServiceAuth do { messageBackupAuth = try await messageBackupRequestManager.fetchBackupServiceAuth( - for: .delete(.media), + for: .media, localAci: localAci, auth: .implicit() ) diff --git a/SignalServiceKit/MessageBackup/FileStreams/MessageBackupProtoStreamProvider.swift b/SignalServiceKit/MessageBackup/FileStreams/MessageBackupProtoStreamProvider.swift index b6a4acc1f45..14203b1b89e 100644 --- a/SignalServiceKit/MessageBackup/FileStreams/MessageBackupProtoStreamProvider.swift +++ b/SignalServiceKit/MessageBackup/FileStreams/MessageBackupProtoStreamProvider.swift @@ -72,7 +72,7 @@ public protocol MessageBackupEncryptedProtoStreamProvider { /// once finished. func openEncryptedOutputFileStream( localAci: Aci, - mode: MessageBackup.EncryptionMode, + backupKey: BackupKey, tx: DBReadTransaction ) -> ProtoStream.OpenOutputStreamResult @@ -82,7 +82,7 @@ public protocol MessageBackupEncryptedProtoStreamProvider { func openEncryptedInputFileStream( fileUrl: URL, localAci: Aci, - mode: MessageBackup.EncryptionMode, + backupKey: BackupKey, tx: DBReadTransaction ) -> ProtoStream.OpenInputStreamResult } @@ -100,10 +100,11 @@ public class MessageBackupEncryptedProtoStreamProviderImpl: MessageBackupEncrypt public func openEncryptedOutputFileStream( localAci: Aci, - mode: MessageBackup.EncryptionMode, + backupKey: BackupKey, tx: any DBReadTransaction ) -> ProtoStream.OpenOutputStreamResult { do { + let messageBackupKey = try backupKey.asMessageBackupKey(for: localAci) let inputTrackingTransform = MetadataStreamTransform(calculateDigest: false) let outputTrackingTransform = MetadataStreamTransform(calculateDigest: true) @@ -111,8 +112,11 @@ public class MessageBackupEncryptedProtoStreamProviderImpl: MessageBackupEncrypt inputTrackingTransform, ChunkedOutputStreamTransform(), try GzipStreamTransform(.compress), - try backupKeyMaterial.createEncryptingStreamTransform(localAci: localAci, mode: mode, tx: tx), - try backupKeyMaterial.createHmacGeneratingStreamTransform(localAci: localAci, mode: mode, tx: tx), + try EncryptingStreamTransform( + iv: Randomness.generateRandomBytes(UInt(Cryptography.Constants.aescbcIVLength)), + encryptionKey: Data(messageBackupKey.aesKey) + ), + try HmacStreamTransform(hmacKey: Data(messageBackupKey.hmacKey), operation: .generate), outputTrackingTransform ] @@ -147,17 +151,18 @@ public class MessageBackupEncryptedProtoStreamProviderImpl: MessageBackupEncrypt public func openEncryptedInputFileStream( fileUrl: URL, localAci: Aci, - mode: MessageBackup.EncryptionMode, + backupKey: BackupKey, tx: any DBReadTransaction ) -> ProtoStream.OpenInputStreamResult { - guard validateBackupHMAC(localAci: localAci, mode: mode, fileUrl: fileUrl, tx: tx) else { + guard validateBackupHMAC(localAci: localAci, backupKey: backupKey, fileUrl: fileUrl, tx: tx) else { return .hmacValidationFailedOnEncryptedFile } do { + let messageBackupKey = try backupKey.asMessageBackupKey(for: localAci) let transforms: [any StreamTransform] = [ - try backupKeyMaterial.createHmacValidatingStreamTransform(localAci: localAci, mode: mode, tx: tx), - try backupKeyMaterial.createDecryptingStreamTransform(localAci: localAci, mode: mode, tx: tx), + try HmacStreamTransform(hmacKey: Data(messageBackupKey.hmacKey), operation: .validate), + try DecryptingStreamTransform(encryptionKey: Data(messageBackupKey.aesKey)), try GzipStreamTransform(.decompress), ChunkedInputStreamTransform(), ] @@ -173,19 +178,16 @@ public class MessageBackupEncryptedProtoStreamProviderImpl: MessageBackupEncrypt private func validateBackupHMAC( localAci: Aci, - mode: MessageBackup.EncryptionMode, + backupKey: BackupKey, fileUrl: URL, tx: DBReadTransaction ) -> Bool { do { + let messageBackupKey = try backupKey.asMessageBackupKey(for: localAci) let inputStreamResult = genericStreamProvider.openInputFileStream( fileUrl: fileUrl, transforms: [ - try backupKeyMaterial.createHmacGeneratingStreamTransform( - localAci: localAci, - mode: mode, - tx: tx - ) + try HmacStreamTransform(hmacKey: Data(messageBackupKey.hmacKey), operation: .validate) ] ) diff --git a/SignalServiceKit/MessageBackup/MessageBackupKeyMaterial.swift b/SignalServiceKit/MessageBackup/MessageBackupKeyMaterial.swift index 3eaaa44c7cd..2f7b7122137 100644 --- a/SignalServiceKit/MessageBackup/MessageBackupKeyMaterial.swift +++ b/SignalServiceKit/MessageBackup/MessageBackupKeyMaterial.swift @@ -7,10 +7,7 @@ import Foundation public import LibSignalClient enum MessageBackupKeyMaterialError: Error { - case invalidKeyInfo case missingMasterKey - case notRegistered - case invalidEncryptionKey } public enum MediaTierEncryptionType: CaseIterable { @@ -29,79 +26,21 @@ public struct MediaTierEncryptionMetadata: Equatable { } } -public protocol MessageBackupKeyMaterial { +extension BackupKey { + public func asMessageBackupKey(for aci: Aci) throws -> MessageBackupKey { + try MessageBackupKey(backupKey: self, backupId: self.deriveBackupId(aci: aci)) + } +} +public protocol MessageBackupKeyMaterial { func backupKey( - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> BackupKey - - /// Backup ID material derived from a combination of the backup key and the - /// local ACI. This ID is used both as the salt for the backup encryption and - /// to create the anonymous credentials for interacting with server stored backups - func backupID( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> Data - - /// Private key derived from the BackupKey + ACI that is used for signing backup auth presentations. - func backupPrivateKey( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> PrivateKey - - /// LibSignal.BackupAuthCredentialRequestContext derived from the ACI and BackupKey and used primarily - /// for building backup credentials. - /// Always implicitly uses ``MessageBackup/EncryptionMode/remote``. - func backupAuthRequestContext( - localAci: Aci, type: MessageBackupAuthCredentialType, tx: DBReadTransaction - ) throws -> BackupAuthCredentialRequestContext - - func messageBackupKey( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> MessageBackupKey + ) throws -> BackupKey - /// Always implicitly uses ``MessageBackup/EncryptionMode/remote``. func mediaEncryptionMetadata( mediaName: String, type: MediaTierEncryptionType, tx: any DBReadTransaction ) throws -> MediaTierEncryptionMetadata - - /// Builds an encrypting StreamTransform object derived from the backup master key and the backupID - func createEncryptingStreamTransform( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> EncryptingStreamTransform - - func createDecryptingStreamTransform( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> DecryptingStreamTransform - - func createHmacGeneratingStreamTransform( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> HmacStreamTransform - - func createHmacValidatingStreamTransform( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> HmacStreamTransform - - func mediaId( - mediaName: String, - type: MediaTierEncryptionType, - backupKey: BackupKey - ) throws -> Data } diff --git a/SignalServiceKit/MessageBackup/MessageBackupKeyMaterialImpl.swift b/SignalServiceKit/MessageBackup/MessageBackupKeyMaterialImpl.swift index de28694cdb1..c37b5b4e2e4 100644 --- a/SignalServiceKit/MessageBackup/MessageBackupKeyMaterialImpl.swift +++ b/SignalServiceKit/MessageBackup/MessageBackupKeyMaterialImpl.swift @@ -7,83 +7,37 @@ import Foundation public import LibSignalClient public struct MessageBackupKeyMaterialImpl: MessageBackupKeyMaterial { - - private enum Constants { - static let MessageBackupThumbnailEncryptionInfoString = "20240513_Signal_Backups_EncryptThumbnail" - static let MessageBackupThumbnailEncryptionDataLength = 64 - } - private let svr: SecureValueRecovery + private let mrbkStore: MediaRootBackupKeyStore - public init(svr: SecureValueRecovery) { + public init( + mrbkStore: MediaRootBackupKeyStore, + svr: SecureValueRecovery + ) { + self.mrbkStore = mrbkStore self.svr = svr } - public func backupAuthRequestContext( - localAci: Aci, - type: MessageBackupAuthCredentialType, - tx: DBReadTransaction - ) throws -> BackupAuthCredentialRequestContext { - return BackupAuthCredentialRequestContext.create( - backupKey: try backupKey(mode: .remote, tx: tx).serialize(), - aci: localAci.rawUUID - ) - } - - public func backupID(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> Data { - Data(try backupKey(mode: mode, tx: tx).deriveBackupId(aci: localAci)) - } - - public func backupPrivateKey(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> PrivateKey { - try backupKey(mode: mode, tx: tx).deriveEcKey(aci: localAci) - } - - public func messageBackupKey( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> MessageBackupKey { - let backupKey = try backupKey(mode: mode, tx: tx) - return try MessageBackupKey( - backupKey: backupKey, - backupId: backupKey.deriveBackupId(aci: localAci) - ) - } - - public func createEncryptingStreamTransform( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> EncryptingStreamTransform { - let encryptionKey = try messageBackupKey(localAci: localAci, mode: mode, tx: tx).aesKey - return try EncryptingStreamTransform(iv: Randomness.generateRandomBytes(16), encryptionKey: Data(encryptionKey)) - } - - public func createDecryptingStreamTransform( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> DecryptingStreamTransform { - let encryptionKey = try messageBackupKey(localAci: localAci, mode: mode, tx: tx).aesKey - return try DecryptingStreamTransform(encryptionKey: Data(encryptionKey)) - } - - public func createHmacGeneratingStreamTransform( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> HmacStreamTransform { - let hmacKey = try messageBackupKey(localAci: localAci, mode: mode, tx: tx).hmacKey - return try HmacStreamTransform(hmacKey: Data(hmacKey), operation: .generate) - } - - public func createHmacValidatingStreamTransform( - localAci: Aci, - mode: MessageBackup.EncryptionMode, - tx: DBReadTransaction - ) throws -> HmacStreamTransform { - let hmacKey = try messageBackupKey(localAci: localAci, mode: mode, tx: tx).hmacKey - return try HmacStreamTransform(hmacKey: Data(hmacKey), operation: .validate) + /// Get the root backup key used by the encryption mode. The key may be derived + /// differently depending on the mode, but derivations downstream of it work the same. + public func backupKey(type: MessageBackupAuthCredentialType, tx: DBReadTransaction) throws -> BackupKey { + let resultData: Data + switch type { + case .media: + guard let backupKey = mrbkStore.getMediaRootBackupKey(tx: tx) else { + throw MessageBackupKeyMaterialError.missingMasterKey + } + resultData = backupKey + case .messages: + guard let backupKey = svr.data(for: .backupKey, transaction: tx) else { + throw MessageBackupKeyMaterialError.missingMasterKey + } + guard backupKey.type == .backupKey else { + throw OWSAssertionError("Wrong key provided") + } + resultData = backupKey.rawData + } + return try resultData.withUnsafeBytes { try BackupKey(contents: Array($0)) } } public func mediaEncryptionMetadata( @@ -91,26 +45,15 @@ public struct MessageBackupKeyMaterialImpl: MessageBackupKeyMaterial { type: MediaTierEncryptionType, tx: any DBReadTransaction ) throws -> MediaTierEncryptionMetadata { - let backupKey = try backupKey(mode: .remote, tx: tx) + let backupKey = try backupKey(type: .media, tx: tx) let mediaId = try backupKey.deriveMediaId(mediaName) - let keyBytes: [UInt8] switch type { case .attachment: keyBytes = try backupKey.deriveMediaEncryptionKey(mediaId) case .thumbnail: - // TODO: Remove this when libSignal supports this. - guard let infoData = Constants.MessageBackupThumbnailEncryptionInfoString.data(using: .utf8) else { - throw MessageBackupKeyMaterialError.invalidKeyInfo - } - keyBytes = try hkdf( - outputLength: Constants.MessageBackupThumbnailEncryptionDataLength, - inputKeyMaterial: backupKey.serialize(), - salt: mediaId, - info: infoData - ) + keyBytes = try backupKey.deriveThumbnailTransitEncryptionKey(mediaId) } - return MediaTierEncryptionMetadata( type: type, mediaId: Data(mediaId), @@ -118,44 +61,25 @@ public struct MessageBackupKeyMaterialImpl: MessageBackupKeyMaterial { aesKey: Data(Array(keyBytes[32..<64])) ) } +} - public func mediaId( - mediaName: String, - type: MediaTierEncryptionType, - backupKey: BackupKey - ) throws -> Data { - let mediaName = { - switch type { - case .attachment: - return mediaName - case .thumbnail: - return mediaName - } - }() +#if TESTABLE_BUILD - return Data(try backupKey.deriveMediaId(mediaName)) +open class MessageBackupKeyMaterialMock: MessageBackupKeyMaterial { + public func backupKey( + type: MessageBackupAuthCredentialType, + tx: any DBReadTransaction + ) throws -> BackupKey { + throw OWSAssertionError("Unimplemented") } - /// Get the root backup key used by the encryption mode. The key may be derived - /// differently depending on the mode, but derivations downstream of it work the same. - public func backupKey(mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> BackupKey { - let resultData: Data - switch mode { - case .remote: - guard let backupKey = svr.data(for: .backupKey, transaction: tx) else { - throw MessageBackupKeyMaterialError.missingMasterKey - } - guard backupKey.type == .backupKey else { - throw OWSAssertionError("Wrong key provided") - } - resultData = backupKey.rawData - case .linknsync(let ephemeralBackupKey): - let rawKey = ephemeralBackupKey.data - guard rawKey.byteLength == SVR.DerivedKey.backupKeyLength else { - throw MessageBackupKeyMaterialError.invalidEncryptionKey - } - resultData = rawKey - } - return try resultData.withUnsafeBytes { try BackupKey(contents: Array($0)) } + public func mediaEncryptionMetadata( + mediaName: String, + type: MediaTierEncryptionType, + tx: any DBReadTransaction + ) throws -> MediaTierEncryptionMetadata { + throw OWSAssertionError("Unimplemented") } } + +#endif diff --git a/SignalServiceKit/MessageBackup/MessageBackupManager.swift b/SignalServiceKit/MessageBackup/MessageBackupManager.swift index e8a7de92b9b..aa70094c4cf 100644 --- a/SignalServiceKit/MessageBackup/MessageBackupManager.swift +++ b/SignalServiceKit/MessageBackup/MessageBackupManager.swift @@ -4,19 +4,7 @@ // import Foundation - -public extension MessageBackup { - enum EncryptionMode { - /// Export/Import an encrypted backup to/from the remote server. - /// The encryption key used derives entirely from the local account keychain. - case remote - /// Export/Import an encrypted backup used to link a new device. - /// The encryption key used is derived from the aci and a 32-byte "ephemeral" backup key. - case linknsync(EphemeralBackupKey) - - // TODO: [LocalBackups] introduce local mode with its own key scheme. - } -} +public import LibSignalClient public protocol MessageBackupManager { @@ -39,7 +27,7 @@ public protocol MessageBackupManager { /// - SeeAlso ``uploadEncryptedBackup(metadata:localIdentifiers:auth:)`` func exportEncryptedBackup( localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws -> Upload.EncryptedBackupUploadMetadata /// Export a plaintext backup binary at the returned file URL. @@ -52,7 +40,7 @@ public protocol MessageBackupManager { func importEncryptedBackup( fileUrl: URL, localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws /// Import a backup from the plaintext binary file at the given local URL. @@ -64,6 +52,6 @@ public protocol MessageBackupManager { func validateEncryptedBackup( fileUrl: URL, localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws } diff --git a/SignalServiceKit/MessageBackup/MessageBackupManagerImpl.swift b/SignalServiceKit/MessageBackup/MessageBackupManagerImpl.swift index 35907defe1f..feae30684dd 100644 --- a/SignalServiceKit/MessageBackup/MessageBackupManagerImpl.swift +++ b/SignalServiceKit/MessageBackup/MessageBackupManagerImpl.swift @@ -3,7 +3,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // -import LibSignalClient +public import LibSignalClient public enum BackupValidationError: Error { case unknownFields([String]) @@ -16,6 +16,7 @@ public class MessageBackupManagerImpl: MessageBackupManager { private enum Constants { static let keyValueStoreCollectionName = "MessageBackupManager" static let keyValueStoreHasReservedBackupKey = "HasReservedBackupKey" + static let keyValueStoreHasReservedMediaBackupKey = "HasReservedMediaBackupKey" static let supportedBackupVersion: UInt64 = 1 } @@ -122,34 +123,33 @@ public class MessageBackupManagerImpl: MessageBackupManager { /// These registration calls are safe to call multiple times, but to avoid unecessary network calls, the app will remember if /// backups have been successfully registered on this device and will no-op in this case. private func reserveAndRegister(localIdentifiers: LocalIdentifiers, auth: ChatServiceAuth) async throws { - guard db.read(block: { tx in - return kvStore.getBool(Constants.keyValueStoreHasReservedBackupKey, transaction: tx) ?? false - }).negated else { + let (hasReservedBackupKey, hasReservedMediaBackupKey) = db.read { tx in + return ( + kvStore.getBool(Constants.keyValueStoreHasReservedBackupKey, transaction: tx) ?? false, + kvStore.getBool(Constants.keyValueStoreHasReservedMediaBackupKey, transaction: tx) ?? false + ) + } + + if hasReservedBackupKey && hasReservedMediaBackupKey { return } // Both reserveBackupId and registerBackupKeys can be called multiple times, so if // we think the backupId needs to be registered, register the public key at the same time. - - try await backupRequestManager.reserveBackupId(localAci: localIdentifiers.aci, auth: auth) - - let backupAuth = try await backupRequestManager.fetchBackupServiceAuth( - for: .oneTimeKeySetup(.messages), - localAci: localIdentifiers.aci, - auth: auth - ) - - try await backupRequestManager.registerBackupKeys(auth: backupAuth) + let localAci = localIdentifiers.aci + try await backupRequestManager.reserveBackupId(localAci: localAci, auth: auth) + try await backupRequestManager.registerBackupKeys(localAci: localAci, auth: auth) // Remember this device has registered for backups await db.awaitableWrite { [weak self] tx in self?.kvStore.setBool(true, key: Constants.keyValueStoreHasReservedBackupKey, transaction: tx) + self?.kvStore.setBool(true, key: Constants.keyValueStoreHasReservedMediaBackupKey, transaction: tx) } } public func downloadEncryptedBackup(localIdentifiers: LocalIdentifiers, auth: ChatServiceAuth) async throws -> URL { let backupAuth = try await backupRequestManager.fetchBackupServiceAuth( - for: .download(.messages), + for: .messages, localAci: localIdentifiers.aci, auth: auth ) @@ -170,7 +170,7 @@ public class MessageBackupManagerImpl: MessageBackupManager { // This will return early if this device has already registered the backup ID. try await reserveAndRegister(localIdentifiers: localIdentifiers, auth: auth) let backupAuth = try await backupRequestManager.fetchBackupServiceAuth( - for: .upload(.messages), + for: .messages, localAci: localIdentifiers.aci, auth: auth ) @@ -182,7 +182,7 @@ public class MessageBackupManagerImpl: MessageBackupManager { public func exportEncryptedBackup( localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws -> Upload.EncryptedBackupUploadMetadata { guard FeatureFlags.messageBackupFileAlpha @@ -202,7 +202,7 @@ public class MessageBackupManagerImpl: MessageBackupManager { let metadataProvider: MessageBackup.ProtoStream.EncryptionMetadataProvider switch self.encryptedStreamProvider.openEncryptedOutputFileStream( localAci: localIdentifiers.aci, - mode: mode, + backupKey: backupKey, tx: tx ) { case let .success(_outputStream, _metadataProvider): @@ -475,7 +475,7 @@ public class MessageBackupManagerImpl: MessageBackupManager { public func importEncryptedBackup( fileUrl: URL, localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws { guard FeatureFlags.messageBackupFileAlpha || FeatureFlags.linkAndSyncSecondary else { owsFailDebug("Should not be able to use backups!") @@ -492,7 +492,7 @@ public class MessageBackupManagerImpl: MessageBackupManager { switch self.encryptedStreamProvider.openEncryptedInputFileStream( fileUrl: fileUrl, localAci: localIdentifiers.aci, - mode: mode, + backupKey: backupKey, tx: tx ) { case .success(let protoStream, _): @@ -846,11 +846,9 @@ public class MessageBackupManagerImpl: MessageBackupManager { public func validateEncryptedBackup( fileUrl: URL, localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws { - let key = try db.read { tx in - return try messageBackupKeyMaterial.messageBackupKey(localAci: localIdentifiers.aci, mode: mode, tx: tx) - } + let key = try backupKey.asMessageBackupKey(for: localIdentifiers.aci) let fileSize = OWSFileSystem.fileSize(ofPath: fileUrl.path)?.uint64Value ?? 0 do { diff --git a/SignalServiceKit/MessageBackup/MessageBackupManagerMock.swift b/SignalServiceKit/MessageBackup/MessageBackupManagerMock.swift index 9abb74092b6..18b9c2f683e 100644 --- a/SignalServiceKit/MessageBackup/MessageBackupManagerMock.swift +++ b/SignalServiceKit/MessageBackup/MessageBackupManagerMock.swift @@ -4,6 +4,7 @@ // import Foundation +public import LibSignalClient #if TESTABLE_BUILD @@ -35,7 +36,7 @@ open class MessageBackupManagerMock: MessageBackupManager { public func exportEncryptedBackup( localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws -> Upload.EncryptedBackupUploadMetadata { return Upload.EncryptedBackupUploadMetadata( fileUrl: URL(string: "file://")!, @@ -52,13 +53,13 @@ open class MessageBackupManagerMock: MessageBackupManager { public func importEncryptedBackup( fileUrl: URL, localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws {} public func importPlaintextBackup(fileUrl: URL, localIdentifiers: LocalIdentifiers) async throws {} public func validateEncryptedBackup( fileUrl: URL, localIdentifiers: LocalIdentifiers, - mode: MessageBackup.EncryptionMode + backupKey: BackupKey ) async throws {} } diff --git a/SignalServiceKit/MessageBackup/MessageBackupRequestManager.swift b/SignalServiceKit/MessageBackup/MessageBackupRequestManager.swift index b51ee16767c..74e18c46740 100644 --- a/SignalServiceKit/MessageBackup/MessageBackupRequestManager.swift +++ b/SignalServiceKit/MessageBackup/MessageBackupRequestManager.swift @@ -89,14 +89,14 @@ public protocol MessageBackupRequestManager { /// The intended purpose for this service auth. This influences the steps /// taken in fetching the credentials underlying the returned service auth. func fetchBackupServiceAuth( - for purpose: MessageBackupAuthCredentialManager.Purpose, + for credentialType: MessageBackupAuthCredentialType, localAci: Aci, auth: ChatServiceAuth ) async throws -> MessageBackupServiceAuth func reserveBackupId(localAci: Aci, auth: ChatServiceAuth) async throws - func registerBackupKeys(auth: MessageBackupServiceAuth) async throws + func registerBackupKeys(localAci: Aci, auth: ChatServiceAuth) async throws func fetchBackupUploadForm(auth: MessageBackupServiceAuth) async throws -> Upload.Form @@ -142,12 +142,31 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { static let keyValueStoreCollectionName = "MessageBackupRequestManager" static let cdnNumberOfDaysFetchIntervalInSeconds: TimeInterval = kDayInterval - static let keyValueStoreCdn2CredentialKey = "Cdn2Credential" - static let keyValueStoreCdn3CredentialKey = "Cdn3Credential" + private static let keyValueStoreCdn2CredentialKey = "Cdn2Credential:" + private static let keyValueStoreCdn3CredentialKey = "Cdn3Credential:" + + static func cdnCredentialCacheKey(for cdn: Int32, auth: MessageBackupServiceAuth) -> String { + switch cdn { + case 2: + return Constants.keyValueStoreCdn2CredentialKey + auth.type.rawValue + case 3: + return Constants.keyValueStoreCdn3CredentialKey + auth.type.rawValue + default: + owsFailDebug("Invalid CDN version requested") + return Constants.keyValueStoreCdn3CredentialKey + auth.type.rawValue + } + } - static let keyValueStoreBackupInfoKey = "BackupInfo" static let backupInfoNumberOfDaysFetchIntervalInSeconds: TimeInterval = kDayInterval - static let keyValueStoreLastBackupInfoFetchTimeKey = "LastBackupInfoFetchTime" + private static let keyValueStoreBackupInfoKeyPrefix = "BackupInfo:" + private static let keyValueStoreLastBackupInfoFetchTimeKeyPrefix = "LastBackupInfoFetchTime:" + + static func backupInfoCacheInfo(for auth: MessageBackupServiceAuth) -> (infoKey: String, lastfetchTimeKey: String) { + ( + keyValueStoreBackupInfoKeyPrefix + auth.type.rawValue, + keyValueStoreLastBackupInfoFetchTimeKeyPrefix + auth.type.rawValue + ) + } } private let dateProvider: DateProvider @@ -177,11 +196,25 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { /// Onetime request to reserve this backup ID. public func reserveBackupId(localAci: Aci, auth: ChatServiceAuth) async throws { - let backupRequestContext = try db.read { tx in - return try messageBackupKeyMaterial.backupAuthRequestContext(localAci: localAci, type: .messages, tx: tx) + let messageBackupRequestContext = try db.read { tx in + BackupAuthCredentialRequestContext.create( + backupKey: try messageBackupKeyMaterial.backupKey(type: .messages, tx: tx).serialize(), + aci: localAci.rawUUID + ) } - let base64RequestContext = Data(backupRequestContext.getRequest().serialize()).base64EncodedString() - let request = try OWSRequestFactory.reserveBackupId(backupId: base64RequestContext, auth: auth) + let mediaBackupRequestContext = try db.read { tx in + return BackupAuthCredentialRequestContext.create( + backupKey: try messageBackupKeyMaterial.backupKey(type: .media, tx: tx).serialize(), + aci: localAci.rawUUID + ) + } + let base64MessageRequestContext = messageBackupRequestContext.getRequest().serialize().asData.base64EncodedString() + let base64MediaRequestContext = mediaBackupRequestContext.getRequest().serialize().asData.base64EncodedString() + let request = try OWSRequestFactory.reserveBackupId( + backupId: base64MessageRequestContext, + mediaBackupId: base64MediaRequestContext, + auth: auth + ) // TODO: Switch this back to true when reg supports websockets _ = try await networkManager.asyncRequest(request, canUseWebSocket: false) } @@ -189,31 +222,52 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { // MARK: - Backup Auth public func fetchBackupServiceAuth( - for purpose: MessageBackupAuthCredentialManager.Purpose, + for credentialType: MessageBackupAuthCredentialType, localAci: Aci, auth: ChatServiceAuth ) async throws -> MessageBackupServiceAuth { let (backupKey, privateKey) = try db.read { tx in - let backupKey = try messageBackupKeyMaterial.backupID(localAci: localAci, mode: .remote, tx: tx) - let privateKey = try messageBackupKeyMaterial.backupPrivateKey(localAci: localAci, mode: .remote, tx: tx) + let key = try messageBackupKeyMaterial.backupKey(type: credentialType, tx: tx) + let backupKey = key.deriveBackupId(aci: localAci) + let privateKey = key.deriveEcKey(aci: localAci) return (backupKey, privateKey) } let authCredential = try await messageBackupAuthCredentialManager.fetchBackupCredential( - for: purpose, + for: credentialType, localAci: localAci, chatServiceAuth: auth ) - return try MessageBackupServiceAuth(backupKey: backupKey, privateKey: privateKey, authCredential: authCredential) + return try MessageBackupServiceAuth( + backupKey: backupKey.asData, + privateKey: privateKey, + authCredential: authCredential, + type: credentialType + ) } // MARK: - Register Backup /// Onetime request to register the backup public key. - public func registerBackupKeys(auth: MessageBackupServiceAuth) async throws { + public func registerBackupKeys(localAci: Aci, auth: ChatServiceAuth) async throws { + let backupAuth = try await fetchBackupServiceAuth( + for: .messages, + localAci: localAci, + auth: auth + ) _ = try await executeBackupServiceRequest( - auth: auth, + auth: backupAuth, + requestFactory: OWSRequestFactory.backupSetPublicKeyRequest(auth:) + ) + + let mediaBackupAuth = try await fetchBackupServiceAuth( + for: .media, + localAci: localAci, + auth: auth + ) + _ = try await executeBackupServiceRequest( + auth: mediaBackupAuth, requestFactory: OWSRequestFactory.backupSetPublicKeyRequest(auth:) ) } @@ -222,6 +276,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { /// CDN upload form for uploading a backup public func fetchBackupUploadForm(auth: MessageBackupServiceAuth) async throws -> Upload.Form { + owsAssertDebug(auth.type == .messages) return try await executeBackupService( auth: auth, requestFactory: OWSRequestFactory.backupUploadFormRequest(auth:) @@ -230,6 +285,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { /// CDN upload form for uploading backup media public func fetchBackupMediaAttachmentUploadForm(auth: MessageBackupServiceAuth) async throws -> Upload.Form { + owsAssertDebug(auth.type == .media) return try await executeBackupService( auth: auth, requestFactory: OWSRequestFactory.backupMediaUploadFormRequest(auth:) @@ -240,9 +296,10 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { /// Fetch details about the current backup public func fetchBackupInfo(auth: MessageBackupServiceAuth) async throws -> MessageBackupRemoteInfo { + let cacheInfo = Constants.backupInfoCacheInfo(for: auth) let cachedBackupInfo = db.read { tx -> MessageBackupRemoteInfo? in let lastInfoFetchTime = kvStore.getDate( - Constants.keyValueStoreLastBackupInfoFetchTimeKey, + cacheInfo.lastfetchTimeKey, transaction: tx ) ?? .distantPast @@ -250,7 +307,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { if abs(lastInfoFetchTime.timeIntervalSinceNow) < Constants.backupInfoNumberOfDaysFetchIntervalInSeconds { do { if let backupInfo: MessageBackupRemoteInfo = try kvStore.getCodableValue( - forKey: Constants.keyValueStoreBackupInfoKey, + forKey: cacheInfo.infoKey, transaction: tx ) { return backupInfo @@ -275,8 +332,17 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { ) try await db.awaitableWrite { tx in - try kvStore.setCodable(backupInfo, key: Constants.keyValueStoreBackupInfoKey, transaction: tx) - kvStore.setDate(dateProvider(), key: Constants.keyValueStoreLastBackupInfoFetchTimeKey, transaction: tx) + try kvStore.setCodable( + backupInfo, + key: cacheInfo.infoKey, + transaction: tx + ) + + kvStore.setDate( + dateProvider(), + key: cacheInfo.lastfetchTimeKey, + transaction: tx + ) } return backupInfo @@ -292,6 +358,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { /// Delete the current backup public func deleteBackup(auth: MessageBackupServiceAuth) async throws { + owsAssertDebug(auth.type == .messages) _ = try await executeBackupServiceRequest( auth: auth, requestFactory: OWSRequestFactory.deleteBackupRequest(auth:) @@ -305,19 +372,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { cdn: Int32, auth: MessageBackupServiceAuth ) async throws -> CDNReadCredential { - - let cacheKey = { - switch cdn { - case 2: - return Constants.keyValueStoreCdn2CredentialKey - case 3: - return Constants.keyValueStoreCdn3CredentialKey - default: - owsFailDebug("Invalid CDN version requested") - return Constants.keyValueStoreCdn3CredentialKey - } - }() - + let cacheKey = Constants.cdnCredentialCacheKey(for: cdn, auth: auth) let result = db.read { tx -> CDNReadCredential? in do { if @@ -361,6 +416,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { cdn: Int32, auth: MessageBackupServiceAuth ) async throws -> MediaTierReadCredential { + owsAssertDebug(auth.type == .media) let info = try await fetchBackupInfo(auth: auth) let authCredential = try await fetchCDNReadCredentials(cdn: cdn, auth: auth) return MediaTierReadCredential(cdn: cdn, credential: authCredential, info: info) @@ -370,6 +426,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { item: MessageBackup.Request.MediaItem, auth: MessageBackupServiceAuth ) async throws -> UInt32 { + owsAssertDebug(auth.type == .media) do { let response = try await executeBackupServiceRequest( auth: auth, @@ -407,6 +464,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { items: [MessageBackup.Request.MediaItem], auth: MessageBackupServiceAuth ) async throws -> [MessageBackup.Response.BatchedBackupMediaResult] { + owsAssertDebug(auth.type == .media) return try await executeBackupService( auth: auth, requestFactory: { @@ -423,6 +481,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { limit: UInt32?, auth: MessageBackupServiceAuth ) async throws -> MessageBackup.Response.ListMediaResult { + owsAssertDebug(auth.type == .media) return try await executeBackupService( auth: auth, requestFactory: { @@ -436,6 +495,7 @@ public struct MessageBackupRequestManagerImpl: MessageBackupRequestManager { } public func deleteMediaObjects(objects: [MessageBackup.Request.DeleteMediaTarget], auth: MessageBackupServiceAuth) async throws { + owsAssertDebug(auth.type == .media) _ = try await executeBackupServiceRequest( auth: auth, requestFactory: { diff --git a/SignalServiceKit/MessageBackup/MessageBackupServiceAuth.swift b/SignalServiceKit/MessageBackup/MessageBackupServiceAuth.swift index c7ba4a2b61a..2d8f4cf9701 100644 --- a/SignalServiceKit/MessageBackup/MessageBackupServiceAuth.swift +++ b/SignalServiceKit/MessageBackup/MessageBackupServiceAuth.swift @@ -10,10 +10,15 @@ public struct MessageBackupServiceAuth { private let authHeaders: [String: String] public let publicKey: PublicKey - public init(backupKey: Data, privateKey: PrivateKey, authCredential: BackupAuthCredential) throws { + // Remember the type of auth this credential represents (message vs media). + // This makes it easier to cache requested information correctly based on the type + public let type: MessageBackupAuthCredentialType + + public init(backupKey: Data, privateKey: PrivateKey, authCredential: BackupAuthCredential, type: MessageBackupAuthCredentialType) throws { let backupServerPublicParams = try GenericServerPublicParams(contents: [UInt8](TSConstants.backupServerPublicParams)) let presentation = authCredential.present(serverParams: backupServerPublicParams).serialize() let signedPresentation = privateKey.generateSignature(message: presentation) + self.type = type self.publicKey = privateKey.publicKey self.authHeaders = [ diff --git a/SignalServiceKit/Messages/Attachments/V2/Downloads/AttachmentDownloadManagerImpl.swift b/SignalServiceKit/Messages/Attachments/V2/Downloads/AttachmentDownloadManagerImpl.swift index 91ca3faa29c..318b9b4d6a6 100644 --- a/SignalServiceKit/Messages/Attachments/V2/Downloads/AttachmentDownloadManagerImpl.swift +++ b/SignalServiceKit/Messages/Attachments/V2/Downloads/AttachmentDownloadManagerImpl.swift @@ -812,7 +812,7 @@ public class AttachmentDownloadManagerImpl: AttachmentDownloadManager { } guard let auth = try? await messageBackupRequestManager.fetchBackupServiceAuth( - for: .download(.media), + for: .media, localAci: localAci, auth: .implicit() ) else { diff --git a/SignalServiceKit/Network/API/Requests/OWSRequestFactory+MessageBackup.swift b/SignalServiceKit/Network/API/Requests/OWSRequestFactory+MessageBackup.swift index 7b86cb05f4c..ad28ad93c85 100644 --- a/SignalServiceKit/Network/API/Requests/OWSRequestFactory+MessageBackup.swift +++ b/SignalServiceKit/Network/API/Requests/OWSRequestFactory+MessageBackup.swift @@ -10,12 +10,16 @@ import Foundation extension OWSRequestFactory { public static func reserveBackupId( backupId: String, + mediaBackupId: String, auth: ChatServiceAuth = .implicit() ) throws -> TSRequest { let request = TSRequest( url: URL(string: "v1/archives/backupid")!, method: "PUT", - parameters: ["backupAuthCredentialRequest": backupId] + parameters: [ + "messagesBackupAuthCredentialRequest": backupId, + "mediaBackupAuthCredentialRequest": mediaBackupId + ] ) request.setAuth(auth) return request diff --git a/SignalServiceKit/Upload/AttachmentUploadManager.swift b/SignalServiceKit/Upload/AttachmentUploadManager.swift index b460e94c5c7..64a99e95011 100644 --- a/SignalServiceKit/Upload/AttachmentUploadManager.swift +++ b/SignalServiceKit/Upload/AttachmentUploadManager.swift @@ -888,7 +888,7 @@ public actor AttachmentUploadManagerImpl: AttachmentUploadManager { logger: PrefixedLogger ) async throws -> UInt32 { let auth = try await messageBackupRequestManager.fetchBackupServiceAuth( - for: .upload(.media), + for: .media, localAci: localAci, auth: .implicit() ) diff --git a/SignalServiceKit/ZkParams/AuthCredentialStore.swift b/SignalServiceKit/ZkParams/AuthCredentialStore.swift index ca4bc80d8a7..6bf6364259f 100644 --- a/SignalServiceKit/ZkParams/AuthCredentialStore.swift +++ b/SignalServiceKit/ZkParams/AuthCredentialStore.swift @@ -10,11 +10,13 @@ class AuthCredentialStore { private let callLinkAuthCredentialStore: any KeyValueStore private let groupAuthCredentialStore: any KeyValueStore private let backupAuthCredentialStore: any KeyValueStore + private let mediaAuthCredentialStore: any KeyValueStore init(keyValueStoreFactory: any KeyValueStoreFactory) { self.callLinkAuthCredentialStore = keyValueStoreFactory.keyValueStore(collection: "CallLinkAuthCredential") self.groupAuthCredentialStore = keyValueStoreFactory.keyValueStore(collection: "GroupsV2Impl.authCredentialStoreStore") self.backupAuthCredentialStore = keyValueStoreFactory.keyValueStore(collection: "BackupAuthCredential") + self.mediaAuthCredentialStore = keyValueStoreFactory.keyValueStore(collection: "MediaAuthCredential") } private static func callLinkAuthCredentialKey(for redemptionTime: UInt64) -> String { @@ -86,11 +88,13 @@ class AuthCredentialStore { } func backupAuthCredential( - for redemptionTime: UInt64, + for credentialType: MessageBackupAuthCredentialType, + redemptionTime: UInt64, tx: DBReadTransaction ) -> BackupAuthCredential? { + let store = credentialType == .messages ? backupAuthCredentialStore : mediaAuthCredentialStore do { - return try backupAuthCredentialStore.getData( + return try store.getData( Self.backupAuthCredentialKey(for: redemptionTime), transaction: tx ).map { @@ -104,17 +108,20 @@ class AuthCredentialStore { func setBackupAuthCredential( _ credential: BackupAuthCredential, - for redemptionTime: UInt64, + for credentialType: MessageBackupAuthCredentialType, + redemptionTime: UInt64, tx: DBWriteTransaction ) { - backupAuthCredentialStore.setData( + let store = credentialType == .messages ? backupAuthCredentialStore : mediaAuthCredentialStore + store.setData( credential.serialize().asData, key: Self.backupAuthCredentialKey(for: redemptionTime), transaction: tx ) } - func removeAllBackupAuthCredentials(tx: DBWriteTransaction) { - backupAuthCredentialStore.removeAll(transaction: tx) + func removeAllBackupAuthCredentials(for credentialType: MessageBackupAuthCredentialType, tx: DBWriteTransaction) { + let store = credentialType == .messages ? backupAuthCredentialStore : mediaAuthCredentialStore + store.removeAll(transaction: tx) } } diff --git a/SignalServiceKit/ZkParams/MessageBackupAuthCredentialManager.swift b/SignalServiceKit/ZkParams/MessageBackupAuthCredentialManager.swift index d16e78dd046..a748ca22064 100644 --- a/SignalServiceKit/ZkParams/MessageBackupAuthCredentialManager.swift +++ b/SignalServiceKit/ZkParams/MessageBackupAuthCredentialManager.swift @@ -6,40 +6,14 @@ import Foundation public import LibSignalClient -/// The purposes for which one could need a Backup auth credential. The intended -/// purpose influences what steps are taken when retrieving a credential. -public enum MessageBackupAuthCredentialPurpose { - /// The credential will be used for one-time set-up of keys related to a - /// user's Backup. - case oneTimeKeySetup(MessageBackupAuthCredentialType) - /// The credential will be used for download, for example of a Backup file - /// or attachments. - case download(MessageBackupAuthCredentialType) - /// The credential will be used for upload, for example of a Backup file or - /// attachments. - case upload(MessageBackupAuthCredentialType) - /// The credential will be used for delete, for example of a Backup file or - /// attachments. - case delete(MessageBackupAuthCredentialType) - - var credentialType: MessageBackupAuthCredentialType { - switch self { - case .delete(let credentialType), .download(let credentialType), .oneTimeKeySetup(let credentialType), .upload(let credentialType): - return credentialType - } - } -} - -public enum MessageBackupAuthCredentialType: String, Codable, CaseIterable { +public enum MessageBackupAuthCredentialType: String, Codable, CaseIterable, CodingKeyRepresentable { case media case messages } public protocol MessageBackupAuthCredentialManager { - typealias Purpose = MessageBackupAuthCredentialPurpose - func fetchBackupCredential( - for purpose: Purpose, + for credentialType: MessageBackupAuthCredentialType, localAci: Aci, chatServiceAuth auth: ChatServiceAuth ) async throws -> BackupAuthCredential @@ -52,7 +26,11 @@ public struct MessageBackupAuthCredentialManagerImpl: MessageBackupAuthCredentia static let numberOfDaysFetchIntervalInSeconds: TimeInterval = 3 * kDayInterval static let keyValueStoreCollectionName = "MessageBackupAuthCredentialManager" - static let keyValueStoreLastCredentialFetchTimeKey = "LastCredentialFetchTime" + + private static let keyValueStoreLastBackupCredentialFetchTimeKeyPrefix = "LastCredentialFetchTime:" + static func cacheKey(for credentialType: MessageBackupAuthCredentialType) -> String { + keyValueStoreLastBackupCredentialFetchTimeKeyPrefix + credentialType.rawValue + } } private let authCredentialStore: AuthCredentialStore @@ -79,20 +57,22 @@ public struct MessageBackupAuthCredentialManagerImpl: MessageBackupAuthCredentia } public func fetchBackupCredential( - for purpose: Purpose, + for credentialType: MessageBackupAuthCredentialType, localAci: Aci, chatServiceAuth auth: ChatServiceAuth ) async throws -> BackupAuthCredential { + let fetchTimeKey = Constants.cacheKey(for: credentialType) let authCredential = db.read { tx -> BackupAuthCredential? in - let lastCredentialFetchTime = kvStore.getDate( - Constants.keyValueStoreLastCredentialFetchTimeKey, - transaction: tx - ) ?? .distantPast + let lastCredentialFetchTime = kvStore.getDate(fetchTimeKey, transaction: tx) ?? .distantPast // every 3 days fetch 7 days worth if abs(lastCredentialFetchTime.timeIntervalSinceNow) < Constants.numberOfDaysFetchIntervalInSeconds { let redemptionTime = self.dateProvider().startOfTodayUTCTimestamp() - if let backupAuthCredential = self.authCredentialStore.backupAuthCredential(for: redemptionTime, tx: tx) { + if let backupAuthCredential = self.authCredentialStore.backupAuthCredential( + for: credentialType, + redemptionTime: redemptionTime, + tx: tx + ) { return backupAuthCredential } else { owsFailDebug("Error retrieving cached auth credential") @@ -106,21 +86,32 @@ public struct MessageBackupAuthCredentialManagerImpl: MessageBackupAuthCredentia return authCredential } - let authCredentials = try await fetchNewAuthCredentials(localAci: localAci, for: purpose, auth: auth) + let authCredentials = try await fetchNewAuthCredentials(localAci: localAci, for: credentialType, auth: auth) await db.awaitableWrite { tx in - self.authCredentialStore.removeAllBackupAuthCredentials(tx: tx) - for receivedCredential in authCredentials { - self.authCredentialStore.setBackupAuthCredential( - receivedCredential.credential, - for: receivedCredential.redemptionTime, - tx: tx - ) + // Fetch both credential types if either is needed. + MessageBackupAuthCredentialType.allCases.forEach { credentialType in + guard let receivedCredentials = authCredentials[credentialType] else { + if credentialType == credentialType { + // If the requested media type fails, make some noise about it. + owsFailDebug("Failed to retrieve credentials for \(credentialType.rawValue)") + } + return + } + self.authCredentialStore.removeAllBackupAuthCredentials(for: credentialType, tx: tx) + for receivedCredential in receivedCredentials { + self.authCredentialStore.setBackupAuthCredential( + receivedCredential.credential, + for: credentialType, + redemptionTime: receivedCredential.redemptionTime, + tx: tx + ) + } + kvStore.setDate(dateProvider(), key: fetchTimeKey, transaction: tx) } - kvStore.setDate(dateProvider(), key: Constants.keyValueStoreLastCredentialFetchTimeKey, transaction: tx) } - guard let authCredential = authCredentials.first?.credential else { + guard let authCredential = authCredentials[credentialType]?.first?.credential else { throw OWSAssertionError("The server didn't give us any auth credentials.") } @@ -129,9 +120,9 @@ public struct MessageBackupAuthCredentialManagerImpl: MessageBackupAuthCredentia private func fetchNewAuthCredentials( localAci: Aci, - for purpose: Purpose, + for credentialType: MessageBackupAuthCredentialType, auth: ChatServiceAuth - ) async throws -> [ReceivedBackupAuthCredentials] { + ) async throws -> [MessageBackupAuthCredentialType: [ReceivedBackupAuthCredentials]] { let startTimestamp = self.dateProvider().startOfTodayUTCTimestamp() let endTimestamp = startTimestamp + UInt64(Constants.numberOfDaysToFetchInSeconds) @@ -151,32 +142,34 @@ public struct MessageBackupAuthCredentialManagerImpl: MessageBackupAuthCredentia let authCredentialRepsonse = try JSONDecoder().decode(BackupCredentialResponse.self, from: data) let backupServerPublicParams = try GenericServerPublicParams(contents: [UInt8](TSConstants.backupServerPublicParams)) - return try authCredentialRepsonse.credentials.map { - do { - let redemptionDate = Date(timeIntervalSince1970: TimeInterval($0.redemptionTime)) - let backupRequestContext = try db.read { tx in - return try messageBackupKeyMaterial.backupAuthRequestContext(localAci: localAci, type: purpose.credentialType, tx: tx) + return try authCredentialRepsonse.credentials.reduce(into: [MessageBackupAuthCredentialType: [ReceivedBackupAuthCredentials]]()) { result, element in + let type = element.key + result[type] = try element.value.compactMap { + do { + let redemptionDate = Date(timeIntervalSince1970: TimeInterval($0.redemptionTime)) + let backupRequestContext = try db.read { tx in + let backupKey = try messageBackupKeyMaterial.backupKey(type: type, tx: tx) + return BackupAuthCredentialRequestContext.create(backupKey: backupKey.serialize(), aci: localAci.rawUUID) + } + let backupAuthResponse = try BackupAuthCredentialResponse(contents: [UInt8]($0.credential)) + let credential = try backupRequestContext.receive( + backupAuthResponse, + timestamp: redemptionDate, + params: backupServerPublicParams + ) + return ReceivedBackupAuthCredentials(redemptionTime: $0.redemptionTime, credential: credential) + } catch MessageBackupKeyMaterialError.missingMasterKey where type != credentialType { + return nil + } catch { + owsFailDebug("Error creating credential") + throw error } - let backupAuthResponse = try BackupAuthCredentialResponse(contents: [UInt8]($0.credential)) - let credential = try backupRequestContext.receive( - backupAuthResponse, - timestamp: redemptionDate, - params: backupServerPublicParams - ) - return ReceivedBackupAuthCredentials(redemptionTime: $0.redemptionTime, credential: credential) - } catch { - owsFailDebug("Error creating credential") - throw error } } } private struct BackupCredentialResponse: Decodable { - enum CodingKeys: String, CodingKey { - case credentials - } - - var credentials: [AuthCredential] + var credentials: [MessageBackupAuthCredentialType: [AuthCredential]] struct AuthCredential: Decodable { var redemptionTime: UInt64 diff --git a/SignalServiceKit/tests/Network/Upload/AttachmentUploadManagerTestMocks.swift b/SignalServiceKit/tests/Network/Upload/AttachmentUploadManagerTestMocks.swift index c2b84bd8948..6822ba6cfc7 100644 --- a/SignalServiceKit/tests/Network/Upload/AttachmentUploadManagerTestMocks.swift +++ b/SignalServiceKit/tests/Network/Upload/AttachmentUploadManagerTestMocks.swift @@ -92,54 +92,25 @@ class _AttachmentUploadManager_ChatConnectionManagerMock: ChatConnectionManager } class _AttachmentUploadManager_MessageBackupKeyMaterialMock: MessageBackupKeyMaterial { - func backupKey(mode: MessageBackup.EncryptionMode, tx: any DBReadTransaction) throws -> BackupKey { - fatalError("Unimplemented for tests") - } - - func backupID(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> Data { - fatalError("Unimplemented for tests") - } - - func backupPrivateKey(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> PrivateKey { - fatalError("Unimplemented for tests") - } - - func backupAuthRequestContext(localAci: Aci, type: MessageBackupAuthCredentialType, tx: DBReadTransaction) throws -> BackupAuthCredentialRequestContext { - fatalError("Unimplemented for tests") - } - - func messageBackupKey(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> MessageBackupKey { + func backupKey( + type: MessageBackupAuthCredentialType, + tx: DBReadTransaction + ) throws -> BackupKey { fatalError("Unimplemented for tests") } - func mediaEncryptionMetadata(mediaName: String, type: MediaTierEncryptionType, tx: DBReadTransaction) throws -> MediaTierEncryptionMetadata { + func mediaEncryptionMetadata( + mediaName: String, + type: MediaTierEncryptionType, + tx: any DBReadTransaction + ) throws -> MediaTierEncryptionMetadata { return .init(type: type, mediaId: Data(), hmacKey: Data(), aesKey: Data()) } - - func mediaId(mediaName: String, type: MediaTierEncryptionType, backupKey: BackupKey) throws -> Data { - return Data() - } - - func createEncryptingStreamTransform(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> EncryptingStreamTransform { - fatalError("Unimplemented for tests") - } - - func createDecryptingStreamTransform(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> DecryptingStreamTransform { - fatalError("Unimplemented for tests") - } - - func createHmacGeneratingStreamTransform(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> HmacStreamTransform { - fatalError("Unimplemented for tests") - } - - func createHmacValidatingStreamTransform(localAci: Aci, mode: MessageBackup.EncryptionMode, tx: DBReadTransaction) throws -> HmacStreamTransform { - fatalError("Unimplemented for tests") - } } class _AttachmentUploadManager_MessageBackupRequestManagerMock: MessageBackupRequestManager { func fetchBackupServiceAuth( - for purpose: MessageBackupAuthCredentialManager.Purpose, + for type: MessageBackupAuthCredentialType, localAci: Aci, auth: ChatServiceAuth ) async throws -> MessageBackupServiceAuth { @@ -148,7 +119,7 @@ class _AttachmentUploadManager_MessageBackupRequestManagerMock: MessageBackupReq func reserveBackupId(localAci: Aci, auth: ChatServiceAuth) async throws { } - func registerBackupKeys(auth: MessageBackupServiceAuth) async throws { } + func registerBackupKeys(localAci: Aci, auth: ChatServiceAuth) async throws {} func fetchBackupUploadForm(auth: MessageBackupServiceAuth) async throws -> Upload.Form { fatalError("Unimplemented for tests")