Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Media authentication & event linking (MSC3916, MSC3911) #465

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api/_apimeta/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type UserInfo struct {
IsShared bool
}

type ServerInfo struct {
ServerName string
}

func GetRequestUserAdminStatus(r *http.Request, rctx rcontext.RequestContext, user UserInfo) (bool, bool) {
isGlobalAdmin := util.IsGlobalAdmin(user.UserId) || user.IsShared
isLocalAdmin, err := matrix.IsUserAdmin(rctx, r.Host, user.AccessToken, r.RemoteAddr)
Expand Down
2 changes: 1 addition & 1 deletion api/_auth_cache/auth_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/turt2live/matrix-media-repo/matrix"
)

var tokenCache = cache.New(0*time.Second, 30*time.Second)
var tokenCache = cache.New(cache.NoExpiration, 30*time.Second)
var rwLock = &sync.RWMutex{}
var regexCache = make(map[string]*regexp.Regexp)

Expand Down
30 changes: 30 additions & 0 deletions api/_routers/97-require-server-auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package _routers

import (
"net/http"

"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/matrix"
)

type GeneratorWithServerFn = func(r *http.Request, ctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{}

func RequireServerAuth(generator GeneratorWithServerFn) GeneratorFn {
return func(r *http.Request, ctx rcontext.RequestContext) interface{} {
serverName, err := matrix.ValidateXMatrixAuth(r, true)
if err != nil {
ctx.Log.Debug("Error with X-Matrix auth: ", err)
return &_responses.ErrorResponse{
Code: common.ErrCodeForbidden,
Message: "no auth provided (required)",
InternalCode: common.ErrCodeMissingToken,
}
}
return generator(r, ctx, _apimeta.ServerInfo{
ServerName: serverName,
})
}
}
32 changes: 18 additions & 14 deletions api/_routers/98-use-rcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,24 @@ func (c *RContextRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
beforeParseDownload:
log.Infof("Replying with result: %T %+v", res, res)
if downloadRes, isDownload := res.(*_responses.DownloadResponse); isDownload {
ranges, err := http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes)
if errors.Is(err, http_range.ErrInvalid) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("invalid range header")
goto beforeParseDownload // reprocess `res`
} else if errors.Is(err, http_range.ErrNoOverlap) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("out of range")
goto beforeParseDownload // reprocess `res`
}
if len(ranges) > 1 {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("only 1 range is supported")
goto beforeParseDownload // reprocess `res`
var ranges []http_range.Range
var err error
if downloadRes.SizeBytes > 0 {
ranges, err = http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes)
if errors.Is(err, http_range.ErrInvalid) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("invalid range header")
goto beforeParseDownload // reprocess `res`
} else if errors.Is(err, http_range.ErrNoOverlap) {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("out of range")
goto beforeParseDownload // reprocess `res`
}
if len(ranges) > 1 {
proposedStatusCode = http.StatusRequestedRangeNotSatisfiable
res = _responses.BadRequest("only 1 range is supported")
goto beforeParseDownload // reprocess `res`
}
}

contentType = "application/octet-stream"
Expand Down
3 changes: 3 additions & 0 deletions api/custom/federation.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ func GetFederationInfo(r *http.Request, rctx rcontext.RequestContext, user _apim

versionUrl := url + "/_matrix/federation/v1/version"
versionResponse, err := matrix.FederatedGet(versionUrl, hostname, rctx)
if versionResponse != nil {
defer versionResponse.Body.Close()
}
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
Expand Down
23 changes: 21 additions & 2 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

const PrefixMedia = "/_matrix/media"
const PrefixClient = "/_matrix/client"
const PrefixFederation = "/_matrix/federation"

func buildRoutes() http.Handler {
counter := &_routers.RequestCounter{}
Expand All @@ -36,13 +37,29 @@ func buildRoutes() http.Handler {
register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId/:filename", mxSpecV3Transition, router, downloadRoute)
register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId", mxSpecV3Transition, router, downloadRoute)
register([]string{"GET"}, PrefixMedia, "thumbnail/:server/:mediaId", mxSpecV3Transition, router, makeRoute(_routers.OptionalAccessToken(r0.ThumbnailMedia), "thumbnail", counter))
register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter))
previewUrlRoute := makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter)
register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, previewUrlRoute)
register([]string{"GET"}, PrefixMedia, "identicon/*seed", mxR0, router, makeRoute(_routers.OptionalAccessToken(r0.Identicon), "identicon", counter))
register([]string{"GET"}, PrefixMedia, "config", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.PublicConfig), "config", counter))
configRoute := makeRoute(_routers.RequireAccessToken(r0.PublicConfig), "config", counter)
register([]string{"GET"}, PrefixMedia, "config", mxSpecV3TransitionCS, router, configRoute)
register([]string{"POST"}, PrefixClient, "logout", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.Logout), "logout", counter))
register([]string{"POST"}, PrefixClient, "logout/all", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.LogoutAll), "logout_all", counter))
register([]string{"POST"}, PrefixMedia, "create", mxV1, router, makeRoute(_routers.RequireAccessToken(v1.CreateMedia), "create", counter))

