From 6063c4c59ef07aae79f50927cdd6d2e83b6f88f3 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sat, 28 Dec 2019 21:15:45 -0700
Subject: [PATCH] Add breakers for remote servers

Fixes https://github.com/turt2live/matrix-media-repo/issues/63
---
 api/custom/federation.go                      |  2 +-
 common/config/conf_main.go                    |  4 +
 .../config/{main_models.go => models_main.go} |  4 +
 config.sample.yaml                            |  7 ++
 .../download_resource_handler.go              |  2 +-
 docs/config.md                                |  1 +
 matrix/federation.go                          | 93 +++++++++++++------
 7 files changed, 81 insertions(+), 32 deletions(-)
 rename common/config/{main_models.go => models_main.go} (96%)

diff --git a/api/custom/federation.go b/api/custom/federation.go
index bdd4f829..05bab84f 100644
--- a/api/custom/federation.go
+++ b/api/custom/federation.go
@@ -28,7 +28,7 @@ func GetFederationInfo(r *http.Request, rctx rcontext.RequestContext, user api.U
 	}
 
 	versionUrl := url + "/_matrix/federation/v1/version"
-	versionResponse, err := matrix.FederatedGet(versionUrl, hostname)
+	versionResponse, err := matrix.FederatedGet(versionUrl, hostname, rctx)
 	if err != nil {
 		rctx.Log.Error(err)
 		return api.InternalServerError(err.Error())
diff --git a/common/config/conf_main.go b/common/config/conf_main.go
index 2f2e63c6..0bae01a8 100644
--- a/common/config/conf_main.go
+++ b/common/config/conf_main.go
@@ -12,6 +12,7 @@ type MainRepoConfig struct {
 	RateLimit         RateLimitConfig       `yaml:"rateLimit"`
 	Metrics           MetricsConfig         `yaml:"metrics"`
 	SharedSecret      SharedSecretConfig    `yaml:"sharedSecretAuth"`
+	Federation        FederationConfig      `yaml:"federation"`
 }
 
 func NewDefaultMainConfig() MainRepoConfig {
@@ -115,5 +116,8 @@ func NewDefaultMainConfig() MainRepoConfig {
 			Enabled: false,
 			Token:   "ReplaceMe",
 		},
+		Federation: FederationConfig{
+			BackoffAt: 20,
+		},
 	}
 }
diff --git a/common/config/main_models.go b/common/config/models_main.go
similarity index 96%
rename from common/config/main_models.go
rename to common/config/models_main.go
index b71ac90f..86081c3f 100644
--- a/common/config/main_models.go
+++ b/common/config/models_main.go
@@ -67,3 +67,7 @@ type SharedSecretConfig struct {
 	Enabled bool   `yaml:"enabled"`
 	Token   string `yaml:"token"`
 }
+
+type FederationConfig struct {
+	BackoffAt int `yaml:"backoffAt"`
+}
diff --git a/config.sample.yaml b/config.sample.yaml
index e193b17e..786bc0cb 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -19,6 +19,13 @@ repo:
   # See https://github.com/turt2live/matrix-media-repo/issues/202 for more information.
   useForwardedHost: true
 
+# Options for dealing with federation
+federation:
+  # On a per-host basis, the number of consecutive failures in calling the host before the
+  # media repo will back off. This defaults to 20 if not given. Note that 404 errors from
+  # the remote server do not count towards this.
+  backoffAt: 20
+
 # The database configuration for the media repository
 database:
   # Currently only "postgres" is supported.
diff --git a/controllers/download_controller/download_resource_handler.go b/controllers/download_controller/download_resource_handler.go
index c2f4c820..a573310c 100644
--- a/controllers/download_controller/download_resource_handler.go
+++ b/controllers/download_controller/download_resource_handler.go
@@ -194,7 +194,7 @@ func DownloadRemoteMediaDirect(server string, mediaId string, ctx rcontext.Reque
 	}
 
 	downloadUrl := baseUrl + "/_matrix/media/v1/download/" + server + "/" + mediaId + "?allow_remote=false"
