From 181125d0d76bb3d8cfdc28b0b73d6b8b16e665a9 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Wed, 9 Aug 2023 13:30:15 -0600
Subject: [PATCH] Support MSC4040 SRV records

---
 matrix/federation.go | 43 +++++++++++++++++++++++++++++++++++++++----
 1 file changed, 39 insertions(+), 4 deletions(-)

diff --git a/matrix/federation.go b/matrix/federation.go
index 0ddda378..f99df303 100644
--- a/matrix/federation.go
+++ b/matrix/federation.go
@@ -137,10 +137,10 @@ func GetServerApiUrl(hostname string) (string, string, error) {
 					return url, wkHost, nil
 				}
 
-				// Step 3c: if the delegated host is not an IP and doesn't have a port, start a SRV lookup and use it
+				// Step 3c: if the delegated host is not an IP and doesn't have a port, start a SRV lookup and use it.
 				// Note: we ignore errors here because the hostname will fail elsewhere.
 				logrus.Debug("Doing SRV on WK host ", wkHost)
-				_, addrs, _ := net.LookupSRV("matrix", "tcp", wkHost)
+				_, addrs, _ := net.LookupSRV("matrix-fed", "tcp", wkHost)
 				if len(addrs) > 0 {
 					// Trim off the trailing period if there is one (golang doesn't like this)
 					realAddr := addrs[0].Target
@@ -154,6 +154,24 @@ func GetServerApiUrl(hostname string) (string, string, error) {
 					return url, wkHost, nil
 				}
 
+				// Step 3d: if the delegated host is not an IP and doesn't have a port, start a DEPRECATED SRV
+				// lookup and use it.
+				// Note: we ignore errors here because the hostname will fail elsewhere.
+				logrus.Debug("Doing SRV on WK host ", wkHost)
+				_, addrs, _ = net.LookupSRV("matrix", "tcp", wkHost)
+				if len(addrs) > 0 {
+					// Trim off the trailing period if there is one (golang doesn't like this)
+					realAddr := addrs[0].Target
+					if realAddr[len(realAddr)-1:] == "." {
+						realAddr = realAddr[0 : len(realAddr)-1]
+					}
+					url := fmt.Sprintf("https://%s", net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port))))
+					server := cachedServer{url, wkHost}
+					apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration)
+					logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; SRV-Deprecated)")
+					return url, wkHost, nil
+				}
+
 				// Step 3d: use the delegated host as-is
 				logrus.Debug("Using .well-known as-is for ", wkHost)
 				url := fmt.Sprintf("https://%s", net.JoinHostPort(wkHost, wkPort))
@@ -172,7 +190,7 @@ func GetServerApiUrl(hostname string) (string, string, error) {
 	// Step 4: try resolving a hostname using SRV records and use it
 	// Note: we ignore errors here because the hostname will fail elsewhere.
 	logrus.Debug("Doing SRV for host ", hostname)
-	_, addrs, _ := net.LookupSRV("matrix", "tcp", hostname)
+	_, addrs, _ := net.LookupSRV("matrix-fed", "tcp", hostname)
 	if len(addrs) > 0 {
 		// Trim off the trailing period if there is one (golang doesn't like this)
 		realAddr := addrs[0].Target
@@ -186,7 +204,24 @@ func GetServerApiUrl(hostname string) (string, string, error) {
 		return url, h, nil
 	}
 
-	// Step 5: use the target host as-is
+	// Step 5: try resolving a hostname using DEPRECATED SRV records and use it
+	// Note: we ignore errors here because the hostname will fail elsewhere.
+	logrus.Debug("Doing SRV for host ", hostname)
+	_, addrs, _ = net.LookupSRV("matrix", "tcp", hostname)
+	if len(addrs) > 0 {
+		// Trim off the trailing period if there is one (golang doesn't like this)
+		realAddr := addrs[0].Target
+		if realAddr[len(realAddr)-1:] == "." {
+			realAddr = realAddr[0 : len(realAddr)-1]
+		}
+		url := fmt.Sprintf("https://%s", net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port))))
+		server := cachedServer{url, h}
+		apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration)
+		logrus.Debug("Server API URL for " + hostname + " is " + url + " (SRV-Deprecated)")
+		return url, h, nil
+	}
+
+	// Step 6: use the target host as-is
 	logrus.Debug("Using host as-is: ", hostname)
 	url := fmt.Sprintf("https://%s", net.JoinHostPort(h, p))
 	server := cachedServer{url, h}
-- 
GitLab