Skip to content

Commit

Permalink
save reference_id to DB and add get_group_messages_with_reactions fun…
Browse files Browse the repository at this point in the history
…ction
  • Loading branch information
cameronvoell committed Jan 7, 2025
1 parent 00a4cbb commit e68a294
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
DROP INDEX idx_group_messages_reference_id;
ALTER TABLE group_messages
DROP COLUMN reference_id;
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE group_messages
ADD COLUMN reference_id BINARY;
CREATE INDEX idx_group_messages_reference_id ON group_messages(reference_id);
1 change: 1 addition & 0 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ where
version_major: conversation_item.version_major?,
version_minor: conversation_item.version_minor?,
authority_id: conversation_item.authority_id?,
reference_id: None, // conversation_item does not use message reference_id
})
});

Expand Down
4 changes: 4 additions & 0 deletions xmtp_mls/src/groups/mls_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ where
version_major: queryable_content_fields.version_major,
version_minor: queryable_content_fields.version_minor,
authority_id: queryable_content_fields.authority_id,
reference_id: queryable_content_fields.reference_id,
}
.store_or_ignore(provider.conn_ref())?
}
Expand Down Expand Up @@ -591,6 +592,7 @@ where
version_major: 0,
version_minor: 0,
authority_id: "unknown".to_string(),
reference_id: None,
}
.store_or_ignore(provider.conn_ref())?;

Expand Down Expand Up @@ -624,6 +626,7 @@ where
version_major: 0,
version_minor: 0,
authority_id: "unknown".to_string(),
reference_id: None,
}
.store_or_ignore(provider.conn_ref())?;

Expand Down Expand Up @@ -971,6 +974,7 @@ where
version_major: content_type.version_major as i32,
version_minor: content_type.version_minor as i32,
authority_id: content_type.authority_id.to_string(),
reference_id: None,
};

msg.store_or_ignore(conn)?;
Expand Down
41 changes: 34 additions & 7 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use openmls_traits::OpenMlsProvider;
use prost::Message;
use thiserror::Error;
use tokio::sync::Mutex;
use xmtp_content_types::reaction::ReactionCodec;

use self::device_sync::DeviceSyncError;
pub use self::group_permissions::PreconfiguredPolicies;
Expand Down Expand Up @@ -66,8 +67,7 @@ use xmtp_proto::xmtp::mls::{
GroupMessage,
},
message_contents::{
plaintext_envelope::{Content, V1},
EncodedContent, PlaintextEnvelope,
content_types::ReactionV2, plaintext_envelope::{Content, V1}, EncodedContent, PlaintextEnvelope
},
};

Expand Down Expand Up @@ -320,6 +320,7 @@ pub struct QueryableContentFields {
pub version_major: i32,
pub version_minor: i32,
pub authority_id: String,
pub reference_id: Option<Vec<u8>>,
}

impl Default for QueryableContentFields {
Expand All @@ -329,20 +330,38 @@ impl Default for QueryableContentFields {
version_major: 0,
version_minor: 0,
authority_id: String::new(),
reference_id: None,
}
}
}

impl From<EncodedContent> for QueryableContentFields {
fn from(content: EncodedContent) -> Self {
impl TryFrom<EncodedContent> for QueryableContentFields {
type Error = prost::DecodeError;

fn try_from(content: EncodedContent) -> Result<Self, Self::Error> {
let content_type_id = content.r#type.unwrap_or_default();
let reference_id = match (
content_type_id.type_id.as_str(),
content_type_id.version_major,
) {
(ReactionCodec::TYPE_ID, major) if major >= 2 => {
let reaction = ReactionV2::decode(content.content.as_slice())?;
hex::decode(reaction.reference).ok()
}
(ReactionCodec::TYPE_ID, _) => {
// TODO: Implement JSON deserialization for legacy reaction format
None
}
_ => None,
};

QueryableContentFields {
Ok(QueryableContentFields {
content_type: content_type_id.type_id.into(),
version_major: content_type_id.version_major as i32,
version_minor: content_type_id.version_minor as i32,
authority_id: content_type_id.authority_id.to_string(),
}
reference_id,
})
}
}

Expand Down Expand Up @@ -746,7 +765,14 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
// Return early with default if decoding fails or type is missing
EncodedContent::decode(message)
.inspect_err(|e| tracing::debug!("Failed to decode message as EncodedContent: {}", e))
.map(QueryableContentFields::from)
.and_then(|content| {
QueryableContentFields::try_from(content).inspect_err(|e| {
tracing::debug!(
"Failed to convert EncodedContent to QueryableContentFields: {}",
e
)
})
})
.unwrap_or_default()
}

Expand Down Expand Up @@ -792,6 +818,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
version_major: queryable_content_fields.version_major,
version_minor: queryable_content_fields.version_minor,
authority_id: queryable_content_fields.authority_id,
reference_id: queryable_content_fields.reference_id,
};
group_message.store(provider.conn_ref())?;

