Skip to content
Snippets Groups Projects
Commit 4116e786 authored by Travis Ralston's avatar Travis Ralston
Browse files

Centralize all purge API calls

parent 902eaa35
No related branches found
No related tags found
No related merge requests found
package custom
import (
"database/sql"
"net/http"
"strconv"
......@@ -9,15 +8,13 @@ import (
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/tasks/task_runner"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/controllers/maintenance_controller"
"github.com/turt2live/matrix-media-repo/matrix"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
)
......@@ -51,8 +48,7 @@ func PurgeRemoteMedia(r *http.Request, rctx rcontext.RequestContext, user _apime
}
func PurgeIndividualRecord(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user)
localServerName := r.Host
authCtx, _, _ := getPurgeAuthContext(rctx, r, user)
server := _routers.GetParam("server", r)
mediaId := _routers.GetParam("mediaId", r)
......@@ -66,66 +62,54 @@ func PurgeIndividualRecord(r *http.Request, rctx rcontext.RequestContext, user _
"mediaId": mediaId,
})
// If the user is NOT a global admin, ensure they are speaking to the right server
if !isGlobalAdmin {
if server != localServerName {
_, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{
Single: &task_runner.QuarantineRecord{
Origin: server,
MediaId: mediaId,
},
})
if err != nil {
if err == common.ErrWrongUser {
return _responses.AuthFailed()
}
// If the user is NOT a local admin, ensure they uploaded the content in the first place
if !isLocalAdmin {
db := storage.GetDatabase().GetMediaStore(rctx)
m, err := db.Get(server, mediaId)
if err == sql.ErrNoRows {
return _responses.NotFoundError()
}
if err != nil {
rctx.Log.Error("Error checking ownership of media: ", err)
sentry.CaptureException(err)
return _responses.InternalServerError("error checking media ownership")
}
if m.UserId != user.UserId {
return _responses.AuthFailed()
}
}
}
err := maintenance_controller.PurgeMedia(server, mediaId, rctx)
if err == sql.ErrNoRows || err == common.ErrMediaNotFound {
return _responses.NotFoundError()
}
if err != nil {
rctx.Log.Error("Error purging media: ", err)
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("error purging media")
return _responses.InternalServerError("unexpected error")
}
return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true}}
}
func PurgeQuarantined(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user)
localServerName := r.Host
authCtx, isGlobalAdmin, isLocalAdmin := getPurgeAuthContext(rctx, r, user)
var affected []*types.Media
var affected []*database.DbMedia
var err error
mediaDb := database.GetInstance().Media.Prepare(rctx)
if isGlobalAdmin {
affected, err = maintenance_controller.PurgeQuarantined(rctx)
affected, err = mediaDb.GetByQuarantine()
} else if isLocalAdmin {
affected, err = maintenance_controller.PurgeQuarantinedFor(localServerName, rctx)
affected, err = mediaDb.GetByOriginQuarantine(r.Host)
} else {
return _responses.AuthFailed()
}
if err != nil {
rctx.Log.Error("Error purging media: ", err)
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("error purging media")
return _responses.InternalServerError("error fetching media records")
}
mxcs := make([]string, 0)
for _, a := range affected {
mxcs = append(mxcs, a.MxcUri())
mxcs, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{
DbMedia: affected,
})
if err != nil {
if err == common.ErrWrongUser {
return _responses.AuthFailed()
}
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("unexpected error")
}
return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}}
......@@ -156,24 +140,36 @@ func PurgeOldMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
"include_local": includeLocal,
})
affected, err := maintenance_controller.PurgeOldMedia(beforeTs, includeLocal, rctx)
domains := make([]string, 0)
if !includeLocal {
domains = util.GetOurDomains()
}
mediaDb := database.GetInstance().Media.Prepare(rctx)
records, err := mediaDb.GetOldExcluding(domains, beforeTs)
if err != nil {
rctx.Log.Error("Error purging media: ", err)
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("error purging media")
return _responses.InternalServerError("error fetching media records")
}
mxcs := make([]string, 0)
for _, a := range affected {
mxcs = append(mxcs, a.MxcUri())
mxcs, err := task_runner.PurgeMedia(rctx, &task_runner.PurgeAuthContext{}, &task_runner.QuarantineThis{
DbMedia: records,
})
if err != nil {
if err == common.ErrWrongUser {
return _responses.AuthFailed()
}
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("unexpected error")
}
return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}}
}
func PurgeUserMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user)
authCtx, isGlobalAdmin, isLocalAdmin := getPurgeAuthContext(rctx, r, user)
if !isGlobalAdmin && !isLocalAdmin {
return _responses.AuthFailed()
}
......@@ -206,24 +202,31 @@ func PurgeUserMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
return _responses.AuthFailed()
}
affected, err := maintenance_controller.PurgeUserMedia(userId, beforeTs, rctx)
mediaDb := database.GetInstance().Media.Prepare(rctx)
records, err := mediaDb.GetOldByUserId(userId, beforeTs)
if err != nil {
rctx.Log.Error("Error purging media: ", err)
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("error purging media")
return _responses.InternalServerError("error fetching media records")
}
mxcs := make([]string, 0)
for _, a := range affected {
mxcs = append(mxcs, a.MxcUri())
mxcs, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{
DbMedia: records,
})
if err != nil {
if err == common.ErrWrongUser {
return _responses.AuthFailed()
}
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("unexpected error")
}
return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}}
}
func PurgeRoomMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user)
authCtx, isGlobalAdmin, isLocalAdmin := getPurgeAuthContext(rctx, r, user)
if !isGlobalAdmin && !isLocalAdmin {
return _responses.AuthFailed()
}
......@@ -276,32 +279,27 @@ func PurgeRoomMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
mxcs = append(mxcs, mxc)
}
} else {
for _, mxc := range allMedia.LocalMxcs {
mxcs = append(mxcs, mxc)
}
for _, mxc := range allMedia.RemoteMxcs {
mxcs = append(mxcs, mxc)
}
mxcs = append(mxcs, allMedia.LocalMxcs...)
mxcs = append(mxcs, allMedia.RemoteMxcs...)
}
affected, err := maintenance_controller.PurgeRoomMedia(mxcs, beforeTs, rctx)
mxcs2, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{
MxcUris: mxcs,
})
if err != nil {
rctx.Log.Error("Error purging media: ", err)
if err == common.ErrWrongUser {
return _responses.AuthFailed()
}
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("error purging media")
return _responses.InternalServerError("unexpected error")
}
mxcs = make([]string, 0)
for _, a := range affected {
mxcs = append(mxcs, a.MxcUri())
}
return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}}
return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs2}}
}
func PurgeDomainMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user)
authCtx, isGlobalAdmin, isLocalAdmin := getPurgeAuthContext(rctx, r, user)
if !isGlobalAdmin && !isLocalAdmin {
return _responses.AuthFailed()
}
......@@ -331,18 +329,36 @@ func PurgeDomainMedia(r *http.Request, rctx rcontext.RequestContext, user _apime
return _responses.AuthFailed()
}
affected, err := maintenance_controller.PurgeDomainMedia(serverName, beforeTs, rctx)
mediaDb := database.GetInstance().Media.Prepare(rctx)
records, err := mediaDb.GetOldByOrigin(serverName, beforeTs)
if err != nil {
rctx.Log.Error("Error purging media: ", err)
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("error purging media")
return _responses.InternalServerError("error fetching media records")
}
mxcs := make([]string, 0)
for _, a := range affected {
mxcs = append(mxcs, a.MxcUri())
mxcs, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{
DbMedia: records,
})
if err != nil {
if err == common.ErrWrongUser {
return _responses.AuthFailed()
}
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("unexpected error")
}
return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}}
}
func getPurgeAuthContext(ctx rcontext.RequestContext, r *http.Request, user _apimeta.UserInfo) (*task_runner.PurgeAuthContext, bool, bool) {
globalAdmin, localAdmin := _apimeta.GetRequestUserAdminStatus(r, ctx, user)
if globalAdmin {
return &task_runner.PurgeAuthContext{}, true, localAdmin
}
if localAdmin {
return &task_runner.PurgeAuthContext{SourceOrigin: r.Host}, false, true
}
return &task_runner.PurgeAuthContext{UploaderUserId: user.UserId}, false, false
}
......@@ -36,7 +36,9 @@ const insertMedia = "INSERT INTO media (origin, media_id, upload_name, content_t
const selectMediaExists = "SELECT TRUE FROM media WHERE origin = $1 AND media_id = $2 LIMIT 1;"
const selectMediaById = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1 AND media_id = $2;"
const selectMediaByUserId = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE user_id = $1;"
const selectOldMediaByUserId = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE user_id = $1 AND creation_ts < $2;"
const selectMediaByOrigin = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1;"
const selectOldMediaByOrigin = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1 AND creation_ts < $2;"
const selectMediaByLocationExists = "SELECT TRUE FROM media WHERE datastore_id = $1 AND location = $2 LIMIT 1;"
const selectMediaByUserCount = "SELECT COUNT(*) FROM media WHERE user_id = $1;"
const selectMediaByOriginAndUserIds = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1 AND user_id = ANY($2);"
......@@ -44,23 +46,31 @@ const selectMediaByOriginAndIds = "SELECT origin, media_id, upload_name, content
const selectOldMediaExcludingDomains = "SELECT m.origin, m.media_id, m.upload_name, m.content_type, m.user_id, m.sha256_hash, m.size_bytes, m.creation_ts, m.quarantined, m.datastore_id, m.location FROM media AS m WHERE m.origin <> ANY($1) AND m.creation_ts < $2 AND (SELECT COUNT(d.*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.creation_ts >= $2) = 0 AND (SELECT COUNT(d.*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.origin = ANY($1)) = 0;"
const deleteMedia = "DELETE FROM media WHERE origin = $1 AND media_id = $2;"
const updateMediaLocation = "UPDATE media SET datastore_id = $3, location = $4 WHERE datastore_id = $1 AND location = $2;"
const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE datastore_id = $1 AND location = $2;"
const selectMediaByQuarantine = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE quarantined = TRUE;"
const selectMediaByQuarantineAndOrigin = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE quarantined = TRUE AND origin = $1;"
type mediaTableStatements struct {
selectDistinctMediaDatastoreIds *sql.Stmt
selectMediaIsQuarantinedByHash *sql.Stmt
selectMediaByHash *sql.Stmt
insertMedia *sql.Stmt
selectMediaExists *sql.Stmt
selectMediaById *sql.Stmt
selectMediaByUserId *sql.Stmt
selectMediaByOrigin *sql.Stmt
selectMediaByLocationExists *sql.Stmt
selectMediaByUserCount *sql.Stmt
selectMediaByOriginAndUserIds *sql.Stmt
selectMediaByOriginAndIds *sql.Stmt
selectOldMediaExcludingDomains *sql.Stmt
deleteMedia *sql.Stmt
updateMediaLocation *sql.Stmt
selectDistinctMediaDatastoreIds *sql.Stmt
selectMediaIsQuarantinedByHash *sql.Stmt
selectMediaByHash *sql.Stmt
insertMedia *sql.Stmt
selectMediaExists *sql.Stmt
selectMediaById *sql.Stmt
selectMediaByUserId *sql.Stmt
selectOldMediaByUserId *sql.Stmt
selectMediaByOrigin *sql.Stmt
selectOldMediaByOrigin *sql.Stmt
selectMediaByLocationExists *sql.Stmt
selectMediaByUserCount *sql.Stmt
selectMediaByOriginAndUserIds *sql.Stmt
selectMediaByOriginAndIds *sql.Stmt
selectOldMediaExcludingDomains *sql.Stmt
deleteMedia *sql.Stmt
updateMediaLocation *sql.Stmt
selectMediaByLocation *sql.Stmt
selectMediaByQuarantine *sql.Stmt
selectMediaByQuarantineAndOrigin *sql.Stmt
}
type mediaTableWithContext struct {
......@@ -93,9 +103,15 @@ func prepareMediaTables(db *sql.DB) (*mediaTableStatements, error) {
if stmts.selectMediaByUserId, err = db.Prepare(selectMediaByUserId); err != nil {
return nil, errors.New("error preparing selectMediaByUserId: " + err.Error())
}
if stmts.selectOldMediaByUserId, err = db.Prepare(selectOldMediaByUserId); err != nil {
return nil, errors.New("error preparing selectOldMediaByUserId: " + err.Error())
}
if stmts.selectMediaByOrigin, err = db.Prepare(selectMediaByOrigin); err != nil {
return nil, errors.New("error preparing selectMediaByOrigin: " + err.Error())
}
if stmts.selectOldMediaByOrigin, err = db.Prepare(selectOldMediaByOrigin); err != nil {
return nil, errors.New("error preparing selectOldMediaByOrigin: " + err.Error())
}
if stmts.selectMediaByLocationExists, err = db.Prepare(selectMediaByLocationExists); err != nil {
return nil, errors.New("error preparing selectMediaByLocationExists: " + err.Error())
}
......@@ -117,6 +133,15 @@ func prepareMediaTables(db *sql.DB) (*mediaTableStatements, error) {
if stmts.updateMediaLocation, err = db.Prepare(updateMediaLocation); err != nil {
return nil, errors.New("error preparing updateMediaLocation: " + err.Error())
}
if stmts.selectMediaByLocation, err = db.Prepare(selectMediaByLocation); err != nil {
return nil, errors.New("error preparing selectMediaByLocation: " + err.Error())
}
if stmts.selectMediaByQuarantine, err = db.Prepare(selectMediaByQuarantine); err != nil {
return nil, errors.New("error preparing selectMediaByQuarantine: " + err.Error())
}
if stmts.selectMediaByQuarantineAndOrigin, err = db.Prepare(selectMediaByQuarantineAndOrigin); err != nil {
return nil, errors.New("error preparing selectMediaByQuarantineAndOrigin: " + err.Error())
}
return stmts, nil
}
......@@ -188,10 +213,18 @@ func (s *mediaTableWithContext) GetByUserId(userId string) ([]*DbMedia, error) {
return s.scanRows(s.statements.selectMediaByUserId.QueryContext(s.ctx, userId))
}
func (s *mediaTableWithContext) GetOldByUserId(userId string, beforeTs int64) ([]*DbMedia, error) {
return s.scanRows(s.statements.selectOldMediaByUserId.QueryContext(s.ctx, userId, beforeTs))
}
func (s *mediaTableWithContext) GetByOrigin(origin string) ([]*DbMedia, error) {
return s.scanRows(s.statements.selectMediaByOrigin.QueryContext(s.ctx, origin))
}
func (s *mediaTableWithContext) GetOldByOrigin(origin string, beforeTs int64) ([]*DbMedia, error) {
return s.scanRows(s.statements.selectOldMediaByOrigin.QueryContext(s.ctx, origin, beforeTs))
}
func (s *mediaTableWithContext) GetByOriginUsers(origin string, userIds []string) ([]*DbMedia, error) {
return s.scanRows(s.statements.selectMediaByOriginAndUserIds.QueryContext(s.ctx, origin, pq.Array(userIds)))
}
......@@ -204,6 +237,18 @@ func (s *mediaTableWithContext) GetOldExcluding(origins []string, beforeTs int64
return s.scanRows(s.statements.selectOldMediaExcludingDomains.QueryContext(s.ctx, pq.Array(origins), beforeTs))
}
func (s *mediaTableWithContext) GetByLocation(datastoreId string, location string) ([]*DbMedia, error) {
return s.scanRows(s.statements.selectMediaByLocation.QueryContext(s.ctx, datastoreId, location))
}
func (s *mediaTableWithContext) GetByQuarantine() ([]*DbMedia, error) {
return s.scanRows(s.statements.selectMediaByQuarantine.QueryContext(s.ctx))
}
func (s *mediaTableWithContext) GetByOriginQuarantine(origin string) ([]*DbMedia, error) {
return s.scanRows(s.statements.selectMediaByQuarantineAndOrigin.QueryContext(s.ctx, origin))
}
func (s *mediaTableWithContext) GetById(origin string, mediaId string) (*DbMedia, error) {
row := s.statements.selectMediaById.QueryRowContext(s.ctx, origin, mediaId)
val := &DbMedia{Locatable: &Locatable{}}
......
......@@ -13,10 +13,12 @@ type DbReservedMedia struct {
Reason string
}
const insertReservedMedia = "INSERT INTO reserved_media (origin, media_id, reason) VALUES ($1, $2, $3);"
const insertReservedMediaNoConflict = "INSERT INTO reserved_media (origin, media_id, reason) VALUES ($1, $2, $3) ON CONFLICT (origin, media_id) DO NOTHING;"
const selectReservedMediaExists = "SELECT TRUE FROM reserved_media WHERE origin = $1 AND media_id = $2 LIMIT 1;"
type reservedMediaTableStatements struct {
insertReservedMedia *sql.Stmt
insertReservedMediaNoConflict *sql.Stmt
selectReservedMediaExists *sql.Stmt
}
type reservedMediaTableWithContext struct {
......@@ -28,8 +30,11 @@ func prepareReservedMediaTables(db *sql.DB) (*reservedMediaTableStatements, erro
var err error
var stmts = &reservedMediaTableStatements{}
if stmts.insertReservedMedia, err = db.Prepare(insertReservedMedia); err != nil {
return nil, errors.New("error preparing insertReservedMedia: " + err.Error())
if stmts.insertReservedMediaNoConflict, err = db.Prepare(insertReservedMediaNoConflict); err != nil {
return nil, errors.New("error preparing insertReservedMediaNoConflict: " + err.Error())
}
if stmts.selectReservedMediaExists, err = db.Prepare(selectReservedMediaExists); err != nil {
return nil, errors.New("error preparing selectReservedMediaExists: " + err.Error())
}
return stmts, nil
......@@ -42,7 +47,18 @@ func (s *reservedMediaTableStatements) Prepare(ctx rcontext.RequestContext) *res
}
}
func (s *reservedMediaTableWithContext) TryInsert(origin string, mediaId string, reason string) error {
_, err := s.statements.insertReservedMedia.ExecContext(s.ctx, origin, mediaId, reason)
func (s *reservedMediaTableWithContext) InsertNoConflict(origin string, mediaId string, reason string) error {
_, err := s.statements.insertReservedMediaNoConflict.ExecContext(s.ctx, origin, mediaId, reason)
return err
}
func (s *reservedMediaTableWithContext) IdExists(origin string, mediaId string) (bool, error) {
row := s.statements.selectReservedMediaExists.QueryRowContext(s.ctx, origin, mediaId)
val := false
err := row.Scan(&val)
if err == sql.ErrNoRows {
err = nil
val = false
}
return val, err
}
......@@ -32,6 +32,7 @@ const selectThumbnailsForMedia = "SELECT origin, media_id, content_type, width,
const selectOldThumbnails = "SELECT origin, media_id, content_type, width, height, method, animated, sha256_hash, size_bytes, creation_ts, datastore_id, location FROM thumbnails WHERE sha256_hash IN (SELECT t2.sha256_hash FROM thumbnails AS t2 WHERE t2.creation_ts < $1);"
const deleteThumbnail = "DELETE FROM thumbnails WHERE origin = $1 AND media_id = $2 AND content_type = $3 AND width = $4 AND height = $5 AND method = $6 AND animated = $7 AND sha256_hash = $8 AND size_bytes = $9 AND creation_ts = $10 AND datastore_id = $11 AND location = $11;"
const updateThumbnailLocation = "UPDATE thumbnails SET datastore_id = $3, location = $4 WHERE datastore_id = $1 AND location = $2;"
const selectThumbnailsByLocation = "SELECT origin, media_id, content_type, width, height, method, animated, sha256_hash, size_bytes, creation_ts, datastore_id, location FROM thumbnails WHERE datastore_id = $1 AND location = $2;"
type thumbnailsTableStatements struct {
selectThumbnailByParams *sql.Stmt
......@@ -41,6 +42,7 @@ type thumbnailsTableStatements struct {
selectOldThumbnails *sql.Stmt
deleteThumbnail *sql.Stmt
updateThumbnailLocation *sql.Stmt
selectThumbnailsByLocation *sql.Stmt
}
type thumbnailsTableWithContext struct {
......@@ -73,6 +75,9 @@ func prepareThumbnailsTables(db *sql.DB) (*thumbnailsTableStatements, error) {
if stmts.updateThumbnailLocation, err = db.Prepare(updateThumbnailLocation); err != nil {
return nil, errors.New("error preparing updateThumbnailLocation: " + err.Error())
}
if stmts.selectThumbnailsByLocation, err = db.Prepare(selectThumbnailsByLocation); err != nil {
return nil, errors.New("error preparing selectThumbnailsByLocation: " + err.Error())
}
return stmts, nil
}
......@@ -95,9 +100,8 @@ func (s *thumbnailsTableWithContext) GetByParams(origin string, mediaId string,
return val, err
}
func (s *thumbnailsTableWithContext) GetForMedia(origin string, mediaId string) ([]*DbThumbnail, error) {
func (s *thumbnailsTableWithContext) scanRows(rows *sql.Rows, err error) ([]*DbThumbnail, error) {
results := make([]*DbThumbnail, 0)
rows, err := s.statements.selectThumbnailsForMedia.QueryContext(s.ctx, origin, mediaId)
if err != nil {
if err == sql.ErrNoRows {
return results, nil
......@@ -111,26 +115,20 @@ func (s *thumbnailsTableWithContext) GetForMedia(origin string, mediaId string)
}
results = append(results, val)
}
return results, nil
}
func (s *thumbnailsTableWithContext) GetForMedia(origin string, mediaId string) ([]*DbThumbnail, error) {
return s.scanRows(s.statements.selectThumbnailsForMedia.QueryContext(s.ctx, origin, mediaId))
}
func (s *thumbnailsTableWithContext) GetOlderThan(ts int64) ([]*DbThumbnail, error) {
results := make([]*DbThumbnail, 0)
rows, err := s.statements.selectOldThumbnails.QueryContext(s.ctx, ts)
if err != nil {
if err == sql.ErrNoRows {
return results, nil
}
return nil, err
}
for rows.Next() {
val := &DbThumbnail{Locatable: &Locatable{}}
if err = rows.Scan(&val.Origin, &val.MediaId, &val.ContentType, &val.Width, &val.Height, &val.Method, &val.Animated, &val.Sha256Hash, &val.SizeBytes, &val.CreationTs, &val.DatastoreId, &val.Location); err != nil {
return nil, err
}
results = append(results, val)
}
return results, nil
return s.scanRows(s.statements.selectOldThumbnails.QueryContext(s.ctx, ts))
}
func (s *thumbnailsTableWithContext) GetByLocation(datastoreId string, location string) ([]*DbThumbnail, error) {
return s.scanRows(s.statements.selectThumbnailsByLocation.QueryContext(s.ctx, datastoreId, location))
}
func (s *thumbnailsTableWithContext) Insert(record *DbThumbnail) error {
......
DROP INDEX IF EXISTS idx_datastore_id_location_thumbnails;
DROP INDEX IF EXISTS idx_datastore_id_location_media;
CREATE INDEX IF NOT EXISTS idx_datastore_id_location_thumbnails ON thumbnails(datastore_id, location);
CREATE INDEX IF NOT EXISTS idx_datastore_id_location_media ON media(datastore_id, location);
......@@ -15,6 +15,7 @@ func GenerateMediaId(ctx rcontext.RequestContext, origin string) (string, error)
}
heldDb := database.GetInstance().HeldMedia.Prepare(ctx)
mediaDb := database.GetInstance().Media.Prepare(ctx)
reservedDb := database.GetInstance().ReservedMedia.Prepare(ctx)
var mediaId string
var err error
var exists bool
......@@ -41,6 +42,15 @@ func GenerateMediaId(ctx rcontext.RequestContext, origin string) (string, error)
continue
}
// Also check to see if the media ID is reserved due to a past action
exists, err = reservedDb.IdExists(origin, mediaId)
if err != nil {
return "", err
}
if exists {
continue
}
return mediaId, nil
}
return "", errors.New("internal limit reached: fell out of media ID generation loop")
......
package task_runner
import (
"errors"
"fmt"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/util"
)
type purgeConfig struct {
IncludeQuarantined bool
}
type PurgeAuthContext struct {
UploaderUserId string
SourceOrigin string
}
func (c *PurgeAuthContext) canAffect(media *database.DbMedia) bool {
if c.UploaderUserId != "" && c.UploaderUserId != media.UserId {
return false
}
if c.SourceOrigin != "" && c.SourceOrigin != media.Origin {
return false
}
return true
}
func PurgeMedia(ctx rcontext.RequestContext, authContext *PurgeAuthContext, toHandle *QuarantineThis) ([]string, error) {
records, err := resolveMedia(ctx, "", toHandle)
if err != nil {
return nil, err
}
// Check auth on all records before actually processing them
for _, r := range records {
if !authContext.canAffect(r) {
return nil, common.ErrWrongUser
}
}
// Now we process all the records
return doPurge(ctx, records, &purgeConfig{IncludeQuarantined: true})
}
func doPurge(ctx rcontext.RequestContext, records []*database.DbMedia, config *purgeConfig) ([]string, error) {
mediaDb := database.GetInstance().Media.Prepare(ctx)
thumbsDb := database.GetInstance().Thumbnails.Prepare(ctx)
attrsDb := database.GetInstance().MediaAttributes.Prepare(ctx)
reservedDb := database.GetInstance().ReservedMedia.Prepare(ctx)
// Filter the records early on to remove things we're not going to handle
ctx.Log.Debug("Purge pre-filter")
records2 := make([]*database.DbMedia, 0)
for _, r := range records {
if r.Quarantined && !config.IncludeQuarantined {
continue // skip quarantined media so later loops don't try to purge it
}
attrs, err := attrsDb.Get(r.Origin, r.MediaId)
if err != nil {
return nil, err
}
if attrs != nil && attrs.Purpose == database.PurposePinned {
continue
}
records2 = append(records2, r)
}
records = records2
flagMap := make(map[string]map[string]bool) // outer key = file location, inner key = MXC, value = in records[]
thumbsMap := make(map[string][]*database.DbThumbnail)
// First, we identify all the media which is using the file references we think we want to delete
// This includes thumbnails (flagged under the original media MXC URI)
ctx.Log.Debug("Stage 1 of purge")
doFlagging := func(datastoreId string, location string) error {
locationId := fmt.Sprintf("%s/%s", datastoreId, location)
if _, ok := flagMap[locationId]; ok {
return nil // we already processed this file location - skip trying to populate from it
}
flagMap[locationId] = make(map[string]bool)
// Find media records first
media, err := mediaDb.GetByLocation(datastoreId, location)
if err != nil {
return err
}
for _, r2 := range media {
mxc := util.MxcUri(r2.Origin, r2.MediaId)
flagMap[locationId][mxc] = false
}
// Now thumbnails
thumbs, err := thumbsDb.GetByLocation(datastoreId, location)
if err != nil {
return err
}
for _, r2 := range thumbs {
mxc := util.MxcUri(r2.Origin, r2.MediaId)
flagMap[locationId][mxc] = false
}
return nil
}
for _, r := range records {
if err := doFlagging(r.DatastoreId, r.Location); err != nil {
return nil, err
}
// We also grab all the thumbnails of the proposed media to clear those files out safely too
thumbs, err := thumbsDb.GetForMedia(r.Origin, r.MediaId)
if err != nil {
return nil, err
}
thumbsMap[util.MxcUri(r.Origin, r.MediaId)] = thumbs
for _, t := range thumbs {
if err = doFlagging(t.DatastoreId, t.Location); err != nil {
return nil, err
}
}
}
// Next, we re-iterate to flag records as being deleted
ctx.Log.Debug("Stage 2 of purge")
markBeingPurged := func(locationId string, mxc string) error {
if m, ok := flagMap[locationId]; !ok {
return errors.New("logic error: missing flag map for location ID in second step")
} else {
if v, ok := m[mxc]; !ok {
return errors.New("logic error: missing flag map value for MXC URI in second step")
} else if !v { // if v is `true` then it's already been processed - skip a write step
m[mxc] = true
}
}
return nil
}
for _, r := range records {
locationId := fmt.Sprintf("%s/%s", r.DatastoreId, r.Location)
mxc := util.MxcUri(r.Origin, r.MediaId)
if err := markBeingPurged(locationId, mxc); err != nil {
return nil, err
}
// Mark the thumbnails too
if thumbs, ok := thumbsMap[mxc]; !ok {
return nil, errors.New("logic error: missing thumbnails map value for MXC URI in second step")
} else {
for _, t := range thumbs {
locationId = fmt.Sprintf("%s/%s", t.DatastoreId, t.Location)
mxc = util.MxcUri(t.Origin, t.MediaId)
if err := markBeingPurged(locationId, mxc); err != nil {
return nil, err
}
}
}
}
// Finally, we can run through the records and start deleting media that's safe to delete
ctx.Log.Debug("Stage 3 of purge")
deletedLocations := make(map[string]bool)
removedMxcs := make([]string, 0)
tryRemoveDsFile := func(datastoreId string, location string) error {
locationId := fmt.Sprintf("%s/%s", datastoreId, location)
if _, ok := deletedLocations[locationId]; ok {
return nil // already deleted/handled
}
if m, ok := flagMap[locationId]; !ok {
return errors.New("logic error: missing flag map value for location ID in third step")
} else {
for _, b := range m {
if !b {
return nil // unsafe to delete, but no error
}
}
}
// Try deleting the file
err := datastores.RemoveWithDsId(ctx, datastoreId, location)
if err != nil {
return err
}
deletedLocations[locationId] = true
return nil
}
for _, r := range records {
mxc := util.MxcUri(r.Origin, r.MediaId)
if err := tryRemoveDsFile(r.DatastoreId, r.Location); err != nil {
return nil, err
}
if util.IsServerOurs(r.Origin) {
if err := reservedDb.InsertNoConflict(r.Origin, r.MediaId, "purged / deleted"); err != nil {
return nil, err
}
}
if !r.Quarantined { // keep quarantined flag
if err := mediaDb.Delete(r.Origin, r.MediaId); err != nil {
return nil, err
}
}
removedMxcs = append(removedMxcs, mxc)
// Remove the thumbnails too
if thumbs, ok := thumbsMap[mxc]; !ok {
return nil, errors.New("logic error: missing thumbnails for MXC URI in third step")
} else {
for _, t := range thumbs {
if err := tryRemoveDsFile(t.DatastoreId, t.Location); err != nil {
return nil, err
}
if err := thumbsDb.Delete(t); err != nil {
return nil, err
}
}
}
}
// Finally, we're done
return removedMxcs, nil
}
package task_runner
import (
"fmt"
"github.com/getsentry/sentry-go"
"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/util"
)
......@@ -29,7 +26,6 @@ func PurgeRemoteMedia(ctx rcontext.RequestContext) {
// PurgeRemoteMediaBefore returns (count affected, error)
func PurgeRemoteMediaBefore(ctx rcontext.RequestContext, beforeTs int64) (int, error) {
mediaDb := database.GetInstance().Media.Prepare(ctx)
thumbsDb := database.GetInstance().Thumbnails.Prepare(ctx)
origins := util.GetOurDomains()
......@@ -38,47 +34,10 @@ func PurgeRemoteMediaBefore(ctx rcontext.RequestContext, beforeTs int64) (int, e
return 0, err
}
removed := 0
deletedLocations := make(map[string]bool)
for _, record := range records {
mxc := util.MxcUri(record.Origin, record.MediaId)
if record.Quarantined {
ctx.Log.Debugf("Skipping quarantined media %s", mxc)
continue // skip quarantined media
}
if exists, err := thumbsDb.LocationExists(record.DatastoreId, record.Location); err != nil {
ctx.Log.Error("Error checking for conflicting thumbnail: ", err)
sentry.CaptureException(err)
} else if !exists { // if exists, skip
locationId := fmt.Sprintf("%s/%s", record.DatastoreId, record.Location)
if _, ok := deletedLocations[locationId]; !ok {
ctx.Log.Debugf("Trying to remove datastore object for %s", mxc)
err = datastores.RemoveWithDsId(ctx, record.DatastoreId, record.Location)
if err != nil {
ctx.Log.Error("Error deleting media from datastore: ", err)
sentry.CaptureException(err)
continue
}
deletedLocations[locationId] = true
}
ctx.Log.Debugf("Trying to database record for %s", mxc)
if err = mediaDb.Delete(record.Origin, record.MediaId); err != nil {
ctx.Log.Error("Error deleting thumbnail record: ", err)
sentry.CaptureException(err)
}
removed = removed + 1
thumbs, err := thumbsDb.GetForMedia(record.Origin, record.MediaId)
if err != nil {
ctx.Log.Warn("Error getting thumbnails for media: ", err)
sentry.CaptureException(err)
continue
}
doPurgeThumbnails(ctx, thumbs)
}
removed, err := doPurge(ctx, records, &purgeConfig{IncludeQuarantined: false})
if err != nil {
return 0, err
}
return removed, nil
return len(removed), nil
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment