From 5d2a52fce8766aed77d136d481b581c0d29afc51 Mon Sep 17 00:00:00 2001 From: Travis Ralston <travpc@gmail.com> Date: Sun, 3 Feb 2019 17:27:54 -0700 Subject: [PATCH] Support MSC1708 and prepare for MSC1711 Fixes https://github.com/turt2live/matrix-media-repo/issues/136 Later support for MSC1711 will be done in https://github.com/turt2live/matrix-media-repo/issues/137 --- .../download_resource_handler.go | 4 +- .../matrix-media-repo/matrix/federation.go | 115 +++++++++++++++--- .../matrix-media-repo/matrix/responses.go | 4 + 3 files changed, 103 insertions(+), 20 deletions(-) diff --git a/src/github.com/turt2live/matrix-media-repo/controllers/download_controller/download_resource_handler.go b/src/github.com/turt2live/matrix-media-repo/controllers/download_controller/download_resource_handler.go index 7b76345e..a15dce89 100644 --- a/src/github.com/turt2live/matrix-media-repo/controllers/download_controller/download_resource_handler.go +++ b/src/github.com/turt2live/matrix-media-repo/controllers/download_controller/download_resource_handler.go @@ -190,14 +190,14 @@ func DownloadRemoteMediaDirect(server string, mediaId string, log *logrus.Entry) return nil, item.(error) } - baseUrl, err := matrix.GetServerApiUrl(server) + baseUrl, realHost, err := matrix.GetServerApiUrl(server) if err != nil { downloadErrorsCache.Set(cacheKey, err, cache.DefaultExpiration) return nil, err } downloadUrl := baseUrl + "/_matrix/media/v1/download/" + server + "/" + mediaId + "?allow_remote=false" - resp, err := matrix.FederatedGet(downloadUrl, server) + resp, err := matrix.FederatedGet(downloadUrl, realHost) if err != nil { downloadErrorsCache.Set(cacheKey, err, cache.DefaultExpiration) return nil, err diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/federation.go b/src/github.com/turt2live/matrix-media-repo/matrix/federation.go index e211a439..d02db147 100644 --- a/src/github.com/turt2live/matrix-media-repo/matrix/federation.go +++ b/src/github.com/turt2live/matrix-media-repo/matrix/federation.go @@ -2,12 +2,16 @@ package matrix import ( "crypto/tls" + "encoding/json" "fmt" + "io/ioutil" "net" "net/http" + "strings" "sync" "time" + "github.com/alioygur/is" "github.com/patrickmn/go-cache" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common/config" @@ -24,7 +28,7 @@ func setupCache() { } } -func GetServerApiUrl(hostname string) (string, error) { +func GetServerApiUrl(hostname string) (string, string, error) { logrus.Info("Getting server API URL for " + hostname) // Check to see if we've cached this hostname at all @@ -33,21 +37,96 @@ func GetServerApiUrl(hostname string) (string, error) { if found { url := record.(string) logrus.Info("Server API URL for " + hostname + " is " + url + " (cache)") - return url, nil + return url, hostname, nil } - // If not cached, start by seeing if there's a port. If there is a port, use that. - // Note: we ignore errors because they are parsing errors. Invalid hostnames will fail through elsewhere. - h, p, _ := net.SplitHostPort(hostname) - if p != "" { + h, p, err := net.SplitHostPort(hostname) + defPort := false + logrus.Info(err.Error()) + if err != nil && strings.HasSuffix(err.Error(),"missing port in address") { + h, p, err = net.SplitHostPort(hostname + ":8448") + defPort = true + } + if err != nil { + return "", "", err + } + + // Step 1 of the discovery process: if the hostname is an IP, use that with explicit or default port + if is.IP(h) { + url := fmt.Sprintf("https://%s:%s", h, p) + apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration) + logrus.Info("Server API URL for " + hostname + " is " + url + " (IP address)") + return url, hostname, nil + } + + // Step 2: if the hostname is not an IP address, and an explicit port is given, use that + if !defPort { url := fmt.Sprintf("https://%s:%s", h, p) apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration) logrus.Info("Server API URL for " + hostname + " is " + url + " (explicit port)") - return url, nil + return url, h, nil + } + + // Step 3: if the hostname is not an IP address and no explicit port is given, do .well-known + // Note that we have sprawling branches here because we need to fall through to step 4 if parsing fails + r, err := http.Get(fmt.Sprintf("https://%s/.well-known/matrix/server", h)) + if err == nil && r.StatusCode == http.StatusOK { + // Try parsing .well-known + c, err2 := ioutil.ReadAll(r.Body) + if err2 == nil { + wk := &wellknownServerResponse{} + err3 := json.Unmarshal(c, wk) + if err3 == nil && wk.ServerAddr != "" { + wkHost, wkPort, err4 := net.SplitHostPort(wk.ServerAddr) + wkDefPort := false + if err4 != nil && strings.HasSuffix(err4.Error(),"missing port in address") { + wkHost, wkPort, err4 = net.SplitHostPort(wk.ServerAddr + ":8448") + wkDefPort = true + } + if err4 == nil { + // Step 3a: if the delegated host is an IP address, use that (regardless of port) + if is.IP(wkHost) { + url := fmt.Sprintf("https://%s:%s", wkHost, wkPort) + apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration) + logrus.Info("Server API URL for " + hostname + " is " + url + " (WK; IP address)") + return url, wk.ServerAddr, nil + } + + // Step 3b: if the delegated host is not an IP and an explicit port is given, use that + if !wkDefPort { + url := fmt.Sprintf("https://%s:%s", wkHost, wkPort) + apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration) + logrus.Info("Server API URL for " + hostname + " is " + url + " (WK; explicit port)") + return url, wkHost, nil + } + + // Step 3c: if the delegated host is not an IP and doesn't have a port, start a SRV lookup and use it + // Note: we ignore errors here because the hostname will fail elsewhere. + _, addrs, _ := net.LookupSRV("matrix", "tcp", wkHost) + if len(addrs) > 0 { + // Trim off the trailing period if there is one (golang doesn't like this) + realAddr := addrs[0].Target + if realAddr[len(realAddr)-1:] == "." { + realAddr = realAddr[0 : len(realAddr)-1] + } + url := fmt.Sprintf("https://%s:%d", realAddr, addrs[0].Port) + apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration) + logrus.Info("Server API URL for " + hostname + " is " + url + " (WK; SRV)") + return url, wkHost, nil + } + + // Step 3d: use the delegated host as-is + url := fmt.Sprintf("https://%s:%s", wkHost, wkPort) + apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration) + logrus.Info("Server API URL for " + hostname + " is " + url + " (WK; fallback)") + return url, wkHost, nil + } + } + } } - // Try resolving by SRV record. If there's at least one result, use that. - // Note: we also ignore errors here because the hostname will fail elsewhere. + // Step 4: try resolving a hostname using SRV records and use it + // Note: we ignore errors here because the hostname will fail elsewhere. _, addrs, _ := net.LookupSRV("matrix", "tcp", hostname) if len(addrs) > 0 { // Trim off the trailing period if there is one (golang doesn't like this) @@ -58,17 +137,18 @@ func GetServerApiUrl(hostname string) (string, error) { url := fmt.Sprintf("https://%s:%d", realAddr, addrs[0].Port) apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration) logrus.Info("Server API URL for " + hostname + " is " + url + " (SRV)") - return url, nil + return url, h, nil } - // Lastly fall back to port 8448 - url := fmt.Sprintf("https://%s:%d", hostname, 8448) + // Step 5: use the target host as-is + url := fmt.Sprintf("https://%s:%s", h, p) apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration) logrus.Info("Server API URL for " + hostname + " is " + url + " (fallback)") - return url, nil + return url, h, nil } func FederatedGet(url string, realHost string) (*http.Response, error) { + // TODO: Support MSC1711 by relying on plain HTTPS requests to servers logrus.Info("Doing federated GET to " + url + " with host " + realHost) transport := &http.Transport{ // Based on https://github.com/matrix-org/gomatrixserverlib/blob/51152a681e69a832efcd934b60080b92bc98b286/client.go#L74-L90 @@ -83,19 +163,18 @@ func FederatedGet(url string, realHost string) (*http.Response, error) { // Wrap a raw connection ourselves since tls.Dial defaults the SNI // #125: Some servers require SNI, so we should try it first. Most things on the planet support it. conn := tls.Client(rawconn, &tls.Config{ - ServerName: realHost, - // TODO: We should be checking that the TLS certificate we see here matches one of the allowed SHA-256 fingerprints for the server. + ServerName: realHost, InsecureSkipVerify: true, }) if err := conn.Handshake(); err != nil { logrus.Warn("Handshake failed due to ", err, ". Attempting handshake without SNI."); // ...however there are reasons for some servers NOT supplying the correct ServerName, so fallback to not providing one. conn := tls.Client(rawconn, &tls.Config{ - ServerName: "", // An empty ServerName means we will not try to verify it. + ServerName: "", // An empty ServerName means we will not try to verify it. InsecureSkipVerify: true, }) - if err := conn.Handshake(); err != nil { - return nil, err; + if err := conn.Handshake(); err != nil { + return nil, err; } return nil, err; } diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/responses.go b/src/github.com/turt2live/matrix-media-repo/matrix/responses.go index 19eb759b..8c1cd065 100644 --- a/src/github.com/turt2live/matrix-media-repo/matrix/responses.go +++ b/src/github.com/turt2live/matrix-media-repo/matrix/responses.go @@ -17,6 +17,10 @@ type mediaListResponse struct { RemoteMxcs []string `json:"remote"` } +type wellknownServerResponse struct { + ServerAddr string `json:"m.server"` +} + type errorResponse struct { ErrorCode string `json:"errcode"` Message string `json:"error"` -- GitLab