From c60269a512ec4eafab96906173957fc63a59be81 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sat, 13 May 2023 14:20:38 -0600
Subject: [PATCH] Move snowflake stuff in-process

---
 cluster/idgen.go                      | 41 ---------------
 cmd/export_synapse_for_import/main.go |  2 +
 cmd/gdpr_export/main.go               |  1 +
 cmd/gdpr_import/main.go               |  1 +
 cmd/import_synapse/main.go            |  1 +
 cmd/s3_consistency_check/main.go      |  4 +-
 cmd/service_idgen/main.go             | 73 ---------------------------
 common/config/conf_main.go            |  7 ---
 common/config/models_main.go          |  9 ----
 common/runtime/init.go                | 13 +++++
 config.sample.yaml                    | 14 -----
 util/ids/snowflake.go                 | 32 ++++++++++++
 util/ids/unique.go                    | 13 ++---
 13 files changed, 58 insertions(+), 153 deletions(-)
 delete mode 100644 cluster/idgen.go
 delete mode 100644 cmd/service_idgen/main.go
 create mode 100644 util/ids/snowflake.go

diff --git a/cluster/idgen.go b/cluster/idgen.go
deleted file mode 100644
index 89050c44..00000000
--- a/cluster/idgen.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package cluster
-
-import (
-	"errors"
-	"fmt"
-	"io"
-	"net/http"
-	"time"
-
-	"github.com/turt2live/matrix-media-repo/common/config"
-	"github.com/turt2live/matrix-media-repo/util"
-	"github.com/turt2live/matrix-media-repo/util/stream_util"
-)
-
-func GetId() (string, error) {
-	req, err := http.NewRequest("GET", util.MakeUrl(config.Get().Cluster.IDGenerator.Location, "/v1/id"), nil)
-	if err != nil {
-		return "", err
-	}
-
-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", config.Get().Cluster.IDGenerator.Secret))
-
-	client := &http.Client{
-		Timeout: 1 * time.Second, // the server should be fast (much faster than this)
-	}
-	res, err := client.Do(req)
-	if err != nil {
-		return "", err
-	}
-	defer stream_util.DumpAndCloseStream(res.Body)
-
-	contents, err := io.ReadAll(res.Body)
-	if err != nil {
-		return "", err
-	}
-	if res.StatusCode != http.StatusOK {
-		return "", errors.New(fmt.Sprintf("unexpected status code from ID generator: %d", res.StatusCode))
-	}
-
-	return string(contents), nil
-}
diff --git a/cmd/export_synapse_for_import/main.go b/cmd/export_synapse_for_import/main.go
index c71b071b..7c6c7c2b 100644
--- a/cmd/export_synapse_for_import/main.go
+++ b/cmd/export_synapse_for_import/main.go
@@ -16,6 +16,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/config"
 	"github.com/turt2live/matrix-media-repo/common/logging"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/common/runtime"
 	"github.com/turt2live/matrix-media-repo/synapse"
 	"github.com/turt2live/matrix-media-repo/util"
 	"github.com/turt2live/matrix-media-repo/util/stream_util"
@@ -69,6 +70,7 @@ func main() {
 	}
 
 	logrus.Info("Setting up for importing...")
+	runtime.CheckIdGenerator()
 
 	connectionString := "postgres://" + *postgresUsername + ":" + realPsqlPassword + "@" + *postgresHost + ":" + strconv.Itoa(*postgresPort) + "/" + *postgresDatabase + "?sslmode=disable"
 
diff --git a/cmd/gdpr_export/main.go b/cmd/gdpr_export/main.go
index 0b5887e0..722c52e9 100644
--- a/cmd/gdpr_export/main.go
+++ b/cmd/gdpr_export/main.go
@@ -41,6 +41,7 @@ func main() {
 	}
 
 	config.Path = *configPath
+	runtime.CheckIdGenerator()
 	assets.SetupMigrations(*migrationsPath)
 	assets.SetupTemplates(*templatesPath)
 
diff --git a/cmd/gdpr_import/main.go b/cmd/gdpr_import/main.go
index 95cf08bd..9b14db3f 100644
--- a/cmd/gdpr_import/main.go
+++ b/cmd/gdpr_import/main.go
@@ -31,6 +31,7 @@ func main() {
 	}
 
 	config.Path = *configPath
+	runtime.CheckIdGenerator()
 	assets.SetupMigrations(*migrationsPath)
 
 	var err error
diff --git a/cmd/import_synapse/main.go b/cmd/import_synapse/main.go
index 4c8b246a..5de58b74 100644
--- a/cmd/import_synapse/main.go
+++ b/cmd/import_synapse/main.go
@@ -52,6 +52,7 @@ func main() {
 	}
 
 	config.Path = *configPath
+	runtime.CheckIdGenerator()
 	assets.SetupMigrations(*migrationsPath)
 
 	var realPsqlPassword string
diff --git a/cmd/s3_consistency_check/main.go b/cmd/s3_consistency_check/main.go
index e9162a97..61e145a3 100644
--- a/cmd/s3_consistency_check/main.go
+++ b/cmd/s3_consistency_check/main.go
@@ -3,6 +3,8 @@ package main
 import (
 	"flag"
 	"fmt"
+	"os"
+
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common/assets"
 	"github.com/turt2live/matrix-media-repo/common/config"
@@ -11,7 +13,6 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/runtime"
 	"github.com/turt2live/matrix-media-repo/storage"
 	"github.com/turt2live/matrix-media-repo/storage/datastore"
-	"os"
 )
 
 func main() {
@@ -29,6 +30,7 @@ func main() {
 	}
 
 	config.Path = *configPath
+	runtime.CheckIdGenerator()
 	assets.SetupMigrations(*migrationsPath)
 	assets.SetupTemplates(*templatesPath)
 
diff --git a/cmd/service_idgen/main.go b/cmd/service_idgen/main.go
deleted file mode 100644
index c7b2b1df..00000000
--- a/cmd/service_idgen/main.go
+++ /dev/null
@@ -1,73 +0,0 @@
-package main
-
-import (
-	"flag"
-	"fmt"
-	"github.com/bwmarrin/snowflake"
-	"github.com/sirupsen/logrus"
-	"github.com/turt2live/matrix-media-repo/util"
-	"net/http"
-	"os"
-	"strconv"
-)
-
-func main() {
-	machineId := flag.Int("machine", getIdFromEnv(), "The machine ID. 0-1023 (inclusive)")
-	secret := flag.String("secret", getValFromEnv("API_SECRET", ""), "The API secret to require on requests")
-	bind := flag.String("bind", getValFromEnv("API_BIND", ":8090"), "Where to bind the API to")
-	flag.Parse()
-
-	node, err := snowflake.NewNode(int64(*machineId))
-	if err != nil {
-		panic(err)
-	}
-
-	fmt.Printf("Running as machine %d\n", *machineId)
-
-	expectedSecret := fmt.Sprintf("Bearer %s", *secret)
-
-	http.HandleFunc("/v1/id", func(w http.ResponseWriter, req *http.Request) {
-		if req.Header.Get("Authorization") != expectedSecret {
-			w.WriteHeader(http.StatusForbidden)
-			return
-		}
-
-		// Generate a random string to pad out the returned ID
-		r, err := util.GenerateRandomString(32)
-		if err != nil {
-			logrus.Error(err)
-			w.WriteHeader(500)
-			return
-		}
-		s := r + node.Generate().String()
-
-		w.Header().Set("Content-Type", "text/plain")
-		w.WriteHeader(http.StatusOK)
-		_, err = w.Write([]byte(s))
-		if err != nil {
-			fmt.Println(err)
-			return
-		}
-	})
-
-	err = http.ListenAndServe(*bind, nil)
-	if err != nil {
-		panic(err)
-	}
-}
-
-func getIdFromEnv() int {
-	if val, ok := os.LookupEnv("MACHINE_ID"); ok {
-		if i, err := strconv.Atoi(val); err == nil {
-			return i
-		}
-	}
-	return 0
-}
-
-func getValFromEnv(key string, def string) string {
-	if val, ok := os.LookupEnv(key); ok {
-		return val
-	}
-	return def
-}
diff --git a/common/config/conf_main.go b/common/config/conf_main.go
index 0353e781..ecc508d0 100644
--- a/common/config/conf_main.go
+++ b/common/config/conf_main.go
@@ -16,7 +16,6 @@ type MainRepoConfig struct {
 	Plugins           []PluginConfig        `yaml:"plugins,flow"`
 	Sentry            SentryConfig          `yaml:"sentry"`
 	Redis             RedisConfig           `yaml:"redis"`
-	Cluster           ClusterConfig         `yaml:"cluster"`
 }
 
 func NewDefaultMainConfig() MainRepoConfig {
@@ -135,11 +134,5 @@ func NewDefaultMainConfig() MainRepoConfig {
 			Enabled: false,
 			Shards:  []RedisShardConfig{},
 		},
-		Cluster: ClusterConfig{
-			IDGenerator: IDGeneratorConfig{
-				Location: "",
-				Secret:   "",
-			},
-		},
 	}
 }
diff --git a/common/config/models_main.go b/common/config/models_main.go
index d8f5948a..fcd043a6 100644
--- a/common/config/models_main.go
+++ b/common/config/models_main.go
@@ -90,12 +90,3 @@ type RedisShardConfig struct {
 	Name    string `yaml:"name"`
 	Address string `yaml:"addr"`
 }
