From a745540cd94f1c9815482b8afb2c8974b38e14ba Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Wed, 11 Apr 2018 21:07:27 -0600
Subject: [PATCH] Move authentication to a dedicated layer

Fixes #87

The layer will return a 403 if the media repo explicitly receives an unknown token error from the homeserver. Otherwise it'll return a 500 upon any other homeserver failure (such as being offline).
---
 .../client/error_handlers.go                  |  4 +-
 .../matrix-media-repo/client/r0/auth.go       | 72 +++++++++++++++++++
 .../matrix-media-repo/client/r0/download.go   | 22 ++----
 .../matrix-media-repo/client/r0/identicon.go  | 12 ++--
 .../matrix-media-repo/client/r0/info.go       | 14 +---
 .../matrix-media-repo/client/r0/local_copy.go | 16 +----
 .../client/r0/preview_url.go                  | 16 +----
 .../matrix-media-repo/client/r0/purge.go      | 20 +-----
 .../matrix-media-repo/client/r0/quarantine.go | 60 +++++-----------
 .../matrix-media-repo/client/r0/thumbnail.go  |  7 +-
 .../matrix-media-repo/client/r0/upload.go     | 17 +----
 .../matrix-media-repo/cmd/media_repo/main.go  | 26 +++----
 .../matrix-media-repo/matrix/admin.go         | 12 ++--
 .../matrix-media-repo/matrix/auth.go          | 16 ++++-
 .../matrix-media-repo/matrix/matrix.go        | 12 ++--
 15 files changed, 158 insertions(+), 168 deletions(-)
 create mode 100644 src/github.com/turt2live/matrix-media-repo/client/r0/auth.go

diff --git a/src/github.com/turt2live/matrix-media-repo/client/error_handlers.go b/src/github.com/turt2live/matrix-media-repo/client/error_handlers.go
index 506a02ae..7bbede01 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/error_handlers.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/error_handlers.go
@@ -6,10 +6,10 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
-func NotFoundHandler(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
+func NotFoundHandler(r *http.Request, log *logrus.Entry) interface{} {
 	return NotFoundError()
 }
 
-func MethodNotAllowedHandler(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
+func MethodNotAllowedHandler(r *http.Request, log *logrus.Entry) interface{} {
 	return MethodNotAllowed()
 }
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/auth.go b/src/github.com/turt2live/matrix-media-repo/client/r0/auth.go
new file mode 100644
index 00000000..acf16368
--- /dev/null
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/auth.go
@@ -0,0 +1,72 @@
+package r0
+
+import (
+	"net/http"
+
+	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/client"
+	"github.com/turt2live/matrix-media-repo/matrix"
+	"github.com/turt2live/matrix-media-repo/util"
+)
+
+type userInfo struct {
+	userId      string
+	accessToken string
+}
+
+func AccessTokenRequiredRoute(next func(r *http.Request, log *logrus.Entry, user userInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
+	return func(r *http.Request, log *logrus.Entry) interface{} {
+		accessToken := util.GetAccessTokenFromRequest(r)
+		appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
+		userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
+		if err != nil || userId == "" {
+			log.Error(err)
+			if err != nil && err != matrix.ErrNoToken {
+				log.Error("Error verifying token: ", err)
+				return client.InternalServerError("Unexpected Error")
+			}
+
+			log.Warn("Failed to verify token (fatal)")
+			return client.AuthFailed()
+		}
+
+		log = log.WithFields(logrus.Fields{"authUserId": userId})
+		return next(r, log, userInfo{userId, accessToken})
+	}
+}
+
+func AccessTokenOptionalRoute(next func(r *http.Request, log *logrus.Entry, user userInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
+	return func(r *http.Request, log *logrus.Entry) interface{} {
+		accessToken := util.GetAccessTokenFromRequest(r)
+		appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
+		userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
+		if err != nil {
+			if err != matrix.ErrNoToken {
+				log.Error("Error verifying token: ", err)
+				return client.InternalServerError("Unexpected Error")
+			}
+
+			log.Warn("Failed to verify token (non-fatal)")
+			userId = ""
+		}
+
+		log = log.WithFields(logrus.Fields{"authUserId": userId})
+		return next(r, log, userInfo{userId, accessToken})
+	}
+}
+
+func RepoAdminRoute(next func(r *http.Request, log *logrus.Entry, user userInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
+	return AccessTokenRequiredRoute(func(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+		if user.userId == "" {
+			log.Warn("Could not identify user for this admin route")
+			return client.AuthFailed()
+		}
+		if !util.IsGlobalAdmin(user.userId) {
+			log.Warn("User " + user.userId + " is not a repository administrator")
+			return client.AuthFailed()
+		}
+
+		log = log.WithFields(logrus.Fields{"isRepoAdmin": true})
+		return next(r, log, user)
+	})
+}
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/download.go b/src/github.com/turt2live/matrix-media-repo/client/r0/download.go
index 317e9286..005eeeb0 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/download.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/download.go
@@ -7,7 +7,6 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/client"
-	"github.com/turt2live/matrix-media-repo/matrix"
 	"github.com/turt2live/matrix-media-repo/media_cache"
 	"github.com/turt2live/matrix-media-repo/util"
 	"github.com/turt2live/matrix-media-repo/util/errs"
@@ -20,8 +19,10 @@ type DownloadMediaResponse struct {
 	Data        io.ReadCloser
 }
 
-func DownloadMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
-	if !ValidateUserCanDownload(r, log) {
+func DownloadMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+	hs := util.GetHomeserverConfig(r.Host)
+	if hs.DownloadRequiresAuth && user.userId == "" {
+		log.Warn("Homeserver requires authenticated downloads - denying request")
 		return client.AuthFailed()
 	}
 
@@ -63,18 +64,3 @@ func DownloadMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) in
 		Data:        streamedMedia.Stream,
 	}
 }
-
-func ValidateUserCanDownload(r *http.Request, log *logrus.Entry) (bool) {
-	hs := util.GetHomeserverConfig(r.Host)
-	if !hs.DownloadRequiresAuth {
-		return true // no auth required == can access
-	}
-
-	accessToken := util.GetAccessTokenFromRequest(r)
-	appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
-	userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
-	if err != nil {
-		log.Error("Error verifying token: " + err.Error())
-	}
-	return userId != "" && err != nil
-}
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/identicon.go b/src/github.com/turt2live/matrix-media-repo/client/r0/identicon.go
index 6af3f27e..4c62e178 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/identicon.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/identicon.go
@@ -14,19 +14,23 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/client"
 	"github.com/turt2live/matrix-media-repo/config"
+	"github.com/turt2live/matrix-media-repo/util"
 )
 
 type IdenticonResponse struct {
 	Avatar io.Reader
 }
 
-func Identicon(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
+func Identicon(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+	hs := util.GetHomeserverConfig(r.Host)
+	if hs.DownloadRequiresAuth && user.userId == "" {
+		log.Warn("Homeserver requires authenticated downloads - denying request")
+		return client.AuthFailed()
+	}
+
 	if !config.Get().Identicons.Enabled {
 		return client.NotFoundError()
 	}
-	if !ValidateUserCanDownload(r, log) {
-		return client.AuthFailed()
-	}
 
 	params := mux.Vars(r)
 	seed := params["seed"]
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/info.go b/src/github.com/turt2live/matrix-media-repo/client/r0/info.go
index 1ff497c7..1af76348 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/info.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/info.go
@@ -7,9 +7,7 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/client"
-	"github.com/turt2live/matrix-media-repo/matrix"
 	"github.com/turt2live/matrix-media-repo/media_cache"
-	"github.com/turt2live/matrix-media-repo/util"
 	"github.com/turt2live/matrix-media-repo/util/errs"
 )
 
@@ -21,17 +19,7 @@ type MediaInfoResponse struct {
 	Size        int64  `json:"size"`
 }
 
-func MediaInfo(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
-	accessToken := util.GetAccessTokenFromRequest(r)
-	appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
-	userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
-	if err != nil || userId == "" {
-		if err != nil {
-			log.Error("Error verifying token: " + err.Error())
-		}
-		return client.AuthFailed()
-	}
-
+func MediaInfo(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
 	params := mux.Vars(r)
 
 	server := params["server"]
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/local_copy.go b/src/github.com/turt2live/matrix-media-repo/client/r0/local_copy.go
index f20e5aff..1e7ffad5 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/local_copy.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/local_copy.go
@@ -6,24 +6,12 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/client"
-	"github.com/turt2live/matrix-media-repo/matrix"
 	"github.com/turt2live/matrix-media-repo/media_cache"
 	"github.com/turt2live/matrix-media-repo/services/media_service"
-	"github.com/turt2live/matrix-media-repo/util"
 	"github.com/turt2live/matrix-media-repo/util/errs"
 )
 
-func LocalCopy(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
-	accessToken := util.GetAccessTokenFromRequest(r)
-	appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
-	userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
-	if err != nil || userId == "" {
-		if err != nil {
-			log.Error("Error verifying token: " + err.Error())
-		}
-		return client.AuthFailed()
-	}
-
+func LocalCopy(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
 	params := mux.Vars(r)
 
 	server := params["server"]
@@ -58,7 +46,7 @@ func LocalCopy(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interf
 		return &MediaUploadedResponse{streamedMedia.Media.MxcUri()}
 	}
 
-	newMedia, err := svc.StoreMedia(streamedMedia.Stream, streamedMedia.Media.ContentType, streamedMedia.Media.UploadName, userId, r.Host, "")
+	newMedia, err := svc.StoreMedia(streamedMedia.Stream, streamedMedia.Media.ContentType, streamedMedia.Media.UploadName, user.userId, r.Host, "")
 	if err != nil {
 		if err == errs.ErrMediaNotAllowed {
 			return client.BadRequest("Media content type not allowed on this server")
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/preview_url.go b/src/github.com/turt2live/matrix-media-repo/client/r0/preview_url.go
index 251474a2..3fc78f1e 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/preview_url.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/preview_url.go
@@ -8,7 +8,6 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/client"
 	"github.com/turt2live/matrix-media-repo/config"
-	"github.com/turt2live/matrix-media-repo/matrix"
 	"github.com/turt2live/matrix-media-repo/services/url_service"
 	"github.com/turt2live/matrix-media-repo/util"
 	"github.com/turt2live/matrix-media-repo/util/errs"
@@ -27,27 +26,18 @@ type MatrixOpenGraph struct {
 	ImageHeight int    `json:"og:image:height,omitempty"`
 }
 
-func PreviewUrl(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
+func PreviewUrl(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
 	if !config.Get().UrlPreviews.Enabled {
 		return client.NotFoundError()
 	}
 
-	accessToken := util.GetAccessTokenFromRequest(r)
-	appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
-	userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
-	if err != nil || userId == "" {
-		if err != nil {
-			log.Error("Error verifying token: " + err.Error())
-		}
-		return client.AuthFailed()
-	}
-
 	params := r.URL.Query()
 
 	// Parse the parameters
 	urlStr := params.Get("url")
 	tsStr := params.Get("ts")
 	ts := util.NowMillis()
+	var err error
 	if tsStr != "" {
 		ts, err = strconv.ParseInt(tsStr, 10, 64)
 		if err != nil {
@@ -65,7 +55,7 @@ func PreviewUrl(w http.ResponseWriter, r *http.Request, log *logrus.Entry) inter
 	}
 
 	svc := url_service.New(r.Context(), log)
-	preview, err := svc.GetPreview(urlStr, r.Host, userId, ts)
+	preview, err := svc.GetPreview(urlStr, r.Host, user.userId, ts)
 	if err != nil {
 		if err == errs.ErrMediaNotFound || err == errs.ErrHostNotFound {
 			return client.NotFoundError()
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/purge.go b/src/github.com/turt2live/matrix-media-repo/client/r0/purge.go
index 5fcdbe38..328cd3ca 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/purge.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/purge.go
@@ -6,31 +6,14 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/client"
-	"github.com/turt2live/matrix-media-repo/matrix"
 	"github.com/turt2live/matrix-media-repo/services/media_service"
-	"github.com/turt2live/matrix-media-repo/util"
 )
 
 type MediaPurgedResponse struct {
 	NumRemoved int `json:"total_removed"`
 }
 
-func PurgeRemoteMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
-	accessToken := util.GetAccessTokenFromRequest(r)
-	appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
-	userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
-	if err != nil || userId == "" {
-		if err != nil {
-			log.Error("Error verifying token: " + err.Error())
-		}
-		return client.AuthFailed()
-	}
-	isAdmin := util.IsGlobalAdmin(userId)
-	if !isAdmin {
-		log.Warn("User " + userId + " is not a repository administrator")
-		return client.AuthFailed()
-	}
-
+func PurgeRemoteMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
 	beforeTsStr := r.URL.Query().Get("before_ts")
 	if beforeTsStr == "" {
 		return client.BadRequest("Missing before_ts argument")
@@ -42,7 +25,6 @@ func PurgeRemoteMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry)
 
 	log = log.WithFields(logrus.Fields{
 		"beforeTs": beforeTs,
-		"userId":   userId,
 	})
 
 	// We don't bother clearing the cache because it's still probably useful there
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/quarantine.go b/src/github.com/turt2live/matrix-media-repo/client/r0/quarantine.go
index 4be6a495..7dde9e70 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/quarantine.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/quarantine.go
@@ -18,18 +18,8 @@ type MediaQuarantinedResponse struct {
 	NumQuarantined int `json:"num_quarantined"`
 }
 
-func QuarantineRoomMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
-	accessToken := util.GetAccessTokenFromRequest(r)
-	appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
-	userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
-	if err != nil || userId == "" {
-		if err != nil {
-			log.Error("Error verifying token: " + err.Error())
-		}
-		return client.AuthFailed()
-	}
-
-	canQuarantine, allowOtherHosts, isLocalAdmin, isGlobalAdmin := getQuarantineRequestInfo(r, log, userId, accessToken)
+func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+	canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, log, user)
 	if !canQuarantine {
 		return client.AuthFailed()
 	}
@@ -39,13 +29,11 @@ func QuarantineRoomMedia(w http.ResponseWriter, r *http.Request, log *logrus.Ent
 	roomId := params["roomId"]
 
 	log = log.WithFields(logrus.Fields{
-		"roomId":      roomId,
-		"userId":      userId,
-		"localAdmin":  isLocalAdmin,
-		"globalAdmin": isGlobalAdmin,
+		"roomId":     roomId,
+		"localAdmin": isLocalAdmin,
 	})
 
-	allMedia, err := matrix.ListMedia(r.Context(), r.Host, accessToken, roomId)
+	allMedia, err := matrix.ListMedia(r.Context(), r.Host, user.accessToken, roomId)
 	if err != nil {
 		log.Error("Error while listing media in the room: " + err.Error())
 		return client.InternalServerError("error retrieving media in room")
@@ -79,18 +67,8 @@ func QuarantineRoomMedia(w http.ResponseWriter, r *http.Request, log *logrus.Ent
 	return &MediaQuarantinedResponse{NumQuarantined: total}
 }
 
-func QuarantineMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
-	accessToken := util.GetAccessTokenFromRequest(r)
-	appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
-	userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
-	if err != nil || userId == "" {
-		if err != nil {
-			log.Error("Error verifying token: " + err.Error())
-		}
-		return client.AuthFailed()
-	}
-
-	canQuarantine, allowOtherHosts, isLocalAdmin, isGlobalAdmin := getQuarantineRequestInfo(r, log, userId, accessToken)
+func QuarantineMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+	canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, log, user)
 	if !canQuarantine {
 		return client.AuthFailed()
 	}
@@ -101,11 +79,9 @@ func QuarantineMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry)
 	mediaId := params["mediaId"]
 
 	log = log.WithFields(logrus.Fields{
-		"server":      server,
-		"mediaId":     mediaId,
-		"userId":      userId,
-		"localAdmin":  isLocalAdmin,
-		"globalAdmin": isGlobalAdmin,
+		"server":     server,
+		"mediaId":    mediaId,
+		"localAdmin": isLocalAdmin,
 	})
 
 	if !allowOtherHosts && r.Host != server {
@@ -139,25 +115,25 @@ func doQuarantine(ctx context.Context, log *logrus.Entry, server string, mediaId
 	return &MediaQuarantinedResponse{NumQuarantined: num}, true
 }
 
-func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, userId string, accessToken string) (bool, bool, bool, bool) {
-	isGlobalAdmin := util.IsGlobalAdmin(userId)
+func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user userInfo) (bool, bool, bool) {
+	isGlobalAdmin := util.IsGlobalAdmin(user.userId)
 	canQuarantine := isGlobalAdmin
 	allowOtherHosts := isGlobalAdmin
 	isLocalAdmin := false
 	var err error
 	if !isGlobalAdmin {
 		if config.Get().Quarantine.AllowLocalAdmins {
-			isLocalAdmin, err = matrix.IsUserAdmin(r.Context(), r.Host, accessToken)
+			isLocalAdmin, err = matrix.IsUserAdmin(r.Context(), r.Host, user.accessToken)
 			if err != nil {
 				log.Error("Error verifying local admin: " + err.Error())
 				canQuarantine = false
-				return canQuarantine, allowOtherHosts, isLocalAdmin, isGlobalAdmin
+				return canQuarantine, allowOtherHosts, isLocalAdmin
 			}
 
 			if !isLocalAdmin {
-				log.Warn(userId + " tried to quarantine media on another server")
+				log.Warn(user.userId + " tried to quarantine media on another server")
 				canQuarantine = false
-				return canQuarantine, allowOtherHosts, isLocalAdmin, isGlobalAdmin
+				return canQuarantine, allowOtherHosts, isLocalAdmin
 			}
 
 			// They have local admin status and we allow local admins to quarantine
@@ -166,8 +142,8 @@ func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, userId string,
 	}
 
 	if !canQuarantine {
-		log.Warn(userId + " tried to quarantine media")
+		log.Warn(user.userId + " tried to quarantine media")
 	}
 
-	return canQuarantine, allowOtherHosts, isLocalAdmin, isGlobalAdmin
+	return canQuarantine, allowOtherHosts, isLocalAdmin
 }
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/thumbnail.go b/src/github.com/turt2live/matrix-media-repo/client/r0/thumbnail.go
index 3373e381..6cf01359 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/thumbnail.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/thumbnail.go
@@ -9,11 +9,14 @@ import (
 	"github.com/turt2live/matrix-media-repo/client"
 	"github.com/turt2live/matrix-media-repo/config"
 	"github.com/turt2live/matrix-media-repo/media_cache"
+	"github.com/turt2live/matrix-media-repo/util"
 	"github.com/turt2live/matrix-media-repo/util/errs"
 )
 
-func ThumbnailMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
-	if !ValidateUserCanDownload(r, log) {
+func ThumbnailMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+	hs := util.GetHomeserverConfig(r.Host)
+	if hs.DownloadRequiresAuth && user.userId == "" {
+		log.Warn("Homeserver requires authenticated downloads - denying request")
 		return client.AuthFailed()
 	}
 
diff --git a/src/github.com/turt2live/matrix-media-repo/client/r0/upload.go b/src/github.com/turt2live/matrix-media-repo/client/r0/upload.go
index 712ba325..ada56891 100644
--- a/src/github.com/turt2live/matrix-media-repo/client/r0/upload.go
+++ b/src/github.com/turt2live/matrix-media-repo/client/r0/upload.go
@@ -7,9 +7,7 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/client"
-	"github.com/turt2live/matrix-media-repo/matrix"
 	"github.com/turt2live/matrix-media-repo/services/media_service"
-	"github.com/turt2live/matrix-media-repo/util"
 	"github.com/turt2live/matrix-media-repo/util/errs"
 )
 
@@ -17,17 +15,7 @@ type MediaUploadedResponse struct {
 	ContentUri string `json:"content_uri"`
 }
 
-func UploadMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) interface{} {
-	accessToken := util.GetAccessTokenFromRequest(r)
-	appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
-	userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId)
-	if err != nil || userId == "" {
-		if err != nil {
-			log.Error("Error verifying token: " + err.Error())
-		}
-		return client.AuthFailed()
-	}
-
+func UploadMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
 	filename := r.URL.Query().Get("filename")
 	if filename == "" {
 		filename = "upload.bin"
@@ -35,7 +23,6 @@ func UploadMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) inte
 
 	log = log.WithFields(logrus.Fields{
 		"filename": filename,
-		"userId":   userId,
 	})
 
 	contentType := r.Header.Get("Content-Type")
@@ -51,7 +38,7 @@ func UploadMedia(w http.ResponseWriter, r *http.Request, log *logrus.Entry) inte
 		return client.RequestTooLarge()
 	}
 
-	media, err := svc.UploadMedia(r.Body, contentType, filename, userId, r.Host)
+	media, err := svc.UploadMedia(r.Body, contentType, filename, user.userId, r.Host)
 	if err != nil {
 		io.Copy(ioutil.Discard, r.Body) // Ditch the entire request
 		defer r.Body.Close()
diff --git a/src/github.com/turt2live/matrix-media-repo/cmd/media_repo/main.go b/src/github.com/turt2live/matrix-media-repo/cmd/media_repo/main.go
index 6c7f5db6..d9335263 100644
--- a/src/github.com/turt2live/matrix-media-repo/cmd/media_repo/main.go
+++ b/src/github.com/turt2live/matrix-media-repo/cmd/media_repo/main.go
@@ -30,7 +30,7 @@ type requestCounter struct {
 }
 
 type Handler struct {
-	h    func(http.ResponseWriter, *http.Request, *log.Entry) interface{}
+	h    func(*http.Request, *log.Entry) interface{}
 	opts HandlerOpts
 }
 
@@ -66,16 +66,16 @@ func main() {
 	hOpts := HandlerOpts{&counter}
 
 	optionsHandler := Handler{optionsRequest, hOpts}
-	uploadHandler := Handler{r0.UploadMedia, hOpts}
-	downloadHandler := Handler{r0.DownloadMedia, hOpts}
-	thumbnailHandler := Handler{r0.ThumbnailMedia, hOpts}
-	previewUrlHandler := Handler{r0.PreviewUrl, hOpts}
-	identiconHandler := Handler{r0.Identicon, hOpts}
-	purgeHandler := Handler{r0.PurgeRemoteMedia, hOpts}
-	quarantineHandler := Handler{r0.QuarantineMedia, hOpts}
-	quarantineRoomHandler := Handler{r0.QuarantineRoomMedia, hOpts}
-	localCopyHandler := Handler{r0.LocalCopy, hOpts}
-	infoHandler := Handler{r0.MediaInfo, hOpts}
+	uploadHandler := Handler{r0.AccessTokenRequiredRoute(r0.UploadMedia), hOpts}
+	downloadHandler := Handler{r0.AccessTokenOptionalRoute(r0.DownloadMedia), hOpts}
+	thumbnailHandler := Handler{r0.AccessTokenOptionalRoute(r0.ThumbnailMedia), hOpts}
+	previewUrlHandler := Handler{r0.AccessTokenRequiredRoute(r0.PreviewUrl), hOpts}
+	identiconHandler := Handler{r0.AccessTokenOptionalRoute(r0.Identicon), hOpts}
+	purgeHandler := Handler{r0.RepoAdminRoute(r0.PurgeRemoteMedia), hOpts}
+	quarantineHandler := Handler{r0.AccessTokenRequiredRoute(r0.QuarantineMedia), hOpts}
+	quarantineRoomHandler := Handler{r0.AccessTokenRequiredRoute(r0.QuarantineRoomMedia), hOpts}
+	localCopyHandler := Handler{r0.AccessTokenRequiredRoute(r0.LocalCopy), hOpts}
+	infoHandler := Handler{r0.AccessTokenRequiredRoute(r0.MediaInfo), hOpts}
 
 	routes := make(map[string]*ApiRoute)
 	versions := []string{"r0", "v1"} // r0 is typically clients and v1 is typically servers
@@ -175,7 +175,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	var res interface{} = client.AuthFailed()
 	if util.IsServerOurs(r.Host) {
 		contextLog.Info("Server is owned by us, processing request")
-		res = h.h(w, r, contextLog)
+		res = h.h(r, contextLog)
 		if res == nil {
 			res = &EmptyResponse{}
 		}
@@ -241,6 +241,6 @@ func (c *requestCounter) GetNextId() string {
 	return "REQ-" + strId
 }
 
-func optionsRequest(w http.ResponseWriter, r *http.Request, log *log.Entry) interface{} {
+func optionsRequest(r *http.Request, log *log.Entry) interface{} {
 	return &EmptyResponse{}
 }
diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/admin.go b/src/github.com/turt2live/matrix-media-repo/matrix/admin.go
index bfe7f43d..b35e063d 100644
--- a/src/github.com/turt2live/matrix-media-repo/matrix/admin.go
+++ b/src/github.com/turt2live/matrix-media-repo/matrix/admin.go
@@ -16,14 +16,16 @@ func IsUserAdmin(ctx context.Context, serverName string, accessToken string) (bo
 	cb.CallContext(ctx, func() error {
 		mtxClient, err := gomatrix.NewClient(hs.ClientServerApi, "", accessToken)
 		if err != nil {
-			return filterError(err, &replyError)
+			err, replyError = filterError(err)
+			return err
 		}
 
 		response := &whoisResponse{}
 		url := mtxClient.BuildURL("/admin/whois/", fakeUser)
 		_, err = mtxClient.MakeRequest("GET", url, nil, response)
 		if err != nil {
-			return filterError(err, &replyError)
+			err, replyError = filterError(err)
+			return err
 		}
 
 		isAdmin = true // if we made it this far, that is
@@ -41,13 +43,15 @@ func ListMedia(ctx context.Context, serverName string, accessToken string, roomI
 	cb.CallContext(ctx, func() error {
 		mtxClient, err := gomatrix.NewClient(hs.ClientServerApi, "", accessToken)
 		if err != nil {
-			return filterError(err, &replyError)
+			err, replyError = filterError(err)
+			return err
 		}
 
 		url := mtxClient.BuildURL("/admin/room/", roomId, "/media")
 		_, err = mtxClient.MakeRequest("GET", url, nil, response)
 		if err != nil {
-			return filterError(err, &replyError)
+			err, replyError = filterError(err)
+			return err
 		}
 
 		return nil
diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/auth.go b/src/github.com/turt2live/matrix-media-repo/matrix/auth.go
index 989add5e..c8498b1c 100644
--- a/src/github.com/turt2live/matrix-media-repo/matrix/auth.go
+++ b/src/github.com/turt2live/matrix-media-repo/matrix/auth.go
@@ -5,9 +5,16 @@ import (
 	"time"
 
 	"github.com/matrix-org/gomatrix"
+	"github.com/pkg/errors"
 )
 
+var ErrNoToken = errors.New("Missing access token")
+
 func GetUserIdFromToken(ctx context.Context, serverName string, accessToken string, appserviceUserId string) (string, error) {
+	if accessToken == "" {
+		return "", ErrNoToken
+	}
+
 	hs, cb := getBreakerAndConfig(serverName)
 
 	userId := ""
@@ -15,7 +22,8 @@ func GetUserIdFromToken(ctx context.Context, serverName string, accessToken stri
 	cb.CallContext(ctx, func() error {
 		mtxClient, err := gomatrix.NewClient(hs.ClientServerApi, "", accessToken)
 		if err != nil {
-			return filterError(err, &replyError)
+			err, replyError = filterError(err)
+			return err
 		}
 
 		query := map[string]string{}
@@ -27,12 +35,16 @@ func GetUserIdFromToken(ctx context.Context, serverName string, accessToken stri
 		url := mtxClient.BuildURLWithQuery([]string{"/account/whoami"}, query)
 		_, err = mtxClient.MakeRequest("GET", url, nil, response)
 		if err != nil {
-			return filterError(err, &replyError)
+			err, replyError = filterError(err)
+			return err
 		}
 
 		userId = response.UserId
 		return nil
 	}, 1*time.Minute)
 
+	if replyError == nil {
+		return userId, nil
+	}
 	return userId, replyError
 }
diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/matrix.go b/src/github.com/turt2live/matrix-media-repo/matrix/matrix.go
index 0187fa22..b4de5247 100644
--- a/src/github.com/turt2live/matrix-media-repo/matrix/matrix.go
+++ b/src/github.com/turt2live/matrix-media-repo/matrix/matrix.go
@@ -25,22 +25,20 @@ func getBreakerAndConfig(serverName string) (*config.HomeserverConfig, *circuit.
 	return hs, cb
 }
 
-func filterError(err error, replyError *error) error {
+func filterError(err error) (error, error) {
 	if err == nil {
-		replyError = nil
-		return nil
+		return nil, nil
 	}
 
 	// Unknown token errors should be filtered out explicitly to ensure we don't break on bad requests
 	if httpErr, ok := err.(gomatrix.HTTPError); ok {
 		if respErr, ok := httpErr.WrappedError.(gomatrix.RespError); ok {
 			if respErr.ErrCode == "M_UNKNOWN_TOKEN" {
-				replyError = &err // we still want to send the error to the caller though
-				return nil
+				// We send back our own version of UNKNOWN_TOKEN to ensure we can filter it out elsewhere
+				return nil, ErrNoToken
 			}
 		}
 	}
 
-	replyError = &err
-	return err
+	return err, err
 }
-- 
GitLab