Skip to content
Snippets Groups Projects
Commit b5181187 authored by Travis Ralston's avatar Travis Ralston
Browse files

Fix handling of guest accounts (MSC3069)

parent e0ddaf9c
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ...@@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased] ## [Unreleased]
### Fixed
* Handle guest accounts properly. Previously they were still declined, though by coincidence.
## [1.2.5] - March 17th, 2021 ## [1.2.5] - March 17th, 2021
### Added ### Added
......
...@@ -39,6 +39,9 @@ func AccessTokenRequiredRoute(next func(r *http.Request, rctx rcontext.RequestCo ...@@ -39,6 +39,9 @@ func AccessTokenRequiredRoute(next func(r *http.Request, rctx rcontext.RequestCo
appserviceUserId := util.GetAppserviceUserIdFromRequest(r) appserviceUserId := util.GetAppserviceUserIdFromRequest(r)
userId, err := auth_cache.GetUserId(rctx, accessToken, appserviceUserId) userId, err := auth_cache.GetUserId(rctx, accessToken, appserviceUserId)
if err != nil || userId == "" { if err != nil || userId == "" {
if err == matrix.ErrGuestToken {
return GuestAuthFailed()
}
if err != nil && err != matrix.ErrInvalidToken { if err != nil && err != matrix.ErrInvalidToken {
sentry.CaptureException(err) sentry.CaptureException(err)
rctx.Log.Error("Error verifying token (fatal): ", err) rctx.Log.Error("Error verifying token (fatal): ", err)
......
...@@ -46,6 +46,10 @@ func AuthFailed() *ErrorResponse { ...@@ -46,6 +46,10 @@ func AuthFailed() *ErrorResponse {
return &ErrorResponse{common.ErrCodeUnknownToken, "Authentication Failed", common.ErrCodeUnknownToken} return &ErrorResponse{common.ErrCodeUnknownToken, "Authentication Failed", common.ErrCodeUnknownToken}
} }
func GuestAuthFailed() *ErrorResponse {
return &ErrorResponse{common.ErrCodeNoGuests, "Guests cannot use this endpoint", common.ErrCodeNoGuests}
}
func BadRequest(message string) *ErrorResponse { func BadRequest(message string) *ErrorResponse {
return &ErrorResponse{common.ErrCodeUnknown, message, common.ErrCodeBadRequest} return &ErrorResponse{common.ErrCodeUnknown, message, common.ErrCodeBadRequest}
} }
......
...@@ -5,6 +5,7 @@ const ErrCodeHostNotFound = "M_HOST_NOT_FOUND" ...@@ -5,6 +5,7 @@ const ErrCodeHostNotFound = "M_HOST_NOT_FOUND"
const ErrCodeHostBlacklisted = "M_HOST_BLACKLISTED" const ErrCodeHostBlacklisted = "M_HOST_BLACKLISTED"
const ErrCodeNotFound = "M_NOT_FOUND" const ErrCodeNotFound = "M_NOT_FOUND"
const ErrCodeUnknownToken = "M_UNKNOWN_TOKEN" const ErrCodeUnknownToken = "M_UNKNOWN_TOKEN"
const ErrCodeNoGuests = "M_GUEST_ACCESS_FORBIDDEN"
const ErrCodeMissingToken = "M_MISSING_TOKEN" const ErrCodeMissingToken = "M_MISSING_TOKEN"
const ErrCodeMediaTooLarge = "M_MEDIA_TOO_LARGE" const ErrCodeMediaTooLarge = "M_MEDIA_TOO_LARGE"
const ErrCodeMediaTooSmall = "M_MEDIA_TOO_SMALL" const ErrCodeMediaTooSmall = "M_MEDIA_TOO_SMALL"
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
) )
var ErrInvalidToken = errors.New("Missing or invalid access token") var ErrInvalidToken = errors.New("Missing or invalid access token")
var ErrGuestToken = errors.New("Token belongs to a guest")
func doBreakerRequest(ctx rcontext.RequestContext, serverName string, accessToken string, appserviceUserId string, ipAddr string, method string, path string, resp interface{}) error { func doBreakerRequest(ctx rcontext.RequestContext, serverName string, accessToken string, appserviceUserId string, ipAddr string, method string, path string, resp interface{}) error {
if accessToken == "" { if accessToken == "" {
...@@ -53,6 +54,9 @@ func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessTo ...@@ -53,6 +54,9 @@ func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessTo
if err != nil { if err != nil {
return "", err return "", err
} }
if response.IsGuest {
return "", ErrGuestToken
}
return response.UserId, nil return response.UserId, nil
} }
......
...@@ -36,9 +36,11 @@ func filterError(err error) (error, error) { ...@@ -36,9 +36,11 @@ func filterError(err error) (error, error) {
// Unknown token errors should be filtered out explicitly to ensure we don't break on bad requests // Unknown token errors should be filtered out explicitly to ensure we don't break on bad requests
if httpErr, ok := err.(*errorResponse); ok { if httpErr, ok := err.(*errorResponse); ok {
// We send back our own version of errors to ensure we can filter them out elsewhere
if httpErr.ErrorCode == common.ErrCodeUnknownToken { if httpErr.ErrorCode == common.ErrCodeUnknownToken {
// We send back our own version of 'unknown token' to ensure we can filter it out elsewhere
return nil, ErrInvalidToken return nil, ErrInvalidToken
} else if httpErr.ErrorCode == common.ErrCodeNoGuests {
return nil, ErrGuestToken
} }
} }
......
...@@ -8,7 +8,8 @@ type emptyResponse struct { ...@@ -8,7 +8,8 @@ type emptyResponse struct {
} }
type userIdResponse struct { type userIdResponse struct {
UserId string `json:"user_id"` UserId string `json:"user_id"`
IsGuest bool `json:"org.matrix.msc3069.is_guest"`
} }
type whoisResponse struct { type whoisResponse struct {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment