Skip to content
Snippets Groups Projects
Commit 892f4ba2 authored by Travis Ralston
Browse files

Use new upload pipeline in synapse import

parent a1dc464f
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ import (
"github.com/turt2live/matrix-media-repo/common/logging"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/common/runtime"
"github.com/turt2live/matrix-media-repo/synapse"
"github.com/turt2live/matrix-media-repo/homeserver_interop/synapse"
"github.com/turt2live/matrix-media-repo/util"
"github.com/turt2live/matrix-media-repo/util/stream_util"
"golang.org/x/crypto/ssh/terminal"
......
......@@ -11,27 +11,22 @@ import (
"sync"
"time"
"github.com/Jeffail/tunny"
"github.com/panjf2000/ants/v2"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/assets"
"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/common/logging"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/common/runtime"
"github.com/turt2live/matrix-media-repo/controllers/upload_controller"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/synapse"
"github.com/turt2live/matrix-media-repo/util/stream_util"
"github.com/turt2live/matrix-media-repo/common/version"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/homeserver_interop/synapse"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_upload"
"github.com/turt2live/matrix-media-repo/util/ids"
"golang.org/x/crypto/ssh/terminal"
)
// fetchRequest is the payload passed to the tunny-based fetchMedia worker:
// one Synapse local media record plus the connection details needed to
// download it.
type fetchRequest struct {
	// media is the Synapse record to import.
	media *synapse.LocalMedia
	// csApiUrl is the homeserver base URL used for the download request;
	// serverName is the origin the media is stored under.
	csApiUrl, serverName string
}
func main() {
postgresHost := flag.String("dbHost", "localhost", "The PostgresSQL hostname for your Synapse database")
postgresPort := flag.Int("dbPort", 5432, "The port for your Synapse's PostgreSQL database")
......@@ -42,7 +37,7 @@ func main() {
serverName := flag.String("serverName", "localhost", "The name of your homeserver (eg: matrix.org)")
configPath := flag.String("config", "media-repo.yaml", "The path to the media repo configuration (configured for the media repo's database)")
migrationsPath := flag.String("migrations", "./migrations", "The absolute path the media repo's migrations folder")
numWorkers := flag.Int("workers", 1, "The number of workers to use when downloading media. Using multiple workers risks deduplication not working as efficiently.")
numWorkers := flag.Int("workers", 10, "The number of workers to use when downloading media. Using multiple workers is recommended.")
flag.Parse()
// Override config path with config for Docker users
......@@ -51,8 +46,17 @@ func main() {
configPath = &configEnv
}
version.SetDefaults()
config.Path = *configPath
runtime.CheckIdGenerator()
config.Runtime.IsImportProcess = true
if ids.GetMachineId() == 0 {
_ = os.Setenv("MACHINE_ID", "1023")
if ids.GetMachineId() != 1023 {
panic(errors.New("expected machine ID 1023 or custom ID for import process"))
}
}
assets.SetupMigrations(*migrationsPath)
var realPsqlPassword string
......@@ -86,7 +90,7 @@ func main() {
logrus.Info("Starting up...")
runtime.RunStartupSequence()
logrus.Info("Setting up for importing...")
logrus.Debug("Setting up for importing...")
connectionString := "postgres://" + *postgresUsername + ":" + realPsqlPassword + "@" + *postgresHost + ":" + strconv.Itoa(*postgresPort) + "/" + *postgresDatabase + "?sslmode=disable"
csApiUrl := *baseUrl
......@@ -94,13 +98,13 @@ func main() {
csApiUrl = csApiUrl[:len(csApiUrl)-1]
}
logrus.Info("Connecting to synapse database...")
logrus.Debug("Connecting to synapse database...")
synDb, err := synapse.OpenDatabase(connectionString)
if err != nil {
panic(err)
}
logrus.Info("Fetching all local media records from synapse...")
logrus.Debug("Fetching all local media records from synapse...")
records, err := synDb.GetAllMedia()
if err != nil {
panic(err)
......@@ -108,31 +112,44 @@ func main() {
logrus.Info(fmt.Sprintf("Downloading %d media records", len(records)))
pool := tunny.NewFunc(*numWorkers, fetchMedia)
pool, err := ants.NewPool(*numWorkers, ants.WithOptions(ants.Options{
ExpiryDuration: 1 * time.Hour,
PreAlloc: false,
MaxBlockingTasks: 0, // no limit
Nonblocking: false,
Logger: &logging.SendToDebugLogger{},
DisablePurge: false,
PanicHandler: func(err interface{}) {
panic(err)
},
}))
if err != nil {
panic(err)
}
numCompleted := 0
lock := &sync.RWMutex{}
onComplete := func(interface{}, error) {
lock.Lock()
mu := &sync.RWMutex{}
onComplete := func() {
mu.Lock()
numCompleted++
percent := int((float32(numCompleted) / float32(len(records))) * 100)
logrus.Info(fmt.Sprintf("%d/%d downloaded (%d%%)", numCompleted, len(records), percent))
lock.Unlock()
mu.Unlock()
}
for i := 0; i < len(records); i++ {
percent := int((float32(i+1) / float32(len(records))) * 100)
record := records[i]
logrus.Info(fmt.Sprintf("Queuing %s (%d/%d %d%%)", record.MediaId, i+1, len(records), percent))
go func() {
result := pool.Process(&fetchRequest{media: record, serverName: *serverName, csApiUrl: csApiUrl})
onComplete(result, nil)
}()
logrus.Debug(fmt.Sprintf("Queuing %s (%d/%d %d%%)", record.MediaId, i+1, len(records), percent))
err = pool.Submit(doWork(record, *serverName, csApiUrl, onComplete))
if err != nil {
panic(err)
}
}
for numCompleted < len(records) {
logrus.Info("Waiting for import to complete...")
logrus.Debug("Waiting for import to complete...")
time.Sleep(1 * time.Second)
}
......@@ -142,37 +159,41 @@ func main() {
logrus.Info("Import completed")
}
func fetchMedia(req interface{}) interface{} {
payload := req.(*fetchRequest)
record := payload.media
ctx := rcontext.Initial()
func doWork(record *synapse.LocalMedia, serverName string, csApiUrl string, onComplete func()) func() {
return func() {
defer onComplete()
db := storage.GetDatabase().GetMediaStore(ctx)
ctx := rcontext.Initial().LogWithFields(logrus.Fields{"origin": serverName, "mediaId": record.MediaId})
_, err := db.Get(payload.serverName, record.MediaId)
if err == nil {
logrus.Info("Media already downloaded: " + payload.serverName + "/" + record.MediaId)
return nil
}
db := database.GetInstance().Media.Prepare(ctx)
body, err := downloadMedia(payload.csApiUrl, payload.serverName, record.MediaId)
if err != nil {
logrus.Error(err.Error())
return nil
}
defer stream_util.DumpAndCloseStream(body)
dbRecord, err := db.GetById(serverName, record.MediaId)
if err != nil {
panic(err)
}
if dbRecord != nil {
ctx.Log.Debug("Already downloaded - skipping")
return
}
_, err = upload_controller.StoreDirect(nil, body, -1, record.ContentType, record.UploadName, record.UserId, payload.serverName, record.MediaId, common.KindLocalMedia, ctx, false)
if err != nil {
logrus.Error(err.Error())
return nil
}
body, err := downloadMedia(csApiUrl, serverName, record.MediaId)
if err != nil {
panic(err)
}
return nil
dbRecord, err = pipeline_upload.Execute(ctx, serverName, record.MediaId, body, record.ContentType, record.UploadName, record.UserId, datastores.LocalMediaKind)
if err != nil {
panic(err)
}
if dbRecord.SizeBytes != record.SizeBytes {
ctx.Log.Warnf("Size mismatch! Expected %d bytes but got %d", record.SizeBytes, dbRecord.SizeBytes)
}
}
}
func downloadMedia(baseUrl string, serverName string, mediaId string) (io.ReadCloser, error) {
downloadUrl := baseUrl + "/_matrix/media/r0/download/" + serverName + "/" + mediaId
downloadUrl := baseUrl + "/_matrix/media/v3/download/" + serverName + "/" + mediaId
resp, err := http.Get(downloadUrl)
if err != nil {
return nil, err
......
......@@ -30,7 +30,6 @@ func main() {
}
config.Path = *configPath
runtime.CheckIdGenerator()
assets.SetupMigrations(*migrationsPath)
assets.SetupTemplates(*templatesPath)
......
......@@ -14,9 +14,10 @@ import (
)
// runtimeConfig holds process-level settings that are set at startup rather
// than read from the configuration file.
type runtimeConfig struct {
	// MigrationsPath, TemplatesPath and AssetsPath are filesystem locations.
	// NOTE(review): presumably populated from the command-line flags of each
	// binary - confirm against the callers.
	MigrationsPath string
	TemplatesPath  string
	AssetsPath     string

	// IsImportProcess marks the process as an offline importer. The Synapse
	// import tool sets it to true; the upload pipeline skips the per-user
	// quota check when it is set, and machine ID 1023 is only accepted for
	// such processes.
	IsImportProcess bool
}
const DefaultMigrationsPath = "./migrations"
......
File moved
File moved
......@@ -5,6 +5,7 @@ import (
"io"
"github.com/getsentry/sentry-go"
"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
......@@ -67,7 +68,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re
}
// Step 7: Ensure user can upload within quota
if userId != "" {
if userId != "" && !config.Runtime.IsImportProcess {
err = quota.CanUpload(ctx, userId, sizeBytes)
if err != nil {
return nil, err
......
package ids
import (
"errors"
"os"
"strconv"
"github.com/bwmarrin/snowflake"
"github.com/turt2live/matrix-media-repo/common/config"
)
func GetMachineId() int64 {
if val, ok := os.LookupEnv("MACHINE_ID"); ok {
if i, err := strconv.ParseInt(val, 10, 64); err == nil {
if i == 1023 && !config.Runtime.IsImportProcess {
panic(errors.New("machine ID 1023 is reserved for use by import process"))
}
return i
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment