diff --git a/api/auth.go b/api/auth.go index 5d271b8e54b5ecd2c3f14dcd6edd89dbabf9b318..3d37b9767c97c43ebeb0eab0946adf303ef4f3da 100644 --- a/api/auth.go +++ b/api/auth.go @@ -37,8 +37,8 @@ func AccessTokenRequiredRoute(next func(r *http.Request, rctx rcontext.RequestCo appserviceUserId := util.GetAppserviceUserIdFromRequest(r) userId, err := matrix.GetUserIdFromToken(rctx, r.Host, accessToken, appserviceUserId, r.RemoteAddr) if err != nil || userId == "" { - if err != nil && err != matrix.ErrNoToken { - rctx.Log.Error("Error verifying token: ", err) + if err != nil && err != matrix.ErrInvalidToken { + rctx.Log.Error("Error verifying token (fatal): ", err) return InternalServerError("Unexpected Error") } @@ -65,7 +65,7 @@ func AccessTokenOptionalRoute(next func(r *http.Request, rctx rcontext.RequestCo appserviceUserId := util.GetAppserviceUserIdFromRequest(r) userId, err := matrix.GetUserIdFromToken(rctx, r.Host, accessToken, appserviceUserId, r.RemoteAddr) if err != nil { - if err != matrix.ErrNoToken { + if err != matrix.ErrInvalidToken { rctx.Log.Error("Error verifying token: ", err) return InternalServerError("Unexpected Error") } diff --git a/matrix/auth.go b/matrix/auth.go index df0ab62cafe3c9b945c7119e8db533a95681c598..689147a9c12d58e2bd0f71c8b6d3fc3c5345d2ba 100644 --- a/matrix/auth.go +++ b/matrix/auth.go @@ -8,17 +8,18 @@ import ( "github.com/turt2live/matrix-media-repo/common/rcontext" ) -var ErrNoToken = errors.New("Missing access token") +var ErrInvalidToken = errors.New("Missing or invalid access token") func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessToken string, appserviceUserId string, ipAddr string) (string, error) { if accessToken == "" { - return "", ErrNoToken + return "", ErrInvalidToken } hs, cb := getBreakerAndConfig(serverName) userId := "" var replyError error + var authError error replyError = cb.CallContext(ctx, func() error { query := map[string]string{} if appserviceUserId != "" { @@ -34,7 +35,8 @@ func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessTo target.RawQuery = q.Encode() err := doRequest(ctx, "GET", target.String(), nil, response, accessToken, ipAddr) if err != nil { - err, replyError = filterError(err) + ctx.Log.Warn("Error from homeserver: ", err) + err, authError = filterError(err) return err } @@ -42,8 +44,8 @@ func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessTo return nil }, 1*time.Minute) - if replyError == nil { - return userId, nil + if authError != nil { + return userId, authError } return userId, replyError } diff --git a/matrix/matrix.go b/matrix/matrix.go index 70417b6cec4a6734940c511b580fc0b4d0125e45..7a278afdbb6eb43940c9e1ef358341fc288253b6 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -35,10 +35,10 @@ func filterError(err error) (error, error) { } // Unknown token errors should be filtered out explicitly to ensure we don't break on bad requests - if httpErr, ok := err.(errorResponse); ok { + if httpErr, ok := err.(*errorResponse); ok { if httpErr.ErrorCode == common.ErrCodeUnknownToken { // We send back our own version of 'unknown token' to ensure we can filter it out elsewhere - return nil, ErrNoToken + return nil, ErrInvalidToken } }