Skip to content
Snippets Groups Projects
Commit 73dfdc84 authored by Travis Ralston's avatar Travis Ralston
Browse files

Be slightly more responsible with channels in download/thumbnails

parent c2bd8999
No related branches found
No related tags found
No related merge requests found
...@@ -27,12 +27,13 @@ func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMed ...@@ -27,12 +27,13 @@ func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMed
localUploadWaiters[mxc] = make([]chan *database.DbMedia, 0) localUploadWaiters[mxc] = make([]chan *database.DbMedia, 0)
} }
ch := make(chan *database.DbMedia) ch := make(chan *database.DbMedia, 1)
localUploadWaiters[mxc] = append(localUploadWaiters[mxc], ch) localUploadWaiters[mxc] = append(localUploadWaiters[mxc], ch)
finishFn := func() { finishFn := func() {
uploadMutex.Lock() uploadMutex.Lock()
defer uploadMutex.Unlock() defer uploadMutex.Unlock()
defer close(ch)
if arr, ok := localUploadWaiters[mxc]; ok { if arr, ok := localUploadWaiters[mxc]; ok {
newArr := make([]chan *database.DbMedia, 0) newArr := make([]chan *database.DbMedia, 0)
...@@ -43,8 +44,6 @@ func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMed ...@@ -43,8 +44,6 @@ func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMed
} }
localUploadWaiters[mxc] = newArr localUploadWaiters[mxc] = newArr
} }
close(ch)
} }
return ch, finishFn return ch, finishFn
......
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
"github.com/turt2live/matrix-media-repo/pipelines/_steps/download" "github.com/turt2live/matrix-media-repo/pipelines/_steps/download"
"github.com/turt2live/matrix-media-repo/pipelines/_steps/meta" "github.com/turt2live/matrix-media-repo/pipelines/_steps/meta"
"github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine" "github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine"
"github.com/turt2live/matrix-media-repo/util/readers"
) )
var sf = new(sfstreams.Group) var sf = new(sfstreams.Group)
...@@ -35,11 +34,11 @@ func (o DownloadOpts) String() string { ...@@ -35,11 +34,11 @@ func (o DownloadOpts) String() string {
func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) { func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) {
// Step 1: Make our context a timeout context // Step 1: Make our context a timeout context
var cancel context.CancelFunc var cancel context.CancelFunc
//goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct
ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil) ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil)
defer cancel()
// Step 2: Join the singleflight queue // Step 2: Join the singleflight queue
recordCh := make(chan *database.DbMedia) recordCh := make(chan *database.DbMedia, 1)
defer close(recordCh) defer close(recordCh)
r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) { r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) {
serveRecord := func(recordCh chan *database.DbMedia, record *database.DbMedia) { serveRecord := func(recordCh chan *database.DbMedia, record *database.DbMedia) {
...@@ -104,11 +103,9 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do ...@@ -104,11 +103,9 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
return r, nil return r, nil
}) })
if errors.Is(err, common.ErrMediaQuarantined) { if errors.Is(err, common.ErrMediaQuarantined) {
cancel()
return nil, r, err return nil, r, err
} }
if err != nil { if err != nil {
cancel()
return nil, nil, err return nil, nil, err
} }
record := <-recordCh record := <-recordCh
...@@ -119,8 +116,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do ...@@ -119,8 +116,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
sentry.CaptureException(devErr) sentry.CaptureException(devErr)
r.Close() r.Close()
} }
cancel()
return record, nil, nil return record, nil, nil
} }
return record, readers.NewCancelCloser(r, cancel), nil return record, r, nil
} }
...@@ -15,7 +15,6 @@ import ( ...@@ -15,7 +15,6 @@ import (
"github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine" "github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine"
"github.com/turt2live/matrix-media-repo/pipelines/_steps/thumbnails" "github.com/turt2live/matrix-media-repo/pipelines/_steps/thumbnails"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_download" "github.com/turt2live/matrix-media-repo/pipelines/pipeline_download"
"github.com/turt2live/matrix-media-repo/util/readers"
) )
var sf = new(sfstreams.Group) var sf = new(sfstreams.Group)
...@@ -57,11 +56,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th ...@@ -57,11 +56,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
// Step 2: Make our context a timeout context // Step 2: Make our context a timeout context
var cancel context.CancelFunc var cancel context.CancelFunc
//goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct
ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil) ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil)
defer cancel()
// Step 3: Join the singleflight queue // Step 3: Join the singleflight queue
recordCh := make(chan *database.DbThumbnail) recordCh := make(chan *database.DbThumbnail, 1)
defer close(recordCh) defer close(recordCh)
r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) { r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) {
serveRecord := func(recordCh chan *database.DbThumbnail, record *database.DbThumbnail) { serveRecord := func(recordCh chan *database.DbThumbnail, record *database.DbThumbnail) {
...@@ -117,11 +116,9 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th ...@@ -117,11 +116,9 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
return download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte) return download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte)
}) })
if errors.Is(err, common.ErrMediaQuarantined) { if errors.Is(err, common.ErrMediaQuarantined) {
cancel()
return nil, r, err return nil, r, err
} }
if err != nil { if err != nil {
cancel()
return nil, nil, err return nil, nil, err
} }
record := <-recordCh record := <-recordCh
...@@ -132,8 +129,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th ...@@ -132,8 +129,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
sentry.CaptureException(devErr) sentry.CaptureException(devErr)
r.Close() r.Close()
} }
cancel()
return record, nil, nil return record, nil, nil
} }
return record, readers.NewCancelCloser(r, cancel), nil return record, r, nil
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment