From 4a22c6c5eb49e43686aeec177d8f5f7b52d53344 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Fri, 3 Jan 2020 12:59:21 -0700
Subject: [PATCH] Fix access token error handling

We'd return a 500 when the token isn't found on the homeserver, but we should be returned a proper auth error. This just fixes the logic around the area to make that actually work.
---
 api/auth.go      |  6 +++---
 matrix/auth.go   | 12 +++++++-----
 matrix/matrix.go |  4 ++--
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/api/auth.go b/api/auth.go
index 5d271b8e..3d37b976 100644
--- a/api/auth.go
+++ b/api/auth.go
@@ -37,8 +37,8 @@ func AccessTokenRequiredRoute(next func(r *http.Request, rctx rcontext.RequestCo
 		appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
 		userId, err := matrix.GetUserIdFromToken(rctx, r.Host, accessToken, appserviceUserId, r.RemoteAddr)
 		if err != nil || userId == "" {
-			if err != nil && err != matrix.ErrNoToken {
-				rctx.Log.Error("Error verifying token: ", err)
+			if err != nil && err != matrix.ErrInvalidToken {
+				rctx.Log.Error("Error verifying token (fatal): ", err)
 				return InternalServerError("Unexpected Error")
 			}
 
@@ -65,7 +65,7 @@ func AccessTokenOptionalRoute(next func(r *http.Request, rctx rcontext.RequestCo
 		appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
 		userId, err := matrix.GetUserIdFromToken(rctx, r.Host, accessToken, appserviceUserId, r.RemoteAddr)
 		if err != nil {
-			if err != matrix.ErrNoToken {
+			if err != matrix.ErrInvalidToken {
 				rctx.Log.Error("Error verifying token: ", err)
 				return InternalServerError("Unexpected Error")
 			}
diff --git a/matrix/auth.go b/matrix/auth.go
index df0ab62c..689147a9 100644
--- a/matrix/auth.go
+++ b/matrix/auth.go
@@ -8,17 +8,18 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 )
 
-var ErrNoToken = errors.New("Missing access token")
+var ErrInvalidToken = errors.New("Missing or invalid access token")
 
 func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessToken string, appserviceUserId string, ipAddr string) (string, error) {
 	if accessToken == "" {
-		return "", ErrNoToken
+		return "", ErrInvalidToken
 	}
 
 	hs, cb := getBreakerAndConfig(serverName)
 
 	userId := ""
 	var replyError error
+	var authError error
 	replyError = cb.CallContext(ctx, func() error {
 		query := map[string]string{}
 		if appserviceUserId != "" {
@@ -34,7 +35,8 @@ func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessTo
 		target.RawQuery = q.Encode()
 		err := doRequest(ctx, "GET", target.String(), nil, response, accessToken, ipAddr)
 		if err != nil {
-			err, replyError = filterError(err)
+			ctx.Log.Warn("Error from homeserver: ", err)
+			err, authError = filterError(err)
 			return err
 		}
 
@@ -42,8 +44,8 @@ func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessTo
 		return nil
 	}, 1*time.Minute)
 
-	if replyError == nil {
-		return userId, nil
+	if authError != nil {
+		return userId, authError
 	}
 	return userId, replyError
 }
diff --git a/matrix/matrix.go b/matrix/matrix.go
index 70417b6c..7a278afd 100644
--- a/matrix/matrix.go
+++ b/matrix/matrix.go
@@ -35,10 +35,10 @@ func filterError(err error) (error, error) {
 	}
 
 	// Unknown token errors should be filtered out explicitly to ensure we don't break on bad requests
-	if httpErr, ok := err.(errorResponse); ok {
+	if httpErr, ok := err.(*errorResponse); ok {
 		if httpErr.ErrorCode == common.ErrCodeUnknownToken {
 			// We send back our own version of 'unknown token' to ensure we can filter it out elsewhere
-			return nil, ErrNoToken
+			return nil, ErrInvalidToken
 		}
 	}
 
-- 
GitLab