diff --git a/test/templates/mmr.config.yaml b/test/templates/mmr.config.yaml index 2114b279225aba4e2aaa5f6ae30c40afa554978d..40fbf6d6497f983d9be9cae0f8e0bdd04effd9cf 100644 --- a/test/templates/mmr.config.yaml +++ b/test/templates/mmr.config.yaml @@ -37,4 +37,6 @@ datastores: bucketName: "mybucket" accessKeyId: "mykey" accessSecret: "mysecret" - ssl: false \ No newline at end of file + ssl: false +rateLimit: + enabled: false # we've got tests which intentionally spam diff --git a/test/upload_suite_test.go b/test/upload_suite_test.go index 09aec7b111e2ef2ebddae798f5eccc8cc44f9e81..f447fcfaca5752cb214ec2abc59a2a0c345dc349 100644 --- a/test/upload_suite_test.go +++ b/test/upload_suite_test.go @@ -2,8 +2,11 @@ package test import ( "fmt" + "io" "log" + "math/rand" "net/http" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -171,6 +174,82 @@ func (s *UploadTestSuite) TestUploadDeduplicationDifferentUser() { assert.Equal(t, record1.Location, record2.Location) } +func (s *UploadTestSuite) TestUploadSpam() { + t := s.T() + const concurrentUploads = 100 + + // Clients are for the same user/server, but using different MMR machines + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + client2 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[1].HttpUrl) + assert.Equal(t, client1.ServerName, client2.ServerName) + + // Create the image streams first (so we don't accidentally hit slowdowns during upload) + images := make([]io.Reader, concurrentUploads) + contentTypes := make([]string, concurrentUploads) + for i := 0; i < concurrentUploads; i++ { + c, img, err := test_internals.MakeTestImage(128, 128) + assert.NoError(t, err) + images[i] = img + contentTypes[i] = c + } + + // Start all the uploads concurrently, and wait for them to complete + waiter := new(sync.WaitGroup) + waiter.Add(1) + uploadWaiter := new(sync.WaitGroup) + mediaIds := new(sync.Map) + for i := 0; i < concurrentUploads; i++ { + go func(j int) { + uploadWaiter.Add(1) + defer uploadWaiter.Done() + + img := images[j] + contentType := contentTypes[j] + client := client1 + if rand.Float32() < 0.5 { + client = client2 + } + waiter.Wait() + + // We use random file names to guarantee different media IDs at a minimum + rstr, err := util.GenerateRandomString(64) + assert.NoError(t, err) + res, err := client.Upload("image"+rstr+util.ExtensionForContentType(contentType), contentType, img) + assert.NoError(t, err) + assert.NotEmpty(t, res.MxcUri) + + origin, mediaId, err := util.SplitMxc(res.MxcUri) + assert.NoError(t, err) + assert.Equal(t, client.ServerName, origin) + assert.NotEmpty(t, mediaId) + mediaIds.Store(mediaId, true) + }(i) + } + waiter.Done() + uploadWaiter.Wait() + + // Prepare to check that only one copy of the file was uploaded each time + mediaDb := database.GetInstance().Media.Prepare(rcontext.Initial()) + realMediaIds := make([]string, 0) + mediaIds.Range(func(key any, value any) bool { + realMediaIds = append(realMediaIds, key.(string)) + return true + }) + assert.Greater(t, len(realMediaIds), 0) + records, err := mediaDb.GetByIds(client1.ServerName, realMediaIds) + assert.NoError(t, err) + assert.NotNil(t, records) + assert.Len(t, records, len(realMediaIds)) + + // Actually do the comparison + dsId := records[0].DatastoreId + dsLocation := records[0].Location + for _, r := range records { + assert.Equal(t, dsId, r.DatastoreId) + assert.Equal(t, dsLocation, r.Location) + } +} + func TestUploadTestSuite(t *testing.T) { suite.Run(t, new(UploadTestSuite)) }