From fca3457bd94bdf805615af5010e395edbabe7f24 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Thu, 10 Aug 2023 20:00:52 -0600
Subject: [PATCH] Add simple deduplication tests

---
 test/test_internals/deps.go       |  72 ++++++++++++++-----
 test/test_internals/deps_minio.go |  18 +++--
 test/test_internals/deps_mmr.go   |  30 +++++---
 test/upload_suite_test.go         | 110 ++++++++++++++++++++++++++++++
 4 files changed, 198 insertions(+), 32 deletions(-)

diff --git a/test/test_internals/deps.go b/test/test_internals/deps.go
index c977f48e..24e2543c 100644
--- a/test/test_internals/deps.go
+++ b/test/test_internals/deps.go
@@ -12,14 +12,17 @@ import (
 	"github.com/testcontainers/testcontainers-go"
 	"github.com/testcontainers/testcontainers-go/modules/postgres"
 	"github.com/testcontainers/testcontainers-go/wait"
+	"github.com/turt2live/matrix-media-repo/common/assets"
+	"github.com/turt2live/matrix-media-repo/common/config"
 )
 
 type ContainerDeps struct {
-	ctx            context.Context
-	pgContainer    *postgres.PostgresContainer
-	redisContainer testcontainers.Container
-	minioDep       *MinioDep
-	depNet         *NetworkDep
+	ctx              context.Context
+	pgContainer      *postgres.PostgresContainer
+	redisContainer   testcontainers.Container
+	minioDep         *MinioDep
+	depNet           *NetworkDep
+	mmrExtConfigPath string
 
 	Homeservers []*SynapseDep
 	Machines    []*mmrContainer
@@ -63,6 +66,10 @@ func MakeTestDeps() (*ContainerDeps, error) {
 	}
 	// we can hardcode the port and most of the connection details because we're behind the docker network here
 	pgConnStr := fmt.Sprintf("host=%s port=5432 user=postgres password=test1234 dbname=mmr sslmode=disable", pgHost)
+	extPgConnStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
+	if err != nil {
+		return nil, err
+	}
 
 	// Start a redis container
 	cwd, err := os.Getwd()
@@ -76,15 +83,24 @@ func MakeTestDeps() (*ContainerDeps, error) {
 			Mounts: []testcontainers.ContainerMount{
 				testcontainers.BindMount(path.Join(cwd, ".", "dev", "redis.conf"), "/usr/local/etc/redis/redis.conf"),
 			},
-			Cmd:      []string{"redis-server", "/usr/local/etc/redis/redis.conf"},
-			Networks: []string{depNet.NetId},
+			Cmd:        []string{"redis-server", "/usr/local/etc/redis/redis.conf"},
+			Networks:   []string{depNet.NetId},
+			WaitingFor: wait.ForListeningPort("6379/tcp"),
 		},
 		Started: true,
 	})
 	if err != nil {
 		return nil, err
 	}
-	redisHost, err := redisContainer.ContainerIP(ctx)
+	redisIp, err := redisContainer.ContainerIP(ctx)
+	if err != nil {
+		return nil, err
+	}
+	redisHost, err := redisContainer.Host(ctx)
+	if err != nil {
+		return nil, err
+	}
+	redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
 	if err != nil {
 		return nil, err
 	}
@@ -96,7 +112,7 @@ func MakeTestDeps() (*ContainerDeps, error) {
 	}
 
 	// Start two MMRs for testing
-	mmrs, err := makeMmrInstances(ctx, 2, depNet, mmrTmplArgs{
+	tmplArgs := mmrTmplArgs{
 		Homeservers: []mmrHomeserverTmplArgs{
 			{
 				ServerName:         syn1.ServerName,
@@ -107,22 +123,39 @@ func MakeTestDeps() (*ContainerDeps, error) {
 				ClientServerApiUrl: syn2.InternalClientServerApiUrl,
 			},
 		},
-		RedisAddr:          fmt.Sprintf("%s:%d", redisHost, 6379), // we're behind the network for redis
+		RedisAddr:          fmt.Sprintf("%s:%d", redisIp, 6379), // we're behind the network for redis
 		PgConnectionString: pgConnStr,
 		S3Endpoint:         minioDep.Endpoint,
-	})
+	}
+	mmrs, err := makeMmrInstances(ctx, 2, depNet, tmplArgs)
+	if err != nil {
+		return nil, err
+	}
+
+	// Generate a config that's safe to use in tests, for inspecting state of the containers
+	tmplArgs.RedisAddr = fmt.Sprintf("%s:%d", redisHost, redisPort.Int())
+	tmplArgs.PgConnectionString = extPgConnStr
+	tmplArgs.S3Endpoint = minioDep.ExternalEndpoint
+	tmplArgs.Homeservers[0].ClientServerApiUrl = syn1.ExternalClientServerApiUrl
+	tmplArgs.Homeservers[1].ClientServerApiUrl = syn2.ExternalClientServerApiUrl
+	tmpPath, err := writeMmrConfig(tmplArgs)
 	if err != nil {
 		return nil, err
 	}
+	config.Path = tmpPath
+	assets.SetupMigrations(config.DefaultMigrationsPath)
+	assets.SetupTemplates(config.DefaultTemplatesPath)
+	assets.SetupAssets(config.DefaultAssetsPath)
 
 	return &ContainerDeps{
-		ctx:            ctx,
-		pgContainer:    pgContainer,
-		redisContainer: redisContainer,
-		minioDep:       minioDep,
-		Homeservers:    []*SynapseDep{syn1, syn2},
-		Machines:       mmrs,
-		depNet:         depNet,
+		ctx:              ctx,
+		pgContainer:      pgContainer,
+		redisContainer:   redisContainer,
+		minioDep:         minioDep,
+		mmrExtConfigPath: tmpPath,
+		Homeservers:      []*SynapseDep{syn1, syn2},
+		Machines:         mmrs,
+		depNet:           depNet,
 	}, nil
 }
 
@@ -141,6 +174,9 @@ func (c *ContainerDeps) Teardown() {
 	}
 	c.minioDep.Teardown()
 	c.depNet.Teardown()
