From f53e8531a03be1c6b0a6f54c12ff4761a149f540 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Wed, 4 Sep 2019 18:00:14 -0600
Subject: [PATCH] Allow local admins and individual users to delete scoped
 media

---
 api/custom/purge.go                           | 60 ++++++++++++++++++-
 api/webserver/webserver.go                    |  4 +-
 .../maintainance_controller.go                | 17 ++++++
 storage/stores/media_store.go                 | 36 +++++++++++
 4 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/api/custom/purge.go b/api/custom/purge.go
index 9b5f40a5..5bb507d8 100644
--- a/api/custom/purge.go
+++ b/api/custom/purge.go
@@ -1,13 +1,19 @@
 package custom
 
 import (
+	"database/sql"
 	"net/http"
 	"strconv"
 
 	"github.com/gorilla/mux"
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/api"
+	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/controllers/maintenance_controller"
+	"github.com/turt2live/matrix-media-repo/matrix"
+	"github.com/turt2live/matrix-media-repo/storage"
+	"github.com/turt2live/matrix-media-repo/types"
+	"github.com/turt2live/matrix-media-repo/util"
 )
 
 type MediaPurgedResponse struct {
@@ -39,7 +45,8 @@ func PurgeRemoteMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) int
 }
 
 func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
-	// TODO: Allow non-repo-admins to delete things
+	isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user)
+	localServerName := r.Host
 
 	params := mux.Vars(r)
 
@@ -51,7 +58,32 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo
 		"mediaId": mediaId,
 	})
 
+	// If the user is NOT a global admin, ensure they are speaking to the right server
+	if !isGlobalAdmin {
+		if server != localServerName {
+			return api.AuthFailed()
+		}
+		// If the user is NOT a local admin, ensure they uploaded the content in the first place
+		if !isLocalAdmin {
+			db := storage.GetDatabase().GetMediaStore(r.Context(), log)
+			m, err := db.Get(server, mediaId)
+			if err == sql.ErrNoRows {
+				return api.NotFoundError()
+			}
+			if err != nil {
+				log.Error("Error checking ownership of media: " + err.Error())
+				return api.InternalServerError("error checking media ownership")
+			}
+			if m.UserId != user.UserId {
+				return api.AuthFailed()
+			}
+		}
+	}
+
 	err := maintenance_controller.PurgeMedia(server, mediaId, r.Context(), log)
+	if err == sql.ErrNoRows || err == common.ErrMediaNotFound {
+		return api.NotFoundError()
+	}
 	if err != nil {
 		log.Error("Error purging media: " + err.Error())
 		return api.InternalServerError("error purging media")
@@ -61,9 +93,20 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo
 }
 
 func PurgeQurantined(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
-	// TODO: Allow non-repo-admins to delete things
+	isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user)
+	localServerName := r.Host
+
+	var affected []*types.Media
+	var err error
+
+	if isGlobalAdmin {
+		affected, err = maintenance_controller.PurgeQuarantined(r.Context(), log)
+	} else if isLocalAdmin {
+		affected, err = maintenance_controller.PurgeQuarantinedFor(localServerName, r.Context(), log)
+	} else {
+		return api.AuthFailed()
+	}
 
-	affected, err := maintenance_controller.PurgeQuarantined(r.Context(), log)
 	if err != nil {
 		log.Error("Error purging media: " + err.Error())
 		return api.InternalServerError("error purging media")
@@ -76,3 +119,14 @@ func PurgeQurantined(r *http.Request, log *logrus.Entry, user api.UserInfo) inte
 
 	return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}}
 }
+
+func getPurgeRequestInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) (bool, bool) {
+	isGlobalAdmin := util.IsGlobalAdmin(user.UserId)
+	isLocalAdmin, err := matrix.IsUserAdmin(r.Context(), r.Host, user.AccessToken, r.RemoteAddr)
+	if err != nil {
+		log.Error("Error verifying local admin: " + err.Error())
+		return isGlobalAdmin, false
+	}
+
+	return isGlobalAdmin, isLocalAdmin
+}
diff --git a/api/webserver/webserver.go b/api/webserver/webserver.go
index 7274ee27..c3cf298c 100644
--- a/api/webserver/webserver.go
+++ b/api/webserver/webserver.go
@@ -32,8 +32,8 @@ func Init() {
 	previewUrlHandler := handler{api.AccessTokenRequiredRoute(r0.PreviewUrl), "url_preview", counter, false}
 	identiconHandler := handler{api.AccessTokenOptionalRoute(r0.Identicon), "identicon", counter, false}
 	purgeRemote := handler{api.RepoAdminRoute(custom.PurgeRemoteMedia), "purge_remote_media", counter, false}
-	purgeOneHandler := handler{api.RepoAdminRoute(custom.PurgeIndividualRecord), "purge_individual_media", counter, false}
-	purgeQuarantinedHandler := handler{api.RepoAdminRoute(custom.PurgeQurantined), "purge_quarantined", counter, false}
+	purgeOneHandler := handler{api.AccessTokenRequiredRoute(custom.PurgeIndividualRecord), "purge_individual_media", counter, false}
+	purgeQuarantinedHandler := handler{api.AccessTokenRequiredRoute(custom.PurgeQurantined), "purge_quarantined", counter, false}
 	quarantineHandler := handler{api.AccessTokenRequiredRoute(custom.QuarantineMedia), "quarantine_media", counter, false}
 	quarantineRoomHandler := handler{api.AccessTokenRequiredRoute(custom.QuarantineRoomMedia), "quarantine_room", counter, false}
 	localCopyHandler := handler{api.AccessTokenRequiredRoute(unstable.LocalCopy), "local_copy", counter, false}
diff --git a/controllers/maintenance_controller/maintainance_controller.go b/controllers/maintenance_controller/maintainance_controller.go
index c9b701e3..70e1579a 100644
--- a/controllers/maintenance_controller/maintainance_controller.go
+++ b/controllers/maintenance_controller/maintainance_controller.go
@@ -245,6 +245,23 @@ func PurgeQuarantined(ctx context.Context, log *logrus.Entry) ([]*types.Media, e
 	return records, nil
 }
 
+func PurgeQuarantinedFor(serverName string, ctx context.Context, log *logrus.Entry) ([]*types.Media, error) {
+	mediaDb := storage.GetDatabase().GetMediaStore(ctx, log)
+	records, err := mediaDb.GetQuarantinedMediaFor(serverName)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, r := range records {
+		err = doPurge(r, ctx, log)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return records, nil
+}
+
 func PurgeMedia(origin string, mediaId string, ctx context.Context, log *logrus.Entry) error {
 	media, err := download_controller.FindMediaRecord(origin, mediaId, false, ctx, log)
 	if err != nil {
diff --git a/storage/stores/media_store.go b/storage/stores/media_store.go
index e13b2a6a..2f88bd0f 100644
--- a/storage/stores/media_store.go
+++ b/storage/stores/media_store.go
@@ -27,6 +27,7 @@ const selectAllMediaForServer = "SELECT origin, media_id, upload_name, content_t
 const selectAllMediaForServerUsers = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND user_id = ANY($2)"
 const selectAllMediaForServerIds = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND media_id = ANY($2)"
 const selectQuarantinedMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE quarantined = true;"
+const selectServerQuarantinedMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE quarantined = true AND origin = $1;"
 
 var dsCacheByPath = sync.Map{} // [string] => Datastore
 var dsCacheById = sync.Map{}   // [string] => Datastore
@@ -50,6 +51,7 @@ type mediaStoreStatements struct {
 	selectAllMediaForServerUsers    *sql.Stmt
 	selectAllMediaForServerIds      *sql.Stmt
 	selectQuarantinedMedia          *sql.Stmt
+	selectServerQuarantinedMedia    *sql.Stmt
 }
 
 type MediaStoreFactory struct {
@@ -121,6 +123,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) {
 	if store.stmts.selectQuarantinedMedia, err = store.sqlDb.Prepare(selectQuarantinedMedia); err != nil {
 		return nil, err
 	}
+	if store.stmts.selectServerQuarantinedMedia, err = store.sqlDb.Prepare(selectServerQuarantinedMedia); err != nil {
+		return nil, err
+	}
 
 	return &store, nil
 }
@@ -525,3 +530,34 @@ func (s *MediaStore) GetAllQuarantinedMedia() ([]*types.Media, error) {
 
 	return results, nil
 }
+
+func (s *MediaStore) GetQuarantinedMediaFor(serverName string) ([]*types.Media, error) {
+	rows, err := s.statements.selectServerQuarantinedMedia.QueryContext(s.ctx, serverName)
+	if err != nil {
+		return nil, err
+	}
+
+	var results []*types.Media
+	for rows.Next() {
+		obj := &types.Media{}
+		err = rows.Scan(
+			&obj.Origin,
+			&obj.MediaId,
+			&obj.UploadName,
+			&obj.ContentType,
+			&obj.UserId,
+			&obj.Sha256Hash,
+			&obj.SizeBytes,
+			&obj.DatastoreId,
+			&obj.Location,
+			&obj.CreationTs,
+			&obj.Quarantined,
+		)
+		if err != nil {
+			return nil, err
+		}
+		results = append(results, obj)
+	}
+
+	return results, nil
+}
-- 
GitLab