From 049ea12e2bcabfbe7998ea677d4df1b94b2b874d Mon Sep 17 00:00:00 2001 From: Travis Ralston <travpc@gmail.com> Date: Tue, 1 Aug 2023 21:56:54 -0600 Subject: [PATCH] Move quarantine logic to new structures --- api/custom/quarantine.go | 160 ++++++------------------------ database/virtualtable_metadata.go | 26 +++++ redislib/cache.go | 9 ++ tasks/task_runner/quarantine.go | 88 ++++++++++++++++ 4 files changed, 155 insertions(+), 128 deletions(-) create mode 100644 tasks/task_runner/quarantine.go diff --git a/api/custom/quarantine.go b/api/custom/quarantine.go index 6cdf7d18..477815da 100644 --- a/api/custom/quarantine.go +++ b/api/custom/quarantine.go @@ -1,30 +1,25 @@ package custom import ( - "database/sql" "net/http" "github.com/getsentry/sentry-go" "github.com/turt2live/matrix-media-repo/api/_apimeta" "github.com/turt2live/matrix-media-repo/api/_responses" "github.com/turt2live/matrix-media-repo/api/_routers" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/tasks/task_runner" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common/rcontext" - "github.com/turt2live/matrix-media-repo/internal_cache" "github.com/turt2live/matrix-media-repo/matrix" - "github.com/turt2live/matrix-media-repo/storage" - "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) type MediaQuarantinedResponse struct { - NumQuarantined int `json:"num_quarantined"` + NumQuarantined int64 `json:"num_quarantined"` } -// Developer note: This isn't broken out into a dedicated controller class because the logic is slightly -// too complex to do so. If anything, the logic should be improved and moved. - func QuarantineRoomMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, rctx, user) if !canQuarantine { @@ -47,31 +42,13 @@ func QuarantineRoomMedia(r *http.Request, rctx rcontext.RequestContext, user _ap var mxcs []string mxcs = append(mxcs, allMedia.LocalMxcs...) - mxcs = append(mxcs, allMedia.RemoteMxcs...) - - total := 0 - for _, mxc := range mxcs { - server, mediaId, err := util.SplitMxc(mxc) - if err != nil { - rctx.Log.Error("Error parsing MXC URI ("+mxc+"): ", err) - sentry.CaptureException(err) - return _responses.InternalServerError("error parsing mxc uri") - } - - if !allowOtherHosts && r.Host != server { - rctx.Log.Warn("Skipping media " + mxc + " because it is on a different host") - continue - } - - resp, ok := doQuarantine(rctx, server, mediaId, allowOtherHosts) - if !ok { - return resp - } - - total += resp.(*MediaQuarantinedResponse).NumQuarantined + if allowOtherHosts { + mxcs = append(mxcs, allMedia.RemoteMxcs...) } - return &_responses.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}} + return performQuarantineRequest(rctx, r.Host, allowOtherHosts, &task_runner.QuarantineThis{ + MxcUris: mxcs, + }) } func QuarantineUserMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { @@ -98,25 +75,17 @@ func QuarantineUserMedia(r *http.Request, rctx rcontext.RequestContext, user _ap return _responses.AuthFailed() } - db := storage.GetDatabase().GetMediaStore(rctx) - userMedia, err := db.GetMediaByUser(userId) + db := database.GetInstance().Media.Prepare(rctx) + userMedia, err := db.GetByUserId(userId) if err != nil { rctx.Log.Error("Error while listing media for the user: ", err) sentry.CaptureException(err) return _responses.InternalServerError("error retrieving media for user") } - total := 0 - for _, media := range userMedia { - resp, ok := doQuarantineOn(media, allowOtherHosts, rctx) - if !ok { - return resp - } - - total += resp.(*MediaQuarantinedResponse).NumQuarantined - } - - return &_responses.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}} + return performQuarantineRequest(rctx, r.Host, allowOtherHosts, &task_runner.QuarantineThis{ + DbMedia: userMedia, + }) } func QuarantineDomainMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { @@ -140,25 +109,17 @@ func QuarantineDomainMedia(r *http.Request, rctx rcontext.RequestContext, user _ return _responses.AuthFailed() } - db := storage.GetDatabase().GetMediaStore(rctx) - userMedia, err := db.GetAllMediaForServer(serverName) + db := database.GetInstance().Media.Prepare(rctx) + domainMedia, err := db.GetByOrigin(serverName) if err != nil { rctx.Log.Error("Error while listing media for the server: ", err) sentry.CaptureException(err) return _responses.InternalServerError("error retrieving media for server") } - total := 0 - for _, media := range userMedia { - resp, ok := doQuarantineOn(media, allowOtherHosts, rctx) - if !ok { - return resp - } - - total += resp.(*MediaQuarantinedResponse).NumQuarantined - } - - return &_responses.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}} + return performQuarantineRequest(rctx, r.Host, allowOtherHosts, &task_runner.QuarantineThis{ + DbMedia: domainMedia, + }) } func QuarantineMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { @@ -184,80 +145,28 @@ func QuarantineMedia(r *http.Request, rctx rcontext.RequestContext, user _apimet return _responses.BadRequest("unable to quarantine media on other homeservers") } - resp, _ := doQuarantine(rctx, server, mediaId, allowOtherHosts) - return &_responses.DoNotCacheResponse{Payload: resp} -} - -func doQuarantine(ctx rcontext.RequestContext, origin string, mediaId string, allowOtherHosts bool) (interface{}, bool) { - db := storage.GetDatabase().GetMediaStore(ctx) - media, err := db.Get(origin, mediaId) - if err != nil { - if err == sql.ErrNoRows { - ctx.Log.Warn("Media not found, could not quarantine: " + origin + "/" + mediaId) - return &MediaQuarantinedResponse{0}, true - } - - ctx.Log.Error("Error fetching media: ", err) - sentry.CaptureException(err) - return _responses.InternalServerError("error quarantining media"), false - } - - return doQuarantineOn(media, allowOtherHosts, ctx) + return performQuarantineRequest(rctx, r.Host, allowOtherHosts, &task_runner.QuarantineThis{ + Single: &task_runner.QuarantineRecord{ + Origin: server, + MediaId: mediaId, + }, + }) } -func doQuarantineOn(media *types.Media, allowOtherHosts bool, ctx rcontext.RequestContext) (interface{}, bool) { - // Check to make sure the media doesn't have a purpose in staying - attrDb := storage.GetDatabase().GetMediaAttributesStore(ctx) - attr, err := attrDb.GetAttributesDefaulted(media.Origin, media.MediaId) - if err != nil { - ctx.Log.Error("Error while getting attributes for media: ", err) - sentry.CaptureException(err) - return _responses.InternalServerError("Error quarantining media"), false - } - if attr.Purpose == types.PurposePinned { - ctx.Log.Warn("Refusing to quarantine media due to it being pinned") - return &MediaQuarantinedResponse{NumQuarantined: 0}, true +func performQuarantineRequest(ctx rcontext.RequestContext, host string, allowOtherHosts bool, toQuarantine *task_runner.QuarantineThis) interface{} { + lockedHost := host + if allowOtherHosts { + lockedHost = "" } - // We reset the entire cache to avoid any lingering links floating around, such as thumbnails or other media. - // The reset is done before actually quarantining the media because that could fail for some reason - internal_cache.Get().Reset() - - num, err := setMediaQuarantined(media, true, allowOtherHosts, ctx) + total, err := task_runner.QuarantineMedia(ctx, lockedHost, toQuarantine) if err != nil { - ctx.Log.Error("Error quarantining media: ", err) + ctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("Error quarantining media"), false - } - - return &MediaQuarantinedResponse{NumQuarantined: num}, true -} - -func setMediaQuarantined(media *types.Media, isQuarantined bool, allowOtherHosts bool, ctx rcontext.RequestContext) (int, error) { - db := storage.GetDatabase().GetMediaStore(ctx) - numQuarantined := 0 - - // Quarantine all media with the same hash, including the one requested - otherMedia, err := db.GetByHash(media.Sha256Hash) - if err != nil { - return numQuarantined, err - } - for _, m := range otherMedia { - if m.Origin != media.Origin && !allowOtherHosts { - ctx.Log.Warn("Skipping quarantine on " + m.Origin + "/" + m.MediaId + " because it is on a different host from " + media.Origin + "/" + media.MediaId) - continue - } - - err := db.SetQuarantined(m.Origin, m.MediaId, isQuarantined) - if err != nil { - return numQuarantined, err - } - - numQuarantined++ - ctx.Log.Warn("Media has been quarantined: " + m.Origin + "/" + m.MediaId) + return _responses.InternalServerError("error quarantining media") } - return numQuarantined, nil + return &_responses.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}} } func getQuarantineRequestInfo(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) (bool, bool, bool) { @@ -277,7 +186,6 @@ func getQuarantineRequestInfo(r *http.Request, rctx rcontext.RequestContext, use } if !isLocalAdmin { - rctx.Log.Warn(user.UserId + " tried to quarantine media on another server") canQuarantine = false return canQuarantine, allowOtherHosts, isLocalAdmin } @@ -287,9 +195,5 @@ func getQuarantineRequestInfo(r *http.Request, rctx rcontext.RequestContext, use } } - if !canQuarantine { - rctx.Log.Warn(user.UserId + " tried to quarantine media") - } - return canQuarantine, allowOtherHosts, isLocalAdmin } diff --git a/database/virtualtable_metadata.go b/database/virtualtable_metadata.go index 5030f1d0..fac2c0ea 100644 --- a/database/virtualtable_metadata.go +++ b/database/virtualtable_metadata.go @@ -22,6 +22,8 @@ const selectUploadSizesForServer = "SELECT COALESCE((SELECT SUM(size_bytes) FROM const selectUploadCountsForServer = "SELECT COALESCE((SELECT COUNT(origin) FROM media WHERE origin = $1), 0) AS media, COALESCE((SELECT COUNT(origin) FROM thumbnails WHERE origin = $1), 0) AS thumbnails;" const selectMediaForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts, m.content_type FROM media AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;" const selectThumbnailsForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts, m.content_type FROM thumbnails AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;" +const updateQuarantineByHash = "WITH t AS (SELECT m.origin AS origin, m.media_id AS media_id, a.purpose AS purpose FROM media AS m LEFT JOIN media_attributes AS a ON m.origin = a.origin AND m.media_id = a.media_id WHERE m.sha256_hash = $1 AND (a.purpose IS NULL OR a.purpose <> $2) AND m.quarantined <> $3) UPDATE media AS m2 SET quarantined = $3 FROM t WHERE m2.origin = t.origin AND m2.media_id = t.media_id;" +const updateQuarantineByHashAndOrigin = "WITH t AS (SELECT m.origin AS origin, m.media_id AS media_id, a.purpose AS purpose FROM media AS m LEFT JOIN media_attributes AS a ON m.origin = a.origin AND m.media_id = a.media_id WHERE m.origin = $1 AND m.sha256_hash = $2 AND (a.purpose IS NULL OR a.purpose <> $3) AND m.quarantined <> $4) UPDATE media AS m2 SET quarantined = $4 FROM t WHERE m2.origin = t.origin AND m2.media_id = t.media_id;" type SynStatUserOrderBy string @@ -51,6 +53,8 @@ type metadataVirtualTableStatements struct { selectUploadCountsForServer *sql.Stmt selectMediaForDatastoreWithLastAccess *sql.Stmt selectThumbnailsForDatastoreWithLastAccess *sql.Stmt + updateQuarantineByHash *sql.Stmt + updateQuarantineByHashAndOrigin *sql.Stmt } type metadataVirtualTableWithContext struct { @@ -79,6 +83,12 @@ func prepareMetadataVirtualTables(db *sql.DB) (*metadataVirtualTableStatements, if stmts.selectThumbnailsForDatastoreWithLastAccess, err = db.Prepare(selectThumbnailsForDatastoreWithLastAccess); err != nil { return nil, errors.New("error preparing selectThumbnailsForDatastoreWithLastAccess: " + err.Error()) } + if stmts.updateQuarantineByHash, err = db.Prepare(updateQuarantineByHash); err != nil { + return nil, errors.New("error preparing updateQuarantineByHash: " + err.Error()) + } + if stmts.updateQuarantineByHashAndOrigin, err = db.Prepare(updateQuarantineByHashAndOrigin); err != nil { + return nil, errors.New("error preparing updateQuarantineByHashAndOrigin: " + err.Error()) + } return stmts, nil } @@ -228,3 +238,19 @@ func (s *metadataVirtualTableWithContext) GetMediaForDatastoreByLastAccess(datas func (s *metadataVirtualTableWithContext) GetThumbnailsForDatastoreByLastAccess(datastoreId string, lastAccessTs int64) ([]*VirtLastAccess, error) { return s.scanLastAccess(s.statements.selectThumbnailsForDatastoreWithLastAccess.QueryContext(s.ctx, lastAccessTs, datastoreId)) } + +func (s *metadataVirtualTableWithContext) UpdateQuarantineByHash(hash string, quarantined bool) (int64, error) { + c, err := s.statements.updateQuarantineByHash.ExecContext(s.ctx, hash, PurposePinned, quarantined) + if err != nil { + return 0, err + } + return c.RowsAffected() +} + +func (s *metadataVirtualTableWithContext) UpdateQuarantineByHashAndOrigin(origin string, hash string, quarantined bool) (int64, error) { + c, err := s.statements.updateQuarantineByHashAndOrigin.ExecContext(s.ctx, origin, hash, PurposePinned, quarantined) + if err != nil { + return 0, err + } + return c.RowsAffected() +} diff --git a/redislib/cache.go b/redislib/cache.go index 93f013e1..3e3d25b5 100644 --- a/redislib/cache.go +++ b/redislib/cache.go @@ -82,3 +82,12 @@ func TryGetMedia(ctx rcontext.RequestContext, hash string, startByte int64, endB metrics.CacheHits.With(prometheus.Labels{"cache": "media"}).Inc() return bytes.NewBuffer([]byte(s)), nil } + +func DeleteMedia(ctx rcontext.RequestContext, hash string) { + makeConnection() + if ring == nil { + return + } + + ring.Del(ctx, hash) +} diff --git a/tasks/task_runner/quarantine.go b/tasks/task_runner/quarantine.go new file mode 100644 index 00000000..24cdb21d --- /dev/null +++ b/tasks/task_runner/quarantine.go @@ -0,0 +1,88 @@ +package task_runner + +import ( + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/redislib" + "github.com/turt2live/matrix-media-repo/util" +) + +type QuarantineRecord struct { + Origin string + MediaId string +} + +type QuarantineThis struct { + MxcUris []string + Single *QuarantineRecord + DbMedia []*database.DbMedia +} + +// QuarantineMedia returns (count quarantined, error) +func QuarantineMedia(ctx rcontext.RequestContext, onlyHost string, toHandle *QuarantineThis) (int64, error) { + records, err := resolveMedia(ctx, onlyHost, toHandle) // records are roughly safe to rely on host-wise + if err != nil { + return 0, err + } + + metadataDb := database.GetInstance().MetadataView.Prepare(ctx) + total := int64(0) + for _, r := range records { + if onlyHost != "" && onlyHost != r.Origin { + continue + } + + count := int64(0) + if onlyHost != "" { + count, err = metadataDb.UpdateQuarantineByHashAndOrigin(r.Origin, r.Sha256Hash, true) + } else { + count, err = metadataDb.UpdateQuarantineByHash(r.Sha256Hash, true) + } + total += count + if err != nil { + return total, err + } + + redislib.DeleteMedia(ctx, r.Sha256Hash) + } + + return total, nil +} + +func resolveMedia(ctx rcontext.RequestContext, onlyHost string, toHandle *QuarantineThis) ([]*database.DbMedia, error) { + db := database.GetInstance().Media.Prepare(ctx) + + records := make([]*database.DbMedia, 0) + if toHandle.DbMedia != nil { + records = append(records, toHandle.DbMedia...) + } + if toHandle.Single != nil && (onlyHost == "" || toHandle.Single.Origin == onlyHost) { + r, err := db.GetById(toHandle.Single.Origin, toHandle.Single.MediaId) + if err != nil { + return nil, err + } + if r != nil { + records = append(records, r) + } + } + if toHandle.MxcUris != nil { + for _, mxc := range toHandle.MxcUris { + origin, mediaId, err := util.SplitMxc(mxc) + if onlyHost != "" && origin != onlyHost { + continue + } + if err != nil { + return nil, err + } + r, err := db.GetById(origin, mediaId) + if err != nil { + return nil, err + } + if r != nil { + records = append(records, r) + } + } + } + + return records, nil +} -- GitLab