From 06781c8f08ba850827acf0670bb58cfe38ab8477 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Tue, 1 Aug 2023 17:11:13 -0600
Subject: [PATCH] Support datastore migrations on new infrastructure

---
 api/custom/datastores.go               |  21 ++---
 database/table_media.go                |  10 +++
 database/table_tasks.go                |  10 +++
 database/table_thumbnails.go           |  10 +++
 database/virtualtable_metadata.go      |   7 +-
 tasks/exec.go                          |  11 ++-
 tasks/schedule.go                      |  11 ++-
 tasks/task_runner/00-internal.go       |  17 ++++
 tasks/task_runner/datastore_migrate.go | 113 +++++++++++++++++++++++++
 9 files changed, 191 insertions(+), 19 deletions(-)
 create mode 100644 tasks/task_runner/00-internal.go
 create mode 100644 tasks/task_runner/datastore_migrate.go

diff --git a/api/custom/datastores.go b/api/custom/datastores.go
index 19119858..a513f460 100644
--- a/api/custom/datastores.go
+++ b/api/custom/datastores.go
@@ -7,14 +7,13 @@ import (
 	"github.com/turt2live/matrix-media-repo/api/_routers"
 	"github.com/turt2live/matrix-media-repo/common/config"
 	"github.com/turt2live/matrix-media-repo/datastores"
+	"github.com/turt2live/matrix-media-repo/tasks"
 
 	"net/http"
 	"strconv"
 
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
-	"github.com/turt2live/matrix-media-repo/controllers/maintenance_controller"
-	"github.com/turt2live/matrix-media-repo/storage/datastore"
 	"github.com/turt2live/matrix-media-repo/util"
 )
 
