diff --git a/cluster/idgen.go b/cluster/idgen.go new file mode 100644 index 0000000000000000000000000000000000000000..54c8db15a5c53a5b7f147d35d75edae06096a277 --- /dev/null +++ b/cluster/idgen.go @@ -0,0 +1,40 @@ +package cluster + +import ( + "errors" + "fmt" + "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/util" + "github.com/turt2live/matrix-media-repo/util/cleanup" + "io/ioutil" + "net/http" + "time" +) + +func GetId() (string, error) { + req, err := http.NewRequest("GET", util.MakeUrl(config.Get().Cluster.IDGenerator.Location, "/v1/id"), nil) + if err != nil { + return "", err + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", config.Get().Cluster.IDGenerator.Secret)) + + client := &http.Client{ + Timeout: 1 * time.Second, // the server should be fast (much faster than this) + } + res, err := client.Do(req) + if err != nil { + return "", err + } + defer cleanup.DumpAndCloseStream(res.Body) + + contents, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", err + } + if res.StatusCode != http.StatusOK { + return "", errors.New(fmt.Sprintf("unexpected status code from ID generator: %d", res.StatusCode)) + } + + return string(contents), nil +} diff --git a/cmd/service_idgen/main.go b/cmd/service_idgen/main.go new file mode 100644 index 0000000000000000000000000000000000000000..d38a818605b6376fdbf43681fa791c85e5b28ba5 --- /dev/null +++ b/cmd/service_idgen/main.go @@ -0,0 +1,62 @@ +package main + +import ( + "flag" + "fmt" + "github.com/bwmarrin/snowflake" + "net/http" + "os" + "strconv" +) + +func main() { + machineId := flag.Int("machine", getIdFromEnv(), "The machine ID. 
// getIdFromEnv returns the machine ID taken from the MACHINE_ID environment variable.
//
// It returns 0 when the variable is not set. When the variable is set but is not a valid
// integer it still returns 0, but reports the problem on stderr first. A typo would otherwise
// silently start this process as machine 0.
// NOTE(review): machine IDs presumably have to be unique per generator instance - confirm
// against the snowflake library's documentation.
func getIdFromEnv() int {
	val, ok := os.LookupEnv("MACHINE_ID")
	if !ok {
		return 0
	}

	id, err := strconv.Atoi(val)
	if err != nil {
		fmt.Fprintf(os.Stderr, "ignoring invalid MACHINE_ID %q, falling back to machine 0: %v\n", val, err)
		return 0
	}
	return id
}

// getValFromEnv returns the value of the environment variable named key, or def when that
// variable is not set. A variable that is set to the empty string yields "" rather than def.
func getValFromEnv(key string, def string) string {
	if val, ok := os.LookupEnv(key); ok {
		return val
	}
	return def
}
type (
	// ClusterConfig holds the settings read from the "cluster" section of the main config.
	ClusterConfig struct {
		// IDGenerator describes how to reach the ID generator service.
		IDGenerator IDGeneratorConfig `yaml:"idGenerator"`
	}

	// IDGeneratorConfig holds the connection details for the ID generator service.
	IDGeneratorConfig struct {
		// Location is the URL where the ID generator service can be reached.
		Location string `yaml:"location"`
		// Secret is the secret used by the ID generator service. Clustering is disabled
		// unless this is set to a non-empty string.
		Secret string `yaml:"secret"`
	}
)
+ location: "http://localhost:8090" diff --git a/controllers/upload_controller/upload_controller.go b/controllers/upload_controller/upload_controller.go index ce9ee1b4bd7e86609dffdc9c610ebffab78d7291..ae6e5c1ca5168f52f7e7b90f3c4609b480484833 100644 --- a/controllers/upload_controller/upload_controller.go +++ b/controllers/upload_controller/upload_controller.go @@ -1,6 +1,9 @@ package upload_controller import ( + "fmt" + "github.com/getsentry/sentry-go" + "github.com/turt2live/matrix-media-repo/util/ids" "io" "io/ioutil" "strconv" @@ -118,11 +121,7 @@ func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string return nil, errors.New("failed to generate a media ID after 10 rounds") } - mediaId, err = util.GenerateRandomString(64) - if err != nil { - return nil, err - } - mediaId, err = util.GetSha1OfString(mediaId + strconv.FormatInt(util.NowMillis(), 10)) + mediaId, err = ids.NewUniqueId() if err != nil { return nil, err } diff --git a/go.mod b/go.mod index d617647cdd2b7c68c92333d902bdfec7410eea6f..e6e356cc53c7ad32648ae178f0879d70e44fa83f 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require ( github.com/alioygur/is v1.0.3 github.com/bep/debounce v1.2.1 github.com/buckket/go-blurhash v1.1.0 + github.com/bwmarrin/snowflake v0.3.0 + github.com/cenk/backoff v2.2.1+incompatible // indirect github.com/cupcake/sigil v0.0.0-20131127230922-6bf9722f2ae8 github.com/dhowden/tag v0.0.0-20220618230019-adf36e896086 github.com/didip/tollbooth v4.0.2+incompatible diff --git a/go.sum b/go.sum index 975d4d9b75ae4a856f6c192f4fb2f20c6b60e10f..4015d2bf071d8755d1d9998bf4cc1150bab34387 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,10 @@ github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY= github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0= github.com/buckket/go-blurhash v1.1.0 h1:X5M6r0LIvwdvKiUtiNcRL2YlmOfMzYobI3VCKCZc9Do= github.com/buckket/go-blurhash v1.1.0/go.mod 
h1:aT2iqo5W9vu9GpyoLErKfTHwgODsZp3bQfXjXJUxNb8= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= +github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0= +github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE= +github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= github.com/cenk/backoff v2.2.1+incompatible h1:djdFT7f4gF2ttuzRKPbMOWgZajgesItGLwG5FTQKmmE= github.com/cenk/backoff v2.2.1+incompatible/go.mod h1:7FtoeaSnHoZnmZzz47cM35Y9nSW7tNyaidugnHTaFDE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= diff --git a/pipline/upload_pipeline/step_gen_media_id.go b/pipline/upload_pipeline/step_gen_media_id.go index a2ce11acb280c640074ab09ee912dec82f133631..9f09aa4b5bc696c124171a8a05535e0e04a6e0eb 100644 --- a/pipline/upload_pipeline/step_gen_media_id.go +++ b/pipline/upload_pipeline/step_gen_media_id.go @@ -2,13 +2,12 @@ package upload_pipeline import ( "errors" - "strconv" + "github.com/turt2live/matrix-media-repo/util/ids" "time" "github.com/patrickmn/go-cache" "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage" - "github.com/turt2live/matrix-media-repo/util" ) var recentMediaIds = cache.New(30*time.Second, 60*time.Second) @@ -25,14 +24,7 @@ func generateMediaID(ctx rcontext.RequestContext, origin string) (string, error) return "", errors.New("failed to generate a media ID after 10 rounds") } - mediaId, err = util.GenerateRandomString(64) - if err != nil { - return "", err - } - mediaId, err = util.GetSha1OfString(mediaId + strconv.FormatInt(util.NowMillis(), 10)) - if err != nil { - return "", err - } + mediaId, err = ids.NewUniqueId() // Because we use the current time in the media ID, we don't need to worry about // collisions from the database. 
diff --git a/util/ids/unique.go b/util/ids/unique.go new file mode 100644 index 0000000000000000000000000000000000000000..d097189a05b0fb6b4a296ee755528b72a72cc9b3 --- /dev/null +++ b/util/ids/unique.go @@ -0,0 +1,20 @@ +package ids + +import ( + "github.com/turt2live/matrix-media-repo/cluster" + "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/util" + "strconv" +) + +func NewUniqueId() (string, error) { + if config.Get().Cluster.IDGenerator.Secret != "" { + return cluster.GetId() + } + + s, err := util.GenerateRandomString(64) + if err != nil { + return "", err + } + return util.GetSha1OfString(s + strconv.FormatInt(util.NowMillis(), 10)) +}