From be16c9eb44aa736be17f06460fa15b9181130f19 Mon Sep 17 00:00:00 2001 From: Travis Ralston <travpc@gmail.com> Date: Sat, 15 Jul 2023 23:58:33 -0600 Subject: [PATCH] Re-enable plugins --- CHANGELOG.md | 2 + config.sample.yaml | 3 ++ .../upload_controller/upload_controller.go | 3 +- pipelines/_steps/upload/spam.go | 39 +++++++++++++++++++ pipelines/pipeline_upload/pipeline.go | 28 ++++++++++++- plugins/manager.go | 13 ++++++- 6 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 pipelines/_steps/upload/spam.go diff --git a/CHANGELOG.md b/CHANGELOG.md index edcdb3ac..348a17fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -101,6 +101,8 @@ path/server, for example, then you can simply update the path in the config for ### Fixed * URL previews now follow redirects properly. +* Overall memory usage is improved, particularly during media uploads. + * Note: If you use plugins then memory usage will still be somewhat high due to temporary caching of uploads. ## [1.2.13] - February 12, 2023 diff --git a/config.sample.yaml b/config.sample.yaml index bb023ec6..be88b11a 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -503,6 +503,9 @@ metrics: # Currently there are only antispam plugins, but in future there should be more options. # Plugins are not supported on per-domain paths and are instead repo-wide. For more # information on writing plugins, please visit #matrix-media-repo:t2bot.io on Matrix. +# +# Note: Using plugins will mean that your media repo's memory usage will be higher because +# uploads are cached in-memory temporarily. plugins: - exec: /path/to/plugin/executable # Note: the exact config varies by plugin. diff --git a/controllers/upload_controller/upload_controller.go b/controllers/upload_controller/upload_controller.go index c2ad26b6..2cf96b2c 100644 --- a/controllers/upload_controller/upload_controller.go +++ b/controllers/upload_controller/upload_controller.go @@ -1,6 +1,7 @@ package upload_controller import ( + "bytes" "errors" "io" "strconv" @@ -159,7 +160,7 @@ func trackUploadAsLastAccess(ctx rcontext.RequestContext, media *types.Media) { } func checkSpam(contents []byte, filename string, contentType string, userId string, origin string, mediaId string) error { - spam, err := plugins.CheckForSpam(contents, filename, contentType, userId, origin, mediaId) + spam, err := plugins.CheckForSpam(bytes.NewBuffer(contents), filename, contentType, userId, origin, mediaId) if err != nil { logrus.Warn("Error checking spam - assuming not spam: " + err.Error()) sentry.CaptureException(err) diff --git a/pipelines/_steps/upload/spam.go b/pipelines/_steps/upload/spam.go new file mode 100644 index 00000000..1dae84d0 --- /dev/null +++ b/pipelines/_steps/upload/spam.go @@ -0,0 +1,39 @@ +package upload + +import ( + "io" + + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/plugins" +) + +type FileMetadata struct { + Name string + ContentType string + UserId string + Origin string + MediaId string +} + +type SpamResponse struct { + Err error + IsSpam bool +} + +func CheckSpamAsync(ctx rcontext.RequestContext, reader io.Reader, metadata FileMetadata) chan SpamResponse { + opChan := make(chan SpamResponse) + go func() { + //goland:noinspection GoUnhandledErrorResult + defer io.Copy(io.Discard, reader) // we need to flush the reader as we might end up blocking the upload + + spam, err := plugins.CheckForSpam(reader, metadata.Name, metadata.ContentType, metadata.UserId, metadata.Origin, metadata.MediaId) + go func() { + // run async to avoid deadlock + opChan <- SpamResponse{ + Err: err, + IsSpam: spam, + } + }() + }() + return opChan +} diff --git a/pipelines/pipeline_upload/pipeline.go b/pipelines/pipeline_upload/pipeline.go index 4f93e768..8dd30f51 100644 --- a/pipelines/pipeline_upload/pipeline.go +++ b/pipelines/pipeline_upload/pipeline.go @@ -5,6 +5,7 @@ import ( "io" "github.com/getsentry/sentry-go" + "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/database" @@ -14,6 +15,7 @@ import ( "github.com/turt2live/matrix-media-repo/pipelines/_steps/quota" "github.com/turt2live/matrix-media-repo/pipelines/_steps/upload" "github.com/turt2live/matrix-media-repo/util" + "github.com/turt2live/matrix-media-repo/util/readers" ) // Execute Media upload. If mediaId is an empty string, one will be generated. @@ -46,12 +48,34 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re return nil, err } - // Step 4: Buffer to the datastore's temporary path - sha256hash, sizeBytes, reader, err := datastores.BufferTemp(dsConf, r) + // Step 4: Buffer to the datastore's temporary path, and check for spam + spamR, spamW := io.Pipe() + spamTee := io.TeeReader(r, spamW) + spamChan := upload.CheckSpamAsync(ctx, spamR, upload.FileMetadata{ + Name: fileName, + ContentType: contentType, + UserId: userId, + Origin: origin, + MediaId: mediaId, + }) + sha256hash, sizeBytes, reader, err := datastores.BufferTemp(dsConf, readers.NewCancelCloser(io.NopCloser(spamTee), func() { + r.Close() + })) if err != nil { return nil, err } + if err = spamW.Close(); err != nil { + ctx.Log.Warn("Failed to close writer for spam checker: ", err) + spamChan <- upload.SpamResponse{Err: errors.New("failed to close")} + } defer reader.Close() + spam := <-spamChan + if spam.Err != nil { + return nil, err + } + if spam.IsSpam { + return nil, common.ErrMediaQuarantined + } // Step 5: Split the buffer to calculate a blurhash & populate cache later bhR, bhW := io.Pipe() diff --git a/plugins/manager.go b/plugins/manager.go index 75e891c2..b1c91fce 100644 --- a/plugins/manager.go +++ b/plugins/manager.go @@ -2,6 +2,7 @@ package plugins import ( "encoding/base64" + "io" "github.com/hashicorp/go-plugin" "github.com/sirupsen/logrus" @@ -40,7 +41,8 @@ func StopPlugins() { existingPlugins = make([]*mmrPlugin, 0) } -func CheckForSpam(contents []byte, filename string, contentType string, userId string, origin string, mediaId string) (bool, error) { +func CheckForSpam(r io.Reader, filename string, contentType string, userId string, origin string, mediaId string) (bool, error) { + b := make([]byte, 0) for _, pl := range existingPlugins { as, err := pl.Antispam() if err != nil { @@ -48,7 +50,14 @@ func CheckForSpam(contents []byte, filename string, contentType string, userId s continue } - b64 := base64.StdEncoding.EncodeToString(contents) + if len(b) == 0 { + b, err = io.ReadAll(r) + if err != nil { + return false, err + } + } + + b64 := base64.StdEncoding.EncodeToString(b) spam, err := as.CheckForSpam(b64, filename, contentType, userId, origin, mediaId) if err != nil { return false, err -- GitLab