From fa36956fed4b955420521541840517b861c2cad3 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Fri, 9 Jun 2023 22:43:53 -0600
Subject: [PATCH] Handle quarantined media (inline thumbnail generation)

For https://github.com/turt2live/matrix-media-repo/issues/413
For https://github.com/turt2live/matrix-media-repo/issues/412
---
 api/_responses/content.go                | 10 +++
 api/r0/download.go                       | 11 +--
 api/r0/thumbnail.go                      |  9 ++-
 config.sample.yaml                       |  1 +
 go.mod                                   |  2 +-
 go.sum                                   |  4 +-
 pipelines/_steps/quarantine/logic.go     | 29 ++++++++
 pipelines/_steps/quarantine/thumbnail.go | 90 ++++++++++++++++++++++++
 pipelines/pipeline_download/pipeline.go  | 15 ++++
 pipelines/pipeline_thumbnail/pipeline.go | 27 +++++--
 10 files changed, 182 insertions(+), 16 deletions(-)
 create mode 100644 pipelines/_steps/quarantine/logic.go
 create mode 100644 pipelines/_steps/quarantine/thumbnail.go

diff --git a/api/_responses/content.go b/api/_responses/content.go
index 1f697ff0..735e4b1b 100644
--- a/api/_responses/content.go
+++ b/api/_responses/content.go
@@ -19,3 +19,13 @@ type DownloadResponse struct {
 type StreamDataResponse struct {
 	Stream io.Reader
 }
+
+func MakeQuarantinedImageResponse(stream io.ReadCloser) *DownloadResponse {
+	return &DownloadResponse{
+		ContentType:       "image/png",
+		Filename:          "not_allowed.png",
+		SizeBytes:         -1,
+		Data:              stream,
+		TargetDisposition: "inline",
+	}
+}
diff --git a/api/r0/download.go b/api/r0/download.go
index ed81ad4a..9fe8bf8b 100644
--- a/api/r0/download.go
+++ b/api/r0/download.go
@@ -16,8 +16,6 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 )
 
-type DownloadMediaResponse = _responses.DownloadResponse
-
 func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
 	server := _routers.GetParam("server", r)
 	mediaId := _routers.GetParam("mediaId", r)
@@ -76,7 +74,12 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
 		} else if err == common.ErrMediaTooLarge {
 			return _responses.RequestTooLarge()
 		} else if err == common.ErrMediaQuarantined {
-			return _responses.NotFoundError() // We lie for security
+			rctx.Log.Debug("Quarantined media accessed. Has stream? ", stream != nil)
+			if stream != nil {
+				return _responses.MakeQuarantinedImageResponse(stream)
+			} else {
+				return _responses.NotFoundError() // We lie for security
+			}
 		}
 		rctx.Log.Error("Unexpected error locating media: " + err.Error())
 		sentry.CaptureException(err)
@@ -87,7 +90,7 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
 		filename = media.UploadName
 	}
 
-	return &DownloadMediaResponse{
+	return &_responses.DownloadResponse{
 		ContentType:       media.ContentType,
 		Filename:          filename,
 		SizeBytes:         media.SizeBytes,
diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go
index 552fe0b8..139fb52f 100644
--- a/api/r0/thumbnail.go
+++ b/api/r0/thumbnail.go
@@ -123,14 +123,19 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
 		} else if err == common.ErrMediaTooLarge {
 			return _responses.RequestTooLarge()
 		} else if err == common.ErrMediaQuarantined {
-			return _responses.NotFoundError() // We lie for security
+			rctx.Log.Debug("Quarantined media accessed. Has stream? ", stream != nil)
+			if stream != nil {
+				return _responses.MakeQuarantinedImageResponse(stream)
+			} else {
+				return _responses.NotFoundError() // We lie for security
+			}
 		}
 		rctx.Log.Error("Unexpected error locating media: " + err.Error())
 		sentry.CaptureException(err)
 		return _responses.InternalServerError("Unexpected Error")
 	}
 
