From d57593d3857d08269c11622e9daeb22095dbd2f6 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Fri, 21 Aug 2020 22:51:08 -0600
Subject: [PATCH] Initial creation of a common export interface

---
 archival/v2_export.go                 | 217 ++++++++++++++++++++++++++
 archival/v2_export_disk_writer.go     |  26 +++
 archival/v2_manifest.go               |  21 +++
 cmd/export_synapse_for_import/main.go | 189 +---------------------
 go.mod                                |   1 +
 go.sum                                |   2 +
 util/streams.go                       |   5 +
 7 files changed, 279 insertions(+), 182 deletions(-)
 create mode 100644 archival/v2_export.go
 create mode 100644 archival/v2_export_disk_writer.go
 create mode 100644 archival/v2_manifest.go

diff --git a/archival/v2_export.go b/archival/v2_export.go
new file mode 100644
index 00000000..f1ce9ca6
--- /dev/null
+++ b/archival/v2_export.go
@@ -0,0 +1,217 @@
+package archival
+
+import (
+	"archive/tar"
+	"bytes"
+	"compress/gzip"
+	"encoding/json"
+	"fmt"
+	"io"
+	"time"
+
+	"github.com/dustin/go-humanize"
+	"github.com/gabriel-vasile/mimetype"
+	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/templating"
+	"github.com/turt2live/matrix-media-repo/util"
+)
+
+type V2ArchiveWriter interface {
+	WritePart(part int, fileName string, archive io.Reader, size int64) error
+}
+
+type V2ArchiveExport struct {
+	exportId      string
+	entity        string
+	indexModel    *templating.ExportIndexModel
+	writer        V2ArchiveWriter
+	mediaManifest map[string]*V2ManifestRecord
+	partSize      int64
+	ctx           rcontext.RequestContext
+
+	// state variables
+	currentPart     int
+	currentTar      *tar.Writer
+	currentTarBytes *bytes.Buffer
+	currentSize     int64
+	writingManifest bool
+}
+
+func NewV2Export(exportId string, entity string, partSize int64, writer V2ArchiveWriter, ctx rcontext.RequestContext) (*V2ArchiveExport, error) {
+	ctx = ctx.LogWithFields(logrus.Fields{
+		"v2_export-id":       exportId,
+		"v2_export-entity":   entity,
+		"v2_export-partSize": partSize,
+	})
+	archiver := &V2ArchiveExport{
+		exportId: exportId,
+		entity:   entity,
+		writer:   writer,
+		partSize: partSize,
+		ctx: ctx,
+		indexModel: &templating.ExportIndexModel{
+			Entity:   entity,
+			ExportID: exportId,
+			Media:    make([]*templating.ExportIndexMediaModel, 0),
+		},
+		mediaManifest: make(map[string]*V2ManifestRecord),
+		currentPart:   0,
+	}
+	ctx.Log.Info("Preparing first tar file...")
+	err := archiver.newTar()
+	return archiver, err
+}
+
+func (e *V2ArchiveExport) newTar() error {
+	if e.currentPart > 0 {
+		e.ctx.Log.Info("Persisting complete tar file...")
+		if err := e.persistTar(); err != nil {
+			return err
+		}
+	}
+
+	e.ctx.Log.Info("Starting new tar file...")
+	e.currentTarBytes = &bytes.Buffer{}
+	e.currentTar = tar.NewWriter(e.currentTarBytes)
+	e.currentPart = e.currentPart + 1
+	e.currentSize = 0
+
+	return nil
+}
+
+func (e *V2ArchiveExport) persistTar() error {
+	_ = e.currentTar.Close()
+
+	e.ctx.Log.Info("Compressing tar file...")
+	gzipBytes := &bytes.Buffer{}
+	archiver := gzip.NewWriter(gzipBytes)
+	archiver.Name = fmt.Sprintf("export-part-%d.tar", e.currentPart)
+	if e.writingManifest {
+		archiver.Name = "export-manifest.tar"
+	}
+
+	if _, err := io.Copy(archiver, util.ClonedBufReader(*e.currentTarBytes)); err != nil {
+		return err
+	}
+	_ = archiver.Close()
+
+	e.ctx.Log.Info("Writing compressed tar")
+	name := fmt.Sprintf("export-part-%d.tgz", e.currentPart)
+	if e.writingManifest {
+		name = "export-manifest.tgz"
+	}
+	return e.writer.WritePart(e.currentPart, name, gzipBytes, int64(len(gzipBytes.Bytes())))
+}
+
+func (e *V2ArchiveExport) putFile(buf *bytes.Buffer, name string, creationTime time.Time) (int64, error) {
+	length := int64(len(buf.Bytes()))
+	header := &tar.Header{
+		Name:    name,
+		Size:    length,
+		Mode:    int64(0644),
+		ModTime: creationTime,
+	}
+	if err := e.currentTar.WriteHeader(header); err != nil {
+		return 0, err
+	}
+
+	i, err := io.Copy(e.currentTar, buf)
+	if err != nil {
+		return 0, err
+	}
+	e.currentSize += i
+
+	return length, nil
+}
+
+func (e *V2ArchiveExport) AppendMedia(origin string, mediaId string, originalName string, contentType string, creationTime time.Time, file io.Reader, sha256 string, s3Url string, userId string) error {
+	// buffer the entire file into memory
+	buf := &bytes.Buffer{}
+	if _, err := io.Copy(buf, file); err != nil {
+		return err
+	}
+
+	mime := mimetype.Detect(buf.Bytes())
+	internalName := fmt.Sprintf("%s__%s%s", origin, mediaId, mime.Extension())
+
+	length, err := e.putFile(buf, internalName, creationTime)
+	if err != nil {
+		return err
+	}
+
+	mxc := fmt.Sprintf("mxc://%s/%s", origin, mediaId)
+	e.mediaManifest[mxc] = &V2ManifestRecord{
+		ArchivedName: internalName,
+		FileName:     originalName,
+		SizeBytes:    length,
+		ContentType:  contentType,
+		S3Url:        s3Url,
+		Sha256:       sha256,
+		Origin:       origin,
+		MediaId:      mediaId,
+		CreatedTs:    creationTime.UnixNano() / 1000000,
+		Uploader:     userId,
+	}
+	e.indexModel.Media = append(e.indexModel.Media, &templating.ExportIndexMediaModel{
+		ExportID:        e.exportId,
+		ArchivedName:    internalName,
+		FileName:        originalName,
+		SizeBytes:       length,
+		SizeBytesHuman:  humanize.Bytes(uint64(length)),
+		Origin:          origin,
+		MediaID:         mediaId,
+		Sha256Hash:      sha256,
+		ContentType:     contentType,
+		UploadTs:        creationTime.UnixNano() / 1000000,
+		UploadDateHuman: creationTime.Format(time.UnixDate),
+		Uploader:        userId,
+	})
+
+	if e.currentSize >= e.partSize {
+		e.ctx.Log.Info("Rotating tar...")
+		return e.newTar()
+	}
+
+	return nil
+}
+
+func (e *V2ArchiveExport) Finish() error {
+	if err := e.newTar(); err != nil {
+		return err
+	}
+
+	e.ctx.Log.Info("Writing manifest...")
+	e.writingManifest = true
+	defer (func() { e.writingManifest = false })()
+	manifest := &V2Manifest{
+		Version:   2,
+		EntityId:  e.entity,
+		CreatedTs: util.NowMillis(),
+		Media:     e.mediaManifest,
+	}
+	b, err := json.Marshal(manifest)
+	if err != nil {
+		e.writingManifest = false
+		return err
+	}
+	if _, err := e.putFile(bytes.NewBuffer(b), "manifest.json", time.Now()); err != nil {
+		return err
+	}
+
+	e.ctx.Log.Info("Writing index...")
+	t, err := templating.GetTemplate("export_index")
+	if err != nil {
+		return err
+	}
+	html := bytes.Buffer{}
+	if err := t.Execute(&html, e.indexModel); err != nil {
+		return err
+	}
+	if _, err := e.putFile(bytes.NewBuffer(html.Bytes()), "index.html", time.Now()); err != nil {
+		return err
+	}
+
+	e.ctx.Log.Info("Writing manifest tar...")
+	return e.persistTar()
+}
diff --git a/archival/v2_export_disk_writer.go b/archival/v2_export_disk_writer.go
new file mode 100644
index 00000000..6d6cd1a3
--- /dev/null
+++ b/archival/v2_export_disk_writer.go
@@ -0,0 +1,26 @@
+package archival
+
+import (
+	"io"
+	"os"
+	"path"
+)
+
+type V2ArchiveDiskWriter struct {
+	directory string
+}
+
+func NewV2ArchiveDiskWriter(directory string) *V2ArchiveDiskWriter {
+	return &V2ArchiveDiskWriter{directory: directory}
+}
+
+func (w V2ArchiveDiskWriter) WritePart(part int, fileName string, archive io.Reader, size int64) error {
+	f, err := os.Create(path.Join(w.directory, fileName))
+	if err != nil {
+		return err
+	}
+	if _, err := io.Copy(f, archive); err != nil {
+		return err
+	}
+	return f.Close()
+}
diff --git a/archival/v2_manifest.go b/archival/v2_manifest.go
new file mode 100644
index 00000000..249e28f2
--- /dev/null
+++ b/archival/v2_manifest.go
@@ -0,0 +1,21 @@
+package archival
+
+type V2ManifestRecord struct {
+	FileName     string `json:"name"`
+	ArchivedName string `json:"file_name"`
+	SizeBytes    int64  `json:"size_bytes"`
+	ContentType  string `json:"content_type"`
+	S3Url        string `json:"s3_url"`
+	Sha256       string `json:"sha256"`
+	Origin       string `json:"origin"`
+	MediaId      string `json:"media_id"`
+	CreatedTs    int64  `json:"created_ts"`
+	Uploader     string `json:"uploader"`
+}
+
+type V2Manifest struct {
+	Version   int                          `json:"version"`
+	EntityId  string                       `json:"entity_id"`
+	CreatedTs int64                        `json:"created_ts"`
+	Media     map[string]*V2ManifestRecord `json:"media"`
+}
diff --git a/cmd/export_synapse_for_import/main.go b/cmd/export_synapse_for_import/main.go
index 17c4687a..6c8f5ff0 100644
--- a/cmd/export_synapse_for_import/main.go
+++ b/cmd/export_synapse_for_import/main.go
@@ -1,10 +1,7 @@
 package main
 
 import (
-	"archive/tar"
 	"bytes"
-	"compress/gzip"
-	"encoding/json"
 	"flag"
 	"fmt"
 	"io"
@@ -13,16 +10,14 @@ import (
 	"path"
 	"strconv"
 	"strings"
-	"time"
 
-	"github.com/dustin/go-humanize"
 	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/archival"
 	"github.com/turt2live/matrix-media-repo/common/assets"
 	"github.com/turt2live/matrix-media-repo/common/config"
 	"github.com/turt2live/matrix-media-repo/common/logging"
-	"github.com/turt2live/matrix-media-repo/controllers/data_controller"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/synapse"
-	"github.com/turt2live/matrix-media-repo/templating"
 	"github.com/turt2live/matrix-media-repo/util"
 	"golang.org/x/crypto/ssh/terminal"
 )
@@ -86,105 +81,12 @@ func main() {
 
 	logrus.Info(fmt.Sprintf("Exporting %d media records", len(records)))
 
-	// TODO: Share this logic with export_controller somehow
-	var currentTar *tar.Writer
-	var currentTarBytes bytes.Buffer
-	part := 0
-	currentSize := int64(0)
-	isManifestTar := false
-
-	persistTar := func() error {
-		_ = currentTar.Close()
-
-		// compress
-		logrus.Info("Compressing tar file...")
-		gzipBytes := bytes.Buffer{}
-		archiver := gzip.NewWriter(&gzipBytes)
-		archiver.Name = fmt.Sprintf("export-part-%d.tar", part)
-		if isManifestTar {
-			archiver.Name = fmt.Sprintf("export-manifest.tar")
-		}
-		_, err := io.Copy(archiver, bytes.NewBuffer(currentTarBytes.Bytes()))
-		if err != nil {
-			return err
-		}
-		_ = archiver.Close()
-
-		logrus.Info("Writing compressed tar to disk...")
-		name := fmt.Sprintf("export-part-%d.tgz", part)
-		if isManifestTar {
-			name = "export-manifest.tgz"
-		}
-		f, err := os.Create(path.Join(*exportPath, name))
-		if err != nil {
-			return err
-		}
-		_, _ = io.Copy(f, &gzipBytes)
-		_ = f.Close()
-
-		return nil
-	}
-
-	newTar := func() error {
-		if part > 0 {
-			logrus.Info("Persisting complete tar file...")
-			err := persistTar()
-			if err != nil {
-				return err
-			}
-		}
-
-		logrus.Info("Starting new tar file...")
-		currentTarBytes = bytes.Buffer{}
-		currentTar = tar.NewWriter(&currentTarBytes)
-		part = part + 1
-		currentSize = 0
-
-		return nil
-	}
-
-	// Start the first tar file
-	logrus.Info("Preparing first tar file...")
-	err = newTar()
+	writer := archival.NewV2ArchiveDiskWriter(*exportPath)
+	exporter, err := archival.NewV2Export("OOB", *serverName, *partSizeBytes, writer, rcontext.Initial())
 	if err != nil {
 		logrus.Fatal(err)
 	}
 
-	putFile := func(name string, size int64, creationTime time.Time, file io.Reader) error {
-		header := &tar.Header{
-			Name:    name,
-			Size:    size,
-			Mode:    int64(0644),
-			ModTime: creationTime,
-		}
-		err := currentTar.WriteHeader(header)
-		if err != nil {
-			return err
-		}
-
-		i, err := io.Copy(currentTar, file)
-		if err != nil {
-			return err
-		}
-
-		currentSize += i
-
-		return nil
-	}
-
-	archivedName := func(origin string, mediaId string) string {
-		// TODO: Pick the right extension for the file type
-		return fmt.Sprintf("%s__%s.obj", origin, mediaId)
-	}
-
-	logrus.Info("Preparing manifest...")
-	indexModel := &templating.ExportIndexModel{
-		Entity:   *serverName,
-		ExportID: "OOB",
-		Media:    make([]*templating.ExportIndexMediaModel, 0),
-	}
-	mediaManifest := make(map[string]*data_controller.ManifestRecord)
-
 	missing := make([]string, 0)
 
 	for _, r := range records {
@@ -224,90 +126,13 @@ func main() {
 			logrus.Fatal(err)
 		}
 
-		err = putFile(archivedName(*serverName, r.MediaId), r.SizeBytes, util.FromMillis(r.CreatedTs), d)
+		err = exporter.AppendMedia(*serverName, r.MediaId, r.UploadName, r.ContentType, util.FromMillis(r.CreatedTs), d, sha256, "", r.UserId)
 		if err != nil {
 			logrus.Fatal(err)
 		}
-
-		if currentSize >= *partSizeBytes {
-			logrus.Info("Rotating tar...")
-			err = newTar()
-			if err != nil {
-				logrus.Fatal(err)
-			}
-		}
-
-		mediaManifest[mxc] = &data_controller.ManifestRecord{
-			ArchivedName: archivedName(*serverName, r.MediaId),
-			FileName:     r.UploadName,
-			SizeBytes:    r.SizeBytes,
-			ContentType:  r.ContentType,
-			S3Url:        "",
-			Sha256:       sha256,
-			Origin:       *serverName,
-			MediaId:      r.MediaId,
-			CreatedTs:    r.CreatedTs,
-			Uploader:     r.UserId,
-		}
-		indexModel.Media = append(indexModel.Media, &templating.ExportIndexMediaModel{
-			ExportID:        "OOB",
-			ArchivedName:    archivedName(*serverName, r.MediaId),
-			FileName:        r.UploadName,
-			SizeBytes:       r.SizeBytes,
-			SizeBytesHuman:  humanize.Bytes(uint64(r.SizeBytes)),
-			Origin:          *serverName,
-			MediaID:         r.MediaId,
-			Sha256Hash:      sha256,
-			ContentType:     r.ContentType,
-			UploadTs:        r.CreatedTs,
-			UploadDateHuman: util.FromMillis(r.CreatedTs).Format(time.UnixDate),
-			Uploader:        r.UserId,
-		})
-	}
-
-	logrus.Info("Preparing manifest-specific tar...")
-	err = newTar()
-	if err != nil {
-		logrus.Fatal(err)
-	}
-
-	logrus.Info("Writing manifest...")
-	isManifestTar = true
-	manifest := &data_controller.Manifest{
-		Version:   2,
-		EntityId:  *serverName,
-		CreatedTs: util.NowMillis(),
-		Media:     mediaManifest,
-	}
-	b, err := json.Marshal(manifest)
-	if err != nil {
-		logrus.Fatal(err)
-	}
-	err = putFile("manifest.json", int64(len(b)), time.Now(), bytes.NewBuffer(b))
-	if err != nil {
-		logrus.Fatal(err)
-	}
-
-	logrus.Info("Building and writing index...")
-	t, err := templating.GetTemplate("export_index")
-	if err != nil {
-		logrus.Fatal(err)
-		return
-	}
-	html := bytes.Buffer{}
-	err = t.Execute(&html, indexModel)
-	if err != nil {
-		logrus.Fatal(err)
-		return
-	}
-	err = putFile("index.html", int64(html.Len()), time.Now(), util.BufferToStream(bytes.NewBuffer(html.Bytes())))
-	if err != nil {
-		logrus.Fatal(err)
-		return
 	}
 
-	logrus.Info("Writing final tar...")
-	err = persistTar()
+	err = exporter.Finish()
 	if err != nil {
 		logrus.Fatal(err)
 	}
@@ -324,5 +149,5 @@ func main() {
 		}
 	}
 
-	logrus.Info("Import completed")
+	logrus.Info("Export completed")
 }
