From 93a18a19829e7a324c9ab6473eda0e6f38e6fa86 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Wed, 11 Apr 2018 21:15:42 -0600
Subject: [PATCH] Move authentication handler out of the endpoints

---
 .../matrix-media-repo/api/{r0 => }/auth.go    | 37 +++++++++----------
 .../matrix-media-repo/api/r0/download.go      |  4 +-
 .../matrix-media-repo/api/r0/identicon.go     |  4 +-
 .../matrix-media-repo/api/r0/info.go          |  2 +-
 .../matrix-media-repo/api/r0/local_copy.go    |  4 +-
 .../matrix-media-repo/api/r0/preview_url.go   |  4 +-
 .../matrix-media-repo/api/r0/purge.go         |  2 +-
 .../matrix-media-repo/api/r0/quarantine.go    | 16 ++++----
 .../matrix-media-repo/api/r0/thumbnail.go     |  4 +-
 .../matrix-media-repo/api/r0/upload.go        |  4 +-
 .../matrix-media-repo/cmd/media_repo/main.go  | 20 +++++-----
 11 files changed, 50 insertions(+), 51 deletions(-)
 rename src/github.com/turt2live/matrix-media-repo/api/{r0 => }/auth.go (68%)

diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/auth.go b/src/github.com/turt2live/matrix-media-repo/api/auth.go
similarity index 68%
rename from src/github.com/turt2live/matrix-media-repo/api/r0/auth.go
rename to src/github.com/turt2live/matrix-media-repo/api/auth.go
index fc8b18b2..d8fe4a61 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/auth.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/auth.go
@@ -1,20 +1,19 @@
-package r0
+package api
 
 import (
 	"net/http"
 
 	"github.com/sirupsen/logrus"
-	"github.com/turt2live/matrix-media-repo/api"
 	"github.com/turt2live/matrix-media-repo/matrix"
 	"github.com/turt2live/matrix-media-repo/util"
 )
 
-type userInfo struct {
-	userId      string
-	accessToken string
+type UserInfo struct {
+	UserId      string
+	AccessToken string
 }
 
-func AccessTokenRequiredRoute(next func(r *http.Request, log *logrus.Entry, user userInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
+func AccessTokenRequiredRoute(next func(r *http.Request, log *logrus.Entry, user UserInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
 	return func(r *http.Request, log *logrus.Entry) interface{} {
 		accessToken := util.GetAccessTokenFromRequest(r)
 		appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
@@ -23,19 +22,19 @@ func AccessTokenRequiredRoute(next func(r *http.Request, log *logrus.Entry, user
 			log.Error(err)
 			if err != nil && err != matrix.ErrNoToken {
 				log.Error("Error verifying token: ", err)
-				return api.InternalServerError("Unexpected Error")
+				return InternalServerError("Unexpected Error")
 			}
 
 			log.Warn("Failed to verify token (fatal)")
-			return api.AuthFailed()
+			return AuthFailed()
 		}
 
 		log = log.WithFields(logrus.Fields{"authUserId": userId})
-		return next(r, log, userInfo{userId, accessToken})
+		return next(r, log, UserInfo{userId, accessToken})
 	}
 }
 
-func AccessTokenOptionalRoute(next func(r *http.Request, log *logrus.Entry, user userInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
+func AccessTokenOptionalRoute(next func(r *http.Request, log *logrus.Entry, user UserInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
 	return func(r *http.Request, log *logrus.Entry) interface{} {
 		accessToken := util.GetAccessTokenFromRequest(r)
 		appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
@@ -43,7 +42,7 @@ func AccessTokenOptionalRoute(next func(r *http.Request, log *logrus.Entry, user
 		if err != nil {
 			if err != matrix.ErrNoToken {
 				log.Error("Error verifying token: ", err)
-				return api.InternalServerError("Unexpected Error")
+				return InternalServerError("Unexpected Error")
 			}
 
 			log.Warn("Failed to verify token (non-fatal)")
@@ -51,19 +50,19 @@ func AccessTokenOptionalRoute(next func(r *http.Request, log *logrus.Entry, user
 		}
 
 		log = log.WithFields(logrus.Fields{"authUserId": userId})
-		return next(r, log, userInfo{userId, accessToken})
+		return next(r, log, UserInfo{userId, accessToken})
 	}
 }
 
-func RepoAdminRoute(next func(r *http.Request, log *logrus.Entry, user userInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
-	return AccessTokenRequiredRoute(func(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
-		if user.userId == "" {
+func RepoAdminRoute(next func(r *http.Request, log *logrus.Entry, user UserInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} {
+	return AccessTokenRequiredRoute(func(r *http.Request, log *logrus.Entry, user UserInfo) interface{} {
+		if user.UserId == "" {
 			log.Warn("Could not identify user for this admin route")
-			return api.AuthFailed()
+			return AuthFailed()
 		}
-		if !util.IsGlobalAdmin(user.userId) {
-			log.Warn("User " + user.userId + " is not a repository administrator")
-			return api.AuthFailed()
+		if !util.IsGlobalAdmin(user.UserId) {
+			log.Warn("User " + user.UserId + " is not a repository administrator")
+			return AuthFailed()
 		}
 
 		log = log.WithFields(logrus.Fields{"isRepoAdmin": true})
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/download.go b/src/github.com/turt2live/matrix-media-repo/api/r0/download.go
index 2609928a..a6e5460d 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/download.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/download.go
@@ -19,9 +19,9 @@ type DownloadMediaResponse struct {
 	Data        io.ReadCloser
 }
 
-func DownloadMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func DownloadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	hs := util.GetHomeserverConfig(r.Host)
-	if hs.DownloadRequiresAuth && user.userId == "" {
+	if hs.DownloadRequiresAuth && user.UserId == "" {
 		log.Warn("Homeserver requires authenticated downloads - denying request")
 		return api.AuthFailed()
 	}
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/identicon.go b/src/github.com/turt2live/matrix-media-repo/api/r0/identicon.go
index 9a39a49a..82f2ab76 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/identicon.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/identicon.go
@@ -21,9 +21,9 @@ type IdenticonResponse struct {
 	Avatar io.Reader
 }
 
-func Identicon(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func Identicon(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	hs := util.GetHomeserverConfig(r.Host)
-	if hs.DownloadRequiresAuth && user.userId == "" {
+	if hs.DownloadRequiresAuth && user.UserId == "" {
 		log.Warn("Homeserver requires authenticated downloads - denying request")
 		return api.AuthFailed()
 	}
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/info.go b/src/github.com/turt2live/matrix-media-repo/api/r0/info.go
index a241c69f..6624241b 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/info.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/info.go
@@ -19,7 +19,7 @@ type MediaInfoResponse struct {
 	Size        int64  `json:"size"`
 }
 
-func MediaInfo(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func MediaInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	params := mux.Vars(r)
 
 	server := params["server"]
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/local_copy.go b/src/github.com/turt2live/matrix-media-repo/api/r0/local_copy.go
index e00da83e..a3148ed1 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/local_copy.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/local_copy.go
@@ -11,7 +11,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/util/errs"
 )
 
-func LocalCopy(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	params := mux.Vars(r)
 
 	server := params["server"]
@@ -46,7 +46,7 @@ func LocalCopy(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
 		return &MediaUploadedResponse{streamedMedia.Media.MxcUri()}
 	}
 
-	newMedia, err := svc.StoreMedia(streamedMedia.Stream, streamedMedia.Media.ContentType, streamedMedia.Media.UploadName, user.userId, r.Host, "")
+	newMedia, err := svc.StoreMedia(streamedMedia.Stream, streamedMedia.Media.ContentType, streamedMedia.Media.UploadName, user.UserId, r.Host, "")
 	if err != nil {
 		if err == errs.ErrMediaNotAllowed {
 			return api.BadRequest("Media content type not allowed on this server")
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/preview_url.go b/src/github.com/turt2live/matrix-media-repo/api/r0/preview_url.go
index e52d7ad3..fdc80bae 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/preview_url.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/preview_url.go
@@ -26,7 +26,7 @@ type MatrixOpenGraph struct {
 	ImageHeight int    `json:"og:image:height,omitempty"`
 }
 
-func PreviewUrl(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func PreviewUrl(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	if !config.Get().UrlPreviews.Enabled {
 		return api.NotFoundError()
 	}
@@ -55,7 +55,7 @@ func PreviewUrl(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
 	}
 
 	svc := url_service.New(r.Context(), log)
-	preview, err := svc.GetPreview(urlStr, r.Host, user.userId, ts)
+	preview, err := svc.GetPreview(urlStr, r.Host, user.UserId, ts)
 	if err != nil {
 		if err == errs.ErrMediaNotFound || err == errs.ErrHostNotFound {
 			return api.NotFoundError()
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/purge.go b/src/github.com/turt2live/matrix-media-repo/api/r0/purge.go
index 2fcc2ab2..18049839 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/purge.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/purge.go
@@ -13,7 +13,7 @@ type MediaPurgedResponse struct {
 	NumRemoved int `json:"total_removed"`
 }
 
-func PurgeRemoteMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func PurgeRemoteMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	beforeTsStr := r.URL.Query().Get("before_ts")
 	if beforeTsStr == "" {
 		return api.BadRequest("Missing before_ts argument")
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/quarantine.go b/src/github.com/turt2live/matrix-media-repo/api/r0/quarantine.go
index 79da321c..7c21027c 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/quarantine.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/quarantine.go
@@ -18,7 +18,7 @@ type MediaQuarantinedResponse struct {
 	NumQuarantined int `json:"num_quarantined"`
 }
 
-func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, log, user)
 	if !canQuarantine {
 		return api.AuthFailed()
@@ -33,7 +33,7 @@ func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user userInfo) inte
 		"localAdmin": isLocalAdmin,
 	})
 
-	allMedia, err := matrix.ListMedia(r.Context(), r.Host, user.accessToken, roomId)
+	allMedia, err := matrix.ListMedia(r.Context(), r.Host, user.AccessToken, roomId)
 	if err != nil {
 		log.Error("Error while listing media in the room: " + err.Error())
 		return api.InternalServerError("error retrieving media in room")
@@ -67,7 +67,7 @@ func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user userInfo) inte
 	return &MediaQuarantinedResponse{NumQuarantined: total}
 }
 
-func QuarantineMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func QuarantineMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, log, user)
 	if !canQuarantine {
 		return api.AuthFailed()
@@ -115,15 +115,15 @@ func doQuarantine(ctx context.Context, log *logrus.Entry, server string, mediaId
 	return &MediaQuarantinedResponse{NumQuarantined: num}, true
 }
 
-func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user userInfo) (bool, bool, bool) {
-	isGlobalAdmin := util.IsGlobalAdmin(user.userId)
+func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) (bool, bool, bool) {
+	isGlobalAdmin := util.IsGlobalAdmin(user.UserId)
 	canQuarantine := isGlobalAdmin
 	allowOtherHosts := isGlobalAdmin
 	isLocalAdmin := false
 	var err error
 	if !isGlobalAdmin {
 		if config.Get().Quarantine.AllowLocalAdmins {
-			isLocalAdmin, err = matrix.IsUserAdmin(r.Context(), r.Host, user.accessToken)
+			isLocalAdmin, err = matrix.IsUserAdmin(r.Context(), r.Host, user.AccessToken)
 			if err != nil {
 				log.Error("Error verifying local admin: " + err.Error())
 				canQuarantine = false
@@ -131,7 +131,7 @@ func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user userInfo)
 			}
 
 			if !isLocalAdmin {
-				log.Warn(user.userId + " tried to quarantine media on another server")
+				log.Warn(user.UserId + " tried to quarantine media on another server")
 				canQuarantine = false
 				return canQuarantine, allowOtherHosts, isLocalAdmin
 			}
@@ -142,7 +142,7 @@ func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user userInfo)
 	}
 
 	if !canQuarantine {
-		log.Warn(user.userId + " tried to quarantine media")
+		log.Warn(user.UserId + " tried to quarantine media")
 	}
 
 	return canQuarantine, allowOtherHosts, isLocalAdmin
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/thumbnail.go b/src/github.com/turt2live/matrix-media-repo/api/r0/thumbnail.go
index e7f65eb9..8b68fd44 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/thumbnail.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/thumbnail.go
@@ -13,9 +13,9 @@ import (
 	"github.com/turt2live/matrix-media-repo/util/errs"
 )
 
-func ThumbnailMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	hs := util.GetHomeserverConfig(r.Host)
-	if hs.DownloadRequiresAuth && user.userId == "" {
+	if hs.DownloadRequiresAuth && user.UserId == "" {
 		log.Warn("Homeserver requires authenticated downloads - denying request")
 		return api.AuthFailed()
 	}
diff --git a/src/github.com/turt2live/matrix-media-repo/api/r0/upload.go b/src/github.com/turt2live/matrix-media-repo/api/r0/upload.go
index d731b7df..1ae677ed 100644
--- a/src/github.com/turt2live/matrix-media-repo/api/r0/upload.go
+++ b/src/github.com/turt2live/matrix-media-repo/api/r0/upload.go
@@ -15,7 +15,7 @@ type MediaUploadedResponse struct {
 	ContentUri string `json:"content_uri"`
 }
 
-func UploadMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{} {
+func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	filename := r.URL.Query().Get("filename")
 	if filename == "" {
 		filename = "upload.bin"
@@ -38,7 +38,7 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user userInfo) interface{}
 		return api.RequestTooLarge()
 	}
 
-	media, err := svc.UploadMedia(r.Body, contentType, filename, user.userId, r.Host)
+	media, err := svc.UploadMedia(r.Body, contentType, filename, user.UserId, r.Host)
 	if err != nil {
 		io.Copy(ioutil.Discard, r.Body) // Ditch the entire request
 		defer r.Body.Close()
diff --git a/src/github.com/turt2live/matrix-media-repo/cmd/media_repo/main.go b/src/github.com/turt2live/matrix-media-repo/cmd/media_repo/main.go
index c14c35f6..9bf3ec6f 100644
--- a/src/github.com/turt2live/matrix-media-repo/cmd/media_repo/main.go
+++ b/src/github.com/turt2live/matrix-media-repo/cmd/media_repo/main.go
@@ -66,16 +66,16 @@ func main() {
 	hOpts := HandlerOpts{&counter}
 
 	optionsHandler := Handler{optionsRequest, hOpts}
-	uploadHandler := Handler{r0.AccessTokenRequiredRoute(r0.UploadMedia), hOpts}
-	downloadHandler := Handler{r0.AccessTokenOptionalRoute(r0.DownloadMedia), hOpts}
-	thumbnailHandler := Handler{r0.AccessTokenOptionalRoute(r0.ThumbnailMedia), hOpts}
-	previewUrlHandler := Handler{r0.AccessTokenRequiredRoute(r0.PreviewUrl), hOpts}
-	identiconHandler := Handler{r0.AccessTokenOptionalRoute(r0.Identicon), hOpts}
-	purgeHandler := Handler{r0.RepoAdminRoute(r0.PurgeRemoteMedia), hOpts}
-	quarantineHandler := Handler{r0.AccessTokenRequiredRoute(r0.QuarantineMedia), hOpts}
-	quarantineRoomHandler := Handler{r0.AccessTokenRequiredRoute(r0.QuarantineRoomMedia), hOpts}
-	localCopyHandler := Handler{r0.AccessTokenRequiredRoute(r0.LocalCopy), hOpts}
-	infoHandler := Handler{r0.AccessTokenRequiredRoute(r0.MediaInfo), hOpts}
+	uploadHandler := Handler{api.AccessTokenRequiredRoute(r0.UploadMedia), hOpts}
+	downloadHandler := Handler{api.AccessTokenOptionalRoute(r0.DownloadMedia), hOpts}
+	thumbnailHandler := Handler{api.AccessTokenOptionalRoute(r0.ThumbnailMedia), hOpts}
+	previewUrlHandler := Handler{api.AccessTokenRequiredRoute(r0.PreviewUrl), hOpts}
+	identiconHandler := Handler{api.AccessTokenOptionalRoute(r0.Identicon), hOpts}
+	purgeHandler := Handler{api.RepoAdminRoute(r0.PurgeRemoteMedia), hOpts}
+	quarantineHandler := Handler{api.AccessTokenRequiredRoute(r0.QuarantineMedia), hOpts}
+	quarantineRoomHandler := Handler{api.AccessTokenRequiredRoute(r0.QuarantineRoomMedia), hOpts}
+	localCopyHandler := Handler{api.AccessTokenRequiredRoute(r0.LocalCopy), hOpts}
+	infoHandler := Handler{api.AccessTokenRequiredRoute(r0.MediaInfo), hOpts}
 
 	routes := make(map[string]*ApiRoute)
 	versions := []string{"r0", "v1"} // r0 is typically clients and v1 is typically servers
-- 
GitLab