diff --git a/src/github.com/turt2live/matrix-media-repo/api/auth.go b/src/github.com/turt2live/matrix-media-repo/api/auth.go index 0bca41719dc880abe876c473eed78bc7f884b5eb..57fb5701caa9736d809e9ec6d8e3085e60f0c6fa 100644 --- a/src/github.com/turt2live/matrix-media-repo/api/auth.go +++ b/src/github.com/turt2live/matrix-media-repo/api/auth.go @@ -17,7 +17,7 @@ func AccessTokenRequiredRoute(next func(r *http.Request, log *logrus.Entry, user return func(r *http.Request, log *logrus.Entry) interface{} { accessToken := util.GetAccessTokenFromRequest(r) appserviceUserId := util.GetAppserviceUserIdFromRequest(r) - userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId) + userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId, r.RemoteAddr) if err != nil || userId == "" { if err != nil && err != matrix.ErrNoToken { log.Error("Error verifying token: ", err) @@ -37,7 +37,7 @@ func AccessTokenOptionalRoute(next func(r *http.Request, log *logrus.Entry, user return func(r *http.Request, log *logrus.Entry) interface{} { accessToken := util.GetAccessTokenFromRequest(r) appserviceUserId := util.GetAppserviceUserIdFromRequest(r) - userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId) + userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId, r.RemoteAddr) if err != nil { if err != matrix.ErrNoToken { log.Error("Error verifying token: ", err) diff --git a/src/github.com/turt2live/matrix-media-repo/api/custom/quarantine.go b/src/github.com/turt2live/matrix-media-repo/api/custom/quarantine.go index 627a1c121fa503a967fcd089abbaf478e2c52b88..a1cfff12d7df5ec06d56c7a449ca7d6218efe70d 100644 --- a/src/github.com/turt2live/matrix-media-repo/api/custom/quarantine.go +++ b/src/github.com/turt2live/matrix-media-repo/api/custom/quarantine.go @@ -38,7 +38,7 @@ func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) "localAdmin": isLocalAdmin, }) - allMedia, err := matrix.ListMedia(r.Context(), r.Host, user.AccessToken, roomId) + allMedia, err := matrix.ListMedia(r.Context(), r.Host, user.AccessToken, roomId, r.RemoteAddr) if err != nil { log.Error("Error while listing media in the room: " + err.Error()) return api.InternalServerError("error retrieving media in room") @@ -158,7 +158,7 @@ func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user api.UserI var err error if !isGlobalAdmin { if config.Get().Quarantine.AllowLocalAdmins { - isLocalAdmin, err = matrix.IsUserAdmin(r.Context(), r.Host, user.AccessToken) + isLocalAdmin, err = matrix.IsUserAdmin(r.Context(), r.Host, user.AccessToken, r.RemoteAddr) if err != nil { log.Error("Error verifying local admin: " + err.Error()) canQuarantine = false diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/admin.go b/src/github.com/turt2live/matrix-media-repo/matrix/admin.go index f5f2104c1f7f268674030222efd70d20f3bb3d8e..d58dc1608fe655b872ce85f3187cf786e9393410 100644 --- a/src/github.com/turt2live/matrix-media-repo/matrix/admin.go +++ b/src/github.com/turt2live/matrix-media-repo/matrix/admin.go @@ -5,7 +5,7 @@ import ( "time" ) -func IsUserAdmin(ctx context.Context, serverName string, accessToken string) (bool, error) { +func IsUserAdmin(ctx context.Context, serverName string, accessToken string, ipAddr string) (bool, error) { fakeUser := "@media.repo.admin.check:" + serverName hs, cb := getBreakerAndConfig(serverName) @@ -15,7 +15,7 @@ func IsUserAdmin(ctx context.Context, serverName string, accessToken string) (bo response := &whoisResponse{} url := makeUrl(hs.ClientServerApi, "/_matrix/client/r0/admin/whois/", fakeUser) - err := doRequest("GET", url, nil, response, accessToken) + err := doRequest("GET", url, nil, response, accessToken, ipAddr) if err != nil { err, replyError = filterError(err) return err @@ -28,14 +28,14 @@ func IsUserAdmin(ctx context.Context, serverName string, accessToken string) (bo return isAdmin, replyError } -func ListMedia(ctx context.Context, serverName string, accessToken string, roomId string) (*mediaListResponse, error) { +func ListMedia(ctx context.Context, serverName string, accessToken string, roomId string, ipAddr string) (*mediaListResponse, error) { hs, cb := getBreakerAndConfig(serverName) response := &mediaListResponse{} var replyError error replyError = cb.CallContext(ctx, func() error { url := makeUrl(hs.ClientServerApi, "/_matrix/client/r0/admin/room/", roomId, "/media") - err := doRequest("GET", url, nil, response, accessToken) + err := doRequest("GET", url, nil, response, accessToken, ipAddr) if err != nil { err, replyError = filterError(err) return err diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/auth.go b/src/github.com/turt2live/matrix-media-repo/matrix/auth.go index 74b370d9c3ede3c3511e065a407d0263783cc553..2355be2393c8f799e22a8f3ef7814183651aad37 100644 --- a/src/github.com/turt2live/matrix-media-repo/matrix/auth.go +++ b/src/github.com/turt2live/matrix-media-repo/matrix/auth.go @@ -10,7 +10,7 @@ import ( var ErrNoToken = errors.New("Missing access token") -func GetUserIdFromToken(ctx context.Context, serverName string, accessToken string, appserviceUserId string) (string, error) { +func GetUserIdFromToken(ctx context.Context, serverName string, accessToken string, appserviceUserId string, ipAddr string) (string, error) { if accessToken == "" { return "", ErrNoToken } @@ -32,7 +32,7 @@ func GetUserIdFromToken(ctx context.Context, serverName string, accessToken stri q.Set(k, v) } target.RawQuery = q.Encode() - err := doRequest("GET", target.String(), nil, response, accessToken) + err := doRequest("GET", target.String(), nil, response, accessToken, ipAddr) if err != nil { err, replyError = filterError(err) return err diff --git a/src/github.com/turt2live/matrix-media-repo/matrix/client_server.go b/src/github.com/turt2live/matrix-media-repo/matrix/client_server.go index c8b9a8ccf7938dd148c33bdfea083a19072ebbba..85d859ea4900c37086c6b0ddf334136bde866ae4 100644 --- a/src/github.com/turt2live/matrix-media-repo/matrix/client_server.go +++ b/src/github.com/turt2live/matrix-media-repo/matrix/client_server.go @@ -15,7 +15,7 @@ var matrixHttpClient = &http.Client{ } // Based in part on https://github.com/matrix-org/gomatrix/blob/072b39f7fa6b40257b4eead8c958d71985c28bdd/client.go#L180-L243 -func doRequest(method string, urlStr string, body interface{}, result interface{}, accessToken string) (error) { +func doRequest(method string, urlStr string, body interface{}, result interface{}, accessToken string, ipAddr string) (error) { var bodyBytes []byte if body != nil { jsonStr, err := json.Marshal(body) @@ -35,6 +35,10 @@ func doRequest(method string, urlStr string, body interface{}, result interface{ if accessToken != "" { req.Header.Set("Authorization", "Bearer "+accessToken) } + if ipAddr != "" { + req.Header.Set("X-Forwarded-For", ipAddr) + req.Header.Set("X-Real-IP", ipAddr) + } res, err := matrixHttpClient.Do(req) if res != nil {