Skip to content
Snippets Groups Projects
main.go 4.76 KiB
package main

import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/howeyc/gopass"
	"github.com/jeffail/tunny"
	"github.com/sirupsen/logrus"
	"github.com/turt2live/matrix-media-repo/common/config"
	"github.com/turt2live/matrix-media-repo/common/logging"
	"github.com/turt2live/matrix-media-repo/controllers/upload_controller"
	"github.com/turt2live/matrix-media-repo/storage"
	"github.com/turt2live/matrix-media-repo/synapse"
)

// fetchRequest is the unit of work passed to fetchMedia: one Synapse media
// record together with where it should be downloaded from.
type fetchRequest struct {
	media *synapse.LocalMedia // the Synapse local media record to import

	// csApiUrl is the homeserver base URL (main strips one trailing slash);
	// serverName is the server-name component used in the download URL and
	// when storing the media.
	csApiUrl, serverName string
}

// main imports every local media record from a Synapse homeserver.
//
// It reads the records from Synapse's PostgreSQL database and hands each one
// to fetchMedia on a tunny worker pool sized by -workers. fetchMedia
// downloads the item from the homeserver and stores it. main returns once
// every record has been processed.
func main() {
	postgresHost := flag.String("dbHost", "localhost", "The IP or hostname of the postgresql server with the synapse database")
	postgresPort := flag.Int("dbPort", 5432, "The port to access postgres on")
	postgresUsername := flag.String("dbUsername", "synapse", "The username to access postgres with")
	postgresPassword := flag.String("dbPassword", "", "The password to authorize the postgres user. Can be omitted to be prompted when run")
	postgresDatabase := flag.String("dbName", "synapse", "The name of the synapse database")
	baseUrl := flag.String("baseUrl", "http://localhost:8008", "The base URL to access your homeserver with")
	serverName := flag.String("serverName", "localhost", "The name of your homeserver (eg: matrix.org)")
	configPath := flag.String("config", "media-repo.yaml", "The path to the configuration")
	migrationsPath := flag.String("migrations", "./migrations", "The absolute path the migrations folder")
	numWorkers := flag.Int("workers", 1, "The number of workers to use when downloading media. Using multiple workers risks deduplication not working as efficiently.")
	flag.Parse()

	config.Path = *configPath
	config.Runtime.MigrationsPath = *migrationsPath

	realPsqlPassword := *postgresPassword
	if realPsqlPassword == "" {
		fmt.Printf("Postgres password: ")
		pass, err := gopass.GetPasswd()
		if err != nil {
			panic(err)
		}
		realPsqlPassword = string(pass)
	}

	err := logging.Setup(config.Get().General.LogDirectory)
	if err != nil {
		panic(err)
	}

	logrus.Info("Setting up for importing...")

	// A pool with zero workers would never service a request, so every
	// Process call would block forever.
	workers := *numWorkers
	if workers < 1 {
		logrus.Warnf("Invalid number of workers (%d); using 1 instead", workers)
		workers = 1
	}

	connectionString := buildConnectionString(*postgresUsername, realPsqlPassword, *postgresHost, *postgresPort, *postgresDatabase)
	// TrimSuffix is safe on an empty string, unlike manual slicing.
	csApiUrl := strings.TrimSuffix(*baseUrl, "/")

	logrus.Info("Connecting to synapse database...")
	synDb, err := synapse.OpenDatabase(connectionString)
	if err != nil {
		panic(err)
	}

	logrus.Info("Fetching all local media records from synapse...")
	records, err := synDb.GetAllMedia()
	if err != nil {
		panic(err)
	}

	total := len(records)
	logrus.Infof("Downloading %d media records", total)
	start := time.Now()

	pool := tunny.NewFunc(workers, fetchMedia)
	defer pool.Close()

	// numCompleted is only ever touched while holding lock.
	numCompleted := 0
	lock := &sync.Mutex{}
	onComplete := func() {
		lock.Lock()
		defer lock.Unlock()
		numCompleted++
		percent := int((float32(numCompleted) / float32(total)) * 100)
		logrus.Infof("%d/%d downloaded (%d%%)", numCompleted, total, percent)
	}

	// A fixed set of feeder goroutines (one per pool worker) drains the jobs
	// channel. This bounds the number of live goroutines instead of spawning
	// one per record. The WaitGroup replaces the racy busy-wait on
	// numCompleted.
	jobs := make(chan *fetchRequest)
	wg := &sync.WaitGroup{}
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for req := range jobs {
				pool.Process(req)
				onComplete()
			}
		}()
	}

	for i, record := range records {
		percent := int((float32(i+1) / float32(total)) * 100)
		logrus.Infof("Queuing %s (%d/%d %d%%)", record.MediaId, i+1, total, percent)
		jobs <- &fetchRequest{media: record, serverName: *serverName, csApiUrl: csApiUrl}
	}
	close(jobs)

	logrus.Info("Waiting for import to complete...")
	wg.Wait()

	logrus.Infof("Import completed in %s", time.Since(start))
}

