From d856942b0ec753a69c303947732f15b5137fb75b Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sun, 20 Aug 2023 00:21:15 -0600
Subject: [PATCH] Revert "Be slightly more responsible with channels in
 download/thumbnails"

This reverts commit 73dfdc84b0ef6b65907559c0809ef56b2ca94272.
---
 notifier/uploads.go                      |  5 +++--
 pipelines/pipeline_download/pipeline.go  | 10 +++++++---
 pipelines/pipeline_thumbnail/pipeline.go | 10 +++++++---
 3 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/notifier/uploads.go b/notifier/uploads.go
index 2a6ee911..8cf95121 100644
--- a/notifier/uploads.go
+++ b/notifier/uploads.go
@@ -27,13 +27,12 @@ func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMed
 		localUploadWaiters[mxc] = make([]chan *database.DbMedia, 0)
 	}
 
-	ch := make(chan *database.DbMedia, 1)
+	ch := make(chan *database.DbMedia)
 	localUploadWaiters[mxc] = append(localUploadWaiters[mxc], ch)
 
 	finishFn := func() {
 		uploadMutex.Lock()
 		defer uploadMutex.Unlock()
-		defer close(ch)
 
 		if arr, ok := localUploadWaiters[mxc]; ok {
 			newArr := make([]chan *database.DbMedia, 0)
@@ -44,6 +43,8 @@ func GetUploadWaitChannel(origin string, mediaId string) (<-chan *database.DbMed
 			}
 			localUploadWaiters[mxc] = newArr
 		}
+
+		close(ch)
 	}
 
 	return ch, finishFn
diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go
index 9d4db1f2..f57de2fd 100644
--- a/pipelines/pipeline_download/pipeline.go
+++ b/pipelines/pipeline_download/pipeline.go
@@ -15,6 +15,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/download"
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/meta"
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine"
+	"github.com/turt2live/matrix-media-repo/util/readers"
 )
 
 var sf = new(sfstreams.Group)
@@ -34,11 +35,11 @@ func (o DownloadOpts) String() string {
 func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) {
 	// Step 1: Make our context a timeout context
 	var cancel context.CancelFunc
+	//goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct
 	ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil)
-	defer cancel()
 
 	// Step 2: Join the singleflight queue
-	recordCh := make(chan *database.DbMedia, 1)
+	recordCh := make(chan *database.DbMedia)
 	defer close(recordCh)
 	r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) {
 		serveRecord := func(recordCh chan *database.DbMedia, record *database.DbMedia) {
@@ -103,9 +104,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
 		return r, nil
 	})
 	if errors.Is(err, common.ErrMediaQuarantined) {
+		cancel()
 		return nil, r, err
 	}
 	if err != nil {
+		cancel()
 		return nil, nil, err
 	}
 	record := <-recordCh
@@ -116,7 +119,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
 			sentry.CaptureException(devErr)
 			r.Close()
 		}
+		cancel()
 		return record, nil, nil
 	}
-	return record, r, nil
+	return record, readers.NewCancelCloser(r, cancel), nil
 }
diff --git a/pipelines/pipeline_thumbnail/pipeline.go b/pipelines/pipeline_thumbnail/pipeline.go
index 8dad9efd..9786eac9 100644
--- a/pipelines/pipeline_thumbnail/pipeline.go
+++ b/pipelines/pipeline_thumbnail/pipeline.go
@@ -15,6 +15,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine"
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/thumbnails"
 	"github.com/turt2live/matrix-media-repo/pipelines/pipeline_download"
+	"github.com/turt2live/matrix-media-repo/util/readers"
 )
 
 var sf = new(sfstreams.Group)
@@ -56,11 +57,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
 
 	// Step 2: Make our context a timeout context
 	var cancel context.CancelFunc
+	//goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct
 	ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil)
-	defer cancel()
 
 	// Step 3: Join the singleflight queue
-	recordCh := make(chan *database.DbThumbnail, 1)
+	recordCh := make(chan *database.DbThumbnail)
 	defer close(recordCh)
 	r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) {
 		serveRecord := func(recordCh chan *database.DbThumbnail, record *database.DbThumbnail) {
@@ -116,9 +117,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
 		return download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte)
 	})
 	if errors.Is(err, common.ErrMediaQuarantined) {
+		cancel()
 		return nil, r, err
 	}
 	if err != nil {
+		cancel()
 		return nil, nil, err
 	}
 	record := <-recordCh
@@ -129,7 +132,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
 			sentry.CaptureException(devErr)
 			r.Close()
 		}
+		cancel()
 		return record, nil, nil
 	}
-	return record, r, nil
+	return record, readers.NewCancelCloser(r, cancel), nil
 }
-- 
GitLab