-	return &DownloadMediaResponse{
+	return &_responses.DownloadResponse{
 		ContentType:       thumbnail.ContentType,
 		Filename:          "thumbnail" + util.ExtensionForContentType(thumbnail.ContentType),
 		SizeBytes:         thumbnail.SizeBytes,
diff --git a/config.sample.yaml b/config.sample.yaml
index fefc5f3f..bc549de5 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -464,6 +464,7 @@ quarantine:
   replaceDownloads: false
 
   # If provided, the given image will be returned as a thumbnail for media that is quarantined.
+  # The recommended size is at least 512x512.
   #thumbnailPath: "/path/to/thumbnail.png"
 
   # If true, administrators of the configured homeservers may quarantine media for their server
diff --git a/go.mod b/go.mod
index d8f38acf..86e7f62b 100644
--- a/go.mod
+++ b/go.mod
@@ -54,7 +54,7 @@ require (
 	github.com/minio/minio-go/v7 v7.0.55
 	github.com/panjf2000/ants/v2 v2.7.4
 	github.com/redis/go-redis/v9 v9.0.4
-	github.com/t2bot/go-singleflight-streams v0.0.2
+	github.com/t2bot/go-singleflight-streams v0.0.3
 )
 
 require (
diff --git a/go.sum b/go.sum
index f7ab76e5..79521e1e 100644
--- a/go.sum
+++ b/go.sum
@@ -335,8 +335,8 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08YuiTGPZLls0Wq99X9bWd0Q5ZSBesM=
 github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8=
-github.com/t2bot/go-singleflight-streams v0.0.2 h1:N60e6rZvuf5CybfuGREgMDxKbFtvHqr3FzawuPLaGOI=
-github.com/t2bot/go-singleflight-streams v0.0.2/go.mod h1:oiZ5Zj2o6p3SuMWTPXAmye7Ou86WCrQ2oGcr32bEEq8=
+github.com/t2bot/go-singleflight-streams v0.0.3 h1:vedZM34NuCMx7YJkPek2b73zjzV8Qt8v+hR2u7RCbbk=
+github.com/t2bot/go-singleflight-streams v0.0.3/go.mod h1:oiZ5Zj2o6p3SuMWTPXAmye7Ou86WCrQ2oGcr32bEEq8=
 github.com/tebeka/strftime v0.1.3 h1:5HQXOqWKYRFfNyBMNVc9z5+QzuBtIXy03psIhtdJYto=
 github.com/tebeka/strftime v0.1.3/go.mod h1:7wJm3dZlpr4l/oVK0t1HYIc4rMzQ2XJlOMIUJUJH6XQ=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
diff --git a/pipelines/_steps/quarantine/logic.go b/pipelines/_steps/quarantine/logic.go
new file mode 100644
index 00000000..31cb186d
--- /dev/null
+++ b/pipelines/_steps/quarantine/logic.go
@@ -0,0 +1,29 @@
+package quarantine
+
+import (
+	"io"
+
+	"github.com/turt2live/matrix-media-repo/common"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/pipelines/_steps/download"
+)
+
+func ReturnAppropriateThing(ctx rcontext.RequestContext, isDownload bool, recordOnly bool, width int, height int, startByte int64, endByte int64) (io.ReadCloser, error) {
+	flag := ctx.Config.Quarantine.ReplaceDownloads
+	if !isDownload {
+		flag = ctx.Config.Quarantine.ReplaceThumbnails
+	}
+	if !flag || recordOnly {
+		return nil, common.ErrMediaQuarantined
+	} else {
+		if qr, err := MakeThumbnail(ctx, width, height); err != nil {
+			return nil, err
+		} else {
+			if r, err2 := download.CreateLimitedStream(ctx, qr, startByte, endByte); err2 != nil {
+				return nil, err2
+			} else {
+				return r, common.ErrMediaQuarantined
+			}
+		}
+	}
+}
diff --git a/pipelines/_steps/quarantine/thumbnail.go b/pipelines/_steps/quarantine/thumbnail.go
new file mode 100644
index 00000000..d5a37718
--- /dev/null
+++ b/pipelines/_steps/quarantine/thumbnail.go
@@ -0,0 +1,90 @@
+package quarantine
+
+import (
+	"image"
+	"image/color"
+	"io"
+	"math"
+
+	"github.com/disintegration/imaging"
+	"github.com/fogleman/gg"
+	"github.com/golang/freetype/truetype"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"golang.org/x/image/font/gofont/gosmallcaps"
+)
+
+func MakeThumbnail(ctx rcontext.RequestContext, width int, height int) (io.ReadCloser, error) {
+	var centerImage image.Image
+	var err error
+	if ctx.Config.Quarantine.ThumbnailPath != "" {
+		centerImage, err = imaging.Open(ctx.Config.Quarantine.ThumbnailPath)
+	} else {
+		centerImage, err = makeDefaultImage()
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	c := gg.NewContext(width, height)
+	centerImage = imaging.Fit(centerImage, width, height, imaging.Lanczos)
+	c.DrawImageAnchored(centerImage, width/2, height/2, 0.5, 0.5)
+
+	pr, pw := io.Pipe()
+	go func(pw *io.PipeWriter, c *gg.Context) {
+		encErr := c.EncodePNG(pw)
+		if encErr != nil {
+			_ = pw.CloseWithError(encErr)
+		} else {
+			_ = pw.Close()
+		}
+	}(pw, c)
+	return pr, nil
+}
+
+func makeDefaultImage() (image.Image, error) {
+	c := gg.NewContext(700, 700)
+	c.Clear()
+
+	red := color.RGBA{R: 190, G: 26, B: 25, A: 255}
+	orange := color.RGBA{R: 255, G: 186, B: 73, A: 255}
+	x := 350.0
+	y := 300.0
+	r := 256.0
+	w := 55.0
+	p := 64.0
+	m := "media not allowed"
+
+	c.SetColor(orange)
+	c.DrawRectangle(0, 0, 700, 700)
+	c.Fill()
+
+	c.SetColor(red)
+	c.DrawCircle(x, y, r)
+	c.Fill()
+
+	c.SetColor(color.White)
+	c.DrawCircle(x, y, r-w)
+	c.Fill()
+
+	lr := r - (w / 2)
+	sx := x + (lr * math.Cos(gg.Radians(225.0)))
+	sy := y + (lr * math.Sin(gg.Radians(225.0)))
+	ex := x + (lr * math.Cos(gg.Radians(45.0)))
+	ey := y + (lr * math.Sin(gg.Radians(45.0)))
+	c.SetLineCap(gg.LineCapButt)
+	c.SetLineWidth(w)
+	c.SetColor(red)
+	c.DrawLine(sx, sy, ex, ey)
+	c.Stroke()
+
+	f, err := truetype.Parse(gosmallcaps.TTF)
+	if err != nil {
+		return nil, err
+	}
+
+	c.SetColor(color.Black)
+	c.SetFontFace(truetype.NewFace(f, &truetype.Options{Size: 64}))
+	c.DrawStringAnchored(m, x, y+r+p, 0.5, 0.5)
+
+	return c.Image(), nil
+}
diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go
index e4eecf30..b506bfe4 100644
--- a/pipelines/pipeline_download/pipeline.go
+++ b/pipelines/pipeline_download/pipeline.go
@@ -13,6 +13,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/database"
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/download"
+	"github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine"
 	"github.com/turt2live/matrix-media-repo/util"
 )
 
@@ -41,6 +42,10 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
 	defer close(recordCh)
 	r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) {
 		serveRecord := func(recordCh chan *database.DbMedia, record *database.DbMedia) {
+			defer func() {
+				// Don't crash when we send to a closed channel
+				recover()
+			}()
 			recordCh <- record
 		}
 
@@ -52,6 +57,9 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
 		}
 		if record != nil {
 			go serveRecord(recordCh, record) // async function to prevent deadlock
+			if record.Quarantined {
+				return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512, opts.StartByte, opts.EndByte)
+			}
 			if opts.RecordOnly {
 				return nil, nil
 			}
@@ -67,6 +75,9 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
 			return nil, err
 		}
 		go serveRecord(recordCh, record) // async function to prevent deadlock
+		if record.Quarantined {
+			return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512, opts.StartByte, opts.EndByte)
+		}
 		if opts.RecordOnly {
 			r.Close()
 			return nil, nil
@@ -80,6 +91,10 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
 
 		return r, nil
 	})
+	if err == common.ErrMediaQuarantined {
+		cancel()
+		return nil, r, err
+	}
 	if err != nil {
 		cancel()
 		return nil, nil, err
diff --git a/pipelines/pipeline_thumbnail/pipeline.go b/pipelines/pipeline_thumbnail/pipeline.go
index 4d845f81..8d84d7f6 100644
--- a/pipelines/pipeline_thumbnail/pipeline.go
+++ b/pipelines/pipeline_thumbnail/pipeline.go
@@ -12,6 +12,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/database"
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/download"
+	"github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine"
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/thumbnails"
 	"github.com/turt2live/matrix-media-repo/pipelines/pipeline_download"
 	"github.com/turt2live/matrix-media-repo/util"
@@ -64,22 +65,30 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
 	defer close(recordCh)
 	r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) {
 		serveRecord := func(recordCh chan *database.DbThumbnail, record *database.DbThumbnail) {
+			defer func() {
+				// Don't crash when we send to a closed channel
+				recover()
+			}()
 			recordCh <- record
 		}
 
 		// Step 4: Get the associated media record (without stream)
-		mediaRecord, _, err := pipeline_download.Execute(ctx, origin, mediaId, opts.ImpliedDownloadOpts())
+		mediaRecord, dr, err := pipeline_download.Execute(ctx, origin, mediaId, opts.ImpliedDownloadOpts())
 		if err != nil {
+			if err == common.ErrMediaQuarantined {
+				go serveRecord(recordCh, nil) // async function to prevent deadlock
+				if dr != nil {
+					dr.Close()
+				}
+				return quarantine.ReturnAppropriateThing(ctx, false, opts.RecordOnly, opts.Width, opts.Height, opts.StartByte, opts.EndByte)
+			}
 			return nil, err
 		}
 		if mediaRecord == nil {
 			return nil, common.ErrMediaNotFound
 		}
 
-		// Step 5: Check for quarantine
-		// TODO: Quarantine
-
-		// Step 6: See if we're lucky enough to already have this thumbnail
+		// Step 5: See if we're lucky enough to already have this thumbnail
 		thumbDb := database.GetInstance().Thumbnails.Prepare(ctx)
 		record, err := thumbDb.GetByParams(origin, mediaId, opts.Width, opts.Height, opts.Method, opts.Animated)
 		if err != nil {
@@ -93,7 +102,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
 			return download.OpenStream(ctx, record.Locatable, opts.StartByte, opts.EndByte)
 		}
 
-		// Step 7: Generate the thumbnail and return that
+		// Step 6: Generate the thumbnail and return that
 		record, r, err := thumbnails.Generate(ctx, mediaRecord, opts.Width, opts.Height, opts.Method, opts.Animated)
 		if err != nil {
 			return nil, err
@@ -104,9 +113,13 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
 			return nil, nil
 		}
 
-		// Step 8: Create a limited stream
+		// Step 7: Create a limited stream
 		return download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte)
 	})
+	if err == common.ErrMediaQuarantined {
+		cancel()
+		return nil, r, err
+	}
 	if err != nil {
 		cancel()
 		return nil, nil, err
-- 
GitLab