From a3216a0bcc5770b4aa73a122a6ffbf0eccafdab6 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Wed, 9 Aug 2023 23:08:27 -0600
Subject: [PATCH] Make test MatrixClients easier to access/create

---
 test/test_internals/deps_synapse.go | 86 ++++++++++++++---------------
 test/test_internals/util_client.go  | 11 ++++
 test/upload_suite_test.go           |  5 +-
 3 files changed, 54 insertions(+), 48 deletions(-)

diff --git a/test/test_internals/deps_synapse.go b/test/test_internals/deps_synapse.go
index cb317126..fdbd948f 100644
--- a/test/test_internals/deps_synapse.go
+++ b/test/test_internals/deps_synapse.go
@@ -38,21 +38,8 @@ type SynapseDep struct {
 	ExternalClientServerApiUrl string
 	ServerName                 string
 
-	AdminUserId                  string
-	AdminAccessToken             string
-	UnprivilegedAliceUserId      string
-	UnprivilegedAliceAccessToken string
-	UnprivilegedBobUserId        string
-	UnprivilegedBobAccessToken   string
-}
-
-type fixNetwork struct {
-	testcontainers.ContainerCustomizer
-	NetId string
-}
-
-func (f *fixNetwork) Customize(req *testcontainers.GenericContainerRequest) {
-	req.Networks = []string{f.NetId}
+	AdminUsers        []*MatrixClient // uses ExternalClientServerApiUrl
+	UnprivilegedUsers []*MatrixClient // uses ExternalClientServerApiUrl
 }
 
 func MakeSynapse(domainName string, depNet *NetworkDep) (*SynapseDep, error) {
@@ -152,7 +139,9 @@ func MakeSynapse(domainName string, depNet *NetworkDep) (*SynapseDep, error) {
 	extCsApiUrl := fmt.Sprintf("http://%s:%d", synHost, synPort.Int())
 
 	// Register the accounts
-	registerUser := func(localpart string, admin bool) (string, string, error) { // userId, accessToken, err
+	adminUsers := make([]*MatrixClient, 0)
+	unprivilegedUsers := make([]*MatrixClient, 0)
+	registerUser := func(localpart string, admin bool) error {
 		adminFlag := "--admin"
 		if !admin {
 			adminFlag = "--no-admin"
@@ -161,21 +150,21 @@ func MakeSynapse(domainName string, depNet *NetworkDep) (*SynapseDep, error) {
 		log.Println("[Synapse Command] " + cmd)
 		i, r, err := synContainer.Exec(ctx, strings.Split(cmd, " "))
 		if err != nil {
-			return "", "", err
+			return err
 		}
 		b, err := io.ReadAll(r)
 		if err != nil {
-			return "", "", err
+			return err
 		}
 		if i != 0 {
-			return "", "", errors.New(string(b))
+			return errors.New(string(b))
 		}
 
 		// Get user ID and access token from admin API
 		log.Println("[Synapse API] Logging in")
 		endpoint, err := url.JoinPath(extCsApiUrl, "/_matrix/client/v3/login")
 		if err != nil {
-			return "", "", err
+			return err
 		}
 		b, err = json.Marshal(map[string]interface{}{
 			"type": "m.login.password",
@@ -187,66 +176,75 @@ func MakeSynapse(domainName string, depNet *NetworkDep) (*SynapseDep, error) {
 			"refresh_token": false,
 		})
 		if err != nil {
-			return "", "", err
+			return err
 		}
 		res, err := http.DefaultClient.Post(endpoint, "application/json", bytes.NewBuffer(b))
 		if err != nil {
-			return "", "", err
+			return err
 		}
 		b, err = io.ReadAll(res.Body)
 		if err != nil {
-			return "", "", err
+			return err
 		}
 		if res.StatusCode != http.StatusOK {
-			return "", "", errors.New(res.Status + "\n" + string(b))
+			return errors.New(res.Status + "\n" + string(b))
 		}
 		log.Println("[Synapse API] " + string(b))
 		m := make(map[string]interface{})
 		err = json.Unmarshal(b, &m)
 		if err != nil {
-			return "", "", err
+			return err
 		}
 
 		var userId interface{}
 		var accessToken interface{}
 		var ok bool
 		if userId, ok = m["user_id"]; !ok {
-			return "", "", errors.New("missing user_id")
+			return errors.New("missing user_id")
 		}
 		if accessToken, ok = m["access_token"]; !ok {
-			return "", "", errors.New("missing access_token")
+			return errors.New("missing access_token")
+		}
+
+		mxClient := &MatrixClient{
+			AccessToken:     accessToken.(string),
+			ClientServerUrl: extCsApiUrl,
+			UserId:          userId.(string),
+			ServerName:      domainName,
+		}
+
+		if admin {
+			adminUsers = append(adminUsers, mxClient)
+		} else {
+			unprivilegedUsers = append(unprivilegedUsers, mxClient)
 		}
 
-		return userId.(string), accessToken.(string), nil
+		return nil
 	}
-	adminUserId, adminAccessToken, err := registerUser("admin", true)
+	err = registerUser("admin", true)
 	if err != nil {
 		return nil, err
 	}
-	aliceUserId, aliceAccessToken, err := registerUser("user_alice", false)
+	err = registerUser("user_alice", false)
 	if err != nil {
 		return nil, err
 	}
-	bobUserId, bobAccessToken, err := registerUser("user_bob", false)
+	err = registerUser("user_bob", false)
 	if err != nil {
 		return nil, err
 	}
 
 	// Create the dependency
 	return &SynapseDep{
-		ctx:                          ctx,
-		pgContainer:                  pgContainer,
-		synContainer:                 synContainer,
-		tmpConfigPath:                f.Name(),
-		InternalClientServerApiUrl:   intCsApiUrl,
-		ExternalClientServerApiUrl:   extCsApiUrl,
-		ServerName:                   domainName,
-		AdminUserId:                  adminUserId,
-		AdminAccessToken:             adminAccessToken,
-		UnprivilegedAliceUserId:      aliceUserId,
-		UnprivilegedAliceAccessToken: aliceAccessToken,
-		UnprivilegedBobUserId:        bobUserId,
-		UnprivilegedBobAccessToken:   bobAccessToken,
+		ctx:                        ctx,
+		pgContainer:                pgContainer,
+		synContainer:               synContainer,
+		tmpConfigPath:              f.Name(),
+		InternalClientServerApiUrl: intCsApiUrl,
+		ExternalClientServerApiUrl: extCsApiUrl,
+		ServerName:                 domainName,
+		AdminUsers:                 adminUsers,
+		UnprivilegedUsers:          unprivilegedUsers,
 	}, nil
 }
 
diff --git a/test/test_internals/util_client.go b/test/test_internals/util_client.go
index 59080654..16d839bd 100644
--- a/test/test_internals/util_client.go
+++ b/test/test_internals/util_client.go
@@ -14,6 +14,17 @@ import (
 type MatrixClient struct {
 	AccessToken     string
 	ClientServerUrl string
+	UserId          string
+	ServerName      string
+}
+
+func (c *MatrixClient) WithCsUrl(newUrl string) *MatrixClient {
+	return &MatrixClient{
+		AccessToken:     c.AccessToken,
+		ClientServerUrl: newUrl,
+		UserId:          c.UserId,
+		ServerName:      c.ServerName,
+	}
 }
 
 func (c *MatrixClient) Upload(filename string, contentType string, body io.Reader) (*MatrixUploadResponse, error) {
diff --git a/test/upload_suite_test.go b/test/upload_suite_test.go
index 50762912..7d75f9ed 100644
--- a/test/upload_suite_test.go
+++ b/test/upload_suite_test.go
@@ -35,10 +35,7 @@ func (s *UploadTestSuite) TearDownSuite() {
 func (s *UploadTestSuite) TestUpload() {
 	t := s.T()
 
-	client := &test_internals.MatrixClient{
-		AccessToken:     s.deps.Homeservers[0].UnprivilegedAliceAccessToken,
-		ClientServerUrl: s.deps.Machines[0].HttpUrl,
-	}
+	client := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl)
 
 	contentType, img, err := test_internals.MakeTestImage(512, 512)
 	res, err := client.Upload("image"+util.ExtensionForContentType(contentType), contentType, img)
-- 
GitLab