// buildConnectionString assembles the postgres:// connection URL.
//
// The credentials are percent-encoded via url.UserPassword, so a password
// containing '@', ':' or '/' cannot corrupt the URL. net.JoinHostPort
// brackets IPv6 host literals.
func buildConnectionString(username, password, host string, port int, database string) string {
	u := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(username, password),
		Host:     net.JoinHostPort(host, strconv.Itoa(port)),
		Path:     "/" + database,
		RawQuery: "sslmode=disable",
	}
	return u.String()
}

// fetchMedia is the worker function run by the tunny pool for each
// *fetchRequest.
//
// If the media store already returns a record for serverName/mediaId, the
// item is skipped. Otherwise it is downloaded from the homeserver and stored
// with upload_controller.StoreDirect. Errors are logged, not propagated; the
// return value is always nil.
func fetchMedia(req interface{}) interface{} {
	payload := req.(*fetchRequest)
	record := payload.media
	ctx := context.TODO()
	// Scope the logger to this item so failures identify which media broke.
	log := logrus.WithFields(logrus.Fields{
		"serverName": payload.serverName,
		"mediaId":    record.MediaId,
	})

	db := storage.GetDatabase().GetMediaStore(ctx, log)

	// NOTE(review): any error from Get (not only "not found") is treated as
	// "not yet imported", so the item is downloaded again.
	if _, err := db.Get(payload.serverName, record.MediaId); err == nil {
		logrus.Info("Media already downloaded: " + payload.serverName + "/" + record.MediaId)
		return nil
	}

	body, err := downloadMedia(payload.csApiUrl, payload.serverName, record.MediaId)
	if err != nil {
		log.Error(err.Error())
		return nil
	}
	// Close on every path; previously a StoreDirect failure leaked the body.
	defer body.Close()

	if _, err := upload_controller.StoreDirect(body, -1, record.ContentType, record.UploadName, record.UserId, payload.serverName, record.MediaId, ctx, log); err != nil {
		log.Error(err.Error())
		return nil
	}

	return nil
}

// downloadMedia requests serverName/mediaId from the homeserver at baseUrl
// via the /_matrix/media/r0/download endpoint.
//
// On HTTP 200 it returns the response body, which the caller must close.
// Any other status returns an error and the body is closed here.
func downloadMedia(baseUrl string, serverName string, mediaId string) (io.ReadCloser, error) {
	downloadUrl := baseUrl + "/_matrix/media/r0/download/" + serverName + "/" + mediaId
	resp, err := http.Get(downloadUrl)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		// The body is not handed to the caller on this path, so it must be
		// closed here or the underlying connection is leaked.
		resp.Body.Close()
		return nil, errors.New("received status code " + strconv.Itoa(resp.StatusCode))
	}

	return resp.Body, nil
}