From c8e20b4915185ccb8e6d8ea3b39d540811bf98f9 Mon Sep 17 00:00:00 2001 From: slatinsky Date: Sat, 28 Oct 2023 13:30:28 +0200 Subject: [PATCH] add guild whitelisting, simplify database queries --- backend/fastapi/Autocomplete.py | 8 +- backend/fastapi/app.py | 210 ++++++++++--------------------- backend/fastapi/helpers.py | 38 ++++-- backend/preprocess/main_mongo.py | 5 + frontend/src/js/api.ts | 4 +- 5 files changed, 103 insertions(+), 162 deletions(-) diff --git a/backend/fastapi/Autocomplete.py b/backend/fastapi/Autocomplete.py index c8a5b387..d6a8d298 100644 --- a/backend/fastapi/Autocomplete.py +++ b/backend/fastapi/Autocomplete.py @@ -16,7 +16,6 @@ def autocomplete_categories(guild_id: str, partial_category: str, limit: int): "$regex": partial_category, "$options": "i" }, - "guildId": guild_id, "type": { "$nin": [ "GuildPublicThread", @@ -67,8 +66,7 @@ def autocomplete_channels(guild_id: str, partial_channel: str, limit: int): query = { "name": { "$regex": partial_channel, "$options": "i" - }, - "guildId": guild_id + } } cursor = collection_channels.find(query, { "name": 1, @@ -99,8 +97,7 @@ def autocomplete_reactions(guild_id: str, partial_reaction: str, limit: int): "name": { "$regex": partial_reaction, "$options": "i" - }, - "guildIds": guild_id + } } cursor = collection_emojis.find(query, { "name": 1, @@ -169,7 +166,6 @@ def autocomplete_users(guild_id: str, partial_user_name: str, limit: int): print("collection_authors", collection_authors) query = { - "guildIds": guild_id, "names": { "$regex": partial_user_name, "$options": "i" } diff --git a/backend/fastapi/app.py b/backend/fastapi/app.py index 0943325f..22532401 100644 --- a/backend/fastapi/app.py +++ b/backend/fastapi/app.py @@ -10,7 +10,7 @@ from pydantic import BaseModel import Autocomplete -from helpers import get_global_collection, pad_id, get_guild_collection +from helpers import get_global_collection, get_whitelisted_guild_ids, is_db_online, pad_id, get_guild_collection # fix PIPE encoding error on Windows, auto flush print sys.stdout.reconfigure(encoding='utf-8') @@ -19,19 +19,6 @@ - - -# specify guild ids that should be hidden from the public (list of strings) -# TODO: move to config file -blacklisted_guild_ids = [] - -blacklisted_guild_ids = [pad_id(id) for id in blacklisted_guild_ids] - - - - - - app = FastAPI( title="DCEF backend api", description="This is the backend api for the DCEF viewer.", @@ -56,7 +43,7 @@ async def api_status(): Returns the status of the api and the database. """ try: - database_status = "online" if client.server_info()["ok"] == 1 else "offline" + database_status = "online" if is_db_online() else "offline" except: database_status = "offline" return { @@ -66,124 +53,85 @@ async def api_status(): @app.get("/guilds") -async def get_guilds(guild_id: str = None): +async def get_guilds(): """ Returns a list of guilds - or a single guild if a guild_id query parameter is provided. + If whitelist is enabled (by not being an empty list), only whitelisted guilds will be returned. - Filters out blacklisted guilds from the config.toml file. + all other whitelist logic is handled by get_guild_collection() method - it won't return a collection for non-whitelisted guilds """ collection_guilds = get_global_collection("guilds") - if guild_id: - if guild_id in blacklisted_guild_ids: - return {"message": "Not found"} - - guild = collection_guilds.find_one({"_id": guild_id}) - if not guild: - return {"message": "Not found"} - return guild + whitelisted_guild_ids = get_whitelisted_guild_ids() - cursor = collection_guilds.find( - { - "_id": { - "$nin": blacklisted_guild_ids + if len(whitelisted_guild_ids) == 0: + cursor = collection_guilds.find().sort([("msg_count", pymongo.DESCENDING)]) + else: + cursor = collection_guilds.find( + { + "_id": { + "$in": whitelisted_guild_ids + } } - } - ).sort([("msg_count", pymongo.DESCENDING)]) - - - + ).sort([("msg_count", pymongo.DESCENDING)]) return list(cursor) @app.get("/channels") -async def get_channels(guild_id: str = None, channel_id: str = None): +async def get_channels(guild_id: str): """ - Returns a list of all channels. + Returns a list of all channels in a guild. That includes channels, threads and forum posts. - - Optionally, a guild_id query parameter can be provided to filter by guild. - Optionally, a channel_id query parameter can be provided to get only specific channel. """ collection_channels = get_guild_collection(guild_id, "channels") + cursor = collection_channels.find() + return list(cursor) - if channel_id: - channel = collection_channels.find_one( - { - "_id": channel_id, - "guildId": { - "$nin": blacklisted_guild_ids - } - } - ) - if not channel: - return {"message": "Not found"} - return channel - - if guild_id: - if guild_id in blacklisted_guild_ids: - return [] - - cursor = collection_channels.find({"guildId": guild_id}) - return list(cursor) - - cursor = collection_channels.find( +@app.get("/channel") +async def get_channel(guild_id: str, channel_id: str): + """ + get only specific channel (used to resolve channel mentions) + """ + collection_channels = get_guild_collection(guild_id, "channels") + channel = collection_channels.find_one( { - "guildId": { - "$nin": blacklisted_guild_ids - } + "_id": channel_id } ) - return list(cursor) + if not channel: + return {"message": "Not found"} + return channel @app.get("/roles") -async def get_roles(guild_id: str = None, role_id: str = None): +async def get_roles(guild_id: str): """ - Returns a list of all roles. - - Optionally, a guild_id query parameter can be provided to filter by guild. - Optionally, a role_id query parameter can be provided to get only specific role. + Returns a list of all roles in a guild. """ collection_roles = get_guild_collection(guild_id, "roles") + cursor = collection_roles.find().sort([("position", pymongo.DESCENDING)]) + return list(cursor) - if role_id: - role = collection_roles.find_one( - { - "_id": role_id, - "guildId": { - "$nin": blacklisted_guild_ids - } - } - ) - if not role: - return {"message": "Not found"} - return role - - if guild_id: - if guild_id in blacklisted_guild_ids: - return [] - - cursor = collection_roles.find({"guildId": guild_id}) - return list(cursor) - - # order by position desc - cursor = collection_roles.find( +@app.get("/role") +async def get_role(guild_id: str, role_id: str): + """ + similar to /roles, but only returns one role + """ + collection_roles = get_guild_collection(guild_id, "roles") + role = collection_roles.find_one( { - "guildId": { - "$nin": blacklisted_guild_ids - } + "_id": role_id } - ).sort([("position", pymongo.DESCENDING)]) - return list(cursor) + ) + if not role: + return {"message": "Not found"} + return role @app.get("/message-ids") async def get_message_ids(channel_id: str, guild_id: str): """ - Returns a list of message ids. - Optionally, a channel_id query parameter can be provided to filter by channel. + Returns a list of message ids for a channel. """ collection_messages = get_guild_collection(guild_id, "messages") if is_compiled(): @@ -191,8 +139,8 @@ async def get_message_ids(channel_id: str, guild_id: str): else: cache_path = f"../../release/dcef/storage/cache/message-ids/{channel_id}.json" + # read cached ids if available if os.path.exists(cache_path): - # read file and return content with open(cache_path, "r", encoding="utf-8") as f: print("get_message_ids() cache hit - channel id", channel_id) file_content = f.read() @@ -201,13 +149,8 @@ async def get_message_ids(channel_id: str, guild_id: str): print("get_message_ids() cache miss - channel id", channel_id) query = { - "guildId": { - "$nin": blacklisted_guild_ids - } + "channelId": channel_id } - if channel_id: - query["channelId"] = channel_id - ids = collection_messages.find(query, {"_id": 1}).sort([("_id", pymongo.ASCENDING)]) new_ids = [str(id["_id"]) for id in ids] @@ -228,6 +171,7 @@ class MessageRequest(BaseModel): async def get_multiple_message_content(message_req_obj: MessageRequest): """ Returns the content of multiple messages by their ids. + All ids must be from the same guild. """ message_ids = message_req_obj.message_ids guild_id = message_req_obj.guild_id @@ -238,9 +182,6 @@ async def get_multiple_message_content(message_req_obj: MessageRequest): { "_id": { "$in": message_ids - }, - "guildId": { - "$nin": blacklisted_guild_ids } } ) @@ -249,7 +190,7 @@ async def get_multiple_message_content(message_req_obj: MessageRequest): return list_of_messages -def channel_names_to_ids(in_channel_ids: list, in_channels: list, guild_id: str = None): +def channel_names_to_ids(in_channel_ids: list, in_channels: list, guild_id: str): """ Convert channel names to ids. """ @@ -258,9 +199,6 @@ def channel_names_to_ids(in_channel_ids: list, in_channels: list, guild_id: str if len(in_channels) == 0: return in_channel_ids - if guild_id in blacklisted_guild_ids: - return [] - out_channel_ids = in_channel_ids.copy() for channel in in_channels: if channel in out_channel_ids: @@ -268,8 +206,7 @@ def channel_names_to_ids(in_channel_ids: list, in_channels: list, guild_id: str channel_id = collection_channels.find_one( { - "name": channel, - "guildId": guild_id + "name": channel }, { "_id": 1 @@ -280,24 +217,21 @@ def channel_names_to_ids(in_channel_ids: list, in_channels: list, guild_id: str return out_channel_ids -def category_names_to_ids(in_category_ids: list, in_categories: list, guild_id: str = None): +def category_names_to_ids(in_category_ids: list, in_categories: list, guild_id: str): """ - Convert category names to ids. + Convert category names to ids """ collection_channels = get_guild_collection(guild_id, "channels") if len(in_categories) == 0: return in_category_ids - if guild_id in blacklisted_guild_ids: - return [] - out_category_ids = in_category_ids.copy() for category in in_categories: if category in out_category_ids: continue - channel_id = collection_channels.find_one({"category": category, "guildId": guild_id}, {"categoryId": 1}) + channel_id = collection_channels.find_one({"category": category}, {"categoryId": 1}) if channel_id: out_category_ids.append(channel_id["categoryId"]) @@ -359,7 +293,6 @@ def extend_reactions(reaction_ids: list, reactions: list, guild_id: str): return reaction_ids reaction_ids = reaction_ids.copy() - # partial match or_ = [] for reaction in reactions: @@ -404,8 +337,7 @@ def get_channel_info(channel_id: str, guild_id: str): channel = collection_channels.find_one( { - "_id": channel_id, - "guildId": {"$nin": blacklisted_guild_ids} + "_id": channel_id } ) if not channel: @@ -426,7 +358,6 @@ def enrich_messages(list_of_messages: list, guild_id: str) -> list: for message in list_of_messages: for content in message["content"]: message_content = content["content"] - # match all search = regex.findall(message_content) possible_emotes.extend(search) @@ -654,25 +585,22 @@ def enrich_messages(list_of_messages: list, guild_id: str) -> list: @app.get("/search-autocomplete") -def search_autocomplete(guild_id: str = None, key: str = None, value: str = None, limit: int = 100): - if (guild_id == None or key == None or value == None): - return [] - - if guild_id in blacklisted_guild_ids: +def search_autocomplete(guild_id: str, key: str = None, value: str = None, limit: int = 100): + if (key == None or value == None): return [] - guild_id = pad_id(guild_id) + padded_guild_id = pad_id(guild_id) if (key == "users"): - return Autocomplete.autocomplete_users(guild_id, value, limit) + return Autocomplete.autocomplete_users(padded_guild_id, value, limit) elif (key == "filenames"): - return Autocomplete.autocomplete_filenames(guild_id, value, limit) + return Autocomplete.autocomplete_filenames(padded_guild_id, value, limit) elif (key == "reactions"): - return Autocomplete.autocomplete_reactions(guild_id, value, limit) + return Autocomplete.autocomplete_reactions(padded_guild_id, value, limit) elif (key == "channels"): - return Autocomplete.autocomplete_channels(guild_id, value, limit) + return Autocomplete.autocomplete_channels(padded_guild_id, value, limit) elif (key == "categories"): - return Autocomplete.autocomplete_categories(guild_id, value, limit) + return Autocomplete.autocomplete_categories(padded_guild_id, value, limit) else: return [] @@ -776,7 +704,7 @@ def parse_prompt(prompt: str): @app.get("/search") -async def search_messages(prompt: str = None, guild_id: str = None, only_ids: bool = True, order_by: str = Query("newest", enum=["newest", "oldest"])): +async def search_messages(guild_id: str, prompt: str = None, only_ids: bool = True, order_by: str = Query("newest", enum=["newest", "oldest"])): """ Searches for messages that contain the prompt. """ @@ -991,14 +919,6 @@ async def search_messages(prompt: str = None, guild_id: str = None, only_ids: bo query["$and"].append({"$and": and_}) - - if guild_id: - if guild_id in blacklisted_guild_ids: - return [] - query["guildId"]=guild_id - else: - query["guildId"]={"$nin": blacklisted_guild_ids} - if only_ids: limited_fields["_id"]=1 diff --git a/backend/fastapi/helpers.py b/backend/fastapi/helpers.py index 0bce6c0a..caab566b 100644 --- a/backend/fastapi/helpers.py +++ b/backend/fastapi/helpers.py @@ -1,22 +1,42 @@ import pymongo +def pad_id(id): + if id == None: + return None + return str(id).zfill(24) + + + URI = "mongodb://127.0.0.1:27017" client = pymongo.MongoClient(URI) db = client["dcef"] collection_messages = db["messages"] collection_guilds = db["guilds"] -collection_authors = db["authors"] -collection_emojis = db["emojis"] -collection_assets = db["assets"] -collection_roles = db["roles"] +collection_config = db["config"] + +def get_whitelisted_guild_ids(): + whitelisted_guild_ids = collection_config.find_one({"key": "whitelisted_guild_ids"})["value"] + whitelisted_guild_ids = [pad_id(id) for id in whitelisted_guild_ids] + return whitelisted_guild_ids + + def get_guild_collection(guild_id, collection_name): - return db[f"g{pad_id(guild_id)}_{collection_name}"] + whitelisted_guild_ids = get_whitelisted_guild_ids() + padded_guild_id = pad_id(guild_id) + if len(whitelisted_guild_ids) > 0 and padded_guild_id not in whitelisted_guild_ids: + raise Exception(f"Guild {guild_id} not whitelisted") + + return db[f"g{padded_guild_id}_{collection_name}"] def get_global_collection(collection_name): return db[collection_name] -def pad_id(id): - if id == None: - return None - return str(id).zfill(24) \ No newline at end of file + + +def is_db_online(): + try: + client.server_info() + return True + except: + return False \ No newline at end of file diff --git a/backend/preprocess/main_mongo.py b/backend/preprocess/main_mongo.py index 34f07328..14898dca 100644 --- a/backend/preprocess/main_mongo.py +++ b/backend/preprocess/main_mongo.py @@ -28,6 +28,11 @@ def wipe_database(database: MongoDatabase): EXPECTED_VERSION = 13 # <---- change this to wipe database config = database.get_collection("config") + # add empty whitelisted_guild_ids config if it does not exist + whitelisted_guild_ids = config.find_one({"key": "whitelisted_guild_ids"}) + if whitelisted_guild_ids is None: + config.insert_one({"key": "whitelisted_guild_ids", "value": []}) + version = config.find_one({"key": "version"}) if version is None: version = {"key": "version", "value": 0} diff --git a/frontend/src/js/api.ts b/frontend/src/js/api.ts index fda6e4ad..4fb567ff 100644 --- a/frontend/src/js/api.ts +++ b/frontend/src/js/api.ts @@ -19,7 +19,7 @@ export async function getChannelInfo(channelId: string, guildId: string): Promis return createMockChannelInfo(undefined); } const paddedChannelId = channelId.padStart(24, "0"); - const response = await fetch(`/api/channels?channel_id=${paddedChannelId}&guild_id=${guildId}`); + const response = await fetch(`/api/channel?channel_id=${paddedChannelId}&guild_id=${guildId}`); const json = await response.json(); if (!json._id) { @@ -39,7 +39,7 @@ export async function getRoleInfo(roleId: string, guildId: string) { }; } const paddedRoleId = roleId.padStart(24, "0"); - const response = await fetch(`/api/roles?role_id=${paddedRoleId}&guild_id=${guildId}`); + const response = await fetch(`/api/role?role_id=${paddedRoleId}&guild_id=${guildId}`); const json = await response.json(); if (!json._id) {