From 7ff48b7c812fd2f9637680bfff7e6d2f5d93507f Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Fri, 11 Aug 2023 23:14:45 -0600
Subject: [PATCH] Add an upload spam test

---
 test/templates/mmr.config.yaml |  4 +-
 test/upload_suite_test.go      | 79 ++++++++++++++++++++++++++++++++++
 2 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/test/templates/mmr.config.yaml b/test/templates/mmr.config.yaml
index 2114b279..40fbf6d6 100644
--- a/test/templates/mmr.config.yaml
+++ b/test/templates/mmr.config.yaml
@@ -37,4 +37,6 @@ datastores:
       bucketName: "mybucket"
       accessKeyId: "mykey"
       accessSecret: "mysecret"
-      ssl: false
\ No newline at end of file
+      ssl: false
+rateLimit:
+  enabled: false # we've got tests which intentionally spam
diff --git a/test/upload_suite_test.go b/test/upload_suite_test.go
index 09aec7b1..f447fcfa 100644
--- a/test/upload_suite_test.go
+++ b/test/upload_suite_test.go
@@ -2,8 +2,11 @@ package test
 
 import (
 	"fmt"
+	"io"
 	"log"
+	"math/rand"
 	"net/http"
+	"sync"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -171,6 +174,82 @@ func (s *UploadTestSuite) TestUploadDeduplicationDifferentUser() {
 	assert.Equal(t, record1.Location, record2.Location)
 }
 
+// TestUploadSpam starts concurrentUploads uploads at once for a single user, split
+// randomly across two MMR machines, then asserts every record shares one datastore location.
+func (s *UploadTestSuite) TestUploadSpam() {
+	t := s.T()
+	const concurrentUploads = 100
+
+	// Clients are for the same user/server, but using different MMR machines
+	client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl)
+	client2 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[1].HttpUrl)
+	assert.Equal(t, client1.ServerName, client2.ServerName)
+
+	// Create the image streams first (so we don't accidentally hit slowdowns during upload)
+	images := make([]io.Reader, concurrentUploads)
+	contentTypes := make([]string, concurrentUploads)
+	for i := 0; i < concurrentUploads; i++ {
+		c, img, err := test_internals.MakeTestImage(128, 128)
+		assert.NoError(t, err)
+		images[i] = img
+		contentTypes[i] = c
+	}
+
+	// Start all the uploads concurrently, and wait for them to complete
+	waiter := new(sync.WaitGroup) // start gate: released once every goroutine has been spawned
+	waiter.Add(1)
+	uploadWaiter := new(sync.WaitGroup)
+	mediaIds := new(sync.Map)
+	for i := 0; i < concurrentUploads; i++ {
+		uploadWaiter.Add(1) // register before `go` so uploadWaiter.Wait() cannot return early
+		go func(j int) {
+			defer uploadWaiter.Done()
+			client := client1
+			if rand.Float32() < 0.5 {
+				client = client2
+			}
+			waiter.Wait()
+
+			// We use random file names to guarantee different media IDs at a minimum
+			rstr, err := util.GenerateRandomString(64)
+			assert.NoError(t, err)
+			fileName := "image" + rstr + util.ExtensionForContentType(contentTypes[j])
+			res, err := client.Upload(fileName, contentTypes[j], images[j])
+			if !assert.NoError(t, err) {
+				return // res must not be inspected after a failed upload
+			}
+			assert.NotEmpty(t, res.MxcUri)
+
+			origin, mediaId, err := util.SplitMxc(res.MxcUri)
+			assert.NoError(t, err)
+			assert.Equal(t, client.ServerName, origin)
+			assert.NotEmpty(t, mediaId)
+			mediaIds.Store(mediaId, true)
+		}(i)
+	}
+	waiter.Done()
+	uploadWaiter.Wait()
+
+	// Prepare to check that only one copy of the file was uploaded each time
+	mediaDb := database.GetInstance().Media.Prepare(rcontext.Initial())
+	realMediaIds := make([]string, 0)
+	mediaIds.Range(func(key any, value any) bool {
+		realMediaIds = append(realMediaIds, key.(string))
+		return true
+	})
+	assert.Greater(t, len(realMediaIds), 0)
+	records, err := mediaDb.GetByIds(client1.ServerName, realMediaIds)
+	assert.NoError(t, err)
+	assert.NotNil(t, records)
+	assert.Len(t, records, len(realMediaIds))
+
+	// Actually do the comparison; records[0] is only read inside the loop, so empty results cannot panic
+	for _, r := range records {
+		assert.Equal(t, records[0].DatastoreId, r.DatastoreId)
+		assert.Equal(t, records[0].Location, r.Location)
+	}
+}
+
 func TestUploadTestSuite(t *testing.T) {
 	suite.Run(t, new(UploadTestSuite))
 }
-- 
GitLab