diff --git a/pipline/_steps/upload/lock.go b/pipline/_steps/upload/lock.go
index 6a9f12d2b584ec3c95e75817d058d8430ac29cd9..83fdaf24d428e9a02bb922d6fd13bd89c8e62e8b 100644
--- a/pipline/_steps/upload/lock.go
+++ b/pipline/_steps/upload/lock.go
@@ -1,6 +1,7 @@
 package upload
 
 import (
+	"context"
 	"errors"
 	"time"
 
@@ -11,7 +12,7 @@ import (
 const maxLockAttemptTime = 30 * time.Second
 
 func LockForUpload(ctx rcontext.RequestContext, hash string) (func() error, error) {
-	mutex := redislib.GetMutex(hash, 1*time.Minute)
+	mutex := redislib.GetMutex(hash, 5*time.Minute)
 	if mutex != nil {
 		attemptDoneAt := time.Now().Add(maxLockAttemptTime)
 		acquired := false
@@ -20,9 +21,10 @@ func LockForUpload(ctx rcontext.RequestContext, hash string) (func() error, erro
 				return nil, chErr
 			}
 			if err := mutex.LockContext(ctx.Context); err != nil {
-				ctx.Log.Warn("failed to acquire upload lock")
 				if time.Now().After(attemptDoneAt) {
 					return nil, errors.New("failed to acquire upload lock: " + err.Error())
+				} else {
+					ctx.Log.Warn("failed to acquire upload lock: " + err.Error())
 				}
 			} else {
 				acquired = true
@@ -31,12 +33,15 @@ func LockForUpload(ctx rcontext.RequestContext, hash string) (func() error, erro
 		if !acquired {
 			return nil, errors.New("failed to acquire upload lock: timeout")
 		}
+		ctx.Log.Debugf("Lock acquired until %s", mutex.Until().UTC())
 		return func() error {
-			b, err := mutex.UnlockContext(ctx.Context)
-			if !b {
-				ctx.Log.Warn("Did not get quorum on unlock")
+			ctx.Log.Debug("Unlocking upload lock")
+			// We use a background context here to prevent a cancelled context from keeping the lock open
+			if ok, err := mutex.UnlockContext(context.Background()); !ok || err != nil {
+				ctx.Log.Warn("Did not get quorum on unlock: ", err)
+				return err
 			}
-			return err
+			return nil
 		}, nil
 	} else {
 		ctx.Log.Warn("Continuing upload without lock! Set up Redis to make this warning go away.")
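
The hunks above only show fragments of the acquire loop inside LockForUpload. For context, a minimal standalone sketch of the same deadline-bounded retry shape (not part of the patch; the helper name, the plain logger, and the redsync/v4 types are illustrative assumptions):

package example

import (
	"context"
	"errors"
	"log"
	"time"

	"github.com/go-redsync/redsync/v4"
)

// acquireWithDeadline keeps retrying LockContext until it succeeds, the
// context is cancelled, or maxWait elapses - the shape LockForUpload uses.
func acquireWithDeadline(ctx context.Context, mutex *redsync.Mutex, maxWait time.Duration) error {
	deadline := time.Now().Add(maxWait)
	for {
		if err := ctx.Err(); err != nil {
			return err // request cancelled while waiting
		}
		if err := mutex.LockContext(ctx); err != nil {
			if time.Now().After(deadline) {
				return errors.New("failed to acquire lock: " + err.Error())
			}
			log.Println("failed to acquire lock, retrying:", err)
			continue
		}
		return nil // held until mutex.Until(); release with mutex.UnlockContext
	}
}
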
diff --git a/pipline/_steps/upload/redis_async.go b/pipline/_steps/upload/redis_async.go
new file mode 100644
index 0000000000000000000000000000000000000000..5a3b0d544512f25b95007c9a16cc8270992e95d5
--- /dev/null
+++ b/pipline/_steps/upload/redis_async.go
@@ -0,0 +1,25 @@
+package upload
+
+import (
+	"io"
+
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/redislib"
+)
+
+func PopulateCacheAsync(ctx rcontext.RequestContext, reader io.Reader, size int64, sha256hash string) chan struct{} {
+	var err error
+	opChan := make(chan struct{})
+	go func() {
+		//goland:noinspection GoUnhandledErrorResult
+		defer io.Copy(io.Discard, reader) // we need to flush the reader as we might end up blocking the upload
+		defer close(opChan)
+
+		err = redislib.StoreMedia(ctx, sha256hash, reader, size)
+		if err != nil {
+			ctx.Log.Debug("Not populating cache due to error: ", err)
+			return
+		}
+	}()
+	return opChan
+}
diff --git a/pipline/upload_pipeline/pipeline.go b/pipline/upload_pipeline/pipeline.go
index bc18c1460390a312bb776d693806c83909fed4fb..cd0b00dd5a6d1918a5ce4c6d74d1daabfde7e5a5 100644
--- a/pipline/upload_pipeline/pipeline.go
+++ b/pipline/upload_pipeline/pipeline.go
@@ -1,6 +1,7 @@
 package upload_pipeline
 
 import (
+	"errors"
 	"io"
 
 	"github.com/getsentry/sentry-go"
@@ -18,12 +19,14 @@ func UploadMedia(ctx rcontext.RequestContext, origin string, mediaId string, r i
 	r = upload.LimitStream(ctx, r)
 
 	// Step 2: Create a media ID (if needed)
+	mustUseMediaId := true
 	if mediaId == "" {
 		var err error
 		mediaId, err = upload.GenerateMediaId(ctx, origin)
 		if err != nil {
 			return nil, err
 		}
+		mustUseMediaId = false
 	}
 
 	// Step 3: Pick a datastore
@@ -39,9 +42,14 @@ func UploadMedia(ctx rcontext.RequestContext, origin string, mediaId string, r i
 	}
 	defer reader.Close()
 
-	// Step 5: Split the buffer to calculate a blurhash later
+	// Step 5: Split the buffer to calculate a blurhash & populate cache later
 	bhR, bhW := io.Pipe()
-	tee := io.TeeReader(reader, bhW)
+	cacheR, cacheW := io.Pipe()
+	allWriters := io.MultiWriter(cacheW, bhW)
+	tee := io.TeeReader(reader, allWriters)
+
+	defer bhW.CloseWithError(errors.New("failed to finish write"))
+	defer cacheW.CloseWithError(errors.New("failed to finish write"))
 
 	// Step 6: Check quarantine
 	if err = upload.CheckQuarantineStatus(ctx, sha256hash); err != nil {
@@ -56,11 +64,11 @@ func UploadMedia(ctx rcontext.RequestContext, origin string, mediaId string, r i
 
 	// Step 8: Acquire a lock on the media hash for uploading
 	unlockFn, err := upload.LockForUpload(ctx, sha256hash)
-	//goland:noinspection GoUnhandledErrorResult
-	defer unlockFn()
 	if err != nil {
 		return nil, err
 	}
+	//goland:noinspection GoUnhandledErrorResult
+	defer unlockFn()
 
 	// Step 9: Pull all upload records (to check if an upload has already happened)
 	newRecord := &database.DbMedia{
@@ -79,7 +87,7 @@ func UploadMedia(ctx rcontext.RequestContext, origin string, mediaId string, r i
 	record, perfect, err := upload.FindRecord(ctx, sha256hash, userId, contentType, fileName)
 	if record != nil {
 		// We already had this record in some capacity
-		if perfect {
+		if perfect && !mustUseMediaId {
 			// Exact match - deduplicate, skip upload to datastore
 			return record, nil
 		} else {
@@ -97,16 +105,28 @@ func UploadMedia(ctx rcontext.RequestContext, origin string, mediaId string, r i
 	// Step 10: Asynchronously calculate blurhash
 	bhChan := upload.CalculateBlurhashAsync(ctx, bhR, sha256hash)
 
-	// Step 11: Since we didn't find a duplicate, upload it to the datastore
+	// Step 11: Asynchronously upload to cache
+	cacheChan := upload.PopulateCacheAsync(ctx, cacheR, sizeBytes, sha256hash)
+
+	// Step 12: Since we didn't find a duplicate, upload it to the datastore
 	dsLocation, err := datastores.Upload(ctx, dsConf, io.NopCloser(tee), sizeBytes, contentType, sha256hash)
 	if err != nil {
 		return nil, err
 	}
+	if err = bhW.Close(); err != nil {
+		ctx.Log.Warn("Failed to close writer for blurhash: ", err)
+		close(bhChan)
+	}
+	if err = cacheW.Close(); err != nil {
+		ctx.Log.Warn("Failed to close writer for cache layer: ", err)
+		close(cacheChan)
+	}
 
-	// Step 12: Wait for blurhash
+	// Step 13: Wait for channels
 	<-bhChan
+	<-cacheChan
 
-	// Step 13: Everything finally looks good - return some stuff
+	// Step 14: Everything finally looks good - return some stuff
 	newRecord.DatastoreId = dsConf.Id
 	newRecord.Location = dsLocation
 	if err = database.GetInstance().Media.Prepare(ctx).Insert(newRecord); err != nil {
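
The Step 5 plumbing is easier to follow in isolation. A self-contained sketch of the same TeeReader/MultiWriter/Pipe fan-out (illustrative names, not repo code), showing why the side consumers must drain their pipes and why both writers have to be closed once the primary read finishes:

package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	source := strings.NewReader("example upload bytes")

	bhR, bhW := io.Pipe()
	cacheR, cacheW := io.Pipe()
	tee := io.TeeReader(source, io.MultiWriter(cacheW, bhW))

	done := make(chan struct{}, 2)
	for _, r := range []io.Reader{bhR, cacheR} {
		go func(r io.Reader) {
			// Side consumers must keep reading or the MultiWriter (and thus
			// the primary read) blocks: pipes have no internal buffer.
			n, _ := io.Copy(io.Discard, r)
			fmt.Println("side consumer read", n, "bytes")
			done <- struct{}{}
		}(r)
	}

	// The "datastore upload" drains the tee; both pipes see the same bytes.
	n, _ := io.Copy(io.Discard, tee)
	fmt.Println("primary consumer read", n, "bytes")

	// Closing the writers delivers EOF to the side readers so they can exit;
	// without this, both goroutines would block forever.
	_ = bhW.Close()
	_ = cacheW.Close()
	<-done
	<-done
}
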
diff --git a/redislib/cache.go b/redislib/cache.go
new file mode 100644
index 0000000000000000000000000000000000000000..93f013e1bf9761c8d7d833a53b5e30bb79abbea7
--- /dev/null
+++ b/redislib/cache.go
@@ -0,0 +1,90 @@
+package redislib
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"io"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/redis/go-redis/v9"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/metrics"
+)
+
+const appendBufferSize = 1024 // 1kb
+const mediaExpirationTime = 5 * time.Minute
+const redisMaxValueSize = 512 * 1024 * 1024 // 512mb
+
+func StoreMedia(ctx rcontext.RequestContext, hash string, content io.Reader, size int64) error {
+	makeConnection()
+	if ring == nil {
+		return nil
+	}
+	if size >= redisMaxValueSize {
+		return nil
+	}
+
+	if err := ring.ForEachShard(ctx.Context, func(ctx2 context.Context, client *redis.Client) error {
+		res := client.Set(ctx2, hash, make([]byte, 0), mediaExpirationTime)
+		return res.Err()
+	}); err != nil {
+		return err
+	}
+
+	buf := make([]byte, appendBufferSize)
+	for {
+		read, err := content.Read(buf)
+		// Per the io.Reader contract, handle any read bytes before the error:
+		// a reader may return n > 0 together with io.EOF on the final chunk.
+		if read > 0 {
+			if appendErr := ring.ForEachShard(ctx.Context, func(ctx2 context.Context, client *redis.Client) error {
+				res := client.Append(ctx2, hash, string(buf[0:read]))
+				return res.Err()
+			}); appendErr != nil {
+				return appendErr
+			}
+		}
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func TryGetMedia(ctx rcontext.RequestContext, hash string, startByte int64, endByte int64) (io.Reader, error) {
+	makeConnection()
+	if ring == nil {
+		return nil, nil
+	}
+
+	timeoutCtx, cancel := context.WithTimeout(ctx.Context, 20*time.Second)
+	defer cancel()
+
+	var result *redis.StringCmd
+	if startByte >= 0 && endByte >= 1 {
+		if startByte < endByte {
+			result = ring.GetRange(timeoutCtx, hash, startByte, endByte)
+		} else {
+			return nil, errors.New("invalid range - start must be before end")
+		}
+	} else {
+		result = ring.Get(timeoutCtx, hash)
+	}
+
+	s, err := result.Result()
+	if err != nil {
+		if err == redis.Nil {
+			metrics.CacheMisses.With(prometheus.Labels{"cache": "media"}).Inc()
+			return nil, nil
+		}
+		return nil, err
+	}
+
+	metrics.CacheHits.With(prometheus.Labels{"cache": "media"}).Inc()
+	return bytes.NewBuffer([]byte(s)), nil
+}
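
A usage sketch for the new cache (hypothetical caller, not part of the patch). Note that TryGetMedia deliberately returns (nil, nil) both when Redis is not configured and on a cache miss, so callers must treat a nil reader as "fall back to the datastore":

package example

import (
	"io"

	"github.com/turt2live/matrix-media-repo/common/rcontext"
	"github.com/turt2live/matrix-media-repo/redislib"
)

// fetchCached is a hypothetical caller. Passing -1 for both byte offsets
// selects the plain GET path rather than GETRANGE.
func fetchCached(ctx rcontext.RequestContext, sha256hash string) (io.Reader, error) {
	r, err := redislib.TryGetMedia(ctx, sha256hash, -1, -1)
	if err != nil {
		return nil, err
	}
	if r == nil {
		return nil, nil // not cached (or Redis disabled): read from the datastore
	}
	return r, nil
}
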
diff --git a/redislib/locking.go b/redislib/locking.go
index 9f60a74e906ec9ef15f1c026ff8a4c0ebefe8b62..f3f19f4f6cee578bc626759eb9c3770bd8f61cfc 100644
--- a/redislib/locking.go
+++ b/redislib/locking.go
@@ -12,5 +12,8 @@ func GetMutex(key string, expiration time.Duration) *redsync.Mutex {
 		return nil
 	}
 
-	return rs.NewMutex(key, redsync.WithExpiry(expiration))
+	// Dev note: the prefix is to prevent key conflicts. Specifically, we create an upload mutex using
+	// the sha256 hash of the file *and* populate the redis cache with that file at the same key - this
+	// causes the mutex lock to fail unlocking because the value "changed". A prefix avoids that conflict.
+	return rs.NewMutex("mutex-"+key, redsync.WithExpiry(expiration))
 }
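
The conflict the dev note describes can be reproduced in a few lines. A standalone sketch (assumes a local Redis at localhost:6379 and the redsync goredis adapter; not repo code): redsync stores a random token at the lock key, so any writer that overwrites that key - such as the media cache using the bare sha256 hash - makes the unlock check fail.

package main

import (
	"context"
	"fmt"

	"github.com/go-redsync/redsync/v4"
	goredis "github.com/go-redsync/redsync/v4/redis/goredis/v9"
	goredislib "github.com/redis/go-redis/v9"
)

func main() {
	client := goredislib.NewClient(&goredislib.Options{Addr: "localhost:6379"})
	rs := redsync.New(goredis.NewPool(client))

	// No prefix: the lock key is the same key the media cache writes to.
	mutex := rs.NewMutex("sha256-of-file")
	if err := mutex.Lock(); err != nil {
		panic(err)
	}

	// Simulate the cache layer clobbering the lock token at the same key.
	client.Set(context.Background(), "sha256-of-file", "file bytes", 0)

	// Unlock now fails the token comparison: redsync no longer "owns" the key.
	ok, err := mutex.Unlock()
	fmt.Println(ok, err)
}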