+	if err := os.Remove(c.mmrExtConfigPath); err != nil && !os.IsNotExist(err) {
+		log.Fatalf("Error cleaning up MMR-External config file '%s': %s", c.mmrExtConfigPath, err.Error())
+	}
 }
 
 func (c *ContainerDeps) Debug() {
diff --git a/test/test_internals/deps_minio.go b/test/test_internals/deps_minio.go
index 3941ecbc..e00e04ef 100644
--- a/test/test_internals/deps_minio.go
+++ b/test/test_internals/deps_minio.go
@@ -23,7 +23,8 @@ type MinioDep struct {
 	ctx       context.Context
 	container testcontainers.Container
 
-	Endpoint string
+	Endpoint         string
+	ExternalEndpoint string
 }
 
 func MakeMinio(depNet *NetworkDep) (*MinioDep, error) {
@@ -58,6 +59,14 @@ func MakeMinio(depNet *NetworkDep) (*MinioDep, error) {
 	if err != nil {
 		return nil, err
 	}
+	minioHost, err := container.Host(ctx)
+	if err != nil {
+		return nil, err
+	}
+	minioPort, err := container.MappedPort(ctx, "9090/tcp")
+	if err != nil {
+		return nil, err
+	}
 
 	// Prepare the test script
 	t, err := template.New("minio-config.sh").ParseFiles(path.Join(".", "test", "templates", "minio-config.sh"))
@@ -103,9 +112,10 @@ func MakeMinio(depNet *NetworkDep) (*MinioDep, error) {
 	}
 
 	return &MinioDep{
-		ctx:       ctx,
-		container: container,
-		Endpoint:  fmt.Sprintf("%s:%d", minioIp, 9000), // we're behind the network
+		ctx:              ctx,
+		container:        container,
+		Endpoint:         fmt.Sprintf("%s:%d", minioIp, 9000), // we're behind the network
+		ExternalEndpoint: fmt.Sprintf("%s:%d", minioHost, minioPort.Int()),
 	}, nil
 }
 
diff --git a/test/test_internals/deps_mmr.go b/test/test_internals/deps_mmr.go
index 410f8c91..d0bad936 100644
--- a/test/test_internals/deps_mmr.go
+++ b/test/test_internals/deps_mmr.go
@@ -30,9 +30,10 @@ type mmrTmplArgs struct {
 }
 
 type mmrContainer struct {
-	ctx           context.Context
-	container     testcontainers.Container
-	tmpConfigPath string
+	ctx                   context.Context
+	container             testcontainers.Container
+	tmpConfigPath         string
+	tmpExternalConfigPath string
 
 	HttpUrl   string
 	MachineId int
@@ -79,28 +80,37 @@ func reuseMmrBuild(ctx context.Context) (string, error) {
 	return mmrCachedImage, nil
 }
 
-func makeMmrInstances(ctx context.Context, count int, depNet *NetworkDep, tmplArgs mmrTmplArgs) ([]*mmrContainer, error) {
+// writeMmrConfig renders the MMR config template with tmplArgs into a temp file and returns the file's path.
+func writeMmrConfig(tmplArgs mmrTmplArgs) (string, error) {
 	// Prepare a config template
 	t, err := template.New("mmr.config.yaml").ParseFiles(path.Join(".", "test", "templates", "mmr.config.yaml"))
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 	w := new(strings.Builder)
 	err = t.Execute(w, tmplArgs)
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 
 	// Write the MMR config to a temp file
 	f, err := os.CreateTemp(os.TempDir(), "mmr-tests-mediarepo")
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 	_, err = f.Write([]byte(w.String()))
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 	err = f.Close()
+	if err != nil {
+		return "", err
+	}
+	return f.Name(), nil
+}
+
+func makeMmrInstances(ctx context.Context, count int, depNet *NetworkDep, tmplArgs mmrTmplArgs) ([]*mmrContainer, error) {
+	intTmpName, err := writeMmrConfig(tmplArgs)
 	if err != nil {
 		return nil, err
 	}
@@ -121,7 +131,7 @@ func makeMmrInstances(ctx context.Context, count int, depNet *NetworkDep, tmplAr
 				Image:        mmrImage,
 				ExposedPorts: []string{"8000/tcp"},
 				Mounts: []testcontainers.ContainerMount{
-					testcontainers.BindMount(f.Name(), "/data/media-repo.yaml"),
+					testcontainers.BindMount(intTmpName, "/data/media-repo.yaml"),
 				},
 				Env: map[string]string{
 					"MACHINE_ID": strconv.Itoa(i),
@@ -151,7 +161,7 @@ func makeMmrInstances(ctx context.Context, count int, depNet *NetworkDep, tmplAr
 		mmrs = append(mmrs, &mmrContainer{
 			ctx:           ctx,
 			container:     container,
-			tmpConfigPath: f.Name(),
+			tmpConfigPath: intTmpName,
 			HttpUrl:       csApiUrl,
 			MachineId:     i,
 		})
diff --git a/test/upload_suite_test.go b/test/upload_suite_test.go
index 27fa43c7..2f8239cf 100644
--- a/test/upload_suite_test.go
+++ b/test/upload_suite_test.go
@@ -8,6 +8,8 @@ import (
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/suite"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/database"
 	"github.com/turt2live/matrix-media-repo/test/test_internals"
 	"github.com/turt2live/matrix-media-repo/util"
 )
@@ -46,6 +48,7 @@ func (s *UploadTestSuite) TestUpload() {
 	}
 
 	contentType, img, err := test_internals.MakeTestImage(512, 512)
+	assert.NoError(t, err)
 	res, err := client1.Upload("image"+util.ExtensionForContentType(contentType), contentType, img)
 	assert.NoError(t, err)
 	assert.NotEmpty(t, res.MxcUri)
@@ -61,6 +64,113 @@ func (s *UploadTestSuite) TestUpload() {
 	test_internals.AssertIsTestImage(t, raw.Body)
 }
 
+// TestUploadDeduplicationSameUser uploads the MakeTestImage(512, 512) image twice as the same user
+// under the same file name and expects both uploads to return the same MXC URI.
+func (s *UploadTestSuite) TestUploadDeduplicationSameUser() {
+	t := s.T()
+	uploader := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl)
+
+	mimeType, payload, err := test_internals.MakeTestImage(512, 512)
+	assert.NoError(t, err)
+	first, err := uploader.Upload("image"+util.ExtensionForContentType(mimeType), mimeType, payload)
+	assert.NoError(t, err)
+	assert.NotEmpty(t, first.MxcUri)
+	serverName, firstMedia, err := util.SplitMxc(first.MxcUri)
+	assert.NoError(t, err)
+	assert.Equal(t, serverName, uploader.ServerName)
+	assert.NotEmpty(t, firstMedia)
+
+	// Repeat the upload with the same file name.
+	mimeType, payload, err = test_internals.MakeTestImage(512, 512)
+	assert.NoError(t, err)
+	second, err := uploader.Upload("image"+util.ExtensionForContentType(mimeType), mimeType, payload)
+	assert.NoError(t, err)
+	assert.NotEmpty(t, second.MxcUri)
+	assert.Equal(t, first.MxcUri, second.MxcUri)
+}
+
+// TestUploadDeduplicationSameUserDifferentMetadata uploads the MakeTestImage(512, 512) image twice as
+// one user under two file names; it expects two MXC URIs whose records share a datastore ID and location.
+func (s *UploadTestSuite) TestUploadDeduplicationSameUserDifferentMetadata() {
+	t := s.T()
+	uploader := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl)
+
+	mimeType, payload, err := test_internals.MakeTestImage(512, 512)
+	assert.NoError(t, err)
+	first, err := uploader.Upload("image"+util.ExtensionForContentType(mimeType), mimeType, payload)
+	assert.NoError(t, err)
+	assert.NotEmpty(t, first.MxcUri)
+	firstOrigin, firstMedia, err := util.SplitMxc(first.MxcUri)
+	assert.NoError(t, err)
+	assert.Equal(t, firstOrigin, uploader.ServerName)
+	assert.NotEmpty(t, firstMedia)
+
+	mimeType, payload, err = test_internals.MakeTestImage(512, 512)
+	assert.NoError(t, err)
+	second, err := uploader.Upload("DIFFERENT_FILE_NAME_SHOULD_GIVE_DIFFERENT_MEDIA_ID"+util.ExtensionForContentType(mimeType), mimeType, payload)
+	assert.NoError(t, err)
+	assert.NotEmpty(t, second.MxcUri)
+	secondOrigin, secondMedia, err := util.SplitMxc(second.MxcUri)
+	assert.NoError(t, err)
+	assert.Equal(t, secondOrigin, uploader.ServerName) // the origin must not change
+	assert.NotEmpty(t, secondMedia)
+
+	// The new file name is expected to yield a new media ID.
+	assert.NotEqual(t, first.MxcUri, second.MxcUri)
+
+	// Both records should reference one stored object rather than two copies.
+	store := database.GetInstance().Media.Prepare(rcontext.Initial())
+	rows, err := store.GetByIds(firstOrigin, []string{firstMedia, secondMedia})
+	assert.NoError(t, err)
+	assert.NotNil(t, rows)
+	assert.Len(t, rows, 2)
+	assert.NotEqual(t, rows[0].MediaId, rows[1].MediaId)
+	assert.Equal(t, rows[0].DatastoreId, rows[1].DatastoreId)
+	assert.Equal(t, rows[0].Location, rows[1].Location)
+}
+
+// TestUploadDeduplicationDifferentUser uploads the MakeTestImage(512, 512) image as users on two
+// homeservers and expects distinct MXC URIs whose records share one datastore ID and location.
+func (s *UploadTestSuite) TestUploadDeduplicationDifferentUser() {
+	t := s.T()
+	client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl)
+	client2 := s.deps.Homeservers[1].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[1].HttpUrl)
+
+	contentType, img, err := test_internals.MakeTestImage(512, 512)
+	assert.NoError(t, err)
+	res1, err := client1.Upload("image"+util.ExtensionForContentType(contentType), contentType, img)
+	assert.NoError(t, err)
+	assert.NotEmpty(t, res1.MxcUri)
+
+	origin1, mediaId1, err := util.SplitMxc(res1.MxcUri)
+	assert.NoError(t, err)
+	assert.Equal(t, origin1, client1.ServerName)
+	assert.NotEmpty(t, mediaId1)
+
+	contentType, img, err = test_internals.MakeTestImage(512, 512)
+	assert.NoError(t, err)
+	res2, err := client2.Upload("image"+util.ExtensionForContentType(contentType), contentType, img)
+	assert.NoError(t, err)
+	assert.NotEmpty(t, res2.MxcUri)
+
+	origin2, mediaId2, err := util.SplitMxc(res2.MxcUri)
+	assert.NoError(t, err)
+	assert.Equal(t, origin2, client2.ServerName)
+	assert.NotEmpty(t, mediaId2)
+	assert.NotEqual(t, res1.MxcUri, res2.MxcUri) // should be different URIs
+
+	// Inspect database to ensure file was reused rather than uploaded twice
+	mediaDb := database.GetInstance().Media.Prepare(rcontext.Initial())
+	record1, err := mediaDb.GetById(origin1, mediaId1)
+	assert.NoError(t, err)
+	assert.NotNil(t, record1)
+	record2, err := mediaDb.GetById(origin2, mediaId2)
+	assert.NoError(t, err)
+	assert.NotNil(t, record2)
+	assert.Equal(t, record1.DatastoreId, record2.DatastoreId)
+	assert.Equal(t, record1.Location, record2.Location)
+}
+
 func TestUploadTestSuite(t *testing.T) {
 	suite.Run(t, new(UploadTestSuite))
 }
-- 
GitLab