From 049ea12e2bcabfbe7998ea677d4df1b94b2b874d Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Tue, 1 Aug 2023 21:56:54 -0600
Subject: [PATCH] Move quarantine logic to new structures

---
 api/custom/quarantine.go          | 160 ++++++------------------------
 database/virtualtable_metadata.go |  26 +++++
 redislib/cache.go                 |   9 ++
 tasks/task_runner/quarantine.go   |  88 ++++++++++++++++
 4 files changed, 155 insertions(+), 128 deletions(-)
 create mode 100644 tasks/task_runner/quarantine.go

diff --git a/api/custom/quarantine.go b/api/custom/quarantine.go
index 6cdf7d18..477815da 100644
--- a/api/custom/quarantine.go
+++ b/api/custom/quarantine.go
@@ -1,30 +1,25 @@
 package custom
 
 import (
-	"database/sql"
 	"net/http"
 
 	"github.com/getsentry/sentry-go"
 	"github.com/turt2live/matrix-media-repo/api/_apimeta"
 	"github.com/turt2live/matrix-media-repo/api/_responses"
 	"github.com/turt2live/matrix-media-repo/api/_routers"
+	"github.com/turt2live/matrix-media-repo/database"
+	"github.com/turt2live/matrix-media-repo/tasks/task_runner"
 
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
-	"github.com/turt2live/matrix-media-repo/internal_cache"
 	"github.com/turt2live/matrix-media-repo/matrix"
-	"github.com/turt2live/matrix-media-repo/storage"
-	"github.com/turt2live/matrix-media-repo/types"
 	"github.com/turt2live/matrix-media-repo/util"
 )
 
 type MediaQuarantinedResponse struct {
-	NumQuarantined int `json:"num_quarantined"`
+	NumQuarantined int64 `json:"num_quarantined"`
 }
 
-// Developer note: This isn't broken out into a dedicated controller class because the logic is slightly
-// too complex to do so. If anything, the logic should be improved and moved.
-
 func QuarantineRoomMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
 	canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, rctx, user)
 	if !canQuarantine {
@@ -47,31 +42,13 @@ func QuarantineRoomMedia(r *http.Request, rctx rcontext.RequestContext, user _ap
 
 	var mxcs []string
 	mxcs = append(mxcs, allMedia.LocalMxcs...)
-	mxcs = append(mxcs, allMedia.RemoteMxcs...)
-
-	total := 0
-	for _, mxc := range mxcs {
-		server, mediaId, err := util.SplitMxc(mxc)
-		if err != nil {
-			rctx.Log.Error("Error parsing MXC URI ("+mxc+"): ", err)
-			sentry.CaptureException(err)
-			return _responses.InternalServerError("error parsing mxc uri")
-		}
-
-		if !allowOtherHosts && r.Host != server {
-			rctx.Log.Warn("Skipping media " + mxc + " because it is on a different host")
-			continue
-		}
-
-		resp, ok := doQuarantine(rctx, server, mediaId, allowOtherHosts)
-		if !ok {
-			return resp
-		}
-
-		total += resp.(*MediaQuarantinedResponse).NumQuarantined
+	if allowOtherHosts {
+		mxcs = append(mxcs, allMedia.RemoteMxcs...)
 	}
 
-	return &_responses.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}}
+	return performQuarantineRequest(rctx, r.Host, allowOtherHosts, &task_runner.QuarantineThis{
+		MxcUris: mxcs,
+	})
 }
 
 func QuarantineUserMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
@@ -98,25 +75,17 @@ func QuarantineUserMedia(r *http.Request, rctx rcontext.RequestContext, user _ap
 		return _responses.AuthFailed()
 	}
 
-	db := storage.GetDatabase().GetMediaStore(rctx)
-	userMedia, err := db.GetMediaByUser(userId)
+	db := database.GetInstance().Media.Prepare(rctx)
+	userMedia, err := db.GetByUserId(userId)
 	if err != nil {
 		rctx.Log.Error("Error while listing media for the user: ", err)
 		sentry.CaptureException(err)
 		return _responses.InternalServerError("error retrieving media for user")
 	}
 
-	total := 0
-	for _, media := range userMedia {
-		resp, ok := doQuarantineOn(media, allowOtherHosts, rctx)
-		if !ok {
-			return resp
-		}
-
-		total += resp.(*MediaQuarantinedResponse).NumQuarantined
-	}
-
-	return &_responses.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}}
+	return performQuarantineRequest(rctx, r.Host, allowOtherHosts, &task_runner.QuarantineThis{
+		DbMedia: userMedia,
+	})
 }
 
 func QuarantineDomainMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
@@ -140,25 +109,17 @@ func QuarantineDomainMedia(r *http.Request, rctx rcontext.RequestContext, user _
 		return _responses.AuthFailed()
 	}
 
-	db := storage.GetDatabase().GetMediaStore(rctx)
-	userMedia, err := db.GetAllMediaForServer(serverName)
+	db := database.GetInstance().Media.Prepare(rctx)
+	domainMedia, err := db.GetByOrigin(serverName)
 	if err != nil {
 		rctx.Log.Error("Error while listing media for the server: ", err)
 		sentry.CaptureException(err)
 		return _responses.InternalServerError("error retrieving media for server")
 	}
 
-	total := 0
-	for _, media := range userMedia {
-		resp, ok := doQuarantineOn(media, allowOtherHosts, rctx)
-		if !ok {
-			return resp
-		}
-
-		total += resp.(*MediaQuarantinedResponse).NumQuarantined
-	}
-
-	return &_responses.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}}
+	return performQuarantineRequest(rctx, r.Host, allowOtherHosts, &task_runner.QuarantineThis{
+		DbMedia: domainMedia,
+	})
 }
 
 func QuarantineMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
@@ -184,80 +145,28 @@ func QuarantineMedia(r *http.Request, rctx rcontext.RequestContext, user _apimet
 		return _responses.BadRequest("unable to quarantine media on other homeservers")
 	}
 
-	resp, _ := doQuarantine(rctx, server, mediaId, allowOtherHosts)
-	return &_responses.DoNotCacheResponse{Payload: resp}
-}
-
-func doQuarantine(ctx rcontext.RequestContext, origin string, mediaId string, allowOtherHosts bool) (interface{}, bool) {
-	db := storage.GetDatabase().GetMediaStore(ctx)
-	media, err := db.Get(origin, mediaId)
-	if err != nil {
-		if err == sql.ErrNoRows {
-			ctx.Log.Warn("Media not found, could not quarantine: " + origin + "/" + mediaId)
-			return &MediaQuarantinedResponse{0}, true
-		}
-
-		ctx.Log.Error("Error fetching media: ", err)
-		sentry.CaptureException(err)
-		return _responses.InternalServerError("error quarantining media"), false
-	}
-
-	return doQuarantineOn(media, allowOtherHosts, ctx)
+	return performQuarantineRequest(rctx, r.Host, allowOtherHosts, &task_runner.QuarantineThis{
+		Single: &task_runner.QuarantineRecord{
+			Origin:  server,
+			MediaId: mediaId,
+		},
+	})
 }
 
-func doQuarantineOn(media *types.Media, allowOtherHosts bool, ctx rcontext.RequestContext) (interface{}, bool) {
-	// Check to make sure the media doesn't have a purpose in staying
-	attrDb := storage.GetDatabase().GetMediaAttributesStore(ctx)
-	attr, err := attrDb.GetAttributesDefaulted(media.Origin, media.MediaId)
-	if err != nil {
-		ctx.Log.Error("Error while getting attributes for media: ", err)
-		sentry.CaptureException(err)
-		return _responses.InternalServerError("Error quarantining media"), false
-	}
-	if attr.Purpose == types.PurposePinned {
-		ctx.Log.Warn("Refusing to quarantine media due to it being pinned")
-		return &MediaQuarantinedResponse{NumQuarantined: 0}, true
+func performQuarantineRequest(ctx rcontext.RequestContext, host string, allowOtherHosts bool, toQuarantine *task_runner.QuarantineThis) interface{} {
+	lockedHost := host
+	if allowOtherHosts {
+		lockedHost = ""
 	}
 
-	// We reset the entire cache to avoid any lingering links floating around, such as thumbnails or other media.
-	// The reset is done before actually quarantining the media because that could fail for some reason
-	internal_cache.Get().Reset()
-
-	num, err := setMediaQuarantined(media, true, allowOtherHosts, ctx)
+	total, err := task_runner.QuarantineMedia(ctx, lockedHost, toQuarantine)
 	if err != nil {
-		ctx.Log.Error("Error quarantining media: ", err)
+		ctx.Log.Error(err)
 		sentry.CaptureException(err)
-		return _responses.InternalServerError("Error quarantining media"), false
-	}
-
-	return &MediaQuarantinedResponse{NumQuarantined: num}, true
-}
-
-func setMediaQuarantined(media *types.Media, isQuarantined bool, allowOtherHosts bool, ctx rcontext.RequestContext) (int, error) {
-	db := storage.GetDatabase().GetMediaStore(ctx)
-	numQuarantined := 0
-
-	// Quarantine all media with the same hash, including the one requested
-	otherMedia, err := db.GetByHash(media.Sha256Hash)
-	if err != nil {
-		return numQuarantined, err
-	}
-	for _, m := range otherMedia {
-		if m.Origin != media.Origin && !allowOtherHosts {
-			ctx.Log.Warn("Skipping quarantine on " + m.Origin + "/" + m.MediaId + " because it is on a different host from " + media.Origin + "/" + media.MediaId)
-			continue
-		}
-
-		err := db.SetQuarantined(m.Origin, m.MediaId, isQuarantined)
-		if err != nil {
-			return numQuarantined, err
-		}
-
-		numQuarantined++
-		ctx.Log.Warn("Media has been quarantined: " + m.Origin + "/" + m.MediaId)
+		return _responses.InternalServerError("error quarantining media")
 	}
 
-	return numQuarantined, nil
+	return &_responses.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}}
 }
 
 func getQuarantineRequestInfo(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) (bool, bool, bool) {
@@ -277,7 +186,6 @@ func getQuarantineRequestInfo(r *http.Request, rctx rcontext.RequestContext, use
 			}
 
 			if !isLocalAdmin {
-				rctx.Log.Warn(user.UserId + " tried to quarantine media on another server")
 				canQuarantine = false
 				return canQuarantine, allowOtherHosts, isLocalAdmin
 			}
@@ -287,9 +195,5 @@ func getQuarantineRequestInfo(r *http.Request, rctx rcontext.RequestContext, use
 		}
 	}
 
-	if !canQuarantine {
-		rctx.Log.Warn(user.UserId + " tried to quarantine media")
-	}
-
 	return canQuarantine, allowOtherHosts, isLocalAdmin
 }
diff --git a/database/virtualtable_metadata.go b/database/virtualtable_metadata.go
index 5030f1d0..fac2c0ea 100644
--- a/database/virtualtable_metadata.go
+++ b/database/virtualtable_metadata.go
@@ -22,6 +22,8 @@ const selectUploadSizesForServer = "SELECT COALESCE((SELECT SUM(size_bytes) FROM
 const selectUploadCountsForServer = "SELECT COALESCE((SELECT COUNT(origin) FROM media WHERE origin = $1), 0) AS media, COALESCE((SELECT COUNT(origin) FROM thumbnails WHERE origin = $1), 0) AS thumbnails;"
 const selectMediaForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts, m.content_type FROM media AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;"
 const selectThumbnailsForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts, m.content_type FROM thumbnails AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;"
+const updateQuarantineByHash = "WITH t AS (SELECT m.origin AS origin, m.media_id AS media_id, a.purpose AS purpose FROM media AS m LEFT JOIN media_attributes AS a ON m.origin = a.origin AND m.media_id = a.media_id WHERE m.sha256_hash = $1 AND (a.purpose IS NULL OR a.purpose <> $2) AND m.quarantined <> $3) UPDATE media AS m2 SET quarantined = $3 FROM t WHERE m2.origin = t.origin AND m2.media_id = t.media_id;"
+const updateQuarantineByHashAndOrigin = "WITH t AS (SELECT m.origin AS origin, m.media_id AS media_id, a.purpose AS purpose FROM media AS m LEFT JOIN media_attributes AS a ON m.origin = a.origin AND m.media_id = a.media_id WHERE m.origin = $1 AND m.sha256_hash = $2 AND (a.purpose IS NULL OR a.purpose <> $3) AND m.quarantined <> $4) UPDATE media AS m2 SET quarantined = $4 FROM t WHERE m2.origin = t.origin AND m2.media_id = t.media_id;"
 
 type SynStatUserOrderBy string
 
@@ -51,6 +53,8 @@ type metadataVirtualTableStatements struct {
 	selectUploadCountsForServer                *sql.Stmt
 	selectMediaForDatastoreWithLastAccess      *sql.Stmt
 	selectThumbnailsForDatastoreWithLastAccess *sql.Stmt
+	updateQuarantineByHash                     *sql.Stmt
+	updateQuarantineByHashAndOrigin            *sql.Stmt
 }
 
 type metadataVirtualTableWithContext struct {
@@ -79,6 +83,12 @@ func prepareMetadataVirtualTables(db *sql.DB) (*metadataVirtualTableStatements,
 	if stmts.selectThumbnailsForDatastoreWithLastAccess, err = db.Prepare(selectThumbnailsForDatastoreWithLastAccess); err != nil {
 		return nil, errors.New("error preparing selectThumbnailsForDatastoreWithLastAccess: " + err.Error())
 	}
+	if stmts.updateQuarantineByHash, err = db.Prepare(updateQuarantineByHash); err != nil {
+		return nil, errors.New("error preparing updateQuarantineByHash: " + err.Error())
+	}
+	if stmts.updateQuarantineByHashAndOrigin, err = db.Prepare(updateQuarantineByHashAndOrigin); err != nil {
+		return nil, errors.New("error preparing updateQuarantineByHashAndOrigin: " + err.Error())
+	}
 
 	return stmts, nil
 }
@@ -228,3 +238,19 @@ func (s *metadataVirtualTableWithContext) GetMediaForDatastoreByLastAccess(datas
 func (s *metadataVirtualTableWithContext) GetThumbnailsForDatastoreByLastAccess(datastoreId string, lastAccessTs int64) ([]*VirtLastAccess, error) {
 	return s.scanLastAccess(s.statements.selectThumbnailsForDatastoreWithLastAccess.QueryContext(s.ctx, lastAccessTs, datastoreId))
 }
+
+func (s *metadataVirtualTableWithContext) UpdateQuarantineByHash(hash string, quarantined bool) (int64, error) {
+	c, err := s.statements.updateQuarantineByHash.ExecContext(s.ctx, hash, PurposePinned, quarantined)
+	if err != nil {
+		return 0, err
+	}
+	return c.RowsAffected()
+}
+
+func (s *metadataVirtualTableWithContext) UpdateQuarantineByHashAndOrigin(origin string, hash string, quarantined bool) (int64, error) {
+	c, err := s.statements.updateQuarantineByHashAndOrigin.ExecContext(s.ctx, origin, hash, PurposePinned, quarantined)
+	if err != nil {
+		return 0, err
+	}
+	return c.RowsAffected()
+}
diff --git a/redislib/cache.go b/redislib/cache.go
index 93f013e1..3e3d25b5 100644
--- a/redislib/cache.go
+++ b/redislib/cache.go
@@ -82,3 +82,12 @@ func TryGetMedia(ctx rcontext.RequestContext, hash string, startByte int64, endB
 	metrics.CacheHits.With(prometheus.Labels{"cache": "media"}).Inc()
 	return bytes.NewBuffer([]byte(s)), nil
 }
+
+func DeleteMedia(ctx rcontext.RequestContext, hash string) {
+	makeConnection()
+	if ring == nil {
+		return
+	}
+
+	ring.Del(ctx, hash)
+}
diff --git a/tasks/task_runner/quarantine.go b/tasks/task_runner/quarantine.go
new file mode 100644
index 00000000..24cdb21d
--- /dev/null
+++ b/tasks/task_runner/quarantine.go
@@ -0,0 +1,88 @@
+package task_runner
+
+import (
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/database"
+	"github.com/turt2live/matrix-media-repo/redislib"
+	"github.com/turt2live/matrix-media-repo/util"
+)
+
+type QuarantineRecord struct {
+	Origin  string
+	MediaId string
+}
+
+type QuarantineThis struct {
+	MxcUris []string
+	Single  *QuarantineRecord
+	DbMedia []*database.DbMedia
+}
+
+// QuarantineMedia returns (count quarantined, error)
+func QuarantineMedia(ctx rcontext.RequestContext, onlyHost string, toHandle *QuarantineThis) (int64, error) {
+	records, err := resolveMedia(ctx, onlyHost, toHandle) // records are roughly safe to rely on host-wise
+	if err != nil {
+		return 0, err
+	}
+
+	metadataDb := database.GetInstance().MetadataView.Prepare(ctx)
+	total := int64(0)
+	for _, r := range records {
+		if onlyHost != "" && onlyHost != r.Origin {
+			continue
+		}
+
+		count := int64(0)
+		if onlyHost != "" {
+			count, err = metadataDb.UpdateQuarantineByHashAndOrigin(r.Origin, r.Sha256Hash, true)
+		} else {
+			count, err = metadataDb.UpdateQuarantineByHash(r.Sha256Hash, true)
+		}
+		total += count
+		if err != nil {
+			return total, err
+		}
+
+		redislib.DeleteMedia(ctx, r.Sha256Hash)
+	}
+
+	return total, nil
+}
+
+func resolveMedia(ctx rcontext.RequestContext, onlyHost string, toHandle *QuarantineThis) ([]*database.DbMedia, error) {
+	db := database.GetInstance().Media.Prepare(ctx)
+
+	records := make([]*database.DbMedia, 0)
+	if toHandle.DbMedia != nil {
+		records = append(records, toHandle.DbMedia...)
+	}
+	if toHandle.Single != nil && (onlyHost == "" || toHandle.Single.Origin == onlyHost) {
+		r, err := db.GetById(toHandle.Single.Origin, toHandle.Single.MediaId)
+		if err != nil {
+			return nil, err
+		}
+		if r != nil {
+			records = append(records, r)
+		}
+	}
+	if toHandle.MxcUris != nil {
+		for _, mxc := range toHandle.MxcUris {
+			origin, mediaId, err := util.SplitMxc(mxc)
+			if onlyHost != "" && origin != onlyHost {
+				continue
+			}
+			if err != nil {
+				return nil, err
+			}
+			r, err := db.GetById(origin, mediaId)
+			if err != nil {
+				return nil, err
+			}
+			if r != nil {
+				records = append(records, r)
+			}
+		}
+	}
+
+	return records, nil
+}
-- 
GitLab