diff --git a/notifier/uploads.go b/notifier/uploads.go index 2a6ee91177cb0bf00d7b086800e47e47de5ab8ce..8cf951217d3bed9e8355e1e301bc2a4b89602d81 100644 --- a/notifier/uploads.go +++ b/notifier/uploads.go @@ -27,13 +27,12 @@ func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMed localUploadWaiters[mxc] = make([]chan *database.DbMedia, 0) } - ch := make(chan *database.DbMedia, 1) + ch := make(chan *database.DbMedia) localUploadWaiters[mxc] = append(localUploadWaiters[mxc], ch) finishFn := func() { uploadMutex.Lock() defer uploadMutex.Unlock() - defer close(ch) if arr, ok := localUploadWaiters[mxc]; ok { newArr := make([]chan *database.DbMedia, 0) @@ -44,6 +43,8 @@ func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMed } localUploadWaiters[mxc] = newArr } + + close(ch) } return ch, finishFn diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go index 9d4db1f29be2b5cf54ad6cead74f12ddc5f57dfe..f57de2fd449c27cb51eb26e61be0a688b3d824b7 100644 --- a/pipelines/pipeline_download/pipeline.go +++ b/pipelines/pipeline_download/pipeline.go @@ -15,6 +15,7 @@ import ( "github.com/turt2live/matrix-media-repo/pipelines/_steps/download" "github.com/turt2live/matrix-media-repo/pipelines/_steps/meta" "github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine" + "github.com/turt2live/matrix-media-repo/util/readers" ) var sf = new(sfstreams.Group) @@ -34,11 +35,11 @@ func (o DownloadOpts) String() string { func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) { // Step 1: Make our context a timeout context var cancel context.CancelFunc + //goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil) - defer cancel() // Step 2: Join the singleflight queue - recordCh := make(chan *database.DbMedia, 1) + recordCh := make(chan *database.DbMedia) defer close(recordCh) r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) { serveRecord := func(recordCh chan *database.DbMedia, record *database.DbMedia) { @@ -103,9 +104,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do return r, nil }) if errors.Is(err, common.ErrMediaQuarantined) { + cancel() return nil, r, err } if err != nil { + cancel() return nil, nil, err } record := <-recordCh @@ -116,7 +119,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do sentry.CaptureException(devErr) r.Close() } + cancel() return record, nil, nil } - return record, r, nil + return record, readers.NewCancelCloser(r, cancel), nil } diff --git a/pipelines/pipeline_thumbnail/pipeline.go b/pipelines/pipeline_thumbnail/pipeline.go index 8dad9efd2660d441a01a368fd0a7cf6f8eb7a102..9786eac96ea917be2caf93c53d5ff445372b42ad 100644 --- a/pipelines/pipeline_thumbnail/pipeline.go +++ b/pipelines/pipeline_thumbnail/pipeline.go @@ -15,6 +15,7 @@ import ( "github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine" "github.com/turt2live/matrix-media-repo/pipelines/_steps/thumbnails" "github.com/turt2live/matrix-media-repo/pipelines/pipeline_download" + "github.com/turt2live/matrix-media-repo/util/readers" ) var sf = new(sfstreams.Group) @@ -56,11 +57,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th // Step 2: Make our context a timeout context var cancel context.CancelFunc + //goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil) - defer cancel() // Step 3: Join the singleflight queue - recordCh := make(chan *database.DbThumbnail, 1) + recordCh := make(chan *database.DbThumbnail) defer close(recordCh) r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) { serveRecord := func(recordCh chan *database.DbThumbnail, record *database.DbThumbnail) { @@ -116,9 +117,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th return download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte) }) if errors.Is(err, common.ErrMediaQuarantined) { + cancel() return nil, r, err } if err != nil { + cancel() return nil, nil, err } record := <-recordCh @@ -129,7 +132,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th sentry.CaptureException(devErr) r.Close() } + cancel() return record, nil, nil } - return record, r, nil + return record, readers.NewCancelCloser(r, cancel), nil }