From b4da6e526ec6e479d2682592c17f366355caffda Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Aug 2024 00:08:16 +0300 Subject: [PATCH] Add support for authenticated media (#290) Setup instructions: 1. Set up a reverse proxy to pass `/_heisenbridge/media/*` to heisenbridge 2. Configure `heisenbridge` -> `media_url` in the registration file with the public URL that the reverse proxy handles Optionally, you can run another heisenbridge instance with the `--media-proxy` flag to have it in a separate process --- heisenbridge/__main__.py | 111 +++++++++++++++++++++++++++++++-------- 1 file changed, 89 insertions(+), 22 deletions(-) diff --git a/heisenbridge/__main__.py b/heisenbridge/__main__.py index eb4b4a8..0bea1d9 100644 --- a/heisenbridge/__main__.py +++ b/heisenbridge/__main__.py @@ -1,6 +1,9 @@ import argparse import asyncio +import base64 import grp +import hashlib +import hmac import logging import os import pwd @@ -14,6 +17,7 @@ from typing import List from typing import Tuple +from aiohttp import web from mautrix.api import HTTPAPI from mautrix.api import Method from mautrix.api import Path @@ -59,7 +63,7 @@ class BridgeAppService(AppService): _rooms: Dict[str, Room] _users: Dict[str, str] - DEFAULT_MEDIA_PATH = "/_matrix/media/v3/download/{netloc}{path}{filename}" + DEFAULT_MEDIA_PATH = "/_heisenbridge/media/{server}/{media_id}/{checksum}{filename}" async def push_bridge_state( self, @@ -332,17 +336,70 @@ async def detect_public_endpoint(self): logging.warning("Using internal URL for homeserver, media links are likely broken!") return str(self.api.base_url) - def mxc_to_url(self, mxc, filename=None): - mxc = urllib.parse.urlparse(mxc) + def mxc_checksum(self, server: str, media_id: str) -> str: + # Add trailing slash to prevent length extension attacks + checksum_raw = hmac.new(self.media_key, f"mxc://{server}/{media_id}/".encode("utf-8"), hashlib.sha256).digest() + return base64.urlsafe_b64encode(checksum_raw[:8]).decode("utf-8").rstrip("=") + + async def proxy_media(self, req: web.Request) -> web.StreamResponse | web.Response: + server = req.match_info["server"] + media_id = req.match_info["media_id"] + checksum = req.match_info["checksum"] + if self.mxc_checksum(server, media_id) != checksum: + return web.Response(status=403, text="Invalid checksum") + download_url = self.api.base_url / "_matrix/client/v1/media/download" / server / media_id + filename = req.match_info.get("filename", "") + if filename: + download_url /= filename + query_params: dict[str, str] = {"allow_redirect": "true", "user_id": self.az.bot_mxid} + headers: dict[str, str] = {"Authorization": f"Bearer {self.az.as_token}"} + resp_headers = { + "Content-Security-Policy": ( + "sandbox; default-src 'none'; script-src 'none'; style-src 'none'; object-src 'none';" + ), + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, HEAD, OPTIONS", + "Content-Disposition": "attachment", + } + started_writing = False + try: + async with self.api.session.get(download_url, params=query_params, headers=headers) as dl_resp: + resp = web.StreamResponse(status=dl_resp.status, headers=resp_headers) + if dl_resp.content_length: + resp.content_length = dl_resp.content_length + resp.content_type = dl_resp.content_type + if "Content-Disposition" in dl_resp.headers: + resp.headers["Content-Disposition"] = dl_resp.headers["Content-Disposition"] + elif resp.status >= 300: + del resp.headers["Content-Disposition"] + started_writing = True + await resp.prepare(req) + async for chunk, end_of_chunk in dl_resp.content.iter_chunks(): + await resp.write(chunk) + return resp + except Exception: + if not started_writing: + logging.exception("Failed to fetch media") + return web.Response(status=502, text="Failed to fetch media") + + def mxc_to_url(self, mxc: str, filename=None): + if not self.media_endpoint: + return "" + try: + server, media_id = self.api.parse_mxc_uri(mxc) + except ValueError: + return "" if filename is None: filename = "" else: filename = "/" + urllib.parse.quote(filename) - media_path = self.media_path.format(netloc=mxc.netloc, path=mxc.path, filename=filename) + media_path = self.media_path.format( + server=server, media_id=media_id, checksum=self.mxc_checksum(server, media_id), filename=filename + ) - return "{}{}".format(self.endpoint, media_path) + return "{}{}".format(self.media_endpoint, media_path) async def reset(self, config_file, homeserver_url): with open(config_file) as f: @@ -448,7 +505,7 @@ async def ensure_hidden_room(self): return use_hidden_room - async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mode): + async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mode, media_proxy): if "sender_localpart" not in self.registration: print("Missing sender_localpart from registration file.") sys.exit(1) @@ -485,6 +542,8 @@ async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mod print(f"Heisenbridge v{__version__}", flush=True) if safe_mode: print("Safe mode is enabled.", flush=True) + if media_proxy: + print("Media proxy only mode.", flush=True) url = urllib.parse.urlparse(homeserver_url) ws = None @@ -542,6 +601,8 @@ async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mod state_store=MemoryBridgeStateStore(), ) self.az.matrix_event_handler(self._on_mx_event) + self.az.app.router.add_get("/_heisenbridge/media/{server}/{media_id}/{checksum}/{filename}", self.proxy_media) + self.az.app.router.add_get("/_heisenbridge/media/{server}/{media_id}/{checksum}", self.proxy_media) try: await self.az.start(host=listen_address, port=listen_port) @@ -578,6 +639,7 @@ async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mod "use_reacts": True, "media_url": None, "media_path": None, + "media_key": None, "namespace": self.puppet_prefix, } logging.debug(f"Default config: {self.config}") @@ -594,27 +656,21 @@ async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mod # load config from HS await self.load() - async def _resolve_media_endpoint(): - endpoint = await self.detect_public_endpoint() - - # only rewrite it if it wasn't changed - if self.endpoint == str(self.api.base_url): - self.endpoint = endpoint - - print("Homeserver is publicly available at " + self.endpoint, flush=True) + if "heisenbridge" in self.registration and "media_key" in self.registration["heisenbridge"]: + self.media_key = self.registration["heisenbridge"]["media_key"].encode("utf-8") + elif self.config["media_key"]: + self.media_key = self.config["media_key"].encode("utf-8") + else: + self.media_key = self.registration["hs_token"].encode("utf-8") # use configured media_url for endpoint if we have it if "heisenbridge" in self.registration and "media_url" in self.registration["heisenbridge"]: logging.debug( f"Overriding media URL from registration file to {self.registration['heisenbridge']['media_url']}" ) - self.endpoint = self.registration["heisenbridge"]["media_url"] + self.media_endpoint = self.registration["heisenbridge"]["media_url"] elif self.config["media_url"]: - self.endpoint = self.config["media_url"] - else: - print("Trying to detect homeserver public endpoint, this might take a while...", flush=True) - self.endpoint = str(self.api.base_url) - asyncio.ensure_future(_resolve_media_endpoint()) + self.media_endpoint = self.config["media_url"] # use configured media_path for media_path if we have it if "heisenbridge" in self.registration and "media_path" in self.registration["heisenbridge"]: @@ -627,6 +683,11 @@ async def _resolve_media_endpoint(): else: self.media_path = self.DEFAULT_MEDIA_PATH + if media_proxy: + logging.info("Media proxy mode startup complete") + await asyncio.Event().wait() + return + logging.info("Starting presence loop") self._keepalive() @@ -854,6 +915,12 @@ async def async_main(): help="reset ALL bridge configuration from homeserver and exit", default=argparse.SUPPRESS, ) + parser.add_argument( + "--media-proxy", + action="store_true", + help="run in media proxy mode", + default=False, + ) parser.add_argument( "--safe-mode", action="store_true", @@ -924,7 +991,7 @@ async def async_main(): service.load_reg(args.config) - if args.identd: + if args.identd and not args.media_proxy: identd = Identd() await identd.start_listening(service, args.identd_port) @@ -963,7 +1030,7 @@ async def async_main(): except Exception: pass - await service.run(listen_address, listen_port, args.homeserver, args.owner, args.safe_mode) + await service.run(listen_address, listen_port, args.homeserver, args.owner, args.safe_mode, args.media_proxy) def main():