From ff90a5c5ebbf8791a253e329c1702899a2136ce9 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sat, 28 Dec 2019 17:49:48 -0700
Subject: [PATCH] Move URL preview checks further down to avoid eager
 rejections

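Previously the host lookup and allow/deny network check ran in
acl.ValidateUrlForPreview before any fetch was attempted, both for the
page URL and for OpenGraph image URLs. That check now lives in
acl.GetSafeAddress and is called from the HTTP dialer for each address
it is asked to dial. The preview controller only parses the URL up
front, and UrlPayload no longer carries a pre-resolved Address.

Fixes https://github.com/turt2live/matrix-media-repo/issues/173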
---
 controllers/preview_controller/acl/acl.go     | 44 +++++--------------
 .../preview_controller/preview_controller.go  | 18 +++++---
 .../preview_controller/preview_types/types.go |  2 -
 .../preview_controller/previewers/http.go     | 26 ++++++-----
 .../previewers/opengraph_previewer.go         | 19 ++------
 5 files changed, 43 insertions(+), 66 deletions(-)

diff --git a/controllers/preview_controller/acl/acl.go b/controllers/preview_controller/acl/acl.go
index 7a6e11c2..80e91f91 100644
--- a/controllers/preview_controller/acl/acl.go
+++ b/controllers/preview_controller/acl/acl.go
@@ -3,45 +3,30 @@ package acl
 import (
 	"fmt"
 	"net"
-	"net/url"
 
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
-	"github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types"
-	"github.com/turt2live/matrix-media-repo/storage"
 )
 
-func ValidateUrlForPreview(urlStr string, ctx rcontext.RequestContext) (*preview_types.UrlPayload, error) {
-	db := storage.GetDatabase().GetUrlStore(ctx)
-
-	parsedUrl, err := url.Parse(urlStr)
-	if err != nil {
-		ctx.Log.Error("Error parsing URL: ", err.Error())
-		db.InsertPreviewError(urlStr, common.ErrCodeInvalidHost)
-		return nil, common.ErrInvalidHost
-	}
-	parsedUrl.Fragment = "" // Remove fragment because it's not important for servers
-
-	realHost, _, err := net.SplitHostPort(parsedUrl.Host)
+func GetSafeAddress(addr string, ctx rcontext.RequestContext) (net.IP, string, error) {
+	ctx.Log.Info("Checking address: " + addr)
+	realHost, p, err := net.SplitHostPort(addr)
 	if err != nil {
 		ctx.Log.Warn("Error parsing host and port: ", err.Error())
-		realHost = parsedUrl.Host
+		realHost = addr
 	}
 
-	addr := net.IPv4(127, 0, 0, 1)
+	ipAddr := net.IPv4(127, 0, 0, 1)
 	if realHost != "localhost" {
 		addrs, err := net.LookupIP(realHost)
 		if err != nil {
-			ctx.Log.Error("Error getting host info: ", err.Error())
-			db.InsertPreviewError(urlStr, common.ErrCodeInvalidHost)
-			return nil, common.ErrInvalidHost
+			return nil, "", common.ErrInvalidHost
 		}
 		if len(addrs) == 0 {
-			db.InsertPreviewError(urlStr, common.ErrCodeHostNotFound)
-			return nil, common.ErrHostNotFound
+			return nil, "", common.ErrHostNotFound
 		}
-		addr = addrs[0]
+		ipAddr = addrs[0]
 	}
 
 	allowedCidrs := ctx.Config.UrlPreviews.AllowedNetworks
@@ -57,17 +42,10 @@ func ValidateUrlForPreview(urlStr string, ctx rcontext.RequestContext) (*preview
 	deniedCidrs = append(deniedCidrs, "0.0.0.0/32")
 	deniedCidrs = append(deniedCidrs, "::/128")
 
-	if !isAllowed(addr, allowedCidrs, deniedCidrs, ctx) {
-		db.InsertPreviewError(urlStr, common.ErrCodeHostBlacklisted)
-		return nil, common.ErrHostBlacklisted
-	}
-
-	urlToPreview := &preview_types.UrlPayload{
-		UrlString: urlStr,
-		ParsedUrl: parsedUrl,
-		Address:   addr,
+	if !isAllowed(ipAddr, allowedCidrs, deniedCidrs, ctx) {
+		return nil, "", common.ErrHostBlacklisted
 	}
-	return urlToPreview, nil
+	return ipAddr, p, nil
 }
 
 func isAllowed(ip net.IP, allowed []string, disallowed []string, ctx rcontext.RequestContext) bool {
diff --git a/controllers/preview_controller/preview_controller.go b/controllers/preview_controller/preview_controller.go
index bd5ef113..743f5197 100644
--- a/controllers/preview_controller/preview_controller.go
+++ b/controllers/preview_controller/preview_controller.go
@@ -4,12 +4,13 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
+	"net/url"
 
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/globals"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
-	"github.com/turt2live/matrix-media-repo/controllers/preview_controller/acl"
+	"github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types"
 	"github.com/turt2live/matrix-media-repo/storage"
 	"github.com/turt2live/matrix-media-repo/storage/stores"
 	"github.com/turt2live/matrix-media-repo/types"
@@ -47,12 +48,19 @@ func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, ctx
 			return GetPreview(urlStr, onHost, forUserId, now, ctx)
 		}
 
-		ctx.Log.Info("Preview not cached - fetching resource")
-
-		urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx)
+		parsedUrl, err := url.Parse(urlStr)
 		if err != nil {
-			return nil, err
+			ctx.Log.Error("Error parsing URL: ", err.Error())
+			db.InsertPreviewError(urlStr, common.ErrCodeInvalidHost)
+			return nil, common.ErrInvalidHost
 		}
+		parsedUrl.Fragment = "" // Remove fragment because it's not important for servers
+		urlToPreview := &preview_types.UrlPayload{
+			UrlString: urlStr,
+			ParsedUrl: parsedUrl,
+		}
+
+		ctx.Log.Info("Preview not cached - fetching resource")
 
 		previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost)
 		defer close(previewChan)
diff --git a/controllers/preview_controller/preview_types/types.go b/controllers/preview_controller/preview_types/types.go
index f26d6e94..0f33ea6b 100644
--- a/controllers/preview_controller/preview_types/types.go
+++ b/controllers/preview_controller/preview_types/types.go
@@ -3,7 +3,6 @@ package preview_types
 import (
 	"errors"
 	"io"
-	"net"
 	"net/url"
 )
 
@@ -27,7 +26,6 @@ type PreviewImage struct {
 type UrlPayload struct {
 	UrlString string
 	ParsedUrl *url.URL
-	Address   net.IP
 }
 
 var ErrPreviewUnsupported = errors.New("preview not supported by this previewer")
diff --git a/controllers/preview_controller/previewers/http.go b/controllers/preview_controller/previewers/http.go
index 08976463..dac0c4d9 100644
--- a/controllers/preview_controller/previewers/http.go
+++ b/controllers/preview_controller/previewers/http.go
@@ -16,6 +16,7 @@ import (
 	"github.com/ryanuber/go-glob"
 	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/controllers/preview_controller/acl"
 	"github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types"
 )
 
@@ -28,10 +29,14 @@ func doHttpGet(urlPayload *preview_types.UrlPayload, ctx rcontext.RequestContext
 		DualStack: true,
 	}
 
-	dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
-		// If we aren't handling any address then return the default behaviour
-		if urlPayload.Address == nil {
-			return dialer.DialContext(ctx, network, addr)
+	dialContext := func(ctx2 context.Context, network, addr string) (conn net.Conn, e error) {
+		if network != "tcp" {
+			return nil, errors.New("invalid network: expected tcp")
+		}
+
+		safeIp, safePort, err := acl.GetSafeAddress(addr, ctx)
+		if err != nil {
+			return nil, err
 		}
 
 		// Try and determine which port we're expecting a request to come in on. Because the
@@ -39,28 +44,27 @@ func doHttpGet(urlPayload *preview_types.UrlPayload, ctx rcontext.RequestContext
 		// so that redirects don't fail previews. We only support the alternate port if the
 		// default port for the scheme is used, however.
 
-		expectedPort := urlPayload.ParsedUrl.Port()
 		altPort := ""
-		if expectedPort == "" {
+		if safePort == "" {
 			if urlPayload.ParsedUrl.Scheme == "http" {
-				expectedPort = "80"
+				safePort = "80"
 				altPort = "443"
 			} else if urlPayload.ParsedUrl.Scheme == "https" {
-				expectedPort = "443"
+				safePort = "443"
 				altPort = "80"
 			} else {
 				return nil, errors.New("unexpected scheme: cannot determine port")
 			}
 		}
 
-		expectedAddr := fmt.Sprintf("%s:%s", urlPayload.ParsedUrl.Host, expectedPort)
+		expectedAddr := fmt.Sprintf("%s:%s", urlPayload.ParsedUrl.Host, safePort)
 		altAddr := fmt.Sprintf("%s:%s", urlPayload.ParsedUrl.Host, altPort)
 
 		returnAddr := ""
 		if addr == expectedAddr {
-			returnAddr = fmt.Sprintf("%s:%s", urlPayload.Address.String(), expectedPort)
+			returnAddr = fmt.Sprintf("%s:%s", safeIp.String(), safePort)
 		} else if addr == altAddr && altPort != "" {
-			returnAddr = fmt.Sprintf("%s:%s", urlPayload.Address.String(), altPort)
+			returnAddr = fmt.Sprintf("%s:%s", safeIp.String(), altPort)
 		}
 
 		if returnAddr != "" {
diff --git a/controllers/preview_controller/previewers/opengraph_previewer.go b/controllers/preview_controller/previewers/opengraph_previewer.go
index f57891a6..fa21792b 100644
--- a/controllers/preview_controller/previewers/opengraph_previewer.go
+++ b/controllers/preview_controller/previewers/opengraph_previewer.go
@@ -1,7 +1,6 @@
 package previewers
 
 import (
-	"fmt"
 	"net/url"
 	"strconv"
 	"strings"
@@ -11,7 +10,6 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
-	"github.com/turt2live/matrix-media-repo/controllers/preview_controller/acl"
 	"github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types"
 	"github.com/turt2live/matrix-media-repo/metrics"
 )
@@ -62,25 +60,16 @@ func GenerateOpenGraphPreview(urlPayload *preview_types.UrlPayload, ctx rcontext
 	}
 
 	if og.Images != nil && len(og.Images) > 0 {
-		baseUrlS := fmt.Sprintf("%s://%s", urlPayload.ParsedUrl.Scheme, urlPayload.Address.String())
-		baseUrl, err := url.Parse(baseUrlS)
-		if err != nil {
-			ctx.Log.Error("Non-fatal error getting thumbnail (parsing base url): " + err.Error())
-			return *graph, nil
-		}
-
 		imgUrl, err := url.Parse(og.Images[0].URL)
 		if err != nil {
 			ctx.Log.Error("Non-fatal error getting thumbnail (parsing image url): " + err.Error())
 			return *graph, nil
 		}
 
-		// Ensure images pass through the same validation check
-		imgAbsUrl := baseUrl.ResolveReference(imgUrl)
-		imgUrlPayload, err := acl.ValidateUrlForPreview(imgAbsUrl.String(), ctx)
-		if err != nil {
-			ctx.Log.Error("Non-fatal error getting thumbnail (URL validation): " + err.Error())
-			return *graph, nil
+		imgAbsUrl := urlPayload.ParsedUrl.ResolveReference(imgUrl)
+		imgUrlPayload := &preview_types.UrlPayload{
+			UrlString: imgAbsUrl.String(),
+			ParsedUrl: imgAbsUrl,
 		}
 
 		img, err := downloadImage(imgUrlPayload, ctx)
-- 
GitLab