// MSC3916 - Authentication & endpoint API separation
register([]string{"GET"}, PrefixClient, "media/preview_url", msc3916, router, previewUrlRoute)
register([]string{"GET"}, PrefixClient, "media/config", msc3916, router, configRoute)
authedDownloadRoute := makeRoute(_routers.RequireAccessToken(unstable.ClientDownloadMedia), "download", counter)
register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId/:filename", msc3916, router, authedDownloadRoute)
register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId", msc3916, router, authedDownloadRoute)
register([]string{"GET"}, PrefixClient, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireAccessToken(r0.ThumbnailMedia), "thumbnail", counter))
register([]string{"GET"}, PrefixFederation, "media/download/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationDownloadMedia), "download", counter))
register([]string{"GET"}, PrefixFederation, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationThumbnailMedia), "thumbnail", counter))

// MSC3911 - Linking media to events
register([]string{"POST"}, PrefixClient, "media/upload", msc3911, router, makeRoute(_routers.RequireAccessToken(unstable.ClientUploadMediaSync), "upload", counter))
register([]string{"POST"}, PrefixClient, "media/create", msc3911, router, makeRoute(_routers.RequireAccessToken(unstable.ClientCreateMedia), "create", counter))

// Custom features
register([]string{"GET"}, PrefixMedia, "local_copy/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.LocalCopy), "local_copy", counter))
register([]string{"GET"}, PrefixMedia, "info/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.MediaInfo), "info", counter))
Expand Down Expand Up @@ -129,6 +146,8 @@ var (
//mxAllSpec matrixVersions = []string{"r0", "v1", "v3", "unstable", "unstable/io.t2bot.media" /* and MSC routes */}
mxUnstable matrixVersions = []string{"unstable", "unstable/io.t2bot.media"}
msc4034 matrixVersions = []string{"unstable/org.matrix.msc4034"}
msc3916 matrixVersions = []string{"unstable/org.matrix.msc3916"}
msc3911 matrixVersions = []string{"unstable/org.matrix.msc3911"}
mxSpecV3Transition matrixVersions = []string{"r0", "v1", "v3"}
mxSpecV3TransitionCS matrixVersions = []string{"r0", "v3"}
mxR0 matrixVersions = []string{"r0"}
Expand Down
50 changes: 50 additions & 0 deletions api/unstable/msc3911_create.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package unstable

import (
"net/http"

"github.com/getsentry/sentry-go"
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
v1 "github.com/turt2live/matrix-media-repo/api/v1"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_create"
"github.com/turt2live/matrix-media-repo/util"
)

func ClientCreateMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
id, err := restrictAsyncMediaId(rctx, r.Host, user.UserId)
if err != nil {
rctx.Log.Error("Unexpected error creating media ID:", err)
sentry.CaptureException(err)
return _responses.InternalServerError("unexpected error")
}

return &v1.MediaCreatedResponse{
ContentUri: util.MxcUri(id.Origin, id.MediaId),
ExpiresTs: id.ExpiresTs,
}
}

func restrictAsyncMediaId(ctx rcontext.RequestContext, host string, userId string) (*database.DbExpiringMedia, error) {
id, err := pipeline_create.Execute(ctx, host, userId, pipeline_create.DefaultExpirationTime)
if err != nil {
return nil, err
}

db := database.GetInstance().RestrictedMedia.Prepare(ctx)
err = db.Insert(id.Origin, id.MediaId, database.RestrictedToUser, id.UserId)
if err != nil {
// Try to clean up the expiring record, but don't fail if it fails
err2 := database.GetInstance().ExpiringMedia.Prepare(ctx).SetExpiry(id.Origin, id.MediaId, util.NowMillis())
if err2 != nil {
ctx.Log.Warn("Non-fatal error when trying to clean up interstitial expiring media: ", err2)
sentry.CaptureException(err2)
}

return nil, err
}

return id, nil
}
36 changes: 36 additions & 0 deletions api/unstable/msc3911_upload_sync.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package unstable

import (
"net/http"

"github.com/getsentry/sentry-go"
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/api/r0"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util"
)

func ClientUploadMediaSync(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
// We're a bit fancy here. Instead of mirroring the "upload sync" endpoint to include restricted media, we
// internally create an async media ID then claim it immediately.

id, err := restrictAsyncMediaId(rctx, r.Host, user.UserId)
if err != nil {
rctx.Log.Error("Unexpected error creating media ID:", err)
sentry.CaptureException(err)
return _responses.InternalServerError("unexpected error")
}

r = _routers.ForceSetParam("server", id.Origin, r)
r = _routers.ForceSetParam("mediaId", id.MediaId, r)

resp := r0.UploadMediaAsync(r, rctx, user)
if _, ok := resp.(*r0.MediaUploadedResponse); ok {
return &r0.MediaUploadedResponse{
ContentUri: util.MxcUri(id.Origin, id.MediaId),
}
}
return resp
}
37 changes: 37 additions & 0 deletions api/unstable/msc3916_download.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package unstable

import (
"bytes"
"net/http"

"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/r0"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util/readers"
)

func ClientDownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
r.URL.Query().Set("allow_remote", "true")
return r0.DownloadMedia(r, rctx, user)
}

func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} {
r.URL.Query().Set("allow_remote", "false")

res := r0.DownloadMedia(r, rctx, _apimeta.UserInfo{})
if dl, ok := res.(*_responses.DownloadResponse); ok {
return &_responses.DownloadResponse{
ContentType: "multipart/mixed",
Filename: "",
SizeBytes: 0,
Data: readers.NewMultipartReader(
&readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))},
&readers.MultipartPart{ContentType: dl.ContentType, FileName: dl.Filename, Reader: dl.Data},
),
TargetDisposition: "attachment",
}
} else {
return res
}
}
32 changes: 32 additions & 0 deletions api/unstable/msc3916_thumbnail.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package unstable

