Skip to content

Commit

Permalink
use a provider in an arc
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Jan 10, 2025
1 parent 556bc2a commit 9a436a2
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 18 deletions.
12 changes: 6 additions & 6 deletions xmtp_mls/src/groups/device_sync/backup.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::storage::DbConnection;
use crate::XmtpOpenMlsProvider;
use backup_stream::{BackupElement, BackupRecordStreamer, BackupStream};
use futures::{Stream, StreamExt};
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use xmtp_proto::xmtp::device_sync::consent_backup::ConsentRecordSave;
Expand Down Expand Up @@ -31,24 +31,24 @@ pub enum BackupOptionsElementSelection {
impl BackupOptionsElementSelection {
fn to_streamers<'a>(
&self,
conn: &'a DbConnection,
provider: &Arc<XmtpOpenMlsProvider>,
opts: &BackupOptions,
) -> Vec<Pin<Box<dyn Stream<Item = Vec<BackupElement>> + 'a>>> {
match self {
Self::Consent => vec![Box::pin(BackupRecordStreamer::<ConsentRecordSave>::new(
conn, opts,
provider, opts,
))],
Self::Messages => vec![],
}
}
}

impl BackupOptions {
pub fn write(self, conn: &'static DbConnection) -> BackupStream {
pub fn write(self, provider: &Arc<XmtpOpenMlsProvider>) -> BackupStream {
let input_streams = self
.elements
.iter()
.map(|e| e.to_streamers(conn, &self))
.map(|e| e.to_streamers(provider, &self))
.collect::<Vec<_>>();

BackupStream {
Expand Down
14 changes: 7 additions & 7 deletions xmtp_mls/src/groups/device_sync/backup/backup_stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::BackupOptions;
use crate::storage::DbConnection;
use crate::{storage::DbConnection, XmtpOpenMlsProvider};
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::{marker::PhantomData, pin::Pin, sync::Arc};
Expand Down Expand Up @@ -70,27 +70,27 @@ trait BackupRecordProvider {
}

/// A generic struct to make it easier to stream backup records from the database
pub(super) struct BackupRecordStreamer<'a, R> {
pub(super) struct BackupRecordStreamer<R> {
offset: i64,
conn: &'a DbConnection,
provider: Arc<XmtpOpenMlsProvider>,
start_ns: Option<u64>,
end_ns: Option<u64>,
_phantom: PhantomData<R>,
}

impl<'a, R> BackupRecordStreamer<'a, R> {
pub(super) fn new(conn: &'a DbConnection, opts: &BackupOptions) -> Self {
impl<R> BackupRecordStreamer<R> {
pub(super) fn new(provider: &Arc<XmtpOpenMlsProvider>, opts: &BackupOptions) -> Self {
Self {
offset: 0,
conn: conn,
provider: provider.clone(),
start_ns: opts.start_ns,
end_ns: opts.end_ns,
_phantom: PhantomData,
}
}
}

impl<'a, R> Stream for BackupRecordStreamer<'a, R>
impl<R> Stream for BackupRecordStreamer<R>
where
R: BackupRecordProvider + Unpin,
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ impl BackupRecordProvider for ConsentRecordSave {
.offset(streamer.offset);

let batch = streamer
.conn
.provider
.conn_ref()
.raw_query(|conn| query.load::<StoredConsentRecord>(conn))
.expect("Failed to load consent records");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ impl BackupRecordProvider for GroupSave {
query = query.limit(Self::BATCH_SIZE).offset(streamer.offset);

let batch = streamer
.conn
.provider
.conn_ref()
.raw_query(|conn| query.load::<StoredGroup>(conn))
.expect("Failed to load group records");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ impl BackupRecordProvider for GroupMessageSave {
query = query.limit(Self::BATCH_SIZE).offset(streamer.offset);

let batch = streamer
.conn
.provider
.conn_ref()
.raw_query(|conn| query.load::<StoredGroupMessage>(conn))
.expect("Failed to load group records");

Expand Down
3 changes: 1 addition & 2 deletions xmtp_mls/src/xmtp_openmls_provider.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::storage::{db_connection::DbConnectionPrivate, sql_key_store::SqlKeyStore};
use openmls_rust_crypto::RustCrypto;
use openmls_traits::OpenMlsProvider;

use crate::storage::{db_connection::DbConnectionPrivate, sql_key_store::SqlKeyStore};

pub type XmtpOpenMlsProvider = XmtpOpenMlsProviderPrivate<crate::storage::RawDbConnection>;

#[derive(Debug)]
Expand Down

0 comments on commit 9a436a2

Please sign in to comment.