From 7ff48b7c812fd2f9637680bfff7e6d2f5d93507f Mon Sep 17 00:00:00 2001 From: Travis Ralston <travpc@gmail.com> Date: Fri, 11 Aug 2023 23:14:45 -0600 Subject: [PATCH] Add an upload spam test --- test/templates/mmr.config.yaml | 4 +- test/upload_suite_test.go | 79 ++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/test/templates/mmr.config.yaml b/test/templates/mmr.config.yaml index 2114b279..40fbf6d6 100644 --- a/test/templates/mmr.config.yaml +++ b/test/templates/mmr.config.yaml @@ -37,4 +37,6 @@ datastores: bucketName: "mybucket" accessKeyId: "mykey" accessSecret: "mysecret" - ssl: false \ No newline at end of file + ssl: false +rateLimit: + enabled: false # we've got tests which intentionally spam diff --git a/test/upload_suite_test.go b/test/upload_suite_test.go index 09aec7b1..f447fcfa 100644 --- a/test/upload_suite_test.go +++ b/test/upload_suite_test.go @@ -2,8 +2,11 @@ package test import ( "fmt" + "io" "log" + "math/rand" "net/http" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -171,6 +174,82 @@ func (s *UploadTestSuite) TestUploadDeduplicationDifferentUser() { assert.Equal(t, record1.Location, record2.Location) } +func (s *UploadTestSuite) TestUploadSpam() { + t := s.T() + const concurrentUploads = 100 + + // Clients are for the same user/server, but using different MMR machines + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + client2 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[1].HttpUrl) + assert.Equal(t, client1.ServerName, client2.ServerName) + + // Create the image streams first (so we don't accidentally hit slowdowns during upload) + images := make([]io.Reader, concurrentUploads) + contentTypes := make([]string, concurrentUploads) + for i := 0; i < concurrentUploads; i++ { + c, img, err := test_internals.MakeTestImage(128, 128) + assert.NoError(t, err) + images[i] = img + contentTypes[i] = c + } + + // Start all the uploads concurrently, and wait for them to complete + waiter := new(sync.WaitGroup) + waiter.Add(1) + uploadWaiter := new(sync.WaitGroup) + mediaIds := new(sync.Map) + for i := 0; i < concurrentUploads; i++ { + go func(j int) { + uploadWaiter.Add(1) + defer uploadWaiter.Done() + + img := images[j] + contentType := contentTypes[j] + client := client1 + if rand.Float32() < 0.5 { + client = client2 + } + waiter.Wait() + + // We use random file names to guarantee different media IDs at a minimum + rstr, err := util.GenerateRandomString(64) + assert.NoError(t, err) + res, err := client.Upload("image"+rstr+util.ExtensionForContentType(contentType), contentType, img) + assert.NoError(t, err) + assert.NotEmpty(t, res.MxcUri) + + origin, mediaId, err := util.SplitMxc(res.MxcUri) + assert.NoError(t, err) + assert.Equal(t, client.ServerName, origin) + assert.NotEmpty(t, mediaId) + mediaIds.Store(mediaId, true) + }(i) + } + waiter.Done() + uploadWaiter.Wait() + + // Prepare to check that only one copy of the file was uploaded each time + mediaDb := database.GetInstance().Media.Prepare(rcontext.Initial()) + realMediaIds := make([]string, 0) + mediaIds.Range(func(key any, value any) bool { + realMediaIds = append(realMediaIds, key.(string)) + return true + }) + assert.Greater(t, len(realMediaIds), 0) + records, err := mediaDb.GetByIds(client1.ServerName, realMediaIds) + assert.NoError(t, err) + assert.NotNil(t, records) + assert.Len(t, records, len(realMediaIds)) + + // Actually do the comparison + dsId := records[0].DatastoreId + dsLocation := records[0].Location + for _, r := range records { + assert.Equal(t, dsId, r.DatastoreId) + assert.Equal(t, dsLocation, r.Location) + } +} + func TestUploadTestSuite(t *testing.T) { suite.Run(t, new(UploadTestSuite)) } -- GitLab