Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(crypto): Make memory store behave more like other stores #4558

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
32 changes: 27 additions & 5 deletions crates/matrix-sdk-crypto/src/gossiping/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,7 @@ mod tests {
use crate::{
gossiping::KeyForwardDecision,
olm::OutboundGroupSession,
store::{CryptoStore, DeviceChanges},
types::requests::AnyOutgoingRequest,
types::{
events::{
Expand Down Expand Up @@ -1177,20 +1178,41 @@ mod tests {
let user_id = user_id.to_owned();
let device_id = DeviceId::new();

let account = Account::with_device_id(&user_id, &device_id);
let store = Arc::new(CryptoStoreWrapper::new(&user_id, &device_id, MemoryStore::new()));
let store = Arc::new(store_with_account_helper(&user_id, &device_id).await);
let static_data = store.load_account().await.unwrap().unwrap().static_data;
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let verification =
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
let store = Store::new(account.static_data().clone(), identity, store, verification);
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
VerificationMachine::new(static_data.clone(), identity.clone(), store.clone());
let store = Store::new(static_data, identity, store, verification);

let session_cache = GroupSessionCache::new(store.clone());
let identity_manager = IdentityManager::new(store.clone());

GossipMachine::new(store, identity_manager, session_cache, Default::default())
}

#[cfg(feature = "automatic-room-key-forwarding")]
async fn store_with_account_helper(
user_id: &UserId,
device_id: &DeviceId,
) -> CryptoStoreWrapper {
// Properly create the store by first saving the own device and then the account
// data.
let account = Account::with_device_id(user_id, device_id);
let device = DeviceData::from_account(&account);
device.set_trust_state(LocalTrust::Verified);

let changes = Changes {
devices: DeviceChanges { new: vec![device], ..Default::default() },
..Default::default()
};
let mem_store = MemoryStore::new();
mem_store.save_changes(changes).await.unwrap();
mem_store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();

CryptoStoreWrapper::new(user_id, device_id, mem_store)
}

async fn get_machine_test_helper() -> GossipMachine {
let user_id = alice_id().to_owned();
let account = Account::with_device_id(&user_id, alice_device_id());
Expand Down
91 changes: 4 additions & 87 deletions crates/matrix-sdk-crypto/src/store/caches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ use std::{
};

use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, RwLock};
use tracing::{field::display, instrument, trace, Span};

use crate::{
identities::DeviceData,
olm::{InboundGroupSession, Session},
};
use crate::{identities::DeviceData, olm::Session};

/// In-memory store for Olm Sessions.
#[derive(Debug, Default, Clone)]
Expand Down Expand Up @@ -86,52 +83,6 @@ impl SessionStore {
}
}

#[derive(Debug, Default)]
/// In-memory store that holds inbound group sessions.
pub struct GroupSessionStore {
entries: StdRwLock<BTreeMap<OwnedRoomId, HashMap<String, InboundGroupSession>>>,
}

impl GroupSessionStore {
/// Create a new empty store.
pub fn new() -> Self {
Self::default()
}

/// Add an inbound group session to the store.
///
/// Returns true if the session was added, false if the session was
/// already in the store.
pub fn add(&self, session: InboundGroupSession) -> bool {
self.entries
.write()
.entry(session.room_id().to_owned())
.or_default()
.insert(session.session_id().to_owned(), session)
.is_none()
}

/// Get all the group sessions the store knows about.
pub fn get_all(&self) -> Vec<InboundGroupSession> {
self.entries.read().values().flat_map(HashMap::values).cloned().collect()
}

/// Get the number of `InboundGroupSession`s we have.
pub fn count(&self) -> usize {
self.entries.read().values().map(HashMap::len).sum()
}

/// Get a inbound group session from our store.
///
/// # Arguments
/// * `room_id` - The room id of the room that the session belongs to.
///
/// * `session_id` - The unique id of the session.
pub fn get(&self, room_id: &RoomId, session_id: &str) -> Option<InboundGroupSession> {
self.entries.read().get(room_id)?.get(session_id).cloned()
}
}

/// In-memory store holding the devices of users.
#[derive(Debug, Default)]
pub struct DeviceStore {
Expand Down Expand Up @@ -381,13 +332,10 @@ impl UsersForKeyQuery {
mod tests {
use matrix_sdk_test::async_test;
use proptest::prelude::*;
use ruma::room_id;
use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};

use super::{DeviceStore, GroupSessionStore, SequenceNumber, SessionStore};
use super::{DeviceStore, SequenceNumber, SessionStore};
use crate::{
identities::device::testing::get_device,
olm::{tests::get_account_and_session_test_helper, InboundGroupSession, SenderData},
identities::device::testing::get_device, olm::tests::get_account_and_session_test_helper,
};

#[async_test]
Expand Down Expand Up @@ -422,37 +370,6 @@ mod tests {
assert_eq!(&session, loaded_session);
}

#[async_test]
async fn test_group_session_store() {
let (account, _) = get_account_and_session_test_helper();
let room_id = room_id!("!test:localhost");
let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";

let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;

assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
outbound.mark_as_shared();
assert!(outbound.shared());

let inbound = InboundGroupSession::new(
Curve25519PublicKey::from_base64(curve_key).unwrap(),
Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(),
room_id,
&outbound.session_key().await,
SenderData::unknown(),
outbound.settings().algorithm.to_owned(),
None,
)
.unwrap();

let store = GroupSessionStore::new();
store.add(inbound.clone());

let loaded_session = store.get(room_id, outbound.session_id()).unwrap();
assert_eq!(inbound, loaded_session);
}

#[async_test]
async fn test_device_store() {
let device = get_device();
Expand Down
Loading
Loading