Expand Down
75 changes: 74 additions & 1 deletion xmtp_mls/src/storage/encrypted_store/group_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ pub struct StoredGroupMessage {
pub version_minor: i32,
/// The ID of the authority defining the content type
pub authority_id: String,
/// The ID of a referenced message
pub reference_id: Option<Vec<u8>>,
}

pub struct StoredGroupMessageWithReactions {
pub message: StoredGroupMessage,
// Messages who's reference_id matches this message's id
pub reactions: Vec<StoredGroupMessage>,
}

#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -213,7 +221,7 @@ impl_fetch!(StoredGroupMessage, group_messages, Vec<u8>);
impl_store!(StoredGroupMessage, group_messages);
impl_store_or_ignore!(StoredGroupMessage, group_messages);

#[derive(Default)]
#[derive(Default, Clone)]
pub struct MsgQueryArgs {
pub sent_after_ns: Option<i64>,
pub sent_before_ns: Option<i64>,
Expand Down Expand Up @@ -282,6 +290,70 @@ impl DbConnection {
Ok(self.raw_query(|conn| query.load::<StoredGroupMessage>(conn))?)
}


/// Query for group messages with their reactions
#[allow(clippy::too_many_arguments)]
pub fn get_group_messages_with_reactions(
&self,
group_id: &[u8],
args: &MsgQueryArgs,
) -> Result<Vec<StoredGroupMessageWithReactions>, StorageError> {
// First get all the main messages
let mut modified_args = args.clone();
// filter out reactions from the main query so we don't get them twice
let mut content_types = modified_args.content_types.clone().unwrap_or_default();
content_types.retain(|content_type| *content_type != ContentType::Reaction);
modified_args.content_types = Some(content_types);
let messages = self.get_group_messages(group_id, &modified_args)?;

// Then get all reactions for these messages in a single query
let message_ids: Vec<&[u8]> = messages.iter().map(|m| m.id.as_slice()).collect();

let mut reactions_query = dsl::group_messages
.filter(dsl::group_id.eq(group_id))
.filter(dsl::reference_id.is_not_null())
.filter(dsl::reference_id.eq_any(message_ids))
.into_boxed();

// Apply the same sorting as the main messages
reactions_query = match args.direction.as_ref().unwrap_or(&SortDirection::Ascending) {
SortDirection::Ascending => reactions_query.order(dsl::sent_at_ns.asc()),
SortDirection::Descending => reactions_query.order(dsl::sent_at_ns.desc()),
};

let reactions: Vec<StoredGroupMessage> =
self.raw_query(|conn| reactions_query.load(conn))?;

// Group reactions by parent message id
let mut reactions_by_reference: std::collections::HashMap<Vec<u8>, Vec<StoredGroupMessage>> =
std::collections::HashMap::new();

for reaction in reactions {
if let Some(reference_id) = &reaction.reference_id {
reactions_by_reference
.entry(reference_id.clone())
.or_default()
.push(reaction);
}
}

// Combine messages with their reactions
let messages_with_reactions: Vec<StoredGroupMessageWithReactions> = messages
.into_iter()
.map(|message| {
let message_clone = message.clone();
StoredGroupMessageWithReactions {
message,
reactions: reactions_by_reference
.remove(&message_clone.id)
.unwrap_or_default(),
}
})
.collect();

Ok(messages_with_reactions)
}

/// Get a particular group message
pub fn get_group_message<MessageId: AsRef<[u8]>>(
&self,
Expand Down Expand Up @@ -370,6 +442,7 @@ pub(crate) mod tests {
version_major: 0,
version_minor: 0,
authority_id: "unknown".to_string(),
reference_id: None,
}
}

Expand Down
1 change: 1 addition & 0 deletions xmtp_mls/src/storage/encrypted_store/schema_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ diesel::table! {
version_minor -> Integer,
version_major -> Integer,
authority_id -> Text,
reference_id -> Nullable<Binary>,
}
}

Expand Down

0 comments on commit e68a294

Please sign in to comment.