diff --git a/backend/fastapi/Autocomplete.py b/backend/fastapi/Autocomplete.py index fbb34ab3..c8a5b387 100644 --- a/backend/fastapi/Autocomplete.py +++ b/backend/fastapi/Autocomplete.py @@ -1,12 +1,14 @@ from pprint import pprint +from helpers import get_guild_collection -def autocomplete_categories(db, guild_id: str, partial_category: str, limit: int): + +def autocomplete_categories(guild_id: str, partial_category: str, limit: int): """ Searches for categories. limited to {limit} results * 10 """ - collection_channels = db["channels"] + collection_channels = get_guild_collection(guild_id, "channels") # ignore "GuildPublicThread" or "GuildPrivateThread", because their category is channel name query = { @@ -55,12 +57,12 @@ def autocomplete_categories(db, guild_id: str, partial_category: str, limit: int return category_names -def autocomplete_channels(db, guild_id: str, partial_channel: str, limit: int): +def autocomplete_channels(guild_id: str, partial_channel: str, limit: int): """ Searches for channels. limited to {limit} results """ - collection_channels = db["channels"] + collection_channels = get_guild_collection(guild_id, "channels") query = { "name": { @@ -86,12 +88,12 @@ def autocomplete_channels(db, guild_id: str, partial_channel: str, limit: int): return channel_names -def autocomplete_reactions(db, guild_id: str, partial_reaction: str, limit: int): +def autocomplete_reactions(guild_id: str, partial_reaction: str, limit: int): """ Searches for reactions. limited to {limit} results """ - collection_emojis = db["emojis"] + collection_emojis = get_guild_collection(guild_id, "emojis") query = { "name": { @@ -121,12 +123,12 @@ def autocomplete_reactions(db, guild_id: str, partial_reaction: str, limit: int) return reaction_names -def autocomplete_filenames(db, guild_id: str, partial_filename: str, limit: int): +def autocomplete_filenames(guild_id: str, partial_filename: str, limit: int): """ Searches for filenames. limited to {limit} results """ - collection_assets = db["assets"] + collection_assets = get_guild_collection(guild_id, "assets") query = { "filenameWithoutHash": { @@ -157,13 +159,14 @@ def autocomplete_filenames(db, guild_id: str, partial_filename: str, limit: int) return filenames -def autocomplete_users(db, guild_id: str, partial_user_name: str, limit: int): +def autocomplete_users(guild_id: str, partial_user_name: str, limit: int): """ Searches for users by name. limited to {limit} results only shows users that have messages in the guild {guild_id} """ - collection_authors = db["authors"] + collection_authors = get_guild_collection(guild_id, "authors") + print("collection_authors", collection_authors) query = { "guildIds": guild_id, diff --git a/backend/fastapi/app.py b/backend/fastapi/app.py index c543727c..0943325f 100644 --- a/backend/fastapi/app.py +++ b/backend/fastapi/app.py @@ -7,8 +7,10 @@ import traceback import pymongo from fastapi import FastAPI, Query +from pydantic import BaseModel import Autocomplete +from helpers import get_global_collection, pad_id, get_guild_collection # fix PIPE encoding error on Windows, auto flush print sys.stdout.reconfigure(encoding='utf-8') @@ -16,10 +18,7 @@ print = functools.partial(print, flush=True) -def pad_id(id): - if id == None: - return None - return str(id).zfill(24) + # specify guild ids that should be hidden from the public (list of strings) @@ -29,16 +28,9 @@ def pad_id(id): blacklisted_guild_ids = [pad_id(id) for id in blacklisted_guild_ids] -URI = "mongodb://127.0.0.1:27017" -client = pymongo.MongoClient(URI) -db = client["dcef"] -collection_messages = db["messages"] -collection_channels = db["channels"] -collection_guilds = db["guilds"] -collection_authors = db["authors"] -collection_emojis = db["emojis"] -collection_assets = db["assets"] -collection_roles = db["roles"] + + + app = FastAPI( title="DCEF backend api", @@ -81,6 +73,7 @@ async def get_guilds(guild_id: str = None): Filters out blacklisted guilds from the config.toml file. """ + collection_guilds = get_global_collection("guilds") if guild_id: if guild_id in blacklisted_guild_ids: return {"message": "Not found"} @@ -112,12 +105,8 @@ async def get_channels(guild_id: str = None, channel_id: str = None): Optionally, a guild_id query parameter can be provided to filter by guild. Optionally, a channel_id query parameter can be provided to get only specific channel. """ - if guild_id: - if guild_id in blacklisted_guild_ids: - return [] + collection_channels = get_guild_collection(guild_id, "channels") - cursor = collection_channels.find({"guildId": guild_id}) - return list(cursor) if channel_id: channel = collection_channels.find_one( @@ -132,6 +121,13 @@ async def get_channels(guild_id: str = None, channel_id: str = None): return {"message": "Not found"} return channel + if guild_id: + if guild_id in blacklisted_guild_ids: + return [] + + cursor = collection_channels.find({"guildId": guild_id}) + return list(cursor) + cursor = collection_channels.find( { "guildId": { @@ -150,12 +146,7 @@ async def get_roles(guild_id: str = None, role_id: str = None): Optionally, a guild_id query parameter can be provided to filter by guild. Optionally, a role_id query parameter can be provided to get only specific role. """ - if guild_id: - if guild_id in blacklisted_guild_ids: - return [] - - cursor = collection_roles.find({"guildId": guild_id}) - return list(cursor) + collection_roles = get_guild_collection(guild_id, "roles") if role_id: role = collection_roles.find_one( @@ -170,6 +161,13 @@ async def get_roles(guild_id: str = None, role_id: str = None): return {"message": "Not found"} return role + if guild_id: + if guild_id in blacklisted_guild_ids: + return [] + + cursor = collection_roles.find({"guildId": guild_id}) + return list(cursor) + # order by position desc cursor = collection_roles.find( { @@ -182,11 +180,12 @@ async def get_roles(guild_id: str = None, role_id: str = None): @app.get("/message-ids") -async def get_message_ids(channel_id: str = None): +async def get_message_ids(channel_id: str, guild_id: str): """ Returns a list of message ids. Optionally, a channel_id query parameter can be provided to filter by channel. """ + collection_messages = get_guild_collection(guild_id, "messages") if is_compiled(): cache_path = f"../../storage/cache/message-ids/{channel_id}.json" else: @@ -220,32 +219,21 @@ async def get_message_ids(channel_id: str = None): return new_ids -@app.get("/message") -async def get_message_content(message_id: str): - """ - Returns the content of a message by its id. - """ - message = collection_messages.find_one( - { - "_id": message_id, - "guildId": { - "$nin": blacklisted_guild_ids - } - } - ) - if not message: - return {"message": "Not found"} - - if message["guildId"] in blacklisted_guild_ids: - return {"message": "Not found"} - return message +class MessageRequest(BaseModel): + message_ids: list + guild_id: str @app.post("/messages") -async def get_multiple_message_content(message_ids: list): +async def get_multiple_message_content(message_req_obj: MessageRequest): """ Returns the content of multiple messages by their ids. """ + message_ids = message_req_obj.message_ids + guild_id = message_req_obj.guild_id + + collection_messages = get_guild_collection(guild_id, "messages") + messages = collection_messages.find( { "_id": { @@ -257,7 +245,7 @@ async def get_multiple_message_content(message_ids: list): } ) list_of_messages = list(messages) - list_of_messages = enrich_messages(list_of_messages) + list_of_messages = enrich_messages(list_of_messages, guild_id) return list_of_messages @@ -265,6 +253,8 @@ def channel_names_to_ids(in_channel_ids: list, in_channels: list, guild_id: str """ Convert channel names to ids. """ + collection_channels = get_guild_collection(guild_id, "channels") + if len(in_channels) == 0: return in_channel_ids @@ -294,6 +284,8 @@ def category_names_to_ids(in_category_ids: list, in_categories: list, guild_id: """ Convert category names to ids. """ + collection_channels = get_guild_collection(guild_id, "channels") + if len(in_categories) == 0: return in_category_ids @@ -312,13 +304,15 @@ def category_names_to_ids(in_category_ids: list, in_categories: list, guild_id: return out_category_ids -def extend_channels(channels: list): +def extend_channels(channels: list, guild_id: str): """ Extend a list of channels with thread ids and forum post ids. Can be also used to extend a list of categories with channel ids. In this case, we will not clean category ids from the list, but it causes no problems except for a little bit of performance loss. """ + collection_channels = get_guild_collection(guild_id, "channels") + if len(channels) == 0: return channels @@ -329,11 +323,12 @@ def extend_channels(channels: list): channels = list(set(channels)) # remove duplicates return channels -def extend_users(user_ids: list, usernames: list): +def extend_users(user_ids: list, usernames: list, guild_id: str): """ Find new user ids by user names. exactly match user names """ + collection_authors = get_guild_collection(guild_id, "authors") if len(usernames) == 0: return user_ids @@ -352,12 +347,14 @@ def extend_users(user_ids: list, usernames: list): return user_ids -def extend_reactions(reaction_ids: list, reactions: list): +def extend_reactions(reaction_ids: list, reactions: list, guild_id: str): """ Find new reaction ids by reaction names. Support partial or lowercase match. """ + collection_emojis = get_guild_collection(guild_id, "emojis") + if len(reactions) == 0: return reaction_ids @@ -378,11 +375,12 @@ def extend_reactions(reaction_ids: list, reactions: list): -def get_emotes_from_db(emote_names: list) -> dict: +def get_emotes_from_db(emote_names: list, guild_id: str) -> dict: """ try to find emotes from DB by their name use exact match only """ + collection_emojis = get_guild_collection(guild_id, "emojis") if len(emote_names) == 0: return {} @@ -396,11 +394,13 @@ def get_emotes_from_db(emote_names: list) -> dict: emotes = {emote["name"]: emote for emote in emotes} return emotes -def get_channel_info(channel_id): +def get_channel_info(channel_id: str, guild_id: str): """ get channel info by id 'channel' can be thread or channel or forum post """ + collection_channels = get_guild_collection(guild_id, "channels") + collection_messages = get_guild_collection(guild_id, "messages") channel = collection_channels.find_one( { @@ -419,7 +419,7 @@ def get_channel_info(channel_id): return channel -def enrich_messages(list_of_messages: list) -> list: +def enrich_messages(list_of_messages: list, guild_id: str) -> list: regex = re.compile(r':([^ ]+):') possible_emotes = [] @@ -434,7 +434,7 @@ def enrich_messages(list_of_messages: list) -> list: # get all emotes from db - emotes = get_emotes_from_db(emote_names=possible_emotes) + emotes = get_emotes_from_db(emote_names=possible_emotes, guild_id=guild_id) # replace emotes in messages for message in list_of_messages: @@ -452,7 +452,7 @@ def enrich_messages(list_of_messages: list) -> list: for message in list_of_messages: if message["type"] == "ThreadCreated": - message["thread"] = get_channel_info(message["reference"]["channelId"]) + message["thread"] = get_channel_info(message["reference"]["channelId"], guild_id) return list_of_messages @@ -664,15 +664,15 @@ def search_autocomplete(guild_id: str = None, key: str = None, value: str = None guild_id = pad_id(guild_id) if (key == "users"): - return Autocomplete.autocomplete_users(db, guild_id, value, limit) + return Autocomplete.autocomplete_users(guild_id, value, limit) elif (key == "filenames"): - return Autocomplete.autocomplete_filenames(db, guild_id, value, limit) + return Autocomplete.autocomplete_filenames(guild_id, value, limit) elif (key == "reactions"): - return Autocomplete.autocomplete_reactions(db, guild_id, value, limit) + return Autocomplete.autocomplete_reactions(guild_id, value, limit) elif (key == "channels"): - return Autocomplete.autocomplete_channels(db, guild_id, value, limit) + return Autocomplete.autocomplete_channels(guild_id, value, limit) elif (key == "categories"): - return Autocomplete.autocomplete_categories(db, guild_id, value, limit) + return Autocomplete.autocomplete_categories(guild_id, value, limit) else: return [] @@ -781,6 +781,8 @@ async def search_messages(prompt: str = None, guild_id: str = None, only_ids: bo Searches for messages that contain the prompt. """ + collection_messages = get_guild_collection(guild_id, "messages") + try: # todo: parse prompt search = parse_prompt(prompt) @@ -815,27 +817,27 @@ async def search_messages(prompt: str = None, guild_id: str = None, only_ids: bo message_ids = [pad_id(id) for id in message_ids] from_user_ids = [pad_id(id) for id in from_user_ids] - from_user_ids = extend_users(from_user_ids, from_users) + from_user_ids = extend_users(from_user_ids, from_users, guild_id) print("from_user_ids", from_user_ids) reaction_from_ids = [pad_id(id) for id in reaction_from_ids] - reaction_from_ids = extend_users(reaction_from_ids, reaction_from) + reaction_from_ids = extend_users(reaction_from_ids, reaction_from, guild_id) mentions_user_ids = [pad_id(id) for id in mentions_user_ids] - mentions_user_ids = extend_users(mentions_user_ids, mentions_users) + mentions_user_ids = extend_users(mentions_user_ids, mentions_users, guild_id) reaction_ids = [pad_id(id) for id in reaction_ids] reactions = [reaction.lower() for reaction in reactions] - reaction_ids = extend_reactions(reaction_ids, reactions) + reaction_ids = extend_reactions(reaction_ids, reactions, guild_id) extensions = [ext.lower() for ext in extensions] in_channel_ids = channel_names_to_ids(in_channel_ids, in_channels, guild_id) in_channel_ids = [pad_id(id) for id in in_channel_ids] - in_channel_ids = extend_channels(in_channel_ids) # extend channels with threads and forum posts + in_channel_ids = extend_channels(in_channel_ids, guild_id) # extend channels with threads and forum posts in_category_ids = [pad_id(id) for id in in_category_ids] in_category_ids = category_names_to_ids(in_category_ids, in_categories, guild_id) in_category_ids = [pad_id(id) for id in in_category_ids] - in_category_ids = extend_channels(in_category_ids) # extend categories with channels - in_category_ids = extend_channels(in_category_ids) # extend channels with threads and forum posts + in_category_ids = extend_channels(in_category_ids, guild_id) # extend categories with channels + in_category_ids = extend_channels(in_category_ids, guild_id) # extend channels with threads and forum posts @@ -1020,7 +1022,7 @@ async def search_messages(prompt: str = None, guild_id: str = None, only_ids: bo return ids else: list_of_messages = list(cursor) - list_of_messages = enrich_messages(list_of_messages) + list_of_messages = enrich_messages(list_of_messages, guild_id) return list_of_messages except Exception as e: print("/search error:") diff --git a/backend/fastapi/helpers.py b/backend/fastapi/helpers.py new file mode 100644 index 00000000..0bce6c0a --- /dev/null +++ b/backend/fastapi/helpers.py @@ -0,0 +1,22 @@ +import pymongo + +URI = "mongodb://127.0.0.1:27017" +client = pymongo.MongoClient(URI) +db = client["dcef"] +collection_messages = db["messages"] +collection_guilds = db["guilds"] +collection_authors = db["authors"] +collection_emojis = db["emojis"] +collection_assets = db["assets"] +collection_roles = db["roles"] + +def get_guild_collection(guild_id, collection_name): + return db[f"g{pad_id(guild_id)}_{collection_name}"] + +def get_global_collection(collection_name): + return db[collection_name] + +def pad_id(id): + if id == None: + return None + return str(id).zfill(24) \ No newline at end of file diff --git a/backend/preprocess/AssetProcessor.py b/backend/preprocess/AssetProcessor.py index 95a142f9..2e215801 100644 --- a/backend/preprocess/AssetProcessor.py +++ b/backend/preprocess/AssetProcessor.py @@ -8,17 +8,27 @@ from FileFinder import FileFinder from MongoDatabase import MongoDatabase +from helpers import pad_id print = functools.partial(print, flush=True) class AssetProcessor: def __init__(self, file_finder: FileFinder, database: MongoDatabase): + self.database = database self.file_finder = file_finder - self.collection_assets = database.get_collection("assets") + self.collection_assets = None self.local_assets = None self.fast_mode = False self.cache = {} + def set_guild_id(self, guild_id): + """ + it is required to set guild id, to set the correct collection + """ + guild_collections = self.database.get_guild_collections(guild_id) + self.collection_assets = guild_collections["assets"] + + def set_fast_mode(self, fast_mode: bool): self.fast_mode = fast_mode @@ -169,6 +179,9 @@ def process(self, original_filepath: str): original_filepath is the path from json file, but is not necessarily a valid path """ + if self.collection_assets is None: + raise Exception("AssetProcessor: Guild id not set, call set_guild_id method first") + # do not process twice the same file. Processing is relatively slow if original_filepath in self.cache: return self.cache[original_filepath] diff --git a/backend/preprocess/JsonProcessor.py b/backend/preprocess/JsonProcessor.py index 9e8082ab..b2d31660 100644 --- a/backend/preprocess/JsonProcessor.py +++ b/backend/preprocess/JsonProcessor.py @@ -22,16 +22,10 @@ class JsonProcessor: def __init__(self, database: MongoDatabase, file_finder: FileFinder, json_path:str, asset_processor: AssetProcessor, index: int, total: int): + self.collections = None # will be set after guild id is known + self.json_path = json_path self.database = database - self.collection_guilds = self.database.get_collection("guilds") - self.collection_channels = self.database.get_collection("channels") - self.collection_messages = self.database.get_collection("messages") - self.collection_authors = self.database.get_collection("authors") - self.collection_emojis = self.database.get_collection("emojis") - self.collection_assets = self.database.get_collection("assets") - self.collection_roles = self.database.get_collection("roles") - self.collection_jsons = self.database.get_collection("jsons") self.file_finder = file_finder self.asset_processor = asset_processor self.index = index @@ -39,6 +33,7 @@ def __init__(self, database: MongoDatabase, file_finder: FileFinder, json_path:s def process_guild(self, guild): guild["_id"] = pad_id(guild.pop("id")) + self.asset_processor.set_guild_id(guild["_id"]) guild["icon"] = self.asset_processor.process(guild.pop("iconUrl")) return guild @@ -297,7 +292,7 @@ def process_roles(self, messages: list, guild_id: str, exported_at: str, roles: return roles def insert_guild(self, guild): - database_document = self.collection_guilds.find_one({"_id": guild["_id"]}) + database_document = self.collections["guilds"].find_one({"_id": guild["_id"]}) if database_document != None: # guild already exists, ignore @@ -305,10 +300,10 @@ def insert_guild(self, guild): guild["msg_count"] = 0 - self.collection_guilds.insert_one(guild) + self.collections["guilds"].insert_one(guild) def insert_channel(self, channel): - database_document = self.collection_channels.find_one({"_id": channel["_id"]}) + database_document = self.collections["channels"].find_one({"_id": channel["_id"]}) if database_document != None: # channel already exists @@ -316,15 +311,15 @@ def insert_channel(self, channel): channel["msg_count"] = 0 - self.collection_channels.insert_one(channel) + self.collections["channels"].insert_one(channel) def insert_author(self, author): - database_author = self.collection_authors.find_one({"_id": author["_id"]}) + database_author = self.collections["authors"].find_one({"_id": author["_id"]}) if database_author == None: # author doesn't exist yet author["msg_count"] = 0 - self.collection_authors.insert_one(author) + self.collections["authors"].insert_one(author) return # merge new author with existing author @@ -333,7 +328,7 @@ def insert_author(self, author): names = list(set(author["names"] + database_author["names"])) # update guildIds and nicknames in database - self.collection_authors.update_one({"_id": author["_id"]}, { + self.collections["authors"].update_one({"_id": author["_id"]}, { "$set": { "guildIds": guildIds, "nicknames": nicknames, @@ -344,19 +339,19 @@ def insert_author(self, author): def insert_emoji(self, emoji, guild_id): - database_document = self.collection_emojis.find_one({"_id": emoji['emoji']["_id"]}) + database_document = self.collections["emojis"].find_one({"_id": emoji['emoji']["_id"]}) if database_document == None: # new emoji emoji['emoji']["usage_count"] = emoji['count'] emoji['emoji']["guildIds"] = [guild_id] - self.collection_emojis.insert_one(emoji['emoji']) + self.collections["emojis"].insert_one(emoji['emoji']) return guildIds = list(set(emoji['emoji']["guildIds"] + database_document["guildIds"])) # increase usage count - self.collection_emojis.update_one({"_id": emoji['emoji']["_id"]}, { + self.collections["emojis"].update_one({"_id": emoji['emoji']["_id"]}, { "$inc": { "usage_count": emoji['count'] }, @@ -366,11 +361,11 @@ def insert_emoji(self, emoji, guild_id): }) def insert_role(self, role): - database_document = self.collection_roles.find_one({"_id": role["_id"]}) + database_document = self.collections["roles"].find_one({"_id": role["_id"]}) if database_document == None: # new role - self.collection_roles.insert_one(role) + self.collections["roles"].insert_one(role) return def mark_as_processed(self, json_path): @@ -394,7 +389,7 @@ def mark_as_processed(self, json_path): hex_hash = file_hash.hexdigest() - self.collection_jsons.insert_one({ + self.collections["jsons"].insert_one({ "_id": json_path, "size": file_size, "sha256_hash": hex_hash, @@ -477,6 +472,8 @@ def process(self): print(' exported_at:', exported_at) guild = self.process_guild(guild) + self.collections = self.database.get_guild_collections(guild["_id"]) + channel = self.process_channel(channel, guild["_id"]) guild['exportedAt'] = exported_at channel['exportedAt'] = exported_at @@ -485,7 +482,7 @@ def process(self): roles = {} # role_id -> role_object print(' deleted messages - stage 1/3') - iterator = self.collection_messages.find({"channelId": channel["_id"]}, {"_id": 1, "sources": 1}).sort("_id", 1) + iterator = self.collections["messages"].find({"channelId": channel["_id"]}, {"_id": 1, "sources": 1}).sort("_id", 1) old_channel_ids_by_source = {} # source -> set of ids for message in iterator: for source in message["sources"]: @@ -535,26 +532,26 @@ def process(self): message_ids = [message["_id"] for message in messages] print(' getting existing messages') - existing_messages = list(self.collection_messages.find({"_id": {"$in": message_ids}})) + existing_messages = list(self.collections["messages"].find({"_id": {"$in": message_ids}})) print(' existing messages count:', len(list(existing_messages))) print(' removing existing messages') - self.collection_messages.delete_many({"_id": {"$in": message_ids}}) + self.collections["messages"].delete_many({"_id": {"$in": message_ids}}) print(' merging messages') messages = self.merge_messages(list(messages), list(existing_messages)) # insert messages print(' inserting messages') - self.collection_messages.insert_many(messages) + self.collections["messages"].insert_many(messages) print(' updating message counts') new_messages_count = len(messages) - len(list(existing_messages)) # update message count of channel - self.collection_channels.update_one({"_id": message["channelId"]}, {"$inc": {"msg_count": new_messages_count}}) + self.collections["channels"].update_one({"_id": message["channelId"]}, {"$inc": {"msg_count": new_messages_count}}) # update message count of guild - self.collection_guilds.update_one({"_id": message["guildId"]}, {"$inc": {"msg_count": new_messages_count}}) + self.collections["guilds"].update_one({"_id": message["guildId"]}, {"$inc": {"msg_count": new_messages_count}}) # update message count of author bulk = [] @@ -563,7 +560,7 @@ def process(self): for message in messages: bulk.append(UpdateOne({"_id": message["author"]["_id"]}, {"$inc": {"msg_count": 1}})) if len(bulk) > 0: - self.collection_authors.bulk_write(bulk) + self.collections["authors"].bulk_write(bulk) print(' inserting emojis') for emoji in emojis: @@ -583,7 +580,7 @@ def process(self): deleted_messages_ids = find_additional_missing_numbers(old_channel_ids_by_source, new_channel_ids) print(f' found {len(deleted_messages_ids)} new deleted messages') print(' deleted messages - stage 3/3') - self.collection_messages.update_many({"_id": {"$in": list(deleted_messages_ids)}}, {"$set": {"isDeleted": True}}) + self.collections["messages"].update_many({"_id": {"$in": list(deleted_messages_ids)}}, {"$set": {"isDeleted": True}}) self.mark_as_processed(self.json_path) \ No newline at end of file diff --git a/backend/preprocess/MongoDatabase.py b/backend/preprocess/MongoDatabase.py index 05a5984f..f12aa8c9 100644 --- a/backend/preprocess/MongoDatabase.py +++ b/backend/preprocess/MongoDatabase.py @@ -1,6 +1,7 @@ import functools from pymongo import MongoClient +from helpers import pad_id print = functools.partial(print, flush=True) @@ -17,43 +18,82 @@ def __init__(self): self.database = self.client[DATABASE] self.col = { # collections - "messages": self.database["messages"], - "channels": self.database["channels"], "guilds": self.database["guilds"], - "authors": self.database["authors"], - "emojis": self.database["emojis"], - "jsons": self.database["jsons"], - "assets": self.database["assets"], - "roles": self.database["roles"], "jsons": self.database["jsons"], "config": self.database["config"], } self.create_indexes() + def get_guild_collections(self, guild_id): + """ + Returns a list of collections that are guild specific + they are prefixed with the guild id and _ (underscore) + """ + padded_guild_id = pad_id(guild_id) + return { + "messages": self.database[f"g{padded_guild_id}_messages"], + "channels": self.database[f"g{padded_guild_id}_channels"], + "authors": self.database[f"g{padded_guild_id}_authors"], + "emojis": self.database[f"g{padded_guild_id}_emojis"], + "assets": self.database[f"g{padded_guild_id}_assets"], + "roles": self.database[f"g{padded_guild_id}_roles"], + "guilds": self.database["guilds"], + "jsons": self.database["jsons"], + "config": self.database["config"], + } + def create_indexes(self): # create case insensitive text indexes # self.col["messages"].create_index("content.content", default_language="none") - self.col["messages"].create_index("channelId", default_language="none") + pass + # TODO: add this back + # self.col["messages"].create_index("channelId", default_language="none") - def clear_database_except_assets(self): + def clear_database(self, guild_ids): """ Clears the database Useful for debugging Assets are not cleared, because they are expensive to recompute """ - for collection_name in self.col: - if collection_name == "assets": # assets are expensive to recompute - continue - if collection_name == "config": - continue - self.col[collection_name].delete_many({}) + pass + # TODO: add this back + + for guild_id in guild_ids: + print(f"Wiping guild {guild_id}...") + collections = self.get_guild_collections(guild_id) - def clear_assets(self): - self.col["assets"].delete_many({}) + # key, value + for collection_name, collection in collections.items(): + if collection_name in ["config", "guilds", "jsons"]: + continue + print(f" Wiping collection {collection_name}...") + collection.drop() + + print(f" Wiping collection {collection_name}...") + + print(f"Wiping global...") + # this list contains old collections too, that are not used anymore + global_collections = [ + "messages", + "channels", + "authors", + "emojis", + "assets", + "roles", + "jsons", + "guilds", + ] + + for collection_name in global_collections: + print(f" Wiping collection {collection_name}...") + self.database[collection_name].drop() def get_collection(self, collection_name): + """ + returns global collection by name + """ return self.col[collection_name] diff --git a/backend/preprocess/main_mongo.py b/backend/preprocess/main_mongo.py index 68af37f6..34f07328 100644 --- a/backend/preprocess/main_mongo.py +++ b/backend/preprocess/main_mongo.py @@ -20,12 +20,12 @@ print = functools.partial(print, flush=True) -def wipe_database(database): +def wipe_database(database: MongoDatabase): """ Deletes all collections on version bump (on program update) Change EXPECTED_VERSION to force wipe on incompatible schema changes """ - EXPECTED_VERSION = 12 # <---- change this to wipe database + EXPECTED_VERSION = 13 # <---- change this to wipe database config = database.get_collection("config") version = config.find_one({"key": "version"}) @@ -37,9 +37,11 @@ def wipe_database(database): print("Database schema up to date, no wipe needed") return - print("Wiping database...") - database.clear_database_except_assets() - database.clear_assets() + guild_ids = database.get_collection("guilds").find({}, {"_id": 1}) + guild_ids = [guild["_id"] for guild in guild_ids] + + print("Wiping old database...") + database.clear_database(guild_ids) print("Done wiping database") version["value"] = EXPECTED_VERSION @@ -93,8 +95,7 @@ def main(input_dir, output_dir): wipe_database(database) # DEBUG clear database - # database.clear_database_except_assets() - # database.clear_assets() + # database.clear_database() file_finder = FileFinder(input_dir) diff --git a/frontend/src/components/messages/MessageLoader.svelte b/frontend/src/components/messages/MessageLoader.svelte index f38fa1cf..d4f289ea 100644 --- a/frontend/src/components/messages/MessageLoader.svelte +++ b/frontend/src/components/messages/MessageLoader.svelte @@ -7,11 +7,11 @@ export let guildName: string; // fetch message from api - async function fetchMessages(messageId: string, previousMessageId: string | null) { - const messagePromise = getMessageContent(messageId); + async function fetchMessages(messageId: string, previousMessageId: string | null, selectedGuildId: string) { + const messagePromise = getMessageContent(messageId, selectedGuildId); let previousMessage if (previousMessageId !== null) { - previousMessage = await getMessageContent(previousMessageId); + previousMessage = await getMessageContent(previousMessageId, selectedGuildId); } else { previousMessage = null; @@ -22,7 +22,7 @@ let referencedMessage if (message.reference) { - referencedMessage = await getMessageContent(message.reference.messageId); + referencedMessage = await getMessageContent(message.reference.messageId, selectedGuildId); } else { referencedMessage = null; @@ -36,7 +36,7 @@ } // promise - let fullMessagesPromise = fetchMessages(messageId, previousMessageId); + let fullMessagesPromise = fetchMessages(messageId, previousMessageId, selectedGuildId); {#if messageId == "error"}
SEARCH ERROR - check server logs for details
@@ -48,7 +48,7 @@ {/key} {:catch error} -
{error} fullMessagesPromise = fetchMessages(messageId, previousMessageId)}>retry
+
{error} fullMessagesPromise = fetchMessages(messageId, previousMessageId, selectedGuildId)}>retry
{/await} {/if} diff --git a/frontend/src/components/messages/MessageMarkdown.svelte b/frontend/src/components/messages/MessageMarkdown.svelte index 8c9a6934..64bd4925 100644 --- a/frontend/src/components/messages/MessageMarkdown.svelte +++ b/frontend/src/components/messages/MessageMarkdown.svelte @@ -161,7 +161,7 @@ let channelId = matches[1] let fullMatch = matches[0] - let channelInfo = await getChannelInfo(channelId) + let channelInfo = await getChannelInfo(channelId, guildId) processedContent = processedContent.replace(fullMatch, `${channelIcon} ${channelInfo.name} `) } } @@ -174,7 +174,7 @@ let roleId = matches[1] let fullMatch = matches[0] - let roleInfo = await getRoleInfo(roleId) + let roleInfo = await getRoleInfo(roleId, guildId) processedContent = processedContent.replace(fullMatch, `@${roleInfo.name}`) } } diff --git a/frontend/src/js/api.ts b/frontend/src/js/api.ts index a014d691..fda6e4ad 100644 --- a/frontend/src/js/api.ts +++ b/frontend/src/js/api.ts @@ -14,12 +14,12 @@ function createMockChannelInfo(channelId: string | undefined): Channel { } -export async function getChannelInfo(channelId: string): Promise { +export async function getChannelInfo(channelId: string, guildId: string): Promise { if (!channelId) { return createMockChannelInfo(undefined); } const paddedChannelId = channelId.padStart(24, "0"); - const response = await fetch(`/api/channels?channel_id=${paddedChannelId}`); + const response = await fetch(`/api/channels?channel_id=${paddedChannelId}&guild_id=${guildId}`); const json = await response.json(); if (!json._id) { @@ -30,7 +30,7 @@ export async function getChannelInfo(channelId: string): Promise { } -export async function getRoleInfo(roleId: string) { +export async function getRoleInfo(roleId: string, guildId: string) { if (!roleId) { return { "color":"#d4e0fc", @@ -39,7 +39,7 @@ export async function getRoleInfo(roleId: string) { }; } const paddedRoleId = roleId.padStart(24, "0"); - const response = await fetch(`/api/roles?role_id=${paddedRoleId}`); + const response = await fetch(`/api/roles?role_id=${paddedRoleId}&guild_id=${guildId}`); const json = await response.json(); if (!json._id) { diff --git a/frontend/src/js/messageMiddleware.ts b/frontend/src/js/messageMiddleware.ts index 58ee0dbc..7f7d975f 100644 --- a/frontend/src/js/messageMiddleware.ts +++ b/frontend/src/js/messageMiddleware.ts @@ -12,18 +12,24 @@ const MESSAGE_LIMIT_PER_FETCH = 250 // observer pattern using svelte store export const justFetchedMessageIds: any = writable([]) +let messageid_guildid: Record = {} + async function fetchMessages(messageIds: string[]) { console.log("fetching " + messageIds.length + " messages"); + let guild_id = messageid_guildid[messageIds[0]] // TODO: this is a hack, fix it // fetch messages from server - const response = await fetch("/api/messages", { + const response = await fetch(`/api/messages`, { method: "POST", headers: { "Content-Type": "application/json", }, - body: JSON.stringify(messageIds), + body: JSON.stringify({ + guild_id: guild_id, + message_ids: messageIds + }) }) // parse response @@ -56,7 +62,8 @@ export function cancelMessageContentRequest(messageId: string) { } -export async function getMessageContent(messageId: string): Promise { +export async function getMessageContent(messageId: string, guild_id: string): Promise { + messageid_guildid[messageId] = guild_id // if message is already loaded, return it if (messages[messageId]) { return new Promise((resolve) => { diff --git a/frontend/src/routes/channels/[guildId]/[channelId]/+page.ts b/frontend/src/routes/channels/[guildId]/[channelId]/+page.ts index dcf60eed..fe338ffe 100644 --- a/frontend/src/routes/channels/[guildId]/[channelId]/+page.ts +++ b/frontend/src/routes/channels/[guildId]/[channelId]/+page.ts @@ -13,7 +13,7 @@ export const load: Load = async({ fetch, params, parent }) => { let messages try { - let response = await fetch('/api/message-ids?channel_id=' + params.channelId) + let response = await fetch(`/api/message-ids?guild_id=${selectedGuildId}&channel_id=${selectedChannelId}`) let messageIds = await response.json() messages = messageIds.map((messageId: string) => { diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 7db4edc4..5f96e696 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -11,7 +11,7 @@ const config: UserConfig = { }, https: false, hmr: { - clientPort: 5050, + clientPort: 21012, }, // host: '0.0.0.0', port: 5050,