import (
"bytes"
"net/http"

"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/r0"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util/readers"
)

func FederationThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} {
r.URL.Query().Set("allow_remote", "false")

res := r0.ThumbnailMedia(r, rctx, _apimeta.UserInfo{})
if dl, ok := res.(*_responses.DownloadResponse); ok {
return &_responses.DownloadResponse{
ContentType: "multipart/mixed",
Filename: "",
SizeBytes: 0,
Data: readers.NewMultipartReader(
&readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))},
&readers.MultipartPart{ContentType: dl.ContentType, FileName: dl.Filename, Reader: dl.Data},
),
TargetDisposition: "attachment",
}
} else {
return res
}
}
13 changes: 12 additions & 1 deletion config.sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,18 @@ plugins:
# Sections of this config might disappear or be added over time. By default all
# features are disabled in here and must be explicitly enabled to be used.
featureSupport:
# No unstable features are currently supported.
# MSC3911 enables linking media to events, allowing the associated media to be
# deleted when the event is (fully) deleted. MSC3911 support is always enabled
# and requires changes to either the homeserver or how the homeserver is deployed
# to work.
MSC3911:
# How long a "restricted" item of media can exist before it is automatically
# purged from the server. Defaults to 10 minutes.
maxRestrictedAgeMinutes: 10

# The maximum number of media items that can be attached to a single event.
# Defaults to 20.
maxAttachEvent: 20

# Support for redis as a cache mechanism
#
Expand Down
4 changes: 4 additions & 0 deletions database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Database struct {
Tasks *tasksTableStatements
Exports *exportsTableStatements
ExportParts *exportPartsTableStatements
RestrictedMedia *restrictedMediaTableStatements
}

var instance *Database
Expand Down Expand Up @@ -124,6 +125,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error
if d.ExportParts, err = prepareExportPartsTables(d.conn); err != nil {
return errors.New("failed to create export parts table accessor: " + err.Error())
}
if d.RestrictedMedia, err = prepareRestrictedMediaTables(d.conn); err != nil {
return errors.New("failed to create restricted media table accessor: " + err.Error())
}

instance = d
return nil
Expand Down
10 changes: 10 additions & 0 deletions database/table_expiring_media.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_
const selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;"
const selectExpiringMediaById = "SELECT origin, media_id, user_id, expires_ts FROM expiring_media WHERE origin = $1 AND media_id = $2;"
const deleteExpiringMediaById = "DELETE FROM expiring_media WHERE origin = $1 AND media_id = $2;"
const updateExpiringMediaExpiration = "UPDATE expiring_media SET expires_ts = $3 WHERE origin = $1 AND media_id = $2;"

// Dev note: there is an UPDATE query in the Upload test suite.

Expand All @@ -31,6 +32,7 @@ type expiringMediaTableStatements struct {
selectExpiringMediaByUserCount *sql.Stmt
selectExpiringMediaById *sql.Stmt
deleteExpiringMediaById *sql.Stmt
updateExpiringMediaExpiration *sql.Stmt
}

type expiringMediaTableWithContext struct {
Expand All @@ -54,6 +56,9 @@ func prepareExpiringMediaTables(db *sql.DB) (*expiringMediaTableStatements, erro
if stmts.deleteExpiringMediaById, err = db.Prepare(deleteExpiringMediaById); err != nil {
return nil, errors.New("error preparing deleteExpiringMediaById: " + err.Error())
}
if stmts.updateExpiringMediaExpiration, err = db.Prepare(updateExpiringMediaExpiration); err != nil {
return nil, errors.New("error preparing updateExpiringMediaExpiration: " + err.Error())
}

return stmts, nil
}
Expand Down Expand Up @@ -96,3 +101,8 @@ func (s *expiringMediaTableWithContext) Delete(origin string, mediaId string) er
_, err := s.statements.deleteExpiringMediaById.ExecContext(s.ctx, origin, mediaId)
return err
}

func (s *expiringMediaTableWithContext) SetExpiry(origin string, mediaId string, expiresTs int64) error {
_, err := s.statements.updateExpiringMediaExpiration.ExecContext(s.ctx, origin, mediaId, expiresTs)
return err
}
Loading