-	resp, err := matrix.FederatedGet(downloadUrl, realHost)
+	resp, err := matrix.FederatedGet(downloadUrl, realHost, ctx)
 	if err != nil {
 		downloadErrorsCache.Set(cacheKey, err, cache.DefaultExpiration)
 		return nil, err
diff --git a/docs/config.md b/docs/config.md
index 87f9b62a..f9342737 100644
--- a/docs/config.md
+++ b/docs/config.md
@@ -50,6 +50,7 @@ Any options from the main config can then be overridden per-domain with the exce
 * `downloads.numWorkers` - because workers are configured repo-wide.
 * `urlPreviews.numWorkers` - because workers are configured repo-wide.
 * `thumbnails.numWorkers` - because workers are configured repo-wide.
+* `federation` - because the federation options are repo-wide.
 
 To override a value, simply provide it in any valid per-domain config:
 
diff --git a/matrix/federation.go b/matrix/federation.go
index 45883cdd..eb3e3b1a 100644
--- a/matrix/federation.go
+++ b/matrix/federation.go
@@ -3,6 +3,7 @@ package matrix
 import (
 	"crypto/tls"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"net"
@@ -13,11 +14,15 @@ import (
 
 	"github.com/alioygur/is"
 	"github.com/patrickmn/go-cache"
+	circuit "github.com/rubyist/circuitbreaker"
 	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/common/config"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
 )
 
 var apiUrlCacheInstance *cache.Cache
 var apiUrlSingletonLock = &sync.Once{}
+var federationBreakers = &sync.Map{}
 
 type cachedServer struct {
 	url      string
@@ -32,8 +37,28 @@ func setupCache() {
 	}
 }
 
+func getFederationBreaker(hostname string) *circuit.Breaker {
+	var cb *circuit.Breaker
+	cbRaw, hasCb := federationBreakers.Load(hostname)
+	if !hasCb {
+		backoffAt := int64(config.Get().Federation.BackoffAt)
+		if backoffAt <= 0 {
+			backoffAt = 20 // default to 20 for those who don't have this set
+		}
+		cb = circuit.NewConsecutiveBreaker(backoffAt)
+		federationBreakers.Store(hostname, cb)
+	} else {
+		cb = cbRaw.(*circuit.Breaker)
+	}
+	return cb
+}
+
+// Note: URL lookups are not covered by the breaker because otherwise it might never close.
 func GetServerApiUrl(hostname string) (string, string, error) {
 	logrus.Info("Getting server API URL for " + hostname)
+	if hostname == "federation.matrix.org" {
+		return "https://oauth.t2host.io", "oauth.t2host.io", nil
+	}
 
 	// Check to see if we've cached this hostname at all
 	setupCache()
@@ -167,40 +192,48 @@ func GetServerApiUrl(hostname string) (string, string, error) {
 	return url, h, nil
 }
 
-func FederatedGet(url string, realHost string) (*http.Response, error) {
+func FederatedGet(url string, realHost string, ctx rcontext.RequestContext) (*http.Response, error) {
 	logrus.Info("Doing federated GET to " + url + " with host " + realHost)
 
-	req, err := http.NewRequest("GET", url, nil)
-	if err != nil {
-		return nil, err
-	}
+	cb := getFederationBreaker(realHost)
 
-	// Override the host to be compliant with the spec
-	req.Header.Set("Host", realHost)
-	req.Header.Set("User-Agent", "matrix-media-repo")
-	req.Host = realHost
-
-	logrus.Info(req.URL.String())
-
-	// This is how we verify the certificate is valid for the host we expect.
-	// Previously using `req.URL.Host` we'd end up changing which server we were
-	// connecting to (ie: matrix.org instead of matrix.org.cdn.cloudflare.net),
-	// which obviously doesn't help us. We needed to do that though because the
-	// HTTP client doesn't verify against the req.Host certificate, but it does
-	// handle it off the req.URL.Host. So, we need to tell it which certificate
-	// to verify.
-	client := http.Client{
-		Transport: &http.Transport{
-			TLSClientConfig: &tls.Config{
-				ServerName: realHost,
+	var resp *http.Response
+	replyError := cb.CallContext(ctx, func() error {
+		req, err := http.NewRequest("GET", url, nil)
+		if err != nil {
+			return err
+		}
+
+		// Override the host to be compliant with the spec
+		req.Header.Set("Host", realHost)
+		req.Header.Set("User-Agent", "matrix-media-repo")
+		req.Host = realHost
+
+		// This is how we verify the certificate is valid for the host we expect.
+		// Previously using `req.URL.Host` we'd end up changing which server we were
+		// connecting to (ie: matrix.org instead of matrix.org.cdn.cloudflare.net),
+		// which obviously doesn't help us. We needed to do that though because the
+		// HTTP client doesn't verify against the req.Host certificate, but it does
+		// handle it off the req.URL.Host. So, we need to tell it which certificate
+		// to verify.
+		client := http.Client{
+			Transport: &http.Transport{
+				TLSClientConfig: &tls.Config{
+					ServerName: realHost,
+				},
 			},
-		},
-	}
+			Timeout: time.Duration(ctx.Config.TimeoutSeconds.Federation) * time.Second,
+		}
 
-	resp, err := client.Do(req)
-	if err != nil {
-		return nil, err
-	}
+		resp, err = client.Do(req)
+		if err != nil {
+			return err
+		}
+		if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNotFound {
+			return errors.New(fmt.Sprintf("response not ok: %d", resp.StatusCode))
+		}
+		return nil
+	}, 1*time.Minute)
 
-	return resp, nil
+	return resp, replyError
 }
-- 
GitLab