-
-type ClusterConfig struct {
-	IDGenerator IDGeneratorConfig `yaml:"idGenerator"`
-}
-
-type IDGeneratorConfig struct {
-	Location string `yaml:"location"`
-	Secret   string `yaml:"secret"`
-}
diff --git a/common/runtime/init.go b/common/runtime/init.go
index 4d278030..4aaf5065 100644
--- a/common/runtime/init.go
+++ b/common/runtime/init.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/getsentry/sentry-go"
+	"github.com/turt2live/matrix-media-repo/util/ids"
 
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common/config"
@@ -17,6 +18,7 @@ import (
 
 func RunStartupSequence() {
 	version.Print(true)
+	CheckIdGenerator()
 	config.PrintDomainInfo()
 	config.CheckDeprecations()
 	LoadDatabase()
@@ -80,3 +82,14 @@ func LoadDatastores() {
 		}
 	}
 }
+
+// CheckIdGenerator creates a throwaway ID to ensure generation works (panicking if not), then logs the machine ID.
+func CheckIdGenerator() {
+	_, err := ids.NewUniqueId()
+	if err != nil {
+		panic(err)
+	}
+
+	id := ids.GetMachineId() // read from the MACHINE_ID environment variable; 0 when unset or unparsable
+	logrus.Infof("Running as machine %d for ID generation. This ID must be unique within your cluster.", id)
+}
diff --git a/config.sample.yaml b/config.sample.yaml
index e3570bb4..a3900d94 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -578,17 +578,3 @@ sentry:
 
   # Whether or not to turn on sentry's built in debugging. This will increase log output.
   debug: false
-
-# Options for controlling clustering behaviour of the media repo. This requires an ID generator
-# service in your infrastructure.
-#
-# For more information see https://docs.t2bot.io/matrix-media-repo/installing/environments/clustered.html
-cluster:
-  # Options for accessing the ID generator service.
-  idGenerator:
-    # The secret being used by the ID generator service. Clustering is disabled unless this is set
-    # to a non-empty string.
-    secret: ""
-
-    # The URL for where the ID generator service can be reached.
-    location: "http://localhost:8090"
diff --git a/util/ids/snowflake.go b/util/ids/snowflake.go
new file mode 100644
index 00000000..73f531ac
--- /dev/null
+++ b/util/ids/snowflake.go
@@ -0,0 +1,32 @@
+package ids
+
+import (
+	"os"
+	"strconv"
+
+	"github.com/bwmarrin/snowflake"
+)
+
+func GetMachineId() int64 { // GetMachineId returns the snowflake machine (node) ID taken from the MACHINE_ID environment variable.
+	if val, ok := os.LookupEnv("MACHINE_ID"); ok {
+		if i, err := strconv.ParseInt(val, 10, 64); err == nil {
+			return i // not range-checked here; makeSnowflake hands it to snowflake.NewNode, which may reject it
+		}
+	}
+	return 0 // unset or not a base-10 integer: use machine 0. NOTE(review): malformed values are silently ignored — confirm intended
+}
+
+var sfnode *snowflake.Node // package-wide node, set by the first successful makeSnowflake call
+
+func makeSnowflake() (*snowflake.Node, error) { // makeSnowflake returns the cached node, creating it from GetMachineId when absent
+	if sfnode != nil {
+		return sfnode, nil
+	}
+	machineId := GetMachineId()
+	node, err := snowflake.NewNode(machineId)
+	if err != nil {
+		return nil, err // nothing is cached on failure, so a later call attempts creation again
+	}
+	sfnode = node // NOTE(review): no lock guards sfnode; presumably first called at startup before concurrent use — confirm
+	return sfnode, nil
+}
diff --git a/util/ids/unique.go b/util/ids/unique.go
index 2baa7fec..a7ff7e43 100644
--- a/util/ids/unique.go
+++ b/util/ids/unique.go
@@ -1,20 +1,17 @@
 package ids
 
 import (
-	"github.com/turt2live/matrix-media-repo/cluster"
-	"github.com/turt2live/matrix-media-repo/common/config"
 	"github.com/turt2live/matrix-media-repo/util"
-	"strconv"
 )
 
 func NewUniqueId() (string, error) {
-	if config.Get().Cluster.IDGenerator.Secret != "" {
-		return cluster.GetId()
+	r, err := util.GenerateRandomString(32) // random prefix used to pad out the snowflake
+	if err != nil {
+		return "", err
 	}
-
-	b, err := util.GenerateRandomBytes(64)
+	sf, err := makeSnowflake() // shared snowflake node, created on first use
 	if err != nil {
 		return "", err
 	}
-	return util.GetSha1OfString(string(b) + strconv.FormatInt(util.NowMillis(), 10))
+	return r + sf.Generate().String(), nil // random prefix followed by the snowflake's string form
 }
-- 
GitLab