From 5d2a52fce8766aed77d136d481b581c0d29afc51 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sun, 3 Feb 2019 17:27:54 -0700
Subject: [PATCH] Support MSC1708 and prepare for MSC1711

Fixes https://github.com/turt2live/matrix-media-repo/issues/136

Later support for MSC1711 will be done in https://github.com/turt2live/matrix-media-repo/issues/137
---
 .../download_resource_handler.go              |   4 +-
 .../matrix-media-repo/matrix/federation.go    | 115 +++++++++++++++---
 .../matrix-media-repo/matrix/responses.go     |   4 +
 3 files changed, 103 insertions(+), 20 deletions(-)

diff --git a/src/github.com/turt2live/matrix-media-repo/controllers/download_controller/download_resource_handler.go b/src/github.com/turt2live/matrix-media-repo/controllers/download_controller/download_resource_handler.go
index 7b76345e..a15dce89 100644
--- a/src/github.com/turt2live/matrix-media-repo/controllers/download_controller/download_resource_handler.go
+++ b/src/github.com/turt2live/matrix-media-repo/controllers/download_controller/download_resource_handler.go
@@ -190,14 +190,14 @@ func DownloadRemoteMediaDirect(server string, mediaId string, log *logrus.Entry)
 		return nil, item.(error)
 	}
 
-	baseUrl, err := matrix.GetServerApiUrl(server)
+	baseUrl, realHost, err := matrix.GetServerApiUrl(server)
 	if err != nil {
 		downloadErrorsCache.Set(cacheKey, err, cache.DefaultExpiration)
 		return nil, err
 	}
 
 	downloadUrl := baseUrl + "/_matrix/media/v1/download/" + server + "/" + mediaId + "?allow_remote=false"
-	resp, err := matrix.FederatedGet(downloadUrl, server)
+	resp, err := matrix.FederatedGet(downloadUrl, realHost)
 	if err != nil {
 		downloadErrorsCache.Set(cacheKey, err, cache.DefaultExpiration)
 		return nil, err
diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/federation.go b/src/github.com/turt2live/matrix-media-repo/matrix/federation.go
index e211a439..d02db147 100644
--- a/src/github.com/turt2live/matrix-media-repo/matrix/federation.go
+++ b/src/github.com/turt2live/matrix-media-repo/matrix/federation.go
@@ -2,12 +2,16 @@ package matrix
 
 import (
 	"crypto/tls"
+	"encoding/json"
 	"fmt"
+	"io/ioutil"
 	"net"
 	"net/http"
+	"strings"
 	"sync"
 	"time"
 
+	"github.com/alioygur/is"
 	"github.com/patrickmn/go-cache"
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common/config"
@@ -24,7 +28,7 @@ func setupCache() {
 	}
 }
 
-func GetServerApiUrl(hostname string) (string, error) {
+func GetServerApiUrl(hostname string) (string, string, error) {
 	logrus.Info("Getting server API URL for " + hostname)
 
 	// Check to see if we've cached this hostname at all
@@ -33,21 +37,96 @@ func GetServerApiUrl(hostname string) (string, error) {
 	if found {
 		url := record.(string)
 		logrus.Info("Server API URL for " + hostname + " is " + url + " (cache)")
-		return url, nil
+		return url, hostname, nil
 	}
 
-	// If not cached, start by seeing if there's a port. If there is a port, use that.
-	// Note: we ignore errors because they are parsing errors. Invalid hostnames will fail through elsewhere.
-	h, p, _ := net.SplitHostPort(hostname)
-	if p != "" {
+	h, p, err := net.SplitHostPort(hostname)
+	defPort := false
+	logrus.Info(err.Error())
+	if err != nil && strings.HasSuffix(err.Error(),"missing port in address") {
+		h, p, err = net.SplitHostPort(hostname + ":8448")
+		defPort = true
+	}
+	if err != nil {
+		return "", "", err
+	}
+
+	// Step 1 of the discovery process: if the hostname is an IP, use that with explicit or default port
+	if is.IP(h) {
+		url := fmt.Sprintf("https://%s:%s", h, p)
+		apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration)
+		logrus.Info("Server API URL for " + hostname + " is " + url + " (IP address)")
+		return url, hostname, nil
+	}
+
+	// Step 2: if the hostname is not an IP address, and an explicit port is given, use that
+	if !defPort {
 		url := fmt.Sprintf("https://%s:%s", h, p)
 		apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration)
 		logrus.Info("Server API URL for " + hostname + " is " + url + " (explicit port)")
-		return url, nil
+		return url, h, nil
+	}
+
+	// Step 3: if the hostname is not an IP address and no explicit port is given, do .well-known
+	// Note that we have sprawling branches here because we need to fall through to step 4 if parsing fails
+	r, err := http.Get(fmt.Sprintf("https://%s/.well-known/matrix/server", h))
+	if err == nil && r.StatusCode == http.StatusOK {
+		// Try parsing .well-known
+		c, err2 := ioutil.ReadAll(r.Body)
+		if err2 == nil {
+			wk := &wellknownServerResponse{}
+			err3 := json.Unmarshal(c, wk)
+			if err3 == nil && wk.ServerAddr != "" {
+				wkHost, wkPort, err4 := net.SplitHostPort(wk.ServerAddr)
+				wkDefPort := false
+				if err4 != nil && strings.HasSuffix(err4.Error(),"missing port in address") {
+					wkHost, wkPort, err4 = net.SplitHostPort(wk.ServerAddr + ":8448")
+					wkDefPort = true
+				}
+				if err4 == nil {
+					// Step 3a: if the delegated host is an IP address, use that (regardless of port)
+					if is.IP(wkHost) {
+						url := fmt.Sprintf("https://%s:%s", wkHost, wkPort)
+						apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration)
+						logrus.Info("Server API URL for " + hostname + " is " + url + " (WK; IP address)")
+						return url, wk.ServerAddr, nil
+					}
+
+					// Step 3b: if the delegated host is not an IP and an explicit port is given, use that
+					if !wkDefPort {
+						url := fmt.Sprintf("https://%s:%s", wkHost, wkPort)
+						apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration)
+						logrus.Info("Server API URL for " + hostname + " is " + url + " (WK; explicit port)")
+						return url, wkHost, nil
+					}
+
+					// Step 3c: if the delegated host is not an IP and doesn't have a port, start a SRV lookup and use it
+					// Note: lookup errors are ignored on purpose; with no SRV records we fall through to step 3d, and a bad hostname will fail elsewhere.
+					_, addrs, _ := net.LookupSRV("matrix", "tcp", wkHost)
+					if len(addrs) > 0 {
+						// Trim off the trailing period if there is one (golang doesn't like this)
+						realAddr := addrs[0].Target
+						if realAddr[len(realAddr)-1:] == "." {
+							realAddr = realAddr[0 : len(realAddr)-1]
+						}
+						url := fmt.Sprintf("https://%s:%d", realAddr, addrs[0].Port)
+						apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration)
+						logrus.Info("Server API URL for " + hostname + " is " + url + " (WK; SRV)")
+						return url, wkHost, nil
+					}
+
+					// Step 3d: use the delegated host as-is
+					url := fmt.Sprintf("https://%s:%s", wkHost, wkPort)
+					apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration)
+					logrus.Info("Server API URL for " + hostname + " is " + url + " (WK; fallback)")
+					return url, wkHost, nil
+				}
+			}
+		}
 	}
 
-	// Try resolving by SRV record. If there's at least one result, use that.
-	// Note: we also ignore errors here because the hostname will fail elsewhere.
+	// Step 4: try resolving the hostname using SRV records and use the first result
+	// Note: lookup errors are ignored on purpose; with no SRV records we fall through to step 5, and a bad hostname will fail elsewhere.
 	_, addrs, _ := net.LookupSRV("matrix", "tcp", hostname)
 	if len(addrs) > 0 {
 		// Trim off the trailing period if there is one (golang doesn't like this)
@@ -58,17 +137,18 @@ func GetServerApiUrl(hostname string) (string, error) {
 		url := fmt.Sprintf("https://%s:%d", realAddr, addrs[0].Port)
 		apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration)
 		logrus.Info("Server API URL for " + hostname + " is " + url + " (SRV)")
-		return url, nil
+		return url, h, nil
 	}
 
-	// Lastly fall back to port 8448
-	url := fmt.Sprintf("https://%s:%d", hostname, 8448)
+	// Step 5: use the target host as-is
+	url := fmt.Sprintf("https://%s:%s", h, p)
 	apiUrlCacheInstance.Set(hostname, url, cache.DefaultExpiration)
 	logrus.Info("Server API URL for " + hostname + " is " + url + " (fallback)")
-	return url, nil
+	return url, h, nil
 }
 
 func FederatedGet(url string, realHost string) (*http.Response, error) {
+	// TODO: Support MSC1711 by relying on plain HTTPS requests to servers
 	logrus.Info("Doing federated GET to " + url + " with host " + realHost)
 	transport := &http.Transport{
 		// Based on https://github.com/matrix-org/gomatrixserverlib/blob/51152a681e69a832efcd934b60080b92bc98b286/client.go#L74-L90
@@ -83,19 +163,18 @@ func FederatedGet(url string, realHost string) (*http.Response, error) {
 			// Wrap a raw connection ourselves since tls.Dial defaults the SNI
 			// #125: Some servers require SNI, so we should try it first. Most things on the planet support it.
 			conn := tls.Client(rawconn, &tls.Config{
-				ServerName: realHost,
-				// TODO: We should be checking that the TLS certificate we see here matches one of the allowed SHA-256 fingerprints for the server.
+				ServerName:         realHost,
 				InsecureSkipVerify: true,
 			})
 			if err := conn.Handshake(); err != nil {
 				logrus.Warn("Handshake failed due to ", err, ". Attempting handshake without SNI.");
 				// ...however there are reasons for some servers NOT supplying the correct ServerName, so fallback to not providing one.
 				conn := tls.Client(rawconn, &tls.Config{
-					ServerName: "", // An empty ServerName means we will not try to verify it.
+					ServerName:         "", // An empty ServerName means we will not try to verify it.
 					InsecureSkipVerify: true,
 				})
-				if err := conn.Handshake(); err != nil { 
-				 	return nil, err;
+				if err := conn.Handshake(); err != nil {
+					return nil, err;
 				}
 				return nil, err;
 			}
diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/responses.go b/src/github.com/turt2live/matrix-media-repo/matrix/responses.go
index 19eb759b..8c1cd065 100644
--- a/src/github.com/turt2live/matrix-media-repo/matrix/responses.go
+++ b/src/github.com/turt2live/matrix-media-repo/matrix/responses.go
@@ -17,6 +17,10 @@ type mediaListResponse struct {
 	RemoteMxcs []string `json:"remote"`
 }
 
+type wellknownServerResponse struct { // JSON body decoded from a server's /.well-known/matrix/server response
+	ServerAddr string `json:"m.server"` // value of the "m.server" key: the delegated server address, as host or host:port
+}
+
 type errorResponse struct {
 	ErrorCode string `json:"errcode"`
 	Message   string `json:"error"`
-- 
GitLab