@@ -64,17 +63,11 @@ func MigrateBetweenDatastores(r *http.Request, rctx rcontext.RequestContext, use
 	if sourceDsId == targetDsId {
 		return _responses.BadRequest("Source and target datastore cannot be the same")
 	}
-
-	sourceDatastore, err := datastore.LocateDatastore(rctx, sourceDsId)
-	if err != nil {
-		rctx.Log.Error(err)
-		return _responses.BadRequest("Error getting source datastore. Does it exist?")
+	if _, ok := datastores.Get(rctx, sourceDsId); !ok {
+		return _responses.BadRequest("Source datastore does not appear to exist")
 	}
-
-	targetDatastore, err := datastore.LocateDatastore(rctx, targetDsId)
-	if err != nil {
-		rctx.Log.Error(err)
-		return _responses.BadRequest("Error getting target datastore. Does it exist?")
+	if _, ok := datastores.Get(rctx, targetDsId); !ok {
+		return _responses.BadRequest("Target datastore does not appear to exist")
 	}
 
 	estimate, err := datastores.SizeOfDsIdWithAge(rctx, sourceDsId, beforeTs)
@@ -85,7 +78,7 @@ func MigrateBetweenDatastores(r *http.Request, rctx rcontext.RequestContext, use
 	}
 
 	rctx.Log.Infof("User %s has started a datastore media transfer", user.UserId)
-	task, err := maintenance_controller.StartStorageMigration(sourceDatastore, targetDatastore, beforeTs, rctx)
+	task, err := tasks.RunDatastoreMigration(rctx, sourceDsId, targetDsId, beforeTs)
 	if err != nil {
 		rctx.Log.Error(err)
 		sentry.CaptureException(err)
@@ -94,7 +87,7 @@ func MigrateBetweenDatastores(r *http.Request, rctx rcontext.RequestContext, use
 
 	migration := &DatastoreMigration{
 		SizeEstimate: estimate,
-		TaskID:       task.ID,
+		TaskID:       task.TaskId,
 	}
 
 	return &_responses.DoNotCacheResponse{Payload: migration}
diff --git a/database/table_media.go b/database/table_media.go
index 51008da5..d3f1cbb2 100644
--- a/database/table_media.go
+++ b/database/table_media.go
@@ -43,6 +43,7 @@ const selectMediaByOriginAndUserIds = "SELECT origin, media_id, upload_name, con
 const selectMediaByOriginAndIds = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1 AND media_id = ANY($2);"
 const selectOldMediaExcludingDomains = "SELECT m.origin, m.media_id, m.upload_name, m.content_type, m.user_id, m.sha256_hash, m.size_bytes, m.creation_ts, m.quarantined, m.datastore_id, m.location FROM media AS m WHERE m.origin <> ANY($1) AND m.creation_ts < $2 AND (SELECT COUNT(d.*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.creation_ts >= $2) = 0 AND (SELECT COUNT(d.*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.origin = ANY($1)) = 0;"
 const deleteMedia = "DELETE FROM media WHERE origin = $1 AND media_id = $2;"
+const updateMediaLocation = "UPDATE media SET datastore_id = $3, location = $4 WHERE datastore_id = $1 AND location = $2;"
 
 type mediaTableStatements struct {
 	selectDistinctMediaDatastoreIds *sql.Stmt
@@ -59,6 +60,7 @@ type mediaTableStatements struct {
 	selectMediaByOriginAndIds       *sql.Stmt
 	selectOldMediaExcludingDomains  *sql.Stmt
 	deleteMedia                     *sql.Stmt
+	updateMediaLocation             *sql.Stmt
 }
 
 type mediaTableWithContext struct {
@@ -112,6 +114,9 @@ func prepareMediaTables(db *sql.DB) (*mediaTableStatements, error) {
 	if stmts.deleteMedia, err = db.Prepare(deleteMedia); err != nil {
 		return nil, errors.New("error preparing deleteMedia: " + err.Error())
 	}
+	if stmts.updateMediaLocation, err = db.Prepare(updateMediaLocation); err != nil {
+		return nil, errors.New("error preparing updateMediaLocation: " + err.Error())
+	}
 
 	return stmts, nil
 }
@@ -252,3 +257,16 @@ func (s *mediaTableWithContext) Delete(origin string, mediaId string) error {
 	_, err := s.statements.deleteMedia.ExecContext(s.ctx, origin, mediaId)
 	return err
 }
+
+// UpdateLocation repoints every media row currently stored at
+// (sourceDsId, sourceLocation) so that it references (targetDsId,
+// targetLocation) instead. It returns any error from executing the prepared
+// updateMediaLocation statement.
+func (s *mediaTableWithContext) UpdateLocation(sourceDsId string, sourceLocation string, targetDsId string, targetLocation string) error {
+	_, err := s.statements.updateMediaLocation.ExecContext(
+		s.ctx,
+		sourceDsId, sourceLocation, // WHERE datastore_id = $1 AND location = $2
+		targetDsId, targetLocation, // SET datastore_id = $3, location = $4
+	)
+	return err
+}
diff --git a/database/table_tasks.go b/database/table_tasks.go
index 84e34be3..511094d8 100644
--- a/database/table_tasks.go
+++ b/database/table_tasks.go
@@ -19,12 +19,14 @@ const selectTask = "SELECT id, task, params, start_ts, end_ts FROM background_ta
 const insertTask = "INSERT INTO background_tasks (task, params, start_ts, end_ts) VALUES ($1, $2, $3, 0) RETURNING id, task, params, start_ts, end_ts;"
 const selectAllTasks = "SELECT id, task, params, start_ts, end_ts FROM background_tasks;"
 const selectIncompleteTasks = "SELECT id, task, params, start_ts, end_ts FROM background_tasks WHERE end_ts <= 0;"
+const updateTaskEndTime = "UPDATE background_tasks SET end_ts = $2 WHERE id = $1;"
 
 type tasksTableStatements struct {
 	selectTask            *sql.Stmt
 	insertTask            *sql.Stmt
 	selectAllTasks        *sql.Stmt
 	selectIncompleteTasks *sql.Stmt
+	updateTaskEndTime     *sql.Stmt
 }
 
 type tasksTableWithContext struct {
@@ -48,6 +50,9 @@ func prepareTasksTables(db *sql.DB) (*tasksTableStatements, error) {
 	if stmts.selectIncompleteTasks, err = db.Prepare(selectIncompleteTasks); err != nil {
 		return nil, errors.New("error preparing selectIncompleteTasks: " + err.Error())
 	}
+	if stmts.updateTaskEndTime, err = db.Prepare(updateTaskEndTime); err != nil {
+		return nil, errors.New("error preparing updateTaskEndTime: " + err.Error())
+	}
 
 	return stmts, nil
 }
@@ -69,6 +74,18 @@ func (s *tasksTableWithContext) Insert(name string, params *AnonymousJson, start
 	return val, nil
 }
 
+// SetEndTime records endTs as the end_ts of the background task identified by
+// taskId. It returns any error from executing the prepared updateTaskEndTime
+// statement.
+func (s *tasksTableWithContext) SetEndTime(taskId int, endTs int64) error {
+	_, err := s.statements.updateTaskEndTime.ExecContext(
+		s.ctx,
+		taskId, // WHERE id = $1
+		endTs,  // SET end_ts = $2
+	)
+	return err
+}
+
 func (s *tasksTableWithContext) Get(id int) (*DbTask, error) {
 	row := s.statements.selectTask.QueryRowContext(s.ctx, id)
 	val := &DbTask{}
diff --git a/database/table_thumbnails.go b/database/table_thumbnails.go
index 9b6bc1f4..c431621b 100644
--- a/database/table_thumbnails.go
+++ b/database/table_thumbnails.go
@@ -31,6 +31,7 @@ const selectThumbnailByLocationExists = "SELECT TRUE FROM thumbnails WHERE datas
 const selectThumbnailsForMedia = "SELECT origin, media_id, content_type, width, height, method, animated, sha256_hash, size_bytes, creation_ts, datastore_id, location FROM thumbnails WHERE origin = $1 AND media_id = $2;"
 const selectOldThumbnails = "SELECT origin, media_id, content_type, width, height, method, animated, sha256_hash, size_bytes, creation_ts, datastore_id, location FROM thumbnails WHERE sha256_hash IN (SELECT t2.sha256_hash FROM thumbnails AS t2 WHERE t2.creation_ts < $1);"
 const deleteThumbnail = "DELETE FROM thumbnails WHERE origin = $1 AND media_id = $2 AND content_type = $3 AND width = $4 AND height = $5 AND method = $6 AND animated = $7 AND sha256_hash = $8 AND size_bytes = $9 AND creation_ts = $10 AND datastore_id = $11 AND location = $11;"
+const updateThumbnailLocation = "UPDATE thumbnails SET datastore_id = $3, location = $4 WHERE datastore_id = $1 AND location = $2;"
 
 type thumbnailsTableStatements struct {
 	selectThumbnailByParams         *sql.Stmt
@@ -39,6 +40,7 @@ type thumbnailsTableStatements struct {
 	selectThumbnailsForMedia        *sql.Stmt
 	selectOldThumbnails             *sql.Stmt
 	deleteThumbnail                 *sql.Stmt
+	updateThumbnailLocation         *sql.Stmt
 }
 
 type thumbnailsTableWithContext struct {
@@ -68,6 +70,9 @@ func prepareThumbnailsTables(db *sql.DB) (*thumbnailsTableStatements, error) {
 	if stmts.deleteThumbnail, err = db.Prepare(deleteThumbnail); err != nil {
 		return nil, errors.New("error preparing deleteThumbnail: " + err.Error())
 	}
+	if stmts.updateThumbnailLocation, err = db.Prepare(updateThumbnailLocation); err != nil {
+		return nil, errors.New("error preparing updateThumbnailLocation: " + err.Error())
+	}
 
 	return stmts, nil
 }
@@ -148,3 +153,16 @@ func (s *thumbnailsTableWithContext) Delete(record *DbThumbnail) error {
 	_, err := s.statements.deleteThumbnail.ExecContext(s.ctx, record.Origin, record.MediaId, record.ContentType, record.Width, record.Height, record.Method, record.Animated, record.Sha256Hash, record.SizeBytes, record.CreationTs, record.DatastoreId, record.Location)
 	return err
 }
+
+// UpdateLocation repoints every thumbnail row currently stored at
+// (sourceDsId, sourceLocation) so that it references (targetDsId,
+// targetLocation) instead. It returns any error from executing the prepared
+// updateThumbnailLocation statement.
+func (s *thumbnailsTableWithContext) UpdateLocation(sourceDsId string, sourceLocation string, targetDsId string, targetLocation string) error {
+	_, err := s.statements.updateThumbnailLocation.ExecContext(
+		s.ctx,
+		sourceDsId, sourceLocation, // WHERE datastore_id = $1 AND location = $2
+		targetDsId, targetLocation, // SET datastore_id = $3, location = $4
+	)
+	return err
+}
diff --git a/database/virtualtable_metadata.go b/database/virtualtable_metadata.go
index 6ad1dcb4..5030f1d0 100644
--- a/database/virtualtable_metadata.go
+++ b/database/virtualtable_metadata.go
@@ -14,13 +14,14 @@ type VirtLastAccess struct {
 	SizeBytes    int64
 	CreationTs   int64
 	LastAccessTs int64
+	ContentType  string
 }
 
 const selectEstimatedDatastoreSize = "SELECT COALESCE(SUM(m2.size_bytes), 0) + COALESCE((SELECT SUM(t2.size_bytes) FROM (SELECT DISTINCT t.sha256_hash, MAX(t.size_bytes) AS size_bytes FROM thumbnails AS t WHERE t.datastore_id = $1 GROUP BY t.sha256_hash) AS t2), 0) AS size_total FROM (SELECT DISTINCT m.sha256_hash, MAX(m.size_bytes) AS size_bytes FROM media AS m WHERE m.datastore_id = $1 GROUP BY m.sha256_hash) AS m2;"
 const selectUploadSizesForServer = "SELECT COALESCE((SELECT SUM(size_bytes) FROM media WHERE origin = $1), 0) AS media, COALESCE((SELECT SUM(size_bytes) FROM thumbnails WHERE origin = $1), 0) AS thumbnails;"
 const selectUploadCountsForServer = "SELECT COALESCE((SELECT COUNT(origin) FROM media WHERE origin = $1), 0) AS media, COALESCE((SELECT COUNT(origin) FROM thumbnails WHERE origin = $1), 0) AS thumbnails;"
-const selectMediaForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts FROM media AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;"
-const selectThumbnailsForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts FROM thumbnails AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;"
+const selectMediaForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts, m.content_type FROM media AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;"
+const selectThumbnailsForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts, m.content_type FROM thumbnails AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;"
 
 type SynStatUserOrderBy string
 
@@ -211,7 +212,7 @@ func (s *metadataVirtualTableWithContext) scanLastAccess(rows *sql.Rows, err err
 	}
 	for rows.Next() {
 		val := &VirtLastAccess{Locatable: &Locatable{}}
-		if err = rows.Scan(&val.Sha256Hash, &val.SizeBytes, &val.DatastoreId, &val.Location, &val.CreationTs, &val.LastAccessTs); err != nil {
+		if err = rows.Scan(&val.Sha256Hash, &val.SizeBytes, &val.DatastoreId, &val.Location, &val.CreationTs, &val.LastAccessTs, &val.ContentType); err != nil {
 			return nil, err
 		}
 		results = append(results, val)
diff --git a/tasks/exec.go b/tasks/exec.go
index a012fe18..dc9d3ca8 100644
--- a/tasks/exec.go
+++ b/tasks/exec.go
@@ -1,6 +1,7 @@
 package tasks
 
 import (
+	"fmt"
 	"time"
 
 	"github.com/getsentry/sentry-go"
@@ -8,6 +9,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/database"
 	"github.com/turt2live/matrix-media-repo/notifier"
+	"github.com/turt2live/matrix-media-repo/tasks/task_runner"
 	"github.com/turt2live/matrix-media-repo/util/ids"
 )
 
@@ -54,5 +56,17 @@ func tryBeginTask(id int, recur bool) {
 }
 
+// beginTask dispatches task to the runner matching its name. A runner is
+// started on its own goroutine, so this call does not wait for the task to
+// finish. A task whose name is not recognised is reported through the log and
+// Sentry and is otherwise left alone.
 func beginTask(task *database.DbTask) {
-	logrus.Warn(task)
+	runnerCtx := rcontext.Initial().LogWithFields(logrus.Fields{"task_id": task.TaskId})
+	switch task.Name {
+	case string(TaskDatastoreMigrate):
+		go task_runner.DatastoreMigrate(runnerCtx, task)
+	default:
+		msg := fmt.Sprintf("Received unknown task to run %s (ID: %d)", task.Name, task.TaskId)
+		logrus.Warn(msg)
+		sentry.CaptureMessage(msg)
+	}
 }
diff --git a/tasks/schedule.go b/tasks/schedule.go
index 7f12a1a2..c1f394b2 100644
--- a/tasks/schedule.go
+++ b/tasks/schedule.go
@@ -9,6 +9,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/database"
 	"github.com/turt2live/matrix-media-repo/notifier"
+	"github.com/turt2live/matrix-media-repo/tasks/task_runner"
 	"github.com/turt2live/matrix-media-repo/util"
 	"github.com/turt2live/matrix-media-repo/util/ids"
 )
@@ -17,7 +18,7 @@ type TaskName string
 type RecurringTaskName string
 
 const (
-// TaskTesting TaskName = "test1234"
+	TaskDatastoreMigrate TaskName = "storage_migration"
 )
 const (
 	RecurringTaskPurgeThumbnails  RecurringTaskName = "recurring_purge_thumbnails"
@@ -95,3 +96,16 @@ func stopRecurring() {
 		ch <- true
 	}
 }
+
+// RunDatastoreMigration schedules a TaskDatastoreMigrate background task whose
+// parameters name sourceDsId as the datastore to move objects out of,
+// targetDsId as the one to move them into, and beforeTs as the cutoff handed
+// to the runner. It returns whatever scheduleTask returns.
+func RunDatastoreMigration(ctx rcontext.RequestContext, sourceDsId string, targetDsId string, beforeTs int64) (*database.DbTask, error) {
+	params := task_runner.DatastoreMigrateParams{
+		SourceDsId: sourceDsId,
+		TargetDsId: targetDsId,
+		BeforeTs:   beforeTs,
+	}
+	return scheduleTask(ctx, TaskDatastoreMigrate, params)
+}
diff --git a/tasks/task_runner/00-internal.go b/tasks/task_runner/00-internal.go
new file mode 100644
index 00000000..611d272f
--- /dev/null
+++ b/tasks/task_runner/00-internal.go
@@ -0,0 +1,21 @@
+package task_runner
+
+import (
+	"github.com/getsentry/sentry-go"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/database"
+	"github.com/turt2live/matrix-media-repo/util"
+)
+
+// markDone stamps the current time (util.NowMillis) as the task's end time and
+// logs that the task completed. A failure to persist the end time is logged
+// and reported to Sentry; it is not returned to the caller.
+func markDone(ctx rcontext.RequestContext, task *database.DbTask) {
+	tasksTable := database.GetInstance().Tasks.Prepare(ctx)
+	err := tasksTable.SetEndTime(task.TaskId, util.NowMillis())
+	if err != nil {
+		ctx.Log.Warn("Error updating task as complete: ", err)
+		sentry.CaptureException(err)
+	}
+	ctx.Log.Infof("Task '%s' completed", task.Name)
+}
diff --git a/tasks/task_runner/datastore_migrate.go b/tasks/task_runner/datastore_migrate.go
new file mode 100644
index 00000000..1a248b28
--- /dev/null
+++ b/tasks/task_runner/datastore_migrate.go
@@ -0,0 +1,137 @@
+package task_runner
+
+import (
+	"fmt"
+
+	"github.com/getsentry/sentry-go"
+	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/common/config"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/database"
+	"github.com/turt2live/matrix-media-repo/datastores"
+)
+
+// DatastoreMigrateParams holds the parameters of a datastore migration task,
+// serialised as JSON alongside the task record. SourceDsId and TargetDsId
+// name the datastores to move objects from and to; BeforeTs is handed to the
+// last-access metadata queries to select which objects are moved.
+type DatastoreMigrateParams struct {
+	SourceDsId string `json:"source_datastore_id"`
+	TargetDsId string `json:"target_datastore_id"`
+	BeforeTs   int64  `json:"before_ts"`
+}
+
+// DatastoreMigrate runs a datastore migration task. It decodes the task's
+// DatastoreMigrateParams, looks up both datastores, and then moves the media
+// followed by the thumbnails selected by the metadata view from the source
+// datastore to the target one. Failures are logged (and reported to Sentry
+// where an error value exists) instead of being returned, and the task is
+// marked done when this function returns regardless of the outcome.
+func DatastoreMigrate(ctx rcontext.RequestContext, task *database.DbTask) {
+	defer markDone(ctx, task)
+
+	params := DatastoreMigrateParams{}
+	if err := task.Params.ApplyTo(&params); err != nil {
+		ctx.Log.Error("Error decoding params: ", err)
+		sentry.CaptureException(err)
+		return
+	}
+
+	if params.SourceDsId == params.TargetDsId {
+		ctx.Log.Error("Source and target datastore are the same")
+		return
+	}
+
+	sourceDs, ok := datastores.Get(ctx, params.SourceDsId)
+	if !ok {
+		ctx.Log.Error("Unable to locate source datastore ID")
+		return
+	}
+
+	targetDs, ok := datastores.Get(ctx, params.TargetDsId)
+	if !ok {
+		ctx.Log.Error("Unable to locate target datastore ID")
+		return
+	}
+
+	db := database.GetInstance().MetadataView.Prepare(ctx)
+
+	media, err := db.GetMediaForDatastoreByLastAccess(params.SourceDsId, params.BeforeTs)
+	if err != nil {
+		ctx.Log.Error("Error getting movable media: ", err)
+		sentry.CaptureException(err)
+		return
+	}
+	moveDatastoreObjects(ctx, media, sourceDs, targetDs)
+
+	// The thumbnail query runs only after the media pass: that pass also
+	// repoints thumbnails rows sharing a moved location, so they no longer match
+	// the source datastore filter here.
+	thumbnails, err := db.GetThumbnailsForDatastoreByLastAccess(params.SourceDsId, params.BeforeTs)
+	if err != nil {
+		ctx.Log.Error("Error getting movable thumbnails: ", err)
+		sentry.CaptureException(err)
+		return
+	}
+	moveDatastoreObjects(ctx, thumbnails, sourceDs, targetDs)
+}
+
+// moveDatastoreObjects moves each distinct (datastore ID, location) pair found
+// in records from sourceDs to targetDs. For every object it downloads from the
+// source, uploads to the target, repoints matching rows in both the media and
+// thumbnails tables at the new location, and finally removes the source copy.
+// A failure at any step is logged and reported to Sentry, and processing
+// continues with the next record.
+func moveDatastoreObjects(ctx rcontext.RequestContext, records []*database.VirtLastAccess, sourceDs config.DatastoreConfig, targetDs config.DatastoreConfig) {
+	mediaDb := database.GetInstance().Media.Prepare(ctx)
+	thumbsDb := database.GetInstance().Thumbnails.Prepare(ctx)
+	done := make(map[string]bool)
+	for _, record := range records {
+		doneId := fmt.Sprintf("%s/%s", record.DatastoreId, record.Location)
+		if done[doneId] {
+			continue
+		}
+
+		recordCtx := ctx.LogWithFields(logrus.Fields{"sha256": record.Sha256Hash, "dsId": record.DatastoreId, "location": record.Location})
+		recordCtx.Log.Debug("Moving record")
+
+		// NOTE(review): sourceStream is never closed in this function; confirm
+		// that datastores.Upload takes ownership of the stream and closes it.
+		sourceStream, err := datastores.Download(recordCtx, sourceDs, record.Location)
+		if err != nil {
+			recordCtx.Log.Error("Failed to start download from source: ", err)
+			sentry.CaptureException(err)
+			continue
+		}
+
+		newLocation, err := datastores.Upload(recordCtx, targetDs, sourceStream, record.SizeBytes, record.ContentType, record.Sha256Hash)
+		if err != nil {
+			recordCtx.Log.Error("Failed to upload to target: ", err)
+			sentry.CaptureException(err)
+			continue
+		}
+
+		if err = mediaDb.UpdateLocation(record.DatastoreId, record.Location, targetDs.Id, newLocation); err != nil {
+			recordCtx.Log.Error("Failed to update media table with new datastore and location: ", err)
+			sentry.CaptureException(err)
+			continue
+		}
+
+		if err = thumbsDb.UpdateLocation(record.DatastoreId, record.Location, targetDs.Id, newLocation); err != nil {
+			recordCtx.Log.Error("Failed to update thumbnails table with new datastore and location: ", err)
+			sentry.CaptureException(err)
+			continue
+		}
+
+		// Both tables now reference the target copy, so the move is complete as
+		// far as later duplicates of this location are concerned. Recording that
+		// before the source cleanup keeps a failed Remove from triggering another
+		// download and upload whose new location no row would be updated to use.
+		done[doneId] = true
+
+		if err = datastores.Remove(recordCtx, sourceDs, record.Location); err != nil {
+			recordCtx.Log.Error("Failed to remove source object from datastore: ", err)
+			sentry.CaptureException(err)
+		}
+	}
+}
-- 
GitLab