From f20bd9aa3f190e6a3d0bf38be71bacf64c9b634a Mon Sep 17 00:00:00 2001 From: Travis Ralston <travpc@gmail.com> Date: Sun, 11 Jun 2023 00:38:12 -0600 Subject: [PATCH] Support the remaining bit of MSC2246: timeouts/wait for media Fixes https://github.com/turt2live/matrix-media-repo/issues/407 Closes https://github.com/turt2live/matrix-media-repo/issues/412 Closes https://github.com/turt2live/matrix-media-repo/issues/413 --- api/_responses/errors.go | 4 + api/_routers/98-use-rcontext.go | 3 + api/r0/download.go | 2 + api/r0/thumbnail.go | 2 + common/errorcodes.go | 1 + common/errors.go | 1 + database/table_expiring_media.go | 4 + notifier/uploads.go | 112 ++++++++++++++++++++++++ pipelines/_steps/download/wait.go | 28 ++++++ pipelines/pipeline_download/pipeline.go | 13 ++- pipelines/pipeline_upload/pipeline.go | 12 ++- pipelines/pipeline_upload/pipeline2.go | 3 +- redislib/connection.go | 8 +- redislib/pubsub.go | 88 +++++++++++++++++++ 14 files changed, 275 insertions(+), 6 deletions(-) create mode 100644 notifier/uploads.go create mode 100644 pipelines/_steps/download/wait.go create mode 100644 redislib/pubsub.go diff --git a/api/_responses/errors.go b/api/_responses/errors.go index 1cee5021..a2b1c971 100644 --- a/api/_responses/errors.go +++ b/api/_responses/errors.go @@ -55,3 +55,7 @@ func BadRequest(message string) *ErrorResponse { func QuotaExceeded() *ErrorResponse { return &ErrorResponse{common.ErrCodeForbidden, "Quota Exceeded", common.ErrCodeQuotaExceeded} } + +func NotYetUploaded() *ErrorResponse { + return &ErrorResponse{common.ErrCodeNotYetUploaded, "Media not yet uploaded", common.ErrCodeNotYetUploaded} +} diff --git a/api/_routers/98-use-rcontext.go b/api/_routers/98-use-rcontext.go index ba774e81..d9cbc0da 100644 --- a/api/_routers/98-use-rcontext.go +++ b/api/_routers/98-use-rcontext.go @@ -172,6 +172,9 @@ beforeParseDownload: case common.ErrCodeCannotOverwrite: proposedStatusCode = http.StatusConflict break + case common.ErrCodeNotYetUploaded: + proposedStatusCode = http.StatusGatewayTimeout + break default: // Treat as unknown (a generic server error) proposedStatusCode = http.StatusInternalServerError break diff --git a/api/r0/download.go b/api/r0/download.go index 9fe8bf8b..5ecf793a 100644 --- a/api/r0/download.go +++ b/api/r0/download.go @@ -80,6 +80,8 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta. } else { return _responses.NotFoundError() // We lie for security } + } else if err == common.ErrMediaNotYetUploaded { + return _responses.NotYetUploaded() } rctx.Log.Error("Unexpected error locating media: " + err.Error()) sentry.CaptureException(err) diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index 139fb52f..a98a292d 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -129,6 +129,8 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta } else { return _responses.NotFoundError() // We lie for security } + } else if err == common.ErrMediaNotYetUploaded { + return _responses.NotYetUploaded() } rctx.Log.Error("Unexpected error locating media: " + err.Error()) sentry.CaptureException(err) diff --git a/common/errorcodes.go b/common/errorcodes.go index 0be63484..d616d1e4 100644 --- a/common/errorcodes.go +++ b/common/errorcodes.go @@ -17,3 +17,4 @@ const ErrCodeUnknown = "M_UNKNOWN" const ErrCodeForbidden = "M_FORBIDDEN" const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED" const ErrCodeCannotOverwrite = "M_CANNOT_OVERWRITE_MEDIA" +const ErrCodeNotYetUploaded = "M_NOT_YET_UPLOADED" diff --git a/common/errors.go b/common/errors.go index 3da519c9..809e28cd 100644 --- a/common/errors.go +++ b/common/errors.go @@ -14,3 +14,4 @@ var ErrQuotaExceeded = errors.New("quota exceeded") var ErrWrongUser = errors.New("wrong user") var ErrExpired = errors.New("expired") var ErrAlreadyUploaded = errors.New("already uploaded") +var ErrMediaNotYetUploaded = errors.New("media not yet uploaded") diff --git a/database/table_expiring_media.go b/database/table_expiring_media.go index dd1f6546..416d1c37 100644 --- a/database/table_expiring_media.go +++ b/database/table_expiring_media.go @@ -15,6 +15,10 @@ type DbExpiringMedia struct { ExpiresTs int64 } +func (r *DbExpiringMedia) IsExpired() bool { + return r.ExpiresTs < util.NowMillis() +} + const insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_id, expires_ts) VALUES ($1, $2, $3, $4);" const selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;" const selectExpiringMediaById = "SELECT origin, media_id, user_id, expires_ts FROM expiring_media WHERE origin = $1 AND media_id = $2;" diff --git a/notifier/uploads.go b/notifier/uploads.go new file mode 100644 index 00000000..6bd6569d --- /dev/null +++ b/notifier/uploads.go @@ -0,0 +1,112 @@ +package notifier + +import ( + "sync" + + "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/redislib" + "github.com/turt2live/matrix-media-repo/util" +) + +var localWaiters = make(map[string][]chan *database.DbMedia) +var mu = new(sync.Mutex) +var redisChan <-chan string + +const uploadsNotifyRedisChannel = "mmr:upload_mxc" + +func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMedia, func()) { + subscribeRedis() + mxc := util.MxcUri(origin, mediaId) + + mu.Lock() + defer mu.Unlock() + + if _, ok := localWaiters[mxc]; !ok { + localWaiters[mxc] = make([]chan *database.DbMedia, 0) + } + + ch := make(chan *database.DbMedia) + localWaiters[mxc] = append(localWaiters[mxc], ch) + + finishFn := func() { + mu.Lock() + defer mu.Unlock() + + if arr, ok := localWaiters[mxc]; ok { + newArr := make([]chan *database.DbMedia, 0) + for _, xch := range arr { + if xch != ch { + newArr = append(newArr, xch) + } + } + localWaiters[mxc] = newArr + } + + close(ch) + } + + return ch, finishFn +} + +func UploadDone(ctx rcontext.RequestContext, record *database.DbMedia) error { + mxc := util.MxcUri(record.Origin, record.MediaId) + noRelayNotifyUpload(record) + return redislib.Publish(ctx, uploadsNotifyRedisChannel, mxc) +} + +func noRelayNotifyUpload(record *database.DbMedia) { + go func() { + mxc := util.MxcUri(record.Origin, record.MediaId) + + mu.Lock() + defer mu.Unlock() + + if arr, ok := localWaiters[mxc]; ok { + for _, ch := range arr { + ch <- record + } + delete(localWaiters, mxc) + } + }() +} + +func subscribeRedis() { + if redisChan != nil { + return + } + + mu.Lock() + defer mu.Unlock() + + redisChan = redislib.Subscribe(uploadsNotifyRedisChannel) + if redisChan == nil { + return // no redis to subscribe with + } + go func() { + for { + val := <-redisChan + logrus.Debug("Received value from uploads notify channel: ", val) + + origin, mediaId, err := util.SplitMxc(val) + if err != nil { + logrus.Warn("Non-fatal error receiving from uploads notify channel: ", err) + continue + } + + db := database.GetInstance().Media.Prepare(rcontext.Initial()) + record, err := db.GetById(origin, mediaId) + if err != nil { + logrus.Warn("Non-fatal error processing record from uploads notify channel: ", err) + continue + } + if record == nil { + logrus.Warn("Received notification that a media record is available, but it's not") + continue + } + + noRelayNotifyUpload(record) + } + }() +} diff --git a/pipelines/_steps/download/wait.go b/pipelines/_steps/download/wait.go new file mode 100644 index 00000000..c465fa9f --- /dev/null +++ b/pipelines/_steps/download/wait.go @@ -0,0 +1,28 @@ +package download + +import ( + "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/notifier" +) + +func WaitForAsyncMedia(ctx rcontext.RequestContext, origin string, mediaId string) (*database.DbMedia, error) { + db := database.GetInstance().ExpiringMedia.Prepare(ctx) + record, err := db.Get(origin, mediaId) + if err != nil { + return nil, err + } + if record == nil || record.IsExpired() { + return nil, nil // there's not going to be a record + } + + ch, finish := notifier.GetUploadWaitChannel(origin, mediaId) + defer finish() + select { + case <-ctx.Context.Done(): + return nil, common.ErrMediaNotYetUploaded + case val := <-ch: + return val, nil + } +} diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go index 21dcd824..0c5ebef3 100644 --- a/pipelines/pipeline_download/pipeline.go +++ b/pipelines/pipeline_download/pipeline.go @@ -53,6 +53,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do // Step 3: Do we already have the media? Serve it if yes. mediaDb := database.GetInstance().Media.Prepare(ctx) record, err := mediaDb.GetById(origin, mediaId) + didAsyncWait := false + handlePossibleRecord: if err != nil { return nil, err } @@ -68,7 +70,14 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do return download.OpenStream(ctx, record.Locatable, opts.StartByte, opts.EndByte) } - // Step 4: Media record unknown - download it (if possible) + // Step 4: Wait for the media, if we can + if !didAsyncWait { + record, err = download.WaitForAsyncMedia(ctx, origin, mediaId) + didAsyncWait = true + goto handlePossibleRecord + } + + // Step 5: Media record unknown - download it (if possible) if !opts.FetchRemoteIfNeeded { return nil, common.ErrMediaNotFound } @@ -86,7 +95,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do return nil, nil } - // Step 5: Limit the stream if needed + // Step 6: Limit the stream if needed r, err = download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte) if err != nil { return nil, err diff --git a/pipelines/pipeline_upload/pipeline.go b/pipelines/pipeline_upload/pipeline.go index 15fbc609..69bd4412 100644 --- a/pipelines/pipeline_upload/pipeline.go +++ b/pipelines/pipeline_upload/pipeline.go @@ -8,6 +8,7 @@ import ( "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/database" "github.com/turt2live/matrix-media-repo/datastores" + "github.com/turt2live/matrix-media-repo/notifier" "github.com/turt2live/matrix-media-repo/pipelines/_steps/meta" "github.com/turt2live/matrix-media-repo/pipelines/_steps/quota" "github.com/turt2live/matrix-media-repo/pipelines/_steps/upload" @@ -16,6 +17,14 @@ import ( // Execute Media upload. If mediaId is an empty string, one will be generated. func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string, kind datastores.Kind) (*database.DbMedia, error) { + uploadDone := func(record *database.DbMedia) { + meta.FlagAccess(ctx, record.Sha256Hash) + if err := notifier.UploadDone(ctx, record); err != nil { + ctx.Log.Warn("Non-fatal error notifying about completed upload: ", err) + sentry.CaptureException(err) + } + } + // Step 1: Limit the stream's length r = upload.LimitStream(ctx, r) @@ -103,6 +112,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re if err = database.GetInstance().Media.Prepare(ctx).Insert(newRecord); err != nil { return nil, err } + uploadDone(newRecord) return newRecord, nil } } @@ -141,6 +151,6 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re } return nil, err } - meta.FlagAccess(ctx, newRecord.Sha256Hash) + uploadDone(newRecord) return newRecord, nil } diff --git a/pipelines/pipeline_upload/pipeline2.go b/pipelines/pipeline_upload/pipeline2.go index 6d89c1b4..e883cfca 100644 --- a/pipelines/pipeline_upload/pipeline2.go +++ b/pipelines/pipeline_upload/pipeline2.go @@ -8,7 +8,6 @@ import ( "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/database" "github.com/turt2live/matrix-media-repo/datastores" - "github.com/turt2live/matrix-media-repo/util" ) func ExecutePut(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string) (*database.DbMedia, error) { @@ -30,7 +29,7 @@ func ExecutePut(ctx rcontext.RequestContext, origin string, mediaId string, r io } // Step 3: Is the record expired? - if record == nil || record.ExpiresTs < util.NowMillis() { + if record == nil || record.IsExpired() { return nil, common.ErrExpired } diff --git a/redislib/connection.go b/redislib/connection.go index de03604f..50a38bb7 100644 --- a/redislib/connection.go +++ b/redislib/connection.go @@ -49,11 +49,17 @@ func makeConnection() { } func Reconnect() { - Stop() + softStop() makeConnection() + resubscribeAll() } func Stop() { + softStop() + resubscribeAll() // since we don't have a `ring`, it'll close everything +} + +func softStop() { if ring != nil { _ = ring.Close() } diff --git a/redislib/pubsub.go b/redislib/pubsub.go new file mode 100644 index 00000000..126837e5 --- /dev/null +++ b/redislib/pubsub.go @@ -0,0 +1,88 @@ +package redislib + +import ( + "context" + "sync" + + "github.com/redis/go-redis/v9" + "github.com/turt2live/matrix-media-repo/common/rcontext" +) + +var subscribeMutex = new(sync.Mutex) +var subscribeChans = make(map[string][]chan string) + +type PubSubValue struct { + Err error + Str string +} + +func Publish(ctx rcontext.RequestContext, channel string, payload string) error { + makeConnection() + if ring == nil { + return nil + } + + if ring.PoolStats().TotalConns == 0 { + ctx.Log.Warn("Not broadcasting upload to Redis - no connections available") + return nil + } + + r := ring.Publish(ctx.Context, channel, payload) + if r.Err() != nil { + if r.Err() == redis.Nil { + ctx.Log.Warn("Not broadcasting upload to Redis - no connections available") + return nil + } + return r.Err() + } + return nil +} + +func Subscribe(channel string) <-chan string { + makeConnection() + if ring == nil { + return nil + } + + ch := make(chan string) + subscribeMutex.Lock() + if _, ok := subscribeChans[channel]; !ok { + subscribeChans[channel] = make([]chan string, 0) + } + subscribeChans[channel] = append(subscribeChans[channel], ch) + subscribeMutex.Unlock() + doSubscribe(channel, ch) + return ch +} + +func doSubscribe(channel string, ch chan<- string) { + sub := ring.Subscribe(context.Background(), channel) + go func(ch chan<- string) { + recvCh := sub.Channel() + for { + val := <-recvCh + if val != nil { + ch <- val.Payload + } else { + break + } + } + }(ch) +} + +func resubscribeAll() { + subscribeMutex.Lock() + defer subscribeMutex.Unlock() + for channel, chs := range subscribeChans { + for _, ch := range chs { + if ring == nil { + close(ch) + } else { + doSubscribe(channel, ch) + } + } + } + if ring == nil { + subscribeChans = make(map[string][]chan string) + } +} -- GitLab