diff --git a/go.mod b/go.mod
index 0ab2f08e..eb34d640 100644
--- a/go.mod
+++ b/go.mod
@@ -24,6 +24,7 @@ require (
 	github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 // indirect
 	github.com/fogleman/gg v1.3.0
 	github.com/fsnotify/fsnotify v1.4.7
+	github.com/gabriel-vasile/mimetype v1.1.1
 	github.com/go-redis/redis/v8 v8.0.0-beta.6
 	github.com/go-sql-driver/mysql v1.5.0 // indirect
 	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
diff --git a/go.sum b/go.sum
index 9d3c1f2b..0b9099b8 100644
--- a/go.sum
+++ b/go.sum
@@ -135,6 +135,8 @@ github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8=
 github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
 github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
 github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
+github.com/gabriel-vasile/mimetype v1.1.1 h1:qbN9MPuRf3bstHu9zkI9jDWNfH//9+9kHxr9oRBBBOA=
+github.com/gabriel-vasile/mimetype v1.1.1/go.mod h1:6CDPel/o/3/s4+bp6kIbsWATq8pmgOisOPG40CJa6To=
 github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg=
 github.com/gdamore/tcell v1.1.1/go.mod h1:K1udHkiR3cOtlpKG5tZPD5XxrF7v2y7lDq7Whcj+xkQ=
 github.com/go-bindata/go-bindata v3.1.2+incompatible/go.mod h1:xK8Dsgwmeed+BBsSy2XTopBn/8uK2HWuGSnA11C3Joo=
diff --git a/util/streams.go b/util/streams.go
index a51adcdb..6c6c513a 100644
--- a/util/streams.go
+++ b/util/streams.go
@@ -8,6 +8,7 @@ import (
 	"io/ioutil"
 
 	"github.com/turt2live/matrix-media-repo/util/cleanup"
+	"github.com/turt2live/matrix-media-repo/util/util_byte_seeker"
 )
 
 func BufferToStream(buf *bytes.Buffer) io.ReadCloser {
@@ -50,3 +51,7 @@ func GetSha256HashOfStream(r io.ReadCloser) (string, error) {
 
 	return hex.EncodeToString(hasher.Sum(nil)), nil
 }
+
+func ClonedBufReader(buf bytes.Buffer) util_byte_seeker.ByteSeeker {
+	return util_byte_seeker.NewByteSeeker(buf.Bytes())
+}
\ No newline at end of file
-- 
GitLab