diff --git a/api/custom/imports.go b/api/custom/imports.go index 0982e91ec662e7c516a32e723ff929253fc18697..ba31e58548050f4374f4182c297e0ac93dca8d59 100644 --- a/api/custom/imports.go +++ b/api/custom/imports.go @@ -43,7 +43,7 @@ func AppendToImport(r *http.Request, rctx rcontext.RequestContext, user api.User importId := params["importId"] defer cleanup.DumpAndCloseStream(r.Body) - err := data_controller.AppendToImport(importId, r.Body) + _, err := data_controller.AppendToImport(importId, r.Body, false) if err != nil { rctx.Log.Error(err) return api.InternalServerError("fatal error appending to import") diff --git a/cmd/gdpr_import/main.go b/cmd/gdpr_import/main.go index fd420b3a0f42774671d9ab095e8b99b4342a8dd1..2ac43b97b8e5b41d55673613b1762b8349d0f0be 100644 --- a/cmd/gdpr_import/main.go +++ b/cmd/gdpr_import/main.go @@ -15,6 +15,7 @@ import ( "github.com/turt2live/matrix-media-repo/common/runtime" "github.com/turt2live/matrix-media-repo/controllers/data_controller" "github.com/turt2live/matrix-media-repo/storage" + "github.com/turt2live/matrix-media-repo/util" ) func main() { @@ -51,10 +52,30 @@ func main() { files = append(files, path.Join(*filesDir, f.Name())) } + // Find the manifest so we can import as soon as possible + manifestIdx := 0 + for i, fname := range files { + logrus.Infof("Checking %s for export manifest", fname) + f, err := os.Open(fname) + if err != nil { + panic(err) + } + defer f.Close() + names, err := data_controller.GetFileNames(f) + if err != nil { + panic(err) + } + + if util.ArrayContains(names, "manifest.json") { + manifestIdx = i + break + } + } + logrus.Info("Starting import...") ctx := rcontext.Initial().LogWithFields(logrus.Fields{"flagDir": *filesDir}) - f, err := os.Open(files[0]) + f, err := os.Open(files[manifestIdx]) if err != nil { panic(err) } @@ -66,7 +87,7 @@ func main() { logrus.Info("Appending all other files to import") for i, fname := range files { - if i == 0 { + if i == manifestIdx { continue // already imported } @@ -76,10 +97,14 @@ func main() { panic(err) } defer f.Close() - err = data_controller.AppendToImport(importId, f) + ch, err := data_controller.AppendToImport(importId, f, true) if err != nil { panic(err) } + + logrus.Info("Waiting for file to be processed before moving on") + <-ch + close(ch) } logrus.Info("Waiting for import to complete") diff --git a/controllers/data_controller/import_controller.go b/controllers/data_controller/import_controller.go index 01a29be215fd4b3d8e992fd39987e37f0723a8db..7a84826a0d5ed59f98d85061717118e0912c2069 100644 --- a/controllers/data_controller/import_controller.go +++ b/controllers/data_controller/import_controller.go @@ -23,8 +23,9 @@ import ( ) type importUpdate struct { - stop bool - fileMap map[string]*bytes.Buffer + stop bool + fileMap map[string]*bytes.Buffer + onDoneChan chan bool } var openImports = &sync.Map{} // importId => updateChan @@ -61,21 +62,25 @@ func StartImport(data io.Reader, ctx rcontext.RequestContext) (*types.Background return task, importId, nil } -func AppendToImport(importId string, data io.Reader) error { +func AppendToImport(importId string, data io.Reader, withReturnChan bool) (chan bool, error) { runningImport, ok := openImports.Load(importId) if !ok || runningImport == nil { - return errors.New("import not found or it has been closed") + return nil, errors.New("import not found or it has been closed") } results, err := processArchive(data) if err != nil { - return err + return nil, err } + var doneChan chan bool + if withReturnChan { + doneChan = make(chan bool) + } updateChan := runningImport.(chan *importUpdate) - updateChan <- &importUpdate{stop: false, fileMap: results} + updateChan <- &importUpdate{stop: false, fileMap: results, onDoneChan: doneChan} - return nil + return doneChan, nil } func StopImport(importId string) error { @@ -90,6 +95,38 @@ func StopImport(importId string) error { return nil } +func GetFileNames(data io.Reader) ([]string, error) { + archiver, err := gzip.NewReader(data) + if err != nil { + return nil, err + } + + defer archiver.Close() + + tarFile := tar.NewReader(archiver) + names := make([]string, 0) + for { + header, err := tarFile.Next() + if err == io.EOF { + break // we're done + } + if err != nil { + return nil, err + } + + if header == nil { + continue // skip this weird file + } + if header.Typeflag != tar.TypeReg { + continue // skip directories and other stuff + } + + names = append(names, header.Name) + } + + return names, nil +} + func processArchive(data io.Reader) (map[string]*bytes.Buffer, error) { archiver, err := gzip.NewReader(data) if err != nil { @@ -142,9 +179,14 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx haveManifest := false imported := make(map[string]bool) db := storage.GetDatabase().GetMediaStore(ctx) + var update *importUpdate for !stopImport { - update := <-updateChannel + if update != nil && update.onDoneChan != nil { + ctx.Log.Info("Flagging tar as completed") + update.onDoneChan <- true + } + update = <-updateChannel if update.stop { ctx.Log.Info("Close requested") stopImport = true @@ -198,6 +240,8 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx continue } + toClear := make([]string, 0) + doClear := true for mxc, record := range archiveManifest.Media { _, found := imported[mxc] if found { @@ -232,8 +276,10 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx _, err := upload_controller.StoreDirect(nil, closer, record.SizeBytes, record.ContentType, record.FileName, userId, record.Origin, record.MediaId, kind, ctx, true) if err != nil { ctx.Log.Errorf("Error importing file: %s", err.Error()) + doClear = false // don't clear things on error continue } + toClear = append(toClear, record.ArchivedName) } else if record.S3Url != "" { ctx.Log.Info("Using S3 URL") endpoint, bucket, location, err := ds_s3.ParseS3URL(record.S3Url) @@ -327,6 +373,15 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx imported[mxc] = true } + if doClear { + ctx.Log.Info("Clearing up memory for imported files...") + for _, f := range toClear { + ctx.Log.Infof("Removing %s from memory", f) + delete(fileMap, f) + } + } + + ctx.Log.Info("Checking for any unimported files...") missingAny := false for mxc, _ := range archiveManifest.Media { _, found := imported[mxc] @@ -343,6 +398,12 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx } } + // Clean up the last tar file + if update != nil && update.onDoneChan != nil { + ctx.Log.Info("Flagging tar as completed") + update.onDoneChan <- true + } + openImports.Delete(importId) ctx.Log.Info("Finishing import task")