From be16c9eb44aa736be17f06460fa15b9181130f19 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sat, 15 Jul 2023 23:58:33 -0600
Subject: [PATCH] Re-enable plugins

---
 CHANGELOG.md                                  |  2 +
 config.sample.yaml                            |  3 ++
 .../upload_controller/upload_controller.go    |  3 +-
 pipelines/_steps/upload/spam.go               | 39 +++++++++++++++++++
 pipelines/pipeline_upload/pipeline.go         | 28 ++++++++++++-
 plugins/manager.go                            | 13 ++++++-
 6 files changed, 83 insertions(+), 5 deletions(-)
 create mode 100644 pipelines/_steps/upload/spam.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index edcdb3ac..348a17fa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -101,6 +101,8 @@ path/server, for example, then you can simply update the path in the config for
 ### Fixed
 
 * URL previews now follow redirects properly.
+* Overall memory usage is improved, particularly during media uploads.
+  * Note: If you use plugins, memory usage will still be somewhat high due to temporary caching of uploads.
 
 ## [1.2.13] - February 12, 2023
 
diff --git a/config.sample.yaml b/config.sample.yaml
index bb023ec6..be88b11a 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -503,6 +503,9 @@ metrics:
 # Currently there are only antispam plugins, but in future there should be more options.
 # Plugins are not supported on per-domain paths and are instead repo-wide. For more
 # information on writing plugins, please visit #matrix-media-repo:t2bot.io on Matrix.
+#
+# Note: Using plugins will increase your media repo's memory usage because uploads are
+# temporarily cached in memory.
 plugins:
   - exec: /path/to/plugin/executable
     # Note: the exact config varies by plugin.
diff --git a/controllers/upload_controller/upload_controller.go b/controllers/upload_controller/upload_controller.go
index c2ad26b6..2cf96b2c 100644
--- a/controllers/upload_controller/upload_controller.go
+++ b/controllers/upload_controller/upload_controller.go
@@ -1,6 +1,7 @@
 package upload_controller
 
 import (
+	"bytes"
 	"errors"
 	"io"
 	"strconv"
@@ -159,7 +160,7 @@ func trackUploadAsLastAccess(ctx rcontext.RequestContext, media *types.Media) {
 }
 
 func checkSpam(contents []byte, filename string, contentType string, userId string, origin string, mediaId string) error {
-	spam, err := plugins.CheckForSpam(contents, filename, contentType, userId, origin, mediaId)
+	spam, err := plugins.CheckForSpam(bytes.NewBuffer(contents), filename, contentType, userId, origin, mediaId)
 	if err != nil {
 		logrus.Warn("Error checking spam - assuming not spam: " + err.Error())
 		sentry.CaptureException(err)
diff --git a/pipelines/_steps/upload/spam.go b/pipelines/_steps/upload/spam.go
new file mode 100644
index 00000000..1dae84d0
--- /dev/null
+++ b/pipelines/_steps/upload/spam.go
@@ -0,0 +1,39 @@
+package upload
+
+import (
+	"io"
+
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/plugins"
+)
+
+// FileMetadata carries the descriptive fields of an upload that are passed,
+// alongside its contents, to the spam check.
+type FileMetadata struct {
+	Name, ContentType string
+	UserId, Origin    string
+	MediaId           string
+}
+
+type SpamResponse struct {
+	Err    error // set by CheckSpamAsync to the error from plugins.CheckForSpam
+	IsSpam bool  // set by CheckSpamAsync to the verdict from plugins.CheckForSpam
+}
+
+// CheckSpamAsync runs plugins.CheckForSpam against reader on a new goroutine
+// and returns a channel on which that goroutine sends one SpamResponse. The
+// channel has capacity 1 so that delivering the result does not depend on the
+// caller receiving it. After the check, the rest of reader is discarded so a
+// writer feeding it is not left blocked.
+func CheckSpamAsync(ctx rcontext.RequestContext, reader io.Reader, metadata FileMetadata) chan SpamResponse {
+	opChan := make(chan SpamResponse, 1)
+	go func() {
+		//goland:noinspection GoUnhandledErrorResult
+		defer io.Copy(io.Discard, reader) // we need to flush the reader as we might end up blocking the upload
+
+		spam, err := plugins.CheckForSpam(reader, metadata.Name, metadata.ContentType, metadata.UserId, metadata.Origin, metadata.MediaId)
+		// Capacity 1 lets the single result be queued without a helper goroutine.
+		opChan <- SpamResponse{Err: err, IsSpam: spam}
+	}()
+	return opChan
+}
diff --git a/pipelines/pipeline_upload/pipeline.go b/pipelines/pipeline_upload/pipeline.go
index 4f93e768..8dd30f51 100644
--- a/pipelines/pipeline_upload/pipeline.go
+++ b/pipelines/pipeline_upload/pipeline.go
@@ -5,6 +5,7 @@ import (
 	"io"
 
 	"github.com/getsentry/sentry-go"
+	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/config"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/database"
@@ -14,6 +15,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/quota"
 	"github.com/turt2live/matrix-media-repo/pipelines/_steps/upload"
 	"github.com/turt2live/matrix-media-repo/util"
+	"github.com/turt2live/matrix-media-repo/util/readers"
 )
 
 // Execute Media upload. If mediaId is an empty string, one will be generated.
@@ -46,12 +48,34 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re
 		return nil, err
 	}
 
-	// Step 4: Buffer to the datastore's temporary path
-	sha256hash, sizeBytes, reader, err := datastores.BufferTemp(dsConf, r)
+	// Step 4: Buffer to the datastore's temporary path, and check for spam
+	spamR, spamW := io.Pipe()
+	spamTee := io.TeeReader(r, spamW)
+	spamChan := upload.CheckSpamAsync(ctx, spamR, upload.FileMetadata{
+		Name:        fileName,
+		ContentType: contentType,
+		UserId:      userId,
+		Origin:      origin,
+		MediaId:     mediaId,
+	})
+	sha256hash, sizeBytes, reader, err := datastores.BufferTemp(dsConf, readers.NewCancelCloser(io.NopCloser(spamTee), func() {
+		r.Close()
+	}))
 	if err != nil {
 		return nil, err
 	}
+	if err = spamW.Close(); err != nil {
+		ctx.Log.Warn("Failed to close writer for spam checker: ", err)
+		spamChan <- upload.SpamResponse{Err: errors.New("failed to close")}
+	}
 	defer reader.Close()
+	spam := <-spamChan // wait for the result from upload.CheckSpamAsync
+	if spam.Err != nil {
+		return nil, spam.Err // err holds the spamW.Close() result, not the spam check's error
+	}
+	if spam.IsSpam {
+		return nil, common.ErrMediaQuarantined
+	}
 
 	// Step 5: Split the buffer to calculate a blurhash & populate cache later
 	bhR, bhW := io.Pipe()
diff --git a/plugins/manager.go b/plugins/manager.go
index 75e891c2..b1c91fce 100644
--- a/plugins/manager.go
+++ b/plugins/manager.go
@@ -2,6 +2,7 @@ package plugins
 
 import (
 	"encoding/base64"
+	"io"
 
 	"github.com/hashicorp/go-plugin"
 	"github.com/sirupsen/logrus"
@@ -40,7 +41,8 @@ func StopPlugins() {
 	existingPlugins = make([]*mmrPlugin, 0)
 }
 
-func CheckForSpam(contents []byte, filename string, contentType string, userId string, origin string, mediaId string) (bool, error) {
+func CheckForSpam(r io.Reader, filename string, contentType string, userId string, origin string, mediaId string) (bool, error) {
+	b := make([]byte, 0)
 	for _, pl := range existingPlugins {
 		as, err := pl.Antispam()
 		if err != nil {
@@ -48,7 +50,14 @@ func CheckForSpam(contents []byte, filename string, contentType string, userId s
 			continue
 		}
 
-		b64 := base64.StdEncoding.EncodeToString(contents)
+		if len(b) == 0 {
+			b, err = io.ReadAll(r)
+			if err != nil {
+				return false, err
+			}
+		}
+
+		b64 := base64.StdEncoding.EncodeToString(b)
 		spam, err := as.CheckForSpam(b64, filename, contentType, userId, origin, mediaId)
 		if err != nil {
 			return false, err
-- 
GitLab