From ff7f22a1ae92bfe1c9eb80a90faf6a98c2de960e Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Wed, 25 Dec 2019 00:04:19 -0700
Subject: [PATCH] Support importing from a previously exported dataset

Fixes https://github.com/turt2live/matrix-media-repo/issues/208
---
 api/custom/exports.go                         |   2 +-
 api/custom/imports.go                         |  71 ++++
 api/webserver/webserver.go                    |   6 +
 .../data_controller/export_controller.go      |   9 +-
 .../data_controller/import_controller.go      | 348 ++++++++++++++++++
 docs/admin.md                                 |  40 +-
 storage/datastore/datastore.go                |  20 +
 storage/datastore/ds_s3/s3_store.go           |  15 +
 8 files changed, 506 insertions(+), 5 deletions(-)
 create mode 100644 api/custom/imports.go
 create mode 100644 controllers/data_controller/import_controller.go

diff --git a/api/custom/exports.go b/api/custom/exports.go
index 1c8091b0..9d1c9238 100644
--- a/api/custom/exports.go
+++ b/api/custom/exports.go
@@ -52,7 +52,7 @@ func ExportUserData(r *http.Request, log *logrus.Entry, user api.UserInfo) inter
 
 	userId := params["userId"]
 
-	if !isAdmin && user.UserId != userId {
+	if !isAdmin && user.UserId != userId {
 		return api.BadRequest("cannot export data for another user")
 	}
 
diff --git a/api/custom/imports.go b/api/custom/imports.go
new file mode 100644
index 00000000..5592d15d
--- /dev/null
+++ b/api/custom/imports.go
@@ -0,0 +1,71 @@
+package custom
+
+import (
+	"net/http"
+
+	"github.com/gorilla/mux"
+	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/api"
+	"github.com/turt2live/matrix-media-repo/common/config"
+	"github.com/turt2live/matrix-media-repo/controllers/data_controller"
+)
+
+type ImportStarted struct {
+	ImportID string `json:"import_id"`
+	TaskID   int    `json:"task_id"`
+}
+
+func StartImport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
+	if !config.Get().Archiving.Enabled {
+		return api.BadRequest("archiving is not enabled")
+	}
+
+	defer r.Body.Close()
+	task, importId, err := data_controller.StartImport(r.Body, log)
+	if err != nil {
+		log.Error(err)
+		return api.InternalServerError("fatal error starting import")
+	}
+
+	return &api.DoNotCacheResponse{Payload: &ImportStarted{
+		TaskID:   task.ID,
+		ImportID: importId,
+	}}
+}
+
+func AppendToImport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
+	if !config.Get().Archiving.Enabled {
+		return api.BadRequest("archiving is not enabled")
+	}
+
+	params := mux.Vars(r)
+
+	importId := params["importId"]
+
+	defer r.Body.Close()
+	err := data_controller.AppendToImport(importId, r.Body)
+	if err != nil {
+		log.Error(err)
+		return api.InternalServerError("fatal error appending to import")
+	}
+
+	return &api.DoNotCacheResponse{Payload: &api.EmptyResponse{}}
+}
+
+func StopImport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
+	if !config.Get().Archiving.Enabled {
+		return api.BadRequest("archiving is not enabled")
+	}
+
+	params := mux.Vars(r)
+
+	importId := params["importId"]
+
+	err := data_controller.StopImport(importId)
+	if err != nil {
+		log.Error(err)
+		return api.InternalServerError("fatal error stopping import")
+	}
+
+	return &api.DoNotCacheResponse{Payload: &api.EmptyResponse{}}
+}
diff --git a/api/webserver/webserver.go b/api/webserver/webserver.go
index 37b5d6db..37a6bae8 100644
--- a/api/webserver/webserver.go
+++ b/api/webserver/webserver.go
@@ -60,6 +60,9 @@ func Init() {
 	getExportMetadataHandler := handler{api.AccessTokenOptionalRoute(custom.GetExportMetadata), "get_export_metadata", counter, false}
 	downloadExportPartHandler := handler{api.AccessTokenOptionalRoute(custom.DownloadExportPart), "download_export_part", counter, false}
 	deleteExportHandler := handler{api.AccessTokenOptionalRoute(custom.DeleteExport), "delete_export", counter, false}
+	startImportHandler := handler{api.RepoAdminRoute(custom.StartImport), "start_import", counter, false}
+	appendToImportHandler := handler{api.RepoAdminRoute(custom.AppendToImport), "append_to_import", counter, false}
+	stopImportHandler := handler{api.RepoAdminRoute(custom.StopImport), "stop_import", counter, false}
 
 	routes := make(map[string]route)
 	versions := []string{"r0", "v1", "unstable"} // r0 is typically clients and v1 is typically servers. v1 is deprecated.
@@ -102,6 +105,9 @@ func Init() {
 		routes["/_matrix/media/"+version+"/admin/export/{exportId:[a-zA-Z0-9.:\\-_]+}/metadata"] = route{"GET", getExportMetadataHandler}
 		routes["/_matrix/media/"+version+"/admin/export/{exportId:[a-zA-Z0-9.:\\-_]+}/part/{partId:[0-9]+}"] = route{"GET", downloadExportPartHandler}
 		routes["/_matrix/media/"+version+"/admin/export/{exportId:[a-zA-Z0-9.:\\-_]+}/delete"] = route{"DELETE", deleteExportHandler}
+		routes["/_matrix/media/"+version+"/admin/import"] = route{"POST", startImportHandler}
+		routes["/_matrix/media/"+version+"/admin/import/{importId:[a-zA-Z0-9.:\\-_]+}/part"] = route{"POST", appendToImportHandler}
+		routes["/_matrix/media/"+version+"/admin/import/{importId:[a-zA-Z0-9.:\\-_]+}/close"] = route{"POST", stopImportHandler}
 
 		// Routes that we should handle but aren't in the media namespace (synapse compat)
 		routes["/_matrix/client/"+version+"/admin/purge_media_cache"] = route{"POST", purgeRemote}
diff --git a/controllers/data_controller/export_controller.go b/controllers/data_controller/export_controller.go
index edf375a5..7ec542a6 100644
--- a/controllers/data_controller/export_controller.go
+++ b/controllers/data_controller/export_controller.go
@@ -36,9 +36,12 @@ type manifestRecord struct {
 
 type manifest struct {
 	Version   int                        `json:"version"`
-	UserId    string                     `json:"user_id"`
+	EntityId  string                     `json:"entity_id"`
 	CreatedTs int64                      `json:"created_ts"`
 	Media     map[string]*manifestRecord `json:"media"`
+
+	// Deprecated: for v1 manifests
+	UserId string `json:"user_id,omitempty"`
 }
 
 func StartServerExport(serverName string, s3urls bool, includeData bool, log *logrus.Entry) (*types.BackgroundTask, string, error) {
@@ -280,8 +283,8 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas
 		})
 	}
 	manifest := &manifest{
-		Version:   1,
-		UserId:    entityId,
+		Version:   2,
+		EntityId:  entityId,
 		CreatedTs: util.NowMillis(),
 		Media:     mediaManifest,
 	}
diff --git a/controllers/data_controller/import_controller.go b/controllers/data_controller/import_controller.go
new file mode 100644
index 00000000..c769d489
--- /dev/null
+++ b/controllers/data_controller/import_controller.go
@@ -0,0 +1,348 @@
+package data_controller
+
+import (
+	"archive/tar"
+	"bytes"
+	"compress/gzip"
+	"context"
+	"database/sql"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+	"sync"
+
+	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/common"
+	"github.com/turt2live/matrix-media-repo/controllers/upload_controller"
+	"github.com/turt2live/matrix-media-repo/storage"
+	"github.com/turt2live/matrix-media-repo/storage/datastore"
+	"github.com/turt2live/matrix-media-repo/storage/datastore/ds_s3"
+	"github.com/turt2live/matrix-media-repo/types"
+	"github.com/turt2live/matrix-media-repo/util"
+)
+
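+// importUpdate is a single batch of files extracted from an uploaded archive, plus an optional stop signal.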
+type importUpdate struct {
+	stop    bool
+	fileMap map[string]*bytes.Buffer
+}
+
+var openImports = &sync.Map{} // importId => updateChan
+
+func StartImport(data io.Reader, log *logrus.Entry) (*types.BackgroundTask, string, error) {
+	ctx := context.Background()
+
+	// Prepare the first update for the import (sync, so we can error)
+	// We do this before anything else because if the archive is invalid then we shouldn't
+	// even bother with an import.
+	results, err := processArchive(data)
+	if err != nil {
+		return nil, "", err
+	}
+
+	importId, err := util.GenerateRandomString(128)
+	if err != nil {
+		return nil, "", err
+	}
+
+	db := storage.GetDatabase().GetMetadataStore(ctx, log)
+	task, err := db.CreateBackgroundTask("import_data", map[string]interface{}{
+		"import_id": importId,
+	})
+
+	if err != nil {
+		return nil, "", err
+	}
+
+	// Start the import and send it its first update
+	updateChan := make(chan *importUpdate)
+	go doImport(updateChan, task.ID, importId, ctx, log)
+	openImports.Store(importId, updateChan)
+	updateChan <- &importUpdate{stop: false, fileMap: results}
+
+	return task, importId, nil
+}
+
+func AppendToImport(importId string, data io.Reader) error {
+	runningImport, ok := openImports.Load(importId)
+	if !ok || runningImport == nil {
+		return errors.New("import not found or it has been closed")
+	}
+
+	results, err := processArchive(data)
+	if err != nil {
+		return err
+	}
+
+	updateChan := runningImport.(chan *importUpdate)
+	updateChan <- &importUpdate{stop: false, fileMap: results}
+
+	return nil
+}
+
+func StopImport(importId string) error {
+	runningImport, ok := openImports.Load(importId)
+	if !ok || runningImport == nil {
+		return errors.New("import not found or it has been closed")
+	}
+
+	updateChan := runningImport.(chan *importUpdate)
+	updateChan <- &importUpdate{stop: true, fileMap: make(map[string]*bytes.Buffer)}
+
+	return nil
+}
+
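+// processArchive decompresses a gzipped tarball entirely into memory and returns a map of file name to contents.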
+func processArchive(data io.Reader) (map[string]*bytes.Buffer, error) {
+	archiver, err := gzip.NewReader(data)
+	if err != nil {
+		return nil, err
+	}
+
+	defer archiver.Close()
+
+	tarFile := tar.NewReader(archiver)
+	index := make(map[string]*bytes.Buffer)
+	for {
+		header, err := tarFile.Next()
+		if err == io.EOF {
+			break // we're done
+		}
+		if err != nil {
+			return nil, err
+		}
+
+		if header == nil {
+			continue // skip this weird file
+		}
+		if header.Typeflag != tar.TypeReg {
+			continue // skip directories and other stuff
+		}
+
+		// Copy the file into our index
+		buf := &bytes.Buffer{}
+		_, err = io.Copy(buf, tarFile)
+		if err != nil {
+			return nil, err
+		}
+		buf = bytes.NewBuffer(buf.Bytes()) // clone to reset reader position
+		index[header.Name] = buf
+	}
+
+	return index, nil
+}
+
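+// doImport runs in its own goroutine for a single import. It consumes batches of files from
+// updateChannel until a stop is requested or every record in the manifest has been imported,
+// then marks the background task as finished.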
+func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx context.Context, log *logrus.Entry) {
+	log.Info("Preparing for import...")
+	fileMap := make(map[string]*bytes.Buffer)
+	stopImport := false
+	archiveManifest := &manifest{}
+	haveManifest := false
+	imported := make(map[string]bool)
+	db := storage.GetDatabase().GetMediaStore(ctx, log)
+
+	for !stopImport {
+		update := <-updateChannel
+		if update.stop {
+			log.Info("Close requested")
+			stopImport = true
+		}
+
+		// Populate files
+		for name, fileBytes := range update.fileMap {
+			if _, ok := fileMap[name]; ok {
+				log.Warnf("Duplicate file name, skipping: %s", name)
+				continue // file already known to us
+			}
+			log.Infof("Tracking file: %s", name)
+			fileMap[name] = fileBytes
+		}
+
+		// Search the known files for a manifest and import whatever media we can
+		var manifestBuf *bytes.Buffer
+		var ok bool
+		if manifestBuf, ok = fileMap["manifest.json"]; !ok {
+			log.Info("No manifest found - waiting for more files")
+			continue
+		}
+
+		if !haveManifest {
+			haveManifest = true
+			err := json.Unmarshal(manifestBuf.Bytes(), archiveManifest)
+			if err != nil {
+				log.Error("Failed to parse manifest - giving up on import")
+				log.Error(err)
+				break
+			}
+			if archiveManifest.Version != 1 && archiveManifest.Version != 2 {
+				log.Error("Unsupported archive version")
+				break
+			}
+			if archiveManifest.Version == 1 {
+				archiveManifest.EntityId = archiveManifest.UserId
+			}
+			if archiveManifest.EntityId == "" {
+				log.Error("Invalid manifest: no entity")
+				break
+			}
+			if archiveManifest.Media == nil {
+				log.Error("Invalid manifest: no media")
+				break
+			}
+			log.Infof("Using manifest for %s (v%d) created %d", archiveManifest.EntityId, archiveManifest.Version, archiveManifest.CreatedTs)
+		}
+
+		if !haveManifest {
+			// Without a manifest we can't import anything
+			continue
+		}
+
+		for mxc, record := range archiveManifest.Media {
+			_, found := imported[mxc]
+			if found {
+				continue // already imported
+			}
+
+			userId := archiveManifest.EntityId
+			if userId[0] != '@' {
+				userId = "" // assume none for now
+			}
+
+			kind := common.KindLocalMedia
+			serverName := archiveManifest.EntityId
+			if userId != "" {
+				_, s, err := util.SplitUserId(userId)
+				if err != nil {
+					log.Errorf("Invalid user ID: %s", userId)
+					serverName = ""
+				} else {
+					serverName = s
+				}
+			}
+			if !util.IsServerOurs(serverName) {
+				kind = common.KindRemoteMedia
+			}
+
+			log.Infof("Attempting to import %s for %s", mxc, archiveManifest.EntityId)
+			buf, found := fileMap[record.ArchivedName]
+			if found {
+				log.Info("Using file from memory")
+				closer := util.BufferToStream(buf)
+				_, err := upload_controller.StoreDirect(closer, record.SizeBytes, record.ContentType, record.FileName, userId, record.Origin, record.MediaId, kind, ctx, log)
+				if err != nil {
+					log.Errorf("Error importing file: %s", err.Error())
+					continue
+				}
+			} else if record.S3Url != "" {
+				log.Info("Using S3 URL")
+				endpoint, bucket, location, err := ds_s3.ParseS3URL(record.S3Url)
+				if err != nil {
+					log.Errorf("Error importing file: %s", err.Error())
+					continue
+				}
+
+				log.Infof("Seeing if a datastore for %s/%s exists", endpoint, bucket)
+				datastores, err := datastore.GetAvailableDatastores()
+				if err != nil {
+					log.Errorf("Error locating datastore: %s", err.Error())
+					continue
+				}
+				imported := false
+				for _, ds := range datastores {
+					if ds.Type != "s3" {
+						continue
+					}
+
+					tmplUrl, err := ds_s3.GetS3URL(ds.DatastoreId, location)
+					if err != nil {
+						log.Errorf("Error investigating s3 datastore: %s", err.Error())
+						continue
+					}
+					if tmplUrl == record.S3Url {
+						log.Infof("File matches! Assuming the file has been uploaded already")
+
+						existingRecord, err := db.Get(record.Origin, record.MediaId)
+						if err != nil && err != sql.ErrNoRows {
+							log.Errorf("Error testing file in database: %s", err.Error())
+							break
+						}
+						if err != sql.ErrNoRows && existingRecord != nil {
+							log.Warnf("Media %s already exists - skipping without altering record", existingRecord.MxcUri())
+							imported = true
+							break
+						}
+
+						media := &types.Media{
+							Origin:      record.Origin,
+							MediaId:     record.MediaId,
+							UploadName:  record.FileName,
+							ContentType: record.ContentType,
+							UserId:      userId,
+							Sha256Hash:  record.Sha256,
+							SizeBytes:   record.SizeBytes,
+							DatastoreId: ds.DatastoreId,
+							Location:    location,
+							CreationTs:  record.CreatedTs,
+						}
+
+						err = db.Insert(media)
+						if err != nil {
+							log.Errorf("Error creating media record: %s", err.Error())
+							break
+						}
+
+						log.Infof("Media %s has been imported", media.MxcUri())
+						imported = true
+						break
+					}
+				}
+
+				if !imported {
+					log.Info("No datastore found - trying to upload by downloading first")
+					r, err := http.DefaultClient.Get(record.S3Url)
+					if err != nil {
+						log.Errorf("Error trying to download file from S3 via HTTP: %s", err.Error())
+						continue
+					}
+
+					_, err = upload_controller.StoreDirect(r.Body, r.ContentLength, record.ContentType, record.FileName, userId, record.Origin, record.MediaId, kind, ctx, log)
+					if err != nil {
+						log.Errorf("Error importing file: %s", err.Error())
+						continue
+					}
+				}
+			} else {
+				log.Warn("Missing usable file for import - assuming it will show up in a future upload")
+				continue
+			}
+
+			log.Info("Counting file as imported")
+			imported[mxc] = true
+		}
+
+		missingAny := false
+		for mxc := range archiveManifest.Media {
+			_, found := imported[mxc]
+			if found {
+				continue // already imported
+			}
+			missingAny = true
+			break
+		}
+
+		if !missingAny {
+			log.Info("No more files to import - closing import")
+			stopImport = true
+		}
+	}
+
+	openImports.Delete(importId)
+
+	log.Info("Finishing import task")
+	dbMeta := storage.GetDatabase().GetMetadataStore(ctx, log)
+	err := dbMeta.FinishedBackgroundTask(taskId)
+	if err != nil {
+		log.Error(err)
+		log.Error("Failed to flag task as finished")
+	}
+	log.Info("Finished import")
+}
diff --git a/docs/admin.md b/docs/admin.md
index ae390702..30ece387 100644
--- a/docs/admin.md
+++ b/docs/admin.md
@@ -392,4 +392,42 @@ The response is an empty JSON object if successful.
 
 #### Importing a previous export
 
-Not yet implemented.
+Once an export has been completed it can be imported back into the media repo. Files that are already known to the repo will not be overwritten - the repo will use its existing copy instead.
+
+**Note**: Imports happen in memory, which can balloon quickly depending on how you exported your data. Although you can import data without S3, it is recommended that you only import from archives generated with `include_data=false`.
+
+**Note**: Only repository administrators can perform imports, regardless of who they are for.
+
+URL: `POST /_matrix/media/unstable/admin/import`
+
+The request body is the bytes of the first archive (eg: `TravisR-part-1.tgz` in the above examples).
+
+The response body will be something like the following:
+```json
+{
+  "import_id": "abcdef",
+  "task_id": 13
+}
+```
+
+**Note**: the `import_id` will be included in the task's `params`.
+
+**Note**: the `import_id` should be treated as a secret/authentication token as it could allow an attacker to change what the user has uploaded.
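+
+For illustration only, here is a minimal Go sketch of driving this endpoint from an admin script. The base URL, access token, and archive path are placeholders, the repo is assumed to accept a standard Matrix `Authorization: Bearer` header, and error handling is abbreviated:
+
+```go
+package main
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"os"
+)
+
+// importStarted mirrors the response body documented above.
+type importStarted struct {
+	ImportID string `json:"import_id"`
+	TaskID   int    `json:"task_id"`
+}
+
+// startImport uploads the first archive part and returns the import/task identifiers.
+func startImport(baseUrl string, accessToken string, archivePath string) (*importStarted, error) {
+	f, err := os.Open(archivePath)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	req, err := http.NewRequest("POST", baseUrl+"/_matrix/media/unstable/admin/import", f)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", "Bearer "+accessToken)
+	req.Header.Set("Content-Type", "application/gzip") // informational; the repo reads the raw request body
+
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+	if res.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("unexpected status: %d", res.StatusCode)
+	}
+
+	started := &importStarted{}
+	if err := json.NewDecoder(res.Body).Decode(started); err != nil {
+		return nil, err
+	}
+	return started, nil
+}
+
+func main() {
+	started, err := startImport("https://media.example.org", os.Getenv("REPO_ADMIN_TOKEN"), "TravisR-part-1.tgz")
+	if err != nil {
+		panic(err)
+	}
+	fmt.Printf("import_id=%s task_id=%d\n", started.ImportID, started.TaskID)
+}
+```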
+
+To import the subsequent parts of an export, use the following endpoint and supply the archive as the request body: `POST /_matrix/media/unstable/admin/import/<import ID>/part`
+
+The parts can be uploaded in any order and will be extracted in memory.
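+
+Continuing the sketch above (same imports and placeholder values), each additional part can be appended with a helper like this:
+
+```go
+// appendPart uploads one additional archive part to an open import.
+func appendPart(baseUrl string, accessToken string, importId string, archivePath string) error {
+	f, err := os.Open(archivePath)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	url := baseUrl + "/_matrix/media/unstable/admin/import/" + importId + "/part"
+	req, err := http.NewRequest("POST", url, f)
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Authorization", "Bearer "+accessToken)
+
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer res.Body.Close()
+	if res.StatusCode != http.StatusOK {
+		return fmt.Errorf("unexpected status: %d", res.StatusCode)
+	}
+	return nil
+}
+```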
+
+Imports will look for the files included in the uploaded archives, though if a file isn't found and an S3 URL is available for it the S3 URL will be used instead. If the S3 URL points at a datastore known to the repo, it will assume the file already exists there and use that location without pulling it into memory.
+
+Imports stay open until all files have been imported (or until the process crashes). This also means you can upload the parts at your leisure instead of trying to push all the data up to the server as fast as possible. If the task is still considered running, the import is still open.
+
+**Note**: When using S3 URLs to do imports, it is possible for the media to bypass checks like allowed file types, maximum sizes, and quarantines.
+
+#### Closing an import manually
+
+If you have no intention of continuing an import, use this endpoint.
+
+URL: `POST /_matrix/media/unstable/admin/import/<import ID>/close`
+
+The import will be closed and will stop waiting for new files to show up. It will continue importing whatever files it already knows about - to forcefully end the task, restart the process.
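+
+As with the other endpoints, a small Go helper (reusing the imports and placeholders from the earlier sketches) can issue the close request:
+
+```go
+// closeImport tells the repo to stop waiting for more archive parts.
+func closeImport(baseUrl string, accessToken string, importId string) error {
+	url := baseUrl + "/_matrix/media/unstable/admin/import/" + importId + "/close"
+	req, err := http.NewRequest("POST", url, nil)
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Authorization", "Bearer "+accessToken)
+
+	res, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer res.Body.Close()
+	if res.StatusCode != http.StatusOK {
+		return fmt.Errorf("unexpected status: %d", res.StatusCode)
+	}
+	return nil
+}
+```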
diff --git a/storage/datastore/datastore.go b/storage/datastore/datastore.go
index 04c57921..9fe6ba7d 100644
--- a/storage/datastore/datastore.go
+++ b/storage/datastore/datastore.go
@@ -13,6 +13,26 @@ import (
 	"github.com/turt2live/matrix-media-repo/types"
 )
 
+func GetAvailableDatastores() ([]*types.Datastore, error) {
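+	// Collect a datastore handle for every datastore that is enabled in the configuration.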
+	datastores := make([]*types.Datastore, 0)
+	for _, ds := range config.Get().DataStores {
+		if !ds.Enabled {
+			continue
+		}
+
+		uri := GetUriForDatastore(ds)
+
+		dsInstance, err := storage.GetOrCreateDatastoreOfType(context.TODO(), &logrus.Entry{}, ds.Type, uri)
+		if err != nil {
+			return nil, err
+		}
+
+		datastores = append(datastores, dsInstance)
+	}
+
+	return datastores, nil
+}
+
 func LocateDatastore(ctx context.Context, log *logrus.Entry, datastoreId string) (*DatastoreRef, error) {
 	ds, err := storage.GetDatabase().GetMediaStore(ctx, log).GetDatastore(datastoreId)
 	if err != nil {
diff --git a/storage/datastore/ds_s3/s3_store.go b/storage/datastore/ds_s3/s3_store.go
index 64c7ba3f..e29c833d 100644
--- a/storage/datastore/ds_s3/s3_store.go
+++ b/storage/datastore/ds_s3/s3_store.go
@@ -7,6 +7,7 @@ import (
 	"io/ioutil"
 	"os"
 	"strconv"
+	"strings"
 
 	"github.com/minio/minio-go"
 	"github.com/pkg/errors"
@@ -76,6 +77,20 @@ func GetS3URL(datastoreId string, location string) (string, error) {
 	return fmt.Sprintf("https://%s/%s/%s", store.conf.Options["endpoint"], store.bucket, location), nil
 }
 
+func ParseS3URL(s3url string) (string, string, string, error) {
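+	// Expects URLs in the format produced by GetS3URL above: https://<endpoint>/<bucket>/<object key>.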
+	trimmed := s3url[8:] // trim off the leading "https://"
+	parts := strings.Split(trimmed, "/")
+	if len(parts) < 3 {
+		return "", "", "", errors.New("invalid url")
+	}
+
+	endpoint := parts[0]
+	location := parts[len(parts)-1]
+	bucket := strings.Join(parts[1:len(parts)-1], "/")
+
+	return endpoint, bucket, location, nil
+}
+
 func (s *s3Datastore) EnsureBucketExists() error {
 	found, err := s.client.BucketExists(s.bucket)
 	if err != nil {
-- 
GitLab