diff --git a/api/auth.go b/api/auth.go index 9b2798175cb9833e2b3adfd39c5f8b3e8aee85a9..5d271b8e54b5ecd2c3f14dcd6edd89dbabf9b318 100644 --- a/api/auth.go +++ b/api/auth.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/matrix" "github.com/turt2live/matrix-media-repo/util" ) @@ -16,88 +17,93 @@ type UserInfo struct { IsShared bool } -func AccessTokenRequiredRoute(next func(r *http.Request, log *logrus.Entry, user UserInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} { - return func(r *http.Request, log *logrus.Entry) interface{} { +func callUserNext(next func(r *http.Request, rctx rcontext.RequestContext, user UserInfo) interface{}, r *http.Request, rctx rcontext.RequestContext, user UserInfo) interface{} { + r.WithContext(rctx) + return next(r, rctx, user) +} + +func AccessTokenRequiredRoute(next func(r *http.Request, rctx rcontext.RequestContext, user UserInfo) interface{}) func(*http.Request, rcontext.RequestContext) interface{} { + return func(r *http.Request, rctx rcontext.RequestContext) interface{} { accessToken := util.GetAccessTokenFromRequest(r) if accessToken == "" { - log.Error("Error: no token provided (required)") + rctx.Log.Error("Error: no token provided (required)") return &ErrorResponse{common.ErrCodeMissingToken, "no token provided (required)", common.ErrCodeUnknownToken} } if config.Get().SharedSecret.Enabled && accessToken == config.Get().SharedSecret.Token { - log = log.WithFields(logrus.Fields{"isRepoAdmin": true}) + log := rctx.Log.WithFields(logrus.Fields{"isRepoAdmin": true}) log.Info("User authed using shared secret") - return next(r, log, UserInfo{UserId: "@sharedsecret", AccessToken: accessToken, IsShared: true}) + return callUserNext(next, r, rctx, UserInfo{UserId: "@sharedsecret", AccessToken: accessToken, IsShared: true}) } appserviceUserId := util.GetAppserviceUserIdFromRequest(r) - userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId, r.RemoteAddr) + userId, err := matrix.GetUserIdFromToken(rctx, r.Host, accessToken, appserviceUserId, r.RemoteAddr) if err != nil || userId == "" { if err != nil && err != matrix.ErrNoToken { - log.Error("Error verifying token: ", err) + rctx.Log.Error("Error verifying token: ", err) return InternalServerError("Unexpected Error") } - log.Warn("Failed to verify token (fatal): ", err) + rctx.Log.Warn("Failed to verify token (fatal): ", err) return AuthFailed() } - log = log.WithFields(logrus.Fields{"authUserId": userId}) - return next(r, log, UserInfo{userId, accessToken, false}) + rctx = rctx.LogWithFields(logrus.Fields{"authUserId": userId}) + return callUserNext(next, r, rctx, UserInfo{userId, accessToken, false}) } } -func AccessTokenOptionalRoute(next func(r *http.Request, log *logrus.Entry, user UserInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} { - return func(r *http.Request, log *logrus.Entry) interface{} { +func AccessTokenOptionalRoute(next func(r *http.Request, rctx rcontext.RequestContext, user UserInfo) interface{}) func(*http.Request, rcontext.RequestContext) interface{} { + return func(r *http.Request, rctx rcontext.RequestContext) interface{} { accessToken := util.GetAccessTokenFromRequest(r) if accessToken == "" { - return next(r, log, UserInfo{"", "", false}) + return callUserNext(next, r, rctx, UserInfo{"", "", false}) } if config.Get().SharedSecret.Enabled && accessToken == config.Get().SharedSecret.Token { - log = log.WithFields(logrus.Fields{"isRepoAdmin": true}) - log.Info("User authed using shared secret") - return next(r, log, UserInfo{UserId: "@sharedsecret", AccessToken: accessToken, IsShared: true}) + rctx = rctx.LogWithFields(logrus.Fields{"isRepoAdmin": true}) + rctx.Log.Info("User authed using shared secret") + return callUserNext(next, r, rctx, UserInfo{UserId: "@sharedsecret", AccessToken: accessToken, IsShared: true}) } appserviceUserId := util.GetAppserviceUserIdFromRequest(r) - userId, err := matrix.GetUserIdFromToken(r.Context(), r.Host, accessToken, appserviceUserId, r.RemoteAddr) + userId, err := matrix.GetUserIdFromToken(rctx, r.Host, accessToken, appserviceUserId, r.RemoteAddr) if err != nil { if err != matrix.ErrNoToken { - log.Error("Error verifying token: ", err) + rctx.Log.Error("Error verifying token: ", err) return InternalServerError("Unexpected Error") } - log.Warn("Failed to verify token (non-fatal): ", err) + rctx.Log.Warn("Failed to verify token (non-fatal): ", err) userId = "" } - log = log.WithFields(logrus.Fields{"authUserId": userId}) - return next(r, log, UserInfo{userId, accessToken, false}) + rctx = rctx.LogWithFields(logrus.Fields{"authUserId": userId}) + return callUserNext(next, r, rctx, UserInfo{userId, accessToken, false}) } } -func RepoAdminRoute(next func(r *http.Request, log *logrus.Entry, user UserInfo) interface{}) func(*http.Request, *logrus.Entry) interface{} { - regularFunc := AccessTokenRequiredRoute(func(r *http.Request, log *logrus.Entry, user UserInfo) interface{} { +func RepoAdminRoute(next func(r *http.Request, rctx rcontext.RequestContext, user UserInfo) interface{}) func(*http.Request, rcontext.RequestContext) interface{} { + regularFunc := AccessTokenRequiredRoute(func(r *http.Request, rctx rcontext.RequestContext, user UserInfo) interface{} { if user.UserId == "" { - log.Warn("Could not identify user for this admin route") + rctx.Log.Warn("Could not identify user for this admin route") return AuthFailed() } if !util.IsGlobalAdmin(user.UserId) { - log.Warn("User " + user.UserId + " is not a repository administrator") + rctx.Log.Warn("User " + user.UserId + " is not a repository administrator") return AuthFailed() } - log = log.WithFields(logrus.Fields{"isRepoAdmin": true}) - return next(r, log, user) + rctx = rctx.LogWithFields(logrus.Fields{"isRepoAdmin": true}) + return callUserNext(next, r, rctx, user) }) - return func(r *http.Request, log *logrus.Entry) interface{} { + return func(r *http.Request, rctx rcontext.RequestContext) interface{} { if config.Get().SharedSecret.Enabled { accessToken := util.GetAccessTokenFromRequest(r) if accessToken == config.Get().SharedSecret.Token { - log = log.WithFields(logrus.Fields{"isRepoAdmin": true}) - log.Info("User authed using shared secret") - return next(r, log, UserInfo{UserId: "@sharedsecret", AccessToken: accessToken, IsShared: true}) + rctx = rctx.LogWithFields(logrus.Fields{"isRepoAdmin": true}) + rctx.Log.Info("User authed using shared secret") + return callUserNext(next, r, rctx, UserInfo{UserId: "@sharedsecret", AccessToken: accessToken, IsShared: true}) } } - return regularFunc(r, log) + return regularFunc(r, rctx) } } diff --git a/api/custom/datastores.go b/api/custom/datastores.go index 3285a8547834aabf68f31f1af29a3f524effd847..5368c4d55c26bbe21051070d089957d4f60f2177 100644 --- a/api/custom/datastores.go +++ b/api/custom/datastores.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/maintenance_controller" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" @@ -19,10 +20,10 @@ type DatastoreMigration struct { TaskID int `json:"task_id"` } -func GetDatastores(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - datastores, err := storage.GetDatabase().GetMediaStore(r.Context(), log).GetAllDatastores() +func GetDatastores(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + datastores, err := storage.GetDatabase().GetMediaStore(rctx).GetAllDatastores() if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Error getting datastores") } @@ -38,7 +39,7 @@ func GetDatastores(r *http.Request, log *logrus.Entry, user api.UserInfo) interf return &api.DoNotCacheResponse{Payload: response} } -func MigrateBetweenDatastores(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func MigrateBetweenDatastores(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { beforeTsStr := r.URL.Query().Get("before_ts") beforeTs := util.NowMillis() var err error @@ -54,7 +55,7 @@ func MigrateBetweenDatastores(r *http.Request, log *logrus.Entry, user api.UserI sourceDsId := params["sourceDsId"] targetDsId := params["targetDsId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "beforeTs": beforeTs, "sourceDsId": sourceDsId, "targetDsId": targetDsId, @@ -64,28 +65,28 @@ func MigrateBetweenDatastores(r *http.Request, log *logrus.Entry, user api.UserI return api.BadRequest("Source and target datastore cannot be the same") } - sourceDatastore, err := datastore.LocateDatastore(r.Context(), log, sourceDsId) + sourceDatastore, err := datastore.LocateDatastore(rctx, sourceDsId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.BadRequest("Error getting source datastore. Does it exist?") } - targetDatastore, err := datastore.LocateDatastore(r.Context(), log, targetDsId) + targetDatastore, err := datastore.LocateDatastore(rctx, targetDsId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.BadRequest("Error getting target datastore. Does it exist?") } - log.Info("User ", user.UserId, " has started a datastore media transfer") - task, err := maintenance_controller.StartStorageMigration(sourceDatastore, targetDatastore, beforeTs, log) + rctx.Log.Info("User ", user.UserId, " has started a datastore media transfer") + task, err := maintenance_controller.StartStorageMigration(sourceDatastore, targetDatastore, beforeTs, rctx) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Unexpected error starting migration") } - estimate, err := maintenance_controller.EstimateDatastoreSizeWithAge(beforeTs, sourceDsId, r.Context(), log) + estimate, err := maintenance_controller.EstimateDatastoreSizeWithAge(beforeTs, sourceDsId, rctx) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Unexpected error getting storage estimate") } @@ -97,7 +98,7 @@ func MigrateBetweenDatastores(r *http.Request, log *logrus.Entry, user api.UserI return &api.DoNotCacheResponse{Payload: migration} } -func GetDatastoreStorageEstimate(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func GetDatastoreStorageEstimate(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { beforeTsStr := r.URL.Query().Get("before_ts") beforeTs := util.NowMillis() var err error @@ -112,14 +113,14 @@ func GetDatastoreStorageEstimate(r *http.Request, log *logrus.Entry, user api.Us datastoreId := params["datastoreId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "beforeTs": beforeTs, "datastoreId": datastoreId, }) - result, err := maintenance_controller.EstimateDatastoreSizeWithAge(beforeTs, datastoreId, r.Context(), log) + result, err := maintenance_controller.EstimateDatastoreSizeWithAge(beforeTs, datastoreId, rctx) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Unexpected error getting storage estimate") } return &api.DoNotCacheResponse{Payload: result} diff --git a/api/custom/exports.go b/api/custom/exports.go index 9d1c92384ef14a64019a5f9dc930ca7ddb90fb9e..8a4b3e38b2b8422eab61a753536b82ef66a02a33 100644 --- a/api/custom/exports.go +++ b/api/custom/exports.go @@ -11,6 +11,7 @@ import ( "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/api/r0" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/data_controller" "github.com/turt2live/matrix-media-repo/matrix" "github.com/turt2live/matrix-media-repo/storage" @@ -35,7 +36,7 @@ type ExportMetadata struct { Parts []*ExportPartMetadata `json:"parts"` } -func ExportUserData(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func ExportUserData(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } @@ -52,18 +53,18 @@ func ExportUserData(r *http.Request, log *logrus.Entry, user api.UserInfo) inter userId := params["userId"] - if !isAdmin && user.UserId != userId { + if !isAdmin && user.UserId != userId { return api.BadRequest("cannot export data for another user") } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "exportUserId": userId, "includeData": includeData, "s3urls": s3urls, }) - task, exportId, err := data_controller.StartUserExport(userId, s3urls, includeData, log) + task, exportId, err := data_controller.StartUserExport(userId, s3urls, includeData, rctx) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("fatal error starting export") } @@ -73,7 +74,7 @@ func ExportUserData(r *http.Request, log *logrus.Entry, user api.UserInfo) inter }} } -func ExportServerData(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func ExportServerData(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } @@ -98,9 +99,9 @@ func ExportServerData(r *http.Request, log *logrus.Entry, user api.UserInfo) int return api.BadRequest("cannot export data for another server") } - isLocalAdmin, err := matrix.IsUserAdmin(r.Context(), serverName, user.AccessToken, r.RemoteAddr) + isLocalAdmin, err := matrix.IsUserAdmin(rctx, serverName, user.AccessToken, r.RemoteAddr) if err != nil { - log.Error("Error verifying local admin: " + err.Error()) + rctx.Log.Error("Error verifying local admin: " + err.Error()) isLocalAdmin = false } if !isLocalAdmin { @@ -108,14 +109,14 @@ func ExportServerData(r *http.Request, log *logrus.Entry, user api.UserInfo) int } } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "exportServerName": serverName, "includeData": includeData, "s3urls": s3urls, }) - task, exportId, err := data_controller.StartServerExport(serverName, s3urls, includeData, log) + task, exportId, err := data_controller.StartServerExport(serverName, s3urls, includeData, rctx) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("fatal error starting export") } @@ -125,7 +126,7 @@ func ExportServerData(r *http.Request, log *logrus.Entry, user api.UserInfo) int }} } -func ViewExport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func ViewExport(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } @@ -133,27 +134,27 @@ func ViewExport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface params := mux.Vars(r) exportId := params["exportId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "exportId": exportId, }) - exportDb := storage.GetDatabase().GetExportStore(r.Context(), log) + exportDb := storage.GetDatabase().GetExportStore(rctx) exportInfo, err := exportDb.GetExportMetadata(exportId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to get metadata") } parts, err := exportDb.GetExportParts(exportId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to get export parts") } template, err := templating.GetTemplate("view_export") if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to get template") } @@ -175,14 +176,14 @@ func ViewExport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface html := bytes.Buffer{} err = template.Execute(&html, model) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to render template") } return &api.HtmlResponse{HTML: string(html.Bytes())} } -func GetExportMetadata(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func GetExportMetadata(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } @@ -190,21 +191,21 @@ func GetExportMetadata(r *http.Request, log *logrus.Entry, user api.UserInfo) in params := mux.Vars(r) exportId := params["exportId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "exportId": exportId, }) - exportDb := storage.GetDatabase().GetExportStore(r.Context(), log) + exportDb := storage.GetDatabase().GetExportStore(rctx) exportInfo, err := exportDb.GetExportMetadata(exportId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to get metadata") } parts, err := exportDb.GetExportParts(exportId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to get export parts") } @@ -223,7 +224,7 @@ func GetExportMetadata(r *http.Request, log *logrus.Entry, user api.UserInfo) in return &api.DoNotCacheResponse{Payload: metadata} } -func DownloadExportPart(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func DownloadExportPart(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } @@ -233,25 +234,25 @@ func DownloadExportPart(r *http.Request, log *logrus.Entry, user api.UserInfo) i exportId := params["exportId"] partId, err := strconv.ParseInt(params["partId"], 10, 64) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.BadRequest("invalid part index") } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "exportId": exportId, "partId": partId, }) - db := storage.GetDatabase().GetExportStore(r.Context(), log) + db := storage.GetDatabase().GetExportStore(rctx) part, err := db.GetExportPart(exportId, int(partId)) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to get part") } - s, err := datastore.DownloadStream(r.Context(), log, part.DatastoreID, part.Location) + s, err := datastore.DownloadStream(rctx, part.DatastoreID, part.Location) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to start download") } @@ -263,7 +264,7 @@ func DownloadExportPart(r *http.Request, log *logrus.Entry, user api.UserInfo) i } } -func DeleteExport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func DeleteExport(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } @@ -272,38 +273,38 @@ func DeleteExport(r *http.Request, log *logrus.Entry, user api.UserInfo) interfa exportId := params["exportId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "exportId": exportId, }) - db := storage.GetDatabase().GetExportStore(r.Context(), log) + db := storage.GetDatabase().GetExportStore(rctx) - log.Info("Getting information on which parts to delete") + rctx.Log.Info("Getting information on which parts to delete") parts, err := db.GetExportParts(exportId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to delete export") } for _, part := range parts { - log.Info("Locating datastore: " + part.DatastoreID) - ds, err := datastore.LocateDatastore(r.Context(), log, part.DatastoreID) + rctx.Log.Info("Locating datastore: " + part.DatastoreID) + ds, err := datastore.LocateDatastore(rctx, part.DatastoreID) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to delete export") } - log.Info("Deleting object: " + part.Location) + rctx.Log.Info("Deleting object: " + part.Location) err = ds.DeleteObject(part.Location) if err != nil { - log.Warn(err) + rctx.Log.Warn(err) } } - log.Info("Purging export from database") + rctx.Log.Info("Purging export from database") err = db.DeleteExportAndParts(exportId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to delete export") } diff --git a/api/custom/federation.go b/api/custom/federation.go index 728b3423efaf8f60b0b9b5eee0c6b0c4d9e77a9f..bdd4f829d964a3e71c61408bd403d44ab2b3a680 100644 --- a/api/custom/federation.go +++ b/api/custom/federation.go @@ -8,41 +8,42 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/matrix" ) -func GetFederationInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func GetFederationInfo(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) serverName := params["serverName"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "serverName": serverName, }) url, hostname, err := matrix.GetServerApiUrl(serverName) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError(err.Error()) } versionUrl := url + "/_matrix/federation/v1/version" versionResponse, err := matrix.FederatedGet(versionUrl, hostname) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError(err.Error()) } c, err := ioutil.ReadAll(versionResponse.Body) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError(err.Error()) } out := make(map[string]interface{}) err = json.Unmarshal(c, &out) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError(err.Error()) } diff --git a/api/custom/health.go b/api/custom/health.go index d9d067db8decb8f0ba055343dd72553b01e6577f..cab5bf9996372d36f74606209034d38c17904bc0 100644 --- a/api/custom/health.go +++ b/api/custom/health.go @@ -3,8 +3,8 @@ package custom import ( "net/http" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" + "github.com/turt2live/matrix-media-repo/common/rcontext" ) type HealthzResponse struct { @@ -12,7 +12,7 @@ type HealthzResponse struct { Status string `json:"status"` } -func GetHealthz(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func GetHealthz(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { return &api.DoNotCacheResponse{ Payload: &HealthzResponse{ OK: true, diff --git a/api/custom/imports.go b/api/custom/imports.go index 5592d15de4d31d1a96020c3745fe76f758844a2b..9dfdffcd96c6d19e37fa96f95f09aa186cce7d07 100644 --- a/api/custom/imports.go +++ b/api/custom/imports.go @@ -4,9 +4,9 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/data_controller" ) @@ -15,15 +15,15 @@ type ImportStarted struct { TaskID int `json:"task_id"` } -func StartImport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func StartImport(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } defer r.Body.Close() - task, importId, err := data_controller.StartImport(r.Body, log) + task, importId, err := data_controller.StartImport(r.Body, rctx) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("fatal error starting import") } @@ -33,7 +33,7 @@ func StartImport(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac }} } -func AppendToImport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func AppendToImport(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } @@ -45,14 +45,14 @@ func AppendToImport(r *http.Request, log *logrus.Entry, user api.UserInfo) inter defer r.Body.Close() err := data_controller.AppendToImport(importId, r.Body) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("fatal error appending to import") } return &api.DoNotCacheResponse{Payload: &api.EmptyResponse{}} } -func StopImport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func StopImport(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Archiving.Enabled { return api.BadRequest("archiving is not enabled") } @@ -63,7 +63,7 @@ func StopImport(r *http.Request, log *logrus.Entry, user api.UserInfo) interface err := data_controller.StopImport(importId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("fatal error stopping import") } diff --git a/api/custom/purge.go b/api/custom/purge.go index aa0b9dc072c23c1c302360c2ad7fe8472ee7c93d..407f276c6dc689cce918abf69f592d42a2a14c4f 100644 --- a/api/custom/purge.go +++ b/api/custom/purge.go @@ -9,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/maintenance_controller" "github.com/turt2live/matrix-media-repo/matrix" "github.com/turt2live/matrix-media-repo/storage" @@ -20,7 +21,7 @@ type MediaPurgedResponse struct { NumRemoved int `json:"total_removed"` } -func PurgeRemoteMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func PurgeRemoteMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { beforeTsStr := r.URL.Query().Get("before_ts") if beforeTsStr == "" { return api.BadRequest("Missing before_ts argument") @@ -30,22 +31,22 @@ func PurgeRemoteMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) int return api.BadRequest("Error parsing before_ts: " + err.Error()) } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "beforeTs": beforeTs, }) // We don't bother clearing the cache because it's still probably useful there - removed, err := maintenance_controller.PurgeRemoteMediaBefore(beforeTs, r.Context(), log) + removed, err := maintenance_controller.PurgeRemoteMediaBefore(beforeTs, rctx) if err != nil { - log.Error("Error purging remote media: " + err.Error()) + rctx.Log.Error("Error purging remote media: " + err.Error()) return api.InternalServerError("Error purging remote media") } return &api.DoNotCacheResponse{Payload: &MediaPurgedResponse{NumRemoved: removed}} } -func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user) +func PurgeIndividualRecord(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, rctx, user) localServerName := r.Host params := mux.Vars(r) @@ -53,7 +54,7 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo server := params["server"] mediaId := params["mediaId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "server": server, "mediaId": mediaId, }) @@ -65,13 +66,13 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo } // If the user is NOT a local admin, ensure they uploaded the content in the first place if !isLocalAdmin { - db := storage.GetDatabase().GetMediaStore(r.Context(), log) + db := storage.GetDatabase().GetMediaStore(rctx) m, err := db.Get(server, mediaId) if err == sql.ErrNoRows { return api.NotFoundError() } if err != nil { - log.Error("Error checking ownership of media: " + err.Error()) + rctx.Log.Error("Error checking ownership of media: " + err.Error()) return api.InternalServerError("error checking media ownership") } if m.UserId != user.UserId { @@ -80,35 +81,35 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo } } - err := maintenance_controller.PurgeMedia(server, mediaId, r.Context(), log) + err := maintenance_controller.PurgeMedia(server, mediaId, rctx) if err == sql.ErrNoRows || err == common.ErrMediaNotFound { return api.NotFoundError() } if err != nil { - log.Error("Error purging media: " + err.Error()) + rctx.Log.Error("Error purging media: " + err.Error()) return api.InternalServerError("error purging media") } return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true}} } -func PurgeQuarantined(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user) +func PurgeQuarantined(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, rctx, user) localServerName := r.Host var affected []*types.Media var err error if isGlobalAdmin { - affected, err = maintenance_controller.PurgeQuarantined(r.Context(), log) + affected, err = maintenance_controller.PurgeQuarantined(rctx) } else if isLocalAdmin { - affected, err = maintenance_controller.PurgeQuarantinedFor(localServerName, r.Context(), log) + affected, err = maintenance_controller.PurgeQuarantinedFor(localServerName, rctx) } else { return api.AuthFailed() } if err != nil { - log.Error("Error purging media: " + err.Error()) + rctx.Log.Error("Error purging media: " + err.Error()) return api.InternalServerError("error purging media") } @@ -120,7 +121,7 @@ func PurgeQuarantined(r *http.Request, log *logrus.Entry, user api.UserInfo) int return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } -func PurgeOldMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func PurgeOldMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { var err error beforeTs := util.NowMillis() beforeTsStr := r.URL.Query().Get("before_ts") @@ -140,15 +141,15 @@ func PurgeOldMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interf } } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "before_ts": beforeTs, "include_local": includeLocal, }) - affected, err := maintenance_controller.PurgeOldMedia(beforeTs, includeLocal, r.Context(), log) + affected, err := maintenance_controller.PurgeOldMedia(beforeTs, includeLocal, rctx) if err != nil { - log.Error("Error purging media: " + err.Error()) + rctx.Log.Error("Error purging media: " + err.Error()) return api.InternalServerError("error purging media") } @@ -160,8 +161,8 @@ func PurgeOldMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interf return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } -func PurgeUserMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user) +func PurgeUserMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, rctx, user) if !isGlobalAdmin && !isLocalAdmin { return api.AuthFailed() } @@ -180,14 +181,14 @@ func PurgeUserMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter userId := params["userId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "userId": userId, "beforeTs": beforeTs, }) _, userDomain, err := util.SplitUserId(userId) if err != nil { - log.Error("Error parsing user ID (" + userId + "): " + err.Error()) + rctx.Log.Error("Error parsing user ID (" + userId + "): " + err.Error()) return api.InternalServerError("error parsing user ID") } @@ -195,10 +196,10 @@ func PurgeUserMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter return api.AuthFailed() } - affected, err := maintenance_controller.PurgeUserMedia(userId, beforeTs, r.Context(), log) + affected, err := maintenance_controller.PurgeUserMedia(userId, beforeTs, rctx) if err != nil { - log.Error("Error purging media: " + err.Error()) + rctx.Log.Error("Error purging media: " + err.Error()) return api.InternalServerError("error purging media") } @@ -210,8 +211,8 @@ func PurgeUserMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } -func PurgeRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user) +func PurgeRoomMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, rctx, user) if !isGlobalAdmin && !isLocalAdmin { return api.AuthFailed() } @@ -230,14 +231,14 @@ func PurgeRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter roomId := params["roomId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "roomId": roomId, "beforeTs": beforeTs, }) - allMedia, err := matrix.ListMedia(r.Context(), r.Host, user.AccessToken, roomId, r.RemoteAddr) + allMedia, err := matrix.ListMedia(rctx, r.Host, user.AccessToken, roomId, r.RemoteAddr) if err != nil { - log.Error("Error while listing media in the room: " + err.Error()) + rctx.Log.Error("Error while listing media in the room: " + err.Error()) return api.InternalServerError("error retrieving media in room") } @@ -273,10 +274,10 @@ func PurgeRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter } } - affected, err := maintenance_controller.PurgeRoomMedia(mxcs, beforeTs, r.Context(), log) + affected, err := maintenance_controller.PurgeRoomMedia(mxcs, beforeTs, rctx) if err != nil { - log.Error("Error purging media: " + err.Error()) + rctx.Log.Error("Error purging media: " + err.Error()) return api.InternalServerError("error purging media") } @@ -288,8 +289,8 @@ func PurgeRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } -func PurgeDomainMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user) +func PurgeDomainMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, rctx, user) if !isGlobalAdmin && !isLocalAdmin { return api.AuthFailed() } @@ -308,7 +309,7 @@ func PurgeDomainMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) int serverName := params["serverName"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "serverName": serverName, "beforeTs": beforeTs, }) @@ -317,10 +318,10 @@ func PurgeDomainMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) int return api.AuthFailed() } - affected, err := maintenance_controller.PurgeDomainMedia(serverName, beforeTs, r.Context(), log) + affected, err := maintenance_controller.PurgeDomainMedia(serverName, beforeTs, rctx) if err != nil { - log.Error("Error purging media: " + err.Error()) + rctx.Log.Error("Error purging media: " + err.Error()) return api.InternalServerError("error purging media") } @@ -332,11 +333,11 @@ func PurgeDomainMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) int return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } -func getPurgeRequestInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) (bool, bool) { +func getPurgeRequestInfo(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) (bool, bool) { isGlobalAdmin := util.IsGlobalAdmin(user.UserId) || user.IsShared - isLocalAdmin, err := matrix.IsUserAdmin(r.Context(), r.Host, user.AccessToken, r.RemoteAddr) + isLocalAdmin, err := matrix.IsUserAdmin(rctx, r.Host, user.AccessToken, r.RemoteAddr) if err != nil { - log.Error("Error verifying local admin: " + err.Error()) + rctx.Log.Error("Error verifying local admin: " + err.Error()) return isGlobalAdmin, false } diff --git a/api/custom/quarantine.go b/api/custom/quarantine.go index 4922cf9304592d77b65c7d65d14fad53284c21ee..7f3ce133250aaf86de61cda10daf50a0a0f2b2ef 100644 --- a/api/custom/quarantine.go +++ b/api/custom/quarantine.go @@ -1,7 +1,6 @@ package custom import ( - "context" "database/sql" "net/http" @@ -9,6 +8,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/internal_cache" "github.com/turt2live/matrix-media-repo/matrix" "github.com/turt2live/matrix-media-repo/storage" @@ -23,8 +23,8 @@ type MediaQuarantinedResponse struct { // Developer note: This isn't broken out into a dedicated controller class because the logic is slightly // too complex to do so. If anything, the logic should be improved and moved. -func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, log, user) +func QuarantineRoomMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, rctx, user) if !canQuarantine { return api.AuthFailed() } @@ -33,14 +33,14 @@ func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) roomId := params["roomId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "roomId": roomId, "localAdmin": isLocalAdmin, }) - allMedia, err := matrix.ListMedia(r.Context(), r.Host, user.AccessToken, roomId, r.RemoteAddr) + allMedia, err := matrix.ListMedia(rctx, r.Host, user.AccessToken, roomId, r.RemoteAddr) if err != nil { - log.Error("Error while listing media in the room: " + err.Error()) + rctx.Log.Error("Error while listing media in the room: " + err.Error()) return api.InternalServerError("error retrieving media in room") } @@ -52,16 +52,16 @@ func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) for _, mxc := range mxcs { server, mediaId, err := util.SplitMxc(mxc) if err != nil { - log.Error("Error parsing MXC URI (" + mxc + "): " + err.Error()) + rctx.Log.Error("Error parsing MXC URI (" + mxc + "): " + err.Error()) return api.InternalServerError("error parsing mxc uri") } if !allowOtherHosts && r.Host != server { - log.Warn("Skipping media " + mxc + " because it is on a different host") + rctx.Log.Warn("Skipping media " + mxc + " because it is on a different host") continue } - resp, ok := doQuarantine(r.Context(), log, server, mediaId, allowOtherHosts) + resp, ok := doQuarantine(rctx, server, mediaId, allowOtherHosts) if !ok { return resp } @@ -72,8 +72,8 @@ func QuarantineRoomMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) return &api.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}} } -func QuarantineUserMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, log, user) +func QuarantineUserMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, rctx, user) if !canQuarantine { return api.AuthFailed() } @@ -82,14 +82,14 @@ func QuarantineUserMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) userId := params["userId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "userId": userId, "localAdmin": isLocalAdmin, }) _, userDomain, err := util.SplitUserId(userId) if err != nil { - log.Error("Error parsing user ID (" + userId + "): " + err.Error()) + rctx.Log.Error("Error parsing user ID (" + userId + "): " + err.Error()) return api.InternalServerError("error parsing user ID") } @@ -97,16 +97,16 @@ func QuarantineUserMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) return api.AuthFailed() } - db := storage.GetDatabase().GetMediaStore(r.Context(), log) + db := storage.GetDatabase().GetMediaStore(rctx) userMedia, err := db.GetMediaByUser(userId) if err != nil { - log.Error("Error while listing media for the user: " + err.Error()) + rctx.Log.Error("Error while listing media for the user: " + err.Error()) return api.InternalServerError("error retrieving media for user") } total := 0 for _, media := range userMedia { - resp, ok := doQuarantineOn(media, allowOtherHosts, log, r.Context()) + resp, ok := doQuarantineOn(media, allowOtherHosts, rctx) if !ok { return resp } @@ -117,8 +117,8 @@ func QuarantineUserMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) return &api.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}} } -func QuarantineDomainMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, log, user) +func QuarantineDomainMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, rctx, user) if !canQuarantine { return api.AuthFailed() } @@ -127,7 +127,7 @@ func QuarantineDomainMedia(r *http.Request, log *logrus.Entry, user api.UserInfo serverName := params["serverName"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "serverName": serverName, "localAdmin": isLocalAdmin, }) @@ -136,16 +136,16 @@ func QuarantineDomainMedia(r *http.Request, log *logrus.Entry, user api.UserInfo return api.AuthFailed() } - db := storage.GetDatabase().GetMediaStore(r.Context(), log) + db := storage.GetDatabase().GetMediaStore(rctx) userMedia, err := db.GetAllMediaForServer(serverName) if err != nil { - log.Error("Error while listing media for the server: " + err.Error()) + rctx.Log.Error("Error while listing media for the server: " + err.Error()) return api.InternalServerError("error retrieving media for server") } total := 0 for _, media := range userMedia { - resp, ok := doQuarantineOn(media, allowOtherHosts, log, r.Context()) + resp, ok := doQuarantineOn(media, allowOtherHosts, rctx) if !ok { return resp } @@ -156,8 +156,8 @@ func QuarantineDomainMedia(r *http.Request, log *logrus.Entry, user api.UserInfo return &api.DoNotCacheResponse{Payload: &MediaQuarantinedResponse{NumQuarantined: total}} } -func QuarantineMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, log, user) +func QuarantineMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + canQuarantine, allowOtherHosts, isLocalAdmin := getQuarantineRequestInfo(r, rctx, user) if !canQuarantine { return api.AuthFailed() } @@ -167,7 +167,7 @@ func QuarantineMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inte server := params["server"] mediaId := params["mediaId"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "server": server, "mediaId": mediaId, "localAdmin": isLocalAdmin, @@ -177,42 +177,42 @@ func QuarantineMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inte return api.BadRequest("unable to quarantine media on other homeservers") } - resp, _ := doQuarantine(r.Context(), log, server, mediaId, allowOtherHosts) + resp, _ := doQuarantine(rctx, server, mediaId, allowOtherHosts) return &api.DoNotCacheResponse{Payload: resp} } -func doQuarantine(ctx context.Context, log *logrus.Entry, origin string, mediaId string, allowOtherHosts bool) (interface{}, bool) { - db := storage.GetDatabase().GetMediaStore(ctx, log) +func doQuarantine(ctx rcontext.RequestContext, origin string, mediaId string, allowOtherHosts bool) (interface{}, bool) { + db := storage.GetDatabase().GetMediaStore(ctx) media, err := db.Get(origin, mediaId) if err != nil { if err == sql.ErrNoRows { - log.Warn("Media not found, could not quarantine: " + origin + "/" + mediaId) + ctx.Log.Warn("Media not found, could not quarantine: " + origin + "/" + mediaId) return &MediaQuarantinedResponse{0}, true } - log.Error("Error fetching media: " + err.Error()) + ctx.Log.Error("Error fetching media: " + err.Error()) return api.InternalServerError("error quarantining media"), false } - return doQuarantineOn(media, allowOtherHosts, log, ctx) + return doQuarantineOn(media, allowOtherHosts, ctx) } -func doQuarantineOn(media *types.Media, allowOtherHosts bool, log *logrus.Entry, ctx context.Context) (interface{}, bool) { +func doQuarantineOn(media *types.Media, allowOtherHosts bool, ctx rcontext.RequestContext) (interface{}, bool) { // We reset the entire cache to avoid any lingering links floating around, such as thumbnails or other media. // The reset is done before actually quarantining the media because that could fail for some reason internal_cache.Get().Reset() - num, err := setMediaQuarantined(media, true, allowOtherHosts, ctx, log) + num, err := setMediaQuarantined(media, true, allowOtherHosts, ctx) if err != nil { - log.Error("Error quarantining media: " + err.Error()) + ctx.Log.Error("Error quarantining media: " + err.Error()) return api.InternalServerError("Error quarantining media"), false } return &MediaQuarantinedResponse{NumQuarantined: num}, true } -func setMediaQuarantined(media *types.Media, isQuarantined bool, allowOtherHosts bool, ctx context.Context, log *logrus.Entry) (int, error) { - db := storage.GetDatabase().GetMediaStore(ctx, log) +func setMediaQuarantined(media *types.Media, isQuarantined bool, allowOtherHosts bool, ctx rcontext.RequestContext) (int, error) { + db := storage.GetDatabase().GetMediaStore(ctx) numQuarantined := 0 // Quarantine all media with the same hash, including the one requested @@ -222,7 +222,7 @@ func setMediaQuarantined(media *types.Media, isQuarantined bool, allowOtherHosts } for _, m := range otherMedia { if m.Origin != media.Origin && !allowOtherHosts { - log.Warn("Skipping quarantine on " + m.Origin + "/" + m.MediaId + " because it is on a different host from " + media.Origin + "/" + media.MediaId) + ctx.Log.Warn("Skipping quarantine on " + m.Origin + "/" + m.MediaId + " because it is on a different host from " + media.Origin + "/" + media.MediaId) continue } @@ -232,13 +232,13 @@ func setMediaQuarantined(media *types.Media, isQuarantined bool, allowOtherHosts } numQuarantined++ - log.Warn("Media has been quarantined: " + m.Origin + "/" + m.MediaId) + ctx.Log.Warn("Media has been quarantined: " + m.Origin + "/" + m.MediaId) } return numQuarantined, nil } -func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) (bool, bool, bool) { +func getQuarantineRequestInfo(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) (bool, bool, bool) { isGlobalAdmin := util.IsGlobalAdmin(user.UserId) || user.IsShared canQuarantine := isGlobalAdmin allowOtherHosts := isGlobalAdmin @@ -246,15 +246,15 @@ func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user api.UserI var err error if !isGlobalAdmin { if config.Get().Quarantine.AllowLocalAdmins { - isLocalAdmin, err = matrix.IsUserAdmin(r.Context(), r.Host, user.AccessToken, r.RemoteAddr) + isLocalAdmin, err = matrix.IsUserAdmin(rctx, r.Host, user.AccessToken, r.RemoteAddr) if err != nil { - log.Error("Error verifying local admin: " + err.Error()) + rctx.Log.Error("Error verifying local admin: " + err.Error()) canQuarantine = false return canQuarantine, allowOtherHosts, isLocalAdmin } if !isLocalAdmin { - log.Warn(user.UserId + " tried to quarantine media on another server") + rctx.Log.Warn(user.UserId + " tried to quarantine media on another server") canQuarantine = false return canQuarantine, allowOtherHosts, isLocalAdmin } @@ -265,7 +265,7 @@ func getQuarantineRequestInfo(r *http.Request, log *logrus.Entry, user api.UserI } if !canQuarantine { - log.Warn(user.UserId + " tried to quarantine media") + rctx.Log.Warn(user.UserId + " tried to quarantine media") } return canQuarantine, allowOtherHosts, isLocalAdmin diff --git a/api/custom/tasks.go b/api/custom/tasks.go index 63d7a9845d9bbfb15d284e84ae05d49caa031787..bfb84c933723ebf34951cbddadeb2534afba9e23 100644 --- a/api/custom/tasks.go +++ b/api/custom/tasks.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage" ) @@ -19,25 +20,25 @@ type TaskStatus struct { IsFinished bool `json:"is_finished"` } -func GetTask(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func GetTask(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) taskIdStr := params["taskId"] taskId, err := strconv.Atoi(taskIdStr) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.BadRequest("invalid task ID") } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "taskId": taskId, }) - db := storage.GetDatabase().GetMetadataStore(r.Context(), log) + db := storage.GetDatabase().GetMetadataStore(rctx) task, err := db.GetBackgroundTask(taskId) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("failed to get task information") } @@ -51,8 +52,8 @@ func GetTask(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} }} } -func ListAllTasks(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - db := storage.GetDatabase().GetMetadataStore(r.Context(), log) +func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + db := storage.GetDatabase().GetMetadataStore(rctx) tasks, err := db.GetAllBackgroundTasks() if err != nil { @@ -75,8 +76,8 @@ func ListAllTasks(r *http.Request, log *logrus.Entry, user api.UserInfo) interfa return &api.DoNotCacheResponse{Payload: statusObjs} } -func ListUnfinishedTasks(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - db := storage.GetDatabase().GetMetadataStore(r.Context(), log) +func ListUnfinishedTasks(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { + db := storage.GetDatabase().GetMetadataStore(rctx) tasks, err := db.GetAllBackgroundTasks() if err != nil { diff --git a/api/custom/usage.go b/api/custom/usage.go index 379893bee8434ec90ccd8f81fb399baff0da2806..1e0fef4e4e4965a7cb80882dcfce0442e82fa6c2 100644 --- a/api/custom/usage.go +++ b/api/custom/usage.go @@ -6,6 +6,7 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" @@ -44,26 +45,26 @@ type MediaUsageEntry struct { CreatedTs int64 `json:"created_ts"` } -func GetDomainUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func GetDomainUsage(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) serverName := params["serverName"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "serverName": serverName, }) - db := storage.GetDatabase().GetMetadataStore(r.Context(), log) + db := storage.GetDatabase().GetMetadataStore(rctx) mediaBytes, thumbBytes, err := db.GetByteUsageForServer(serverName) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Failed to get byte usage for server") } mediaCount, thumbCount, err := db.GetCountUsageForServer(serverName) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Failed to get count usage for server") } @@ -87,17 +88,17 @@ func GetDomainUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) inter } } -func GetUserUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func GetUserUsage(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) serverName := params["serverName"] userIds := r.URL.Query()["user_id"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "serverName": serverName, }) - db := storage.GetDatabase().GetMediaStore(r.Context(), log) + db := storage.GetDatabase().GetMediaStore(rctx) var records []*types.Media var err error @@ -108,7 +109,7 @@ func GetUserUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interfa } if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Failed to get media records for users") } @@ -143,17 +144,17 @@ func GetUserUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interfa return &api.DoNotCacheResponse{Payload: parsed} } -func GetUploadsUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func GetUploadsUsage(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) serverName := params["serverName"] mxcs := r.URL.Query()["mxc"] - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "serverName": serverName, }) - db := storage.GetDatabase().GetMediaStore(r.Context(), log) + db := storage.GetDatabase().GetMediaStore(rctx) var records []*types.Media var err error @@ -164,7 +165,7 @@ func GetUploadsUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) inte for _, mxc := range mxcs { o, i, err := util.SplitMxc(mxc) if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Error parsing MXC " + mxc) } @@ -178,7 +179,7 @@ func GetUploadsUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) inte } if err != nil { - log.Error(err) + rctx.Log.Error(err) return api.InternalServerError("Failed to get media records for users") } diff --git a/api/general_handlers.go b/api/general_handlers.go index a1b0b2d6c3592eeaf12084a936fda58357f4da65..1c20046008ff2b91b0d3b44c37fd8d5c6b2d5da4 100644 --- a/api/general_handlers.go +++ b/api/general_handlers.go @@ -3,17 +3,17 @@ package api import ( "net/http" - "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" ) -func NotFoundHandler(r *http.Request, log *logrus.Entry) interface{} { +func NotFoundHandler(r *http.Request, rctx rcontext.RequestContext) interface{} { return NotFoundError() } -func MethodNotAllowedHandler(r *http.Request, log *logrus.Entry) interface{} { +func MethodNotAllowedHandler(r *http.Request, rctx rcontext.RequestContext) interface{} { return MethodNotAllowed() } -func EmptyResponseHandler(r *http.Request, log *logrus.Entry) interface{} { +func EmptyResponseHandler(r *http.Request, rctx rcontext.RequestContext) interface{} { return &EmptyResponse{} } diff --git a/api/r0/download.go b/api/r0/download.go index 14ab9bd48ece141309f57ebbc638ef0ca9e15d8d..0122238f7cb393bfd31757b4b04628218c727db1 100644 --- a/api/r0/download.go +++ b/api/r0/download.go @@ -9,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/download_controller" ) @@ -19,7 +20,7 @@ type DownloadMediaResponse struct { Data io.ReadCloser } -func DownloadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) server := params["server"] @@ -36,14 +37,14 @@ func DownloadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interf downloadRemote = parsedFlag } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "mediaId": mediaId, "server": server, "filename": filename, "allowRemote": downloadRemote, }) - streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, false, r.Context(), log) + streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, false, rctx) if err != nil { if err == common.ErrMediaNotFound { return api.NotFoundError() @@ -52,7 +53,7 @@ func DownloadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interf } else if err == common.ErrMediaQuarantined { return api.NotFoundError() // We lie for security } - log.Error("Unexpected error locating media: " + err.Error()) + rctx.Log.Error("Unexpected error locating media: " + err.Error()) return api.InternalServerError("Unexpected Error") } diff --git a/api/r0/identicon.go b/api/r0/identicon.go index 13d6569c0d2793a5367369b39d2c58048fc323d7..f7caf5c0a381f7fa773240cff14e59a368486193 100644 --- a/api/r0/identicon.go +++ b/api/r0/identicon.go @@ -14,13 +14,14 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" ) type IdenticonResponse struct { Avatar io.Reader } -func Identicon(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func Identicon(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().Identicons.Enabled { return api.NotFoundError() } @@ -48,7 +49,7 @@ func Identicon(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ } } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "identiconWidth": width, "identiconHeight": height, "identiconSeed": seed, @@ -72,18 +73,18 @@ func Identicon(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ }, } - log.Info("Generating identicon") + rctx.Log.Info("Generating identicon") img := sig.Make(width, false, []byte(hashed)) if width != height { // Resize to the desired height - log.Info("Resizing image to fit height") + rctx.Log.Info("Resizing image to fit height") img = imaging.Resize(img, width, height, imaging.Lanczos) } imgData := &bytes.Buffer{} err = imaging.Encode(imgData, img, imaging.PNG) if err != nil { - log.Error("Error generating image:" + err.Error()) + rctx.Log.Error("Error generating image:" + err.Error()) return api.InternalServerError("error generating identicon") } diff --git a/api/r0/preview_url.go b/api/r0/preview_url.go index a2d6d87970fb3e71e5da89d1d7727727fbbbd3a4..0107277fb4519fab365783046db50ea87492a684 100644 --- a/api/r0/preview_url.go +++ b/api/r0/preview_url.go @@ -5,10 +5,10 @@ import ( "strconv" "strings" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/preview_controller" "github.com/turt2live/matrix-media-repo/util" ) @@ -26,7 +26,7 @@ type MatrixOpenGraph struct { ImageHeight int `json:"og:image:height,omitempty"` } -func PreviewUrl(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func PreviewUrl(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { if !config.Get().UrlPreviews.Enabled { return api.NotFoundError() } @@ -41,7 +41,7 @@ func PreviewUrl(r *http.Request, log *logrus.Entry, user api.UserInfo) interface if tsStr != "" { ts, err = strconv.ParseInt(tsStr, 10, 64) if err != nil { - log.Error("Error parsing ts: " + err.Error()) + rctx.Log.Error("Error parsing ts: " + err.Error()) return api.BadRequest(err.Error()) } } @@ -54,7 +54,7 @@ func PreviewUrl(r *http.Request, log *logrus.Entry, user api.UserInfo) interface return api.BadRequest("Scheme not accepted") } - preview, err := preview_controller.GetPreview(urlStr, r.Host, user.UserId, ts, r.Context(), log) + preview, err := preview_controller.GetPreview(urlStr, r.Host, user.UserId, ts, rctx) if err != nil { if err == common.ErrMediaNotFound || err == common.ErrHostNotFound { return api.NotFoundError() diff --git a/api/r0/public_config.go b/api/r0/public_config.go index 2dc6f6cc044eac7a69daac975fd483649d0421c7..696f22d71aa6abff2796cd800c75c430f6243158 100644 --- a/api/r0/public_config.go +++ b/api/r0/public_config.go @@ -3,16 +3,16 @@ package r0 import ( "net/http" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" ) type PublicConfigResponse struct { UploadMaxSize int64 `json:"m.upload.size,omitempty"` } -func PublicConfig(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func PublicConfig(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { uploadSize := config.Get().Uploads.ReportedMaxSizeBytes if uploadSize == 0 { uploadSize = config.Get().Uploads.MaxSizeBytes diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index 39ee66f3a6820430f86f669de5b40510a62bc92e..5715d0bd363752931db263a774857e336a6293f6 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -9,10 +9,11 @@ import ( "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/thumbnail_controller" ) -func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) server := params["server"] @@ -28,7 +29,7 @@ func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter downloadRemote = parsedFlag } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "mediaId": mediaId, "server": server, "allowRemote": downloadRemote, @@ -68,21 +69,21 @@ func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter method = "scale" } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "requestedWidth": width, "requestedHeight": height, "requestedMethod": method, "requestedAnimated": animated, }) - streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log) + streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, rctx) if err != nil { if err == common.ErrMediaNotFound { return api.NotFoundError() } else if err == common.ErrMediaTooLarge { return api.RequestTooLarge() } - log.Error("Unexpected error locating media: " + err.Error()) + rctx.Log.Error("Unexpected error locating media: " + err.Error()) return api.InternalServerError("Unexpected Error") } diff --git a/api/r0/upload.go b/api/r0/upload.go index fe42d2051d99446df749aceaa949528e46bce8c9..2f26fa32c86e898a16c6e189263a25ce914ba1ac 100644 --- a/api/r0/upload.go +++ b/api/r0/upload.go @@ -9,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/upload_controller" ) @@ -16,11 +17,11 @@ type MediaUploadedResponse struct { ContentUri string `json:"content_uri"` } -func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { filename := filepath.Base(r.URL.Query().Get("filename")) defer r.Body.Close() - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "filename": filename, }) @@ -41,7 +42,7 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac contentLength := upload_controller.EstimateContentLength(r.ContentLength, r.Header.Get("Content-Length")) - media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, r.Context(), log) + media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, rctx) if err != nil { io.Copy(ioutil.Discard, r.Body) // Ditch the entire request @@ -51,7 +52,7 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac return api.BadRequest("This file is not permitted on this server") } - log.Error("Unexpected error storing media: " + err.Error()) + rctx.Log.Error("Unexpected error storing media: " + err.Error()) return api.InternalServerError("Unexpected Error") } diff --git a/api/unstable/info.go b/api/unstable/info.go index 67d381803be8b0b2110174d30ea42e4a33ed9153..d203d549579cbabb60a0f566584d6e70b57d1bca 100644 --- a/api/unstable/info.go +++ b/api/unstable/info.go @@ -10,6 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/download_controller" "github.com/turt2live/matrix-media-repo/storage" ) @@ -34,7 +35,7 @@ type MediaInfoResponse struct { Thumbnails []*mediaInfoThumbnail `json:"thumbnails,omitempty"` } -func MediaInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func MediaInfo(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) server := params["server"] @@ -50,13 +51,13 @@ func MediaInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ downloadRemote = parsedFlag } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "mediaId": mediaId, "server": server, "allowRemote": downloadRemote, }) - streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, r.Context(), log) + streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, rctx) if err != nil { if err == common.ErrMediaNotFound { return api.NotFoundError() @@ -65,7 +66,7 @@ func MediaInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ } else if err == common.ErrMediaQuarantined { return api.NotFoundError() // We lie for security } - log.Error("Unexpected error locating media: " + err.Error()) + rctx.Log.Error("Unexpected error locating media: " + err.Error()) return api.InternalServerError("Unexpected Error") } defer streamedMedia.Stream.Close() @@ -85,10 +86,10 @@ func MediaInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ response.Height = img.Bounds().Max.Y } - thumbsDb := storage.GetDatabase().GetThumbnailStore(r.Context(), log) + thumbsDb := storage.GetDatabase().GetThumbnailStore(rctx) thumbs, err := thumbsDb.GetAllForMedia(streamedMedia.KnownMedia.Origin, streamedMedia.KnownMedia.MediaId) if err != nil && err != sql.ErrNoRows { - log.Error("Unexpected error locating media: " + err.Error()) + rctx.Log.Error("Unexpected error locating media: " + err.Error()) return api.InternalServerError("Unexpected Error") } diff --git a/api/unstable/local_copy.go b/api/unstable/local_copy.go index 6d17fbc3679725a96c318ea985bb87f2602165e1..963a7526052eea99ad5363d4325ba5d8c30c1f75 100644 --- a/api/unstable/local_copy.go +++ b/api/unstable/local_copy.go @@ -9,11 +9,12 @@ import ( "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/api/r0" "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/download_controller" "github.com/turt2live/matrix-media-repo/controllers/upload_controller" ) -func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { +func LocalCopy(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} { params := mux.Vars(r) server := params["server"] @@ -29,7 +30,7 @@ func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ downloadRemote = parsedFlag } - log = log.WithFields(logrus.Fields{ + rctx = rctx.LogWithFields(logrus.Fields{ "mediaId": mediaId, "server": server, "allowRemote": downloadRemote, @@ -37,7 +38,7 @@ func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ // TODO: There's a lot of room for improvement here. Instead of re-uploading media, we should just update the DB. - streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, r.Context(), log) + streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, rctx) if err != nil { if err == common.ErrMediaNotFound { return api.NotFoundError() @@ -46,7 +47,7 @@ func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ } else if err == common.ErrMediaQuarantined { return api.NotFoundError() // We lie for security } - log.Error("Unexpected error locating media: " + err.Error()) + rctx.Log.Error("Unexpected error locating media: " + err.Error()) return api.InternalServerError("Unexpected Error") } defer streamedMedia.Stream.Close() @@ -56,13 +57,13 @@ func LocalCopy(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{ return &r0.MediaUploadedResponse{ContentUri: streamedMedia.KnownMedia.MxcUri()} } - newMedia, err := upload_controller.UploadMedia(streamedMedia.Stream, streamedMedia.KnownMedia.SizeBytes, streamedMedia.KnownMedia.ContentType, streamedMedia.KnownMedia.UploadName, user.UserId, r.Host, r.Context(), log) + newMedia, err := upload_controller.UploadMedia(streamedMedia.Stream, streamedMedia.KnownMedia.SizeBytes, streamedMedia.KnownMedia.ContentType, streamedMedia.KnownMedia.UploadName, user.UserId, r.Host, rctx) if err != nil { if err == common.ErrMediaNotAllowed { return api.BadRequest("Media content type not allowed on this server") } - log.Error("Unexpected error storing media: " + err.Error()) + rctx.Log.Error("Unexpected error storing media: " + err.Error()) return api.InternalServerError("Unexpected Error") } diff --git a/api/webserver/route_handler.go b/api/webserver/route_handler.go index 218905f7ce7ea37b2f26c4be8c66bf90668e70ca..20e2fe278a9d134ad77b1dbfe96df9f46f688007 100644 --- a/api/webserver/route_handler.go +++ b/api/webserver/route_handler.go @@ -2,6 +2,7 @@ package webserver import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -19,12 +20,13 @@ import ( "github.com/turt2live/matrix-media-repo/api/r0" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/metrics" "github.com/turt2live/matrix-media-repo/util" ) type handler struct { - h func(r *http.Request, entry *logrus.Entry) interface{} + h func(r *http.Request, ctx rcontext.RequestContext) interface{} action string reqCounter *requestCounter ignoreHost bool @@ -78,13 +80,23 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Process response var res interface{} = api.AuthFailed() if util.IsServerOurs(r.Host) || h.ignoreHost { - contextLog.Info("Server is owned by us, processing request") + contextLog.Info("Host is valid - processing request") + + // Build a context that can be used throughout the remainder of the app + // This is kinda annoying, but it's better than trying to pass our own + // thing throughout the layers. + ctx := r.Context() + ctx = context.WithValue(ctx, "mr.logger", contextLog) + ctx = context.WithValue(ctx, "mr.serverConfig", config.Get()) + rctx := rcontext.RequestContext{Context: ctx, Log: contextLog} + r = r.WithContext(rctx) + metrics.HttpRequests.With(prometheus.Labels{ "host": r.Host, "action": h.action, "method": r.Method, }).Inc() - res = h.h(r, contextLog) + res = h.h(r, rctx) if res == nil { res = &api.EmptyResponse{} } diff --git a/api/webserver/webserver.go b/api/webserver/webserver.go index cd82b6d19ce924ab2de035b72e83d06cc22fee72..b54cb2a8b18ff40fa5e1b1b753d53140e2a50322 100644 --- a/api/webserver/webserver.go +++ b/api/webserver/webserver.go @@ -205,10 +205,10 @@ func Reload() { func Stop() { if srv != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := srv.Shutdown(ctx); err != nil { panic(err) } } -} \ No newline at end of file +} diff --git a/cmd/media_repo/inits.go b/cmd/media_repo/inits.go index cf35545210a41b2331ee8810e69cbcf379eb2469..8a9ee1e888d236a340c7b0e19b83728550d3e7f8 100644 --- a/cmd/media_repo/inits.go +++ b/cmd/media_repo/inits.go @@ -1,12 +1,12 @@ package main import ( - "context" "fmt" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/maintenance_controller" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" @@ -14,9 +14,8 @@ import ( ) func scanAndStartUnfinishedTasks() error { - ctx := context.Background() - log := logrus.WithFields(logrus.Fields{"stage": "startup"}) - db := storage.GetDatabase().GetMetadataStore(ctx, log) + ctx := rcontext.Initial().LogWithFields(logrus.Fields{"stage": "startup"}) + db := storage.GetDatabase().GetMetadataStore(ctx) tasks, err := db.GetAllBackgroundTasks() if err != nil { return err @@ -26,7 +25,7 @@ func scanAndStartUnfinishedTasks() error { continue } - taskLog := log.WithFields(logrus.Fields{ + taskCtx := ctx.LogWithFields(logrus.Fields{ "prev_task_id": task.ID, "prev_task_name": task.Name, }) @@ -36,16 +35,16 @@ func scanAndStartUnfinishedTasks() error { sourceDsId := task.Params["source_datastore_id"].(string) targetDsId := task.Params["target_datastore_id"].(string) - sourceDs, err := datastore.LocateDatastore(ctx, taskLog, sourceDsId) + sourceDs, err := datastore.LocateDatastore(taskCtx, sourceDsId) if err != nil { return err } - targetDs, err := datastore.LocateDatastore(ctx, taskLog, targetDsId) + targetDs, err := datastore.LocateDatastore(taskCtx, targetDsId) if err != nil { return err } - newTask, err := maintenance_controller.StartStorageMigration(sourceDs, targetDs, beforeTs, taskLog) + newTask, err := maintenance_controller.StartStorageMigration(sourceDs, targetDs, beforeTs, taskCtx) if err != nil { return err } @@ -55,9 +54,9 @@ func scanAndStartUnfinishedTasks() error { return err } - taskLog.Infof("Started replacement task ID %d for unfinished task %d (%s)", newTask.ID, task.ID, task.Name) + logrus.Infof("Started replacement task ID %d for unfinished task %d (%s)", newTask.ID, task.ID, task.Name) } else { - taskLog.Warn(fmt.Sprintf("Unknown task %s at ID %d - ignoring", task.Name, task.ID)) + logrus.Warn(fmt.Sprintf("Unknown task %s at ID %d - ignoring", task.Name, task.ID)) } } @@ -73,7 +72,8 @@ func loadDatastores() { if len(config.Get().Uploads.StoragePaths) > 0 { logrus.Warn("storagePaths usage is deprecated - please use datastores instead") for _, p := range config.Get().Uploads.StoragePaths { - ds, err := storage.GetOrCreateDatastoreOfType(context.Background(), logrus.WithFields(logrus.Fields{"path": p}), "file", p) + ctx := rcontext.Initial().LogWithFields(logrus.Fields{"path": p}) + ds, err := storage.GetOrCreateDatastoreOfType(ctx, "file", p) if err != nil { logrus.Fatal(err) } @@ -88,7 +88,7 @@ func loadDatastores() { } } - mediaStore := storage.GetDatabase().GetMediaStore(context.TODO(), &logrus.Entry{}) + mediaStore := storage.GetDatabase().GetMediaStore(rcontext.Initial()) logrus.Info("Initializing datastores...") for _, ds := range config.Get().DataStores { @@ -98,7 +98,7 @@ func loadDatastores() { uri := datastore.GetUriForDatastore(ds) - _, err := storage.GetOrCreateDatastoreOfType(context.TODO(), &logrus.Entry{}, ds.Type, uri) + _, err := storage.GetOrCreateDatastoreOfType(rcontext.Initial(), ds.Type, uri) if err != nil { logrus.Fatal(err) } diff --git a/common/config/config.go b/common/config/config.go index 76271be7f0b69e6bc5158fc380eeebe57cc598dc..c5e333c4ee91687f415e73a344f142426c04239d 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -19,154 +19,6 @@ type runtimeConfig struct { var Runtime = &runtimeConfig{} -type HomeserverConfig struct { - Name string `yaml:"name"` - ClientServerApi string `yaml:"csApi"` - BackoffAt int `yaml:"backoffAt"` - AdminApiKind string `yaml:"adminApiKind"` -} - -type GeneralConfig struct { - BindAddress string `yaml:"bindAddress"` - Port int `yaml:"port"` - LogDirectory string `yaml:"logDirectory"` - TrustAnyForward bool `yaml:"trustAnyForwardedAddress"` - UseForwardedHost bool `yaml:"useForwardedHost"` -} - -type DbPoolConfig struct { - MaxConnections int `yaml:"maxConnections"` - MaxIdle int `yaml:"maxIdleConnections"` -} - -type DatabaseConfig struct { - Postgres string `yaml:"postgres"` - Pool *DbPoolConfig `yaml:"pool"` -} - -type ArchivingConfig struct { - Enabled bool `yaml:"enabled"` - SelfService bool `yaml:"selfService"` - TargetBytesPerPart int64 `yaml:"targetBytesPerPart"` -} - -type UploadsConfig struct { - StoragePaths []string `yaml:"storagePaths,flow"` // deprecated - MaxSizeBytes int64 `yaml:"maxBytes"` - MinSizeBytes int64 `yaml:"minBytes"` - AllowedTypes []string `yaml:"allowedTypes,flow"` - PerUserExclusions map[string][]string `yaml:"exclusions,flow"` - ReportedMaxSizeBytes int64 `yaml:"reportedMaxBytes"` -} - -type DatastoreConfig struct { - Type string `yaml:"type"` - Enabled bool `yaml:"enabled"` - ForUploads bool `yaml:"forUploads"` // deprecated - MediaKinds []string `yaml:"forKinds,flow"` - Options map[string]string `yaml:"opts,flow"` -} - -type DownloadsConfig struct { - MaxSizeBytes int64 `yaml:"maxBytes"` - NumWorkers int `yaml:"numWorkers"` - FailureCacheMinutes int `yaml:"failureCacheMinutes"` - Cache *CacheConfig `yaml:"cache"` -} - -type ThumbnailsConfig struct { - MaxSourceBytes int64 `yaml:"maxSourceBytes"` - NumWorkers int `yaml:"numWorkers"` - Types []string `yaml:"types,flow"` - MaxAnimateSizeBytes int64 `yaml:"maxAnimateSizeBytes"` - Sizes []*ThumbnailSize `yaml:"sizes,flow"` - AllowAnimated bool `yaml:"allowAnimated"` - DefaultAnimated bool `yaml:"defaultAnimated"` - StillFrame float32 `yaml:"stillFrame"` -} - -type ThumbnailSize struct { - Width int `yaml:"width"` - Height int `yaml:"height"` -} - -type UrlPreviewsConfig struct { - Enabled bool `yaml:"enabled"` - NumWords int `yaml:"numWords"` - NumTitleWords int `yaml:"numTitleWords"` - MaxLength int `yaml:"maxLength"` - MaxTitleLength int `yaml:"maxTitleLength"` - MaxPageSizeBytes int64 `yaml:"maxPageSizeBytes"` - NumWorkers int `yaml:"numWorkers"` - FilePreviewTypes []string `yaml:"filePreviewTypes,flow"` - DisallowedNetworks []string `yaml:"disallowedNetworks,flow"` - AllowedNetworks []string `yaml:"allowedNetworks,flow"` - UnsafeCertificates bool `yaml:"previewUnsafeCertificates"` -} - -type RateLimitConfig struct { - RequestsPerSecond float64 `yaml:"requestsPerSecond"` - Enabled bool `yaml:"enabled"` - BurstCount int `yaml:"burst"` -} - -type IdenticonsConfig struct { - Enabled bool `yaml:"enabled"` -} - -type CacheConfig struct { - Enabled bool `yaml:"enabled"` - MaxSizeBytes int64 `yaml:"maxSizeBytes"` - MaxFileSizeBytes int64 `yaml:"maxFileSizeBytes"` - TrackedMinutes int `yaml:"trackedMinutes"` - MinCacheTimeSeconds int `yaml:"minCacheTimeSeconds"` - MinEvictedTimeSeconds int `yaml:"minEvictedTimeSeconds"` - MinDownloads int `yaml:"minDownloads"` -} - -type QuarantineConfig struct { - ReplaceThumbnails bool `yaml:"replaceThumbnails"` - ReplaceDownloads bool `yaml:"replaceDownloads"` - ThumbnailPath string `yaml:"thumbnailPath"` - AllowLocalAdmins bool `yaml:"allowLocalAdmins"` -} - -type TimeoutsConfig struct { - UrlPreviews int `yaml:"urlPreviewTimeoutSeconds"` - Federation int `yaml:"federationTimeoutSeconds"` - ClientServer int `yaml:"clientServerTimeoutSeconds"` -} - -type MetricsConfig struct { - Enabled bool `yaml:"enabled"` - BindAddress string `yaml:"bindAddress"` - Port int `yaml:"port"` -} - -type SharedSecretConfig struct { - Enabled bool `yaml:"enabled"` - Token string `yaml:"token"` -} - -type MediaRepoConfig struct { - General *GeneralConfig `yaml:"repo"` - Homeservers []*HomeserverConfig `yaml:"homeservers,flow"` - Admins []string `yaml:"admins,flow"` - Database *DatabaseConfig `yaml:"database"` - DataStores []DatastoreConfig `yaml:"datastores"` - Archiving *ArchivingConfig `yaml:"archiving"` - Uploads *UploadsConfig `yaml:"uploads"` - Downloads *DownloadsConfig `yaml:"downloads"` - Thumbnails *ThumbnailsConfig `yaml:"thumbnails"` - UrlPreviews *UrlPreviewsConfig `yaml:"urlPreviews"` - RateLimit *RateLimitConfig `yaml:"rateLimit"` - Identicons *IdenticonsConfig `yaml:"identicons"` - Quarantine *QuarantineConfig `yaml:"quarantine"` - TimeoutSeconds *TimeoutsConfig `yaml:"timeouts"` - Metrics *MetricsConfig `yaml:"metrics"` - SharedSecret *SharedSecretConfig `yaml:"sharedSecretAuth"` -} - var instance *MediaRepoConfig var singletonLock = &sync.Once{} var Path = "media-repo.yaml" @@ -256,126 +108,3 @@ func Get() *MediaRepoConfig { } return instance } - -func NewDefaultConfig() *MediaRepoConfig { - return &MediaRepoConfig{ - General: &GeneralConfig{ - BindAddress: "127.0.0.1", - Port: 8000, - LogDirectory: "logs", - TrustAnyForward: false, - UseForwardedHost: true, - }, - Database: &DatabaseConfig{ - Postgres: "postgres://your_username:your_password@localhost/database_name?sslmode=disable", - Pool: &DbPoolConfig{ - MaxConnections: 25, - MaxIdle: 5, - }, - }, - Homeservers: []*HomeserverConfig{}, - Admins: []string{}, - DataStores: []DatastoreConfig{}, - Archiving: &ArchivingConfig{ - Enabled: true, - SelfService: false, - TargetBytesPerPart: 209715200, // 200mb - }, - Uploads: &UploadsConfig{ - MaxSizeBytes: 104857600, // 100mb - MinSizeBytes: 100, - ReportedMaxSizeBytes: 0, - StoragePaths: []string{}, - AllowedTypes: []string{"*/*"}, - }, - Downloads: &DownloadsConfig{ - MaxSizeBytes: 104857600, // 100mb - NumWorkers: 10, - FailureCacheMinutes: 15, - Cache: &CacheConfig{ - Enabled: true, - MaxSizeBytes: 1048576000, // 1gb - MaxFileSizeBytes: 104857600, // 100mb - TrackedMinutes: 30, - MinDownloads: 5, - MinCacheTimeSeconds: 300, // 5min - MinEvictedTimeSeconds: 60, - }, - }, - UrlPreviews: &UrlPreviewsConfig{ - Enabled: true, - NumWords: 50, - NumTitleWords: 30, - MaxLength: 200, - MaxTitleLength: 150, - MaxPageSizeBytes: 10485760, // 10mb - NumWorkers: 10, - FilePreviewTypes: []string{ - "image/*", - }, - DisallowedNetworks: []string{ - "127.0.0.1/8", - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "100.64.0.0/10", - "169.254.0.0/16", - "::1/128", - "fe80::/64", - "fc00::/7", - }, - AllowedNetworks: []string{ - "0.0.0.0/0", // "Everything" - }, - }, - Thumbnails: &ThumbnailsConfig{ - MaxSourceBytes: 10485760, // 10mb - MaxAnimateSizeBytes: 10485760, // 10mb - NumWorkers: 10, - AllowAnimated: true, - DefaultAnimated: false, - StillFrame: 0.5, - Sizes: []*ThumbnailSize{ - {32, 32}, - {96, 96}, - {320, 240}, - {640, 480}, - {800, 600}, - }, - Types: []string{ - "image/jpeg", - "image/jpg", - "image/png", - "image/gif", - }, - }, - RateLimit: &RateLimitConfig{ - Enabled: true, - RequestsPerSecond: 5, - BurstCount: 10, - }, - Identicons: &IdenticonsConfig{ - Enabled: true, - }, - Quarantine: &QuarantineConfig{ - ReplaceThumbnails: true, - ReplaceDownloads: false, - ThumbnailPath: "", - AllowLocalAdmins: true, - }, - TimeoutSeconds: &TimeoutsConfig{ - UrlPreviews: 10, - ClientServer: 30, - Federation: 120, - }, - Metrics: &MetricsConfig{ - Enabled: false, - BindAddress: "localhost", - Port: 9000, - }, - SharedSecret: &SharedSecretConfig{ - Enabled: false, - Token: "ReplaceMe", - }, - } -} diff --git a/common/config/defaults.go b/common/config/defaults.go new file mode 100644 index 0000000000000000000000000000000000000000..c19345d140647dbc6f63fdd53f4430239b050602 --- /dev/null +++ b/common/config/defaults.go @@ -0,0 +1,124 @@ +package config + +func NewDefaultConfig() *MediaRepoConfig { + return &MediaRepoConfig{ + General: &GeneralConfig{ + BindAddress: "127.0.0.1", + Port: 8000, + LogDirectory: "logs", + TrustAnyForward: false, + UseForwardedHost: true, + }, + Database: &DatabaseConfig{ + Postgres: "postgres://your_username:your_password@localhost/database_name?sslmode=disable", + Pool: &DbPoolConfig{ + MaxConnections: 25, + MaxIdle: 5, + }, + }, + Homeservers: []*HomeserverConfig{}, + Admins: []string{}, + DataStores: []DatastoreConfig{}, + Archiving: &ArchivingConfig{ + Enabled: true, + SelfService: false, + TargetBytesPerPart: 209715200, // 200mb + }, + Uploads: &UploadsConfig{ + MaxSizeBytes: 104857600, // 100mb + MinSizeBytes: 100, + ReportedMaxSizeBytes: 0, + StoragePaths: []string{}, + AllowedTypes: []string{"*/*"}, + }, + Downloads: &DownloadsConfig{ + MaxSizeBytes: 104857600, // 100mb + NumWorkers: 10, + FailureCacheMinutes: 15, + Cache: &CacheConfig{ + Enabled: true, + MaxSizeBytes: 1048576000, // 1gb + MaxFileSizeBytes: 104857600, // 100mb + TrackedMinutes: 30, + MinDownloads: 5, + MinCacheTimeSeconds: 300, // 5min + MinEvictedTimeSeconds: 60, + }, + }, + UrlPreviews: &UrlPreviewsConfig{ + Enabled: true, + NumWords: 50, + NumTitleWords: 30, + MaxLength: 200, + MaxTitleLength: 150, + MaxPageSizeBytes: 10485760, // 10mb + NumWorkers: 10, + FilePreviewTypes: []string{ + "image/*", + }, + DisallowedNetworks: []string{ + "127.0.0.1/8", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "100.64.0.0/10", + "169.254.0.0/16", + "::1/128", + "fe80::/64", + "fc00::/7", + }, + AllowedNetworks: []string{ + "0.0.0.0/0", // "Everything" + }, + }, + Thumbnails: &ThumbnailsConfig{ + MaxSourceBytes: 10485760, // 10mb + MaxAnimateSizeBytes: 10485760, // 10mb + NumWorkers: 10, + AllowAnimated: true, + DefaultAnimated: false, + StillFrame: 0.5, + Sizes: []*ThumbnailSize{ + {32, 32}, + {96, 96}, + {320, 240}, + {640, 480}, + {800, 600}, + }, + Types: []string{ + "image/jpeg", + "image/jpg", + "image/png", + "image/gif", + }, + }, + RateLimit: &RateLimitConfig{ + Enabled: true, + RequestsPerSecond: 5, + BurstCount: 10, + }, + Identicons: &IdenticonsConfig{ + Enabled: true, + }, + Quarantine: &QuarantineConfig{ + ReplaceThumbnails: true, + ReplaceDownloads: false, + ThumbnailPath: "", + AllowLocalAdmins: true, + }, + TimeoutSeconds: &TimeoutsConfig{ + UrlPreviews: 10, + ClientServer: 30, + Federation: 120, + }, + Metrics: &MetricsConfig{ + Enabled: false, + BindAddress: "localhost", + Port: 9000, + }, + SharedSecret: &SharedSecretConfig{ + Enabled: false, + Token: "ReplaceMe", + }, + } +} diff --git a/common/config/models.go b/common/config/models.go new file mode 100644 index 0000000000000000000000000000000000000000..1ed2defa7c84bd5bfbc782877917928aeda6e2ca --- /dev/null +++ b/common/config/models.go @@ -0,0 +1,149 @@ +package config + +type HomeserverConfig struct { + Name string `yaml:"name"` + ClientServerApi string `yaml:"csApi"` + BackoffAt int `yaml:"backoffAt"` + AdminApiKind string `yaml:"adminApiKind"` +} + +type GeneralConfig struct { + BindAddress string `yaml:"bindAddress"` + Port int `yaml:"port"` + LogDirectory string `yaml:"logDirectory"` + TrustAnyForward bool `yaml:"trustAnyForwardedAddress"` + UseForwardedHost bool `yaml:"useForwardedHost"` +} + +type DbPoolConfig struct { + MaxConnections int `yaml:"maxConnections"` + MaxIdle int `yaml:"maxIdleConnections"` +} + +type DatabaseConfig struct { + Postgres string `yaml:"postgres"` + Pool *DbPoolConfig `yaml:"pool"` +} + +type ArchivingConfig struct { + Enabled bool `yaml:"enabled"` + SelfService bool `yaml:"selfService"` + TargetBytesPerPart int64 `yaml:"targetBytesPerPart"` +} + +type UploadsConfig struct { + StoragePaths []string `yaml:"storagePaths,flow"` // deprecated + MaxSizeBytes int64 `yaml:"maxBytes"` + MinSizeBytes int64 `yaml:"minBytes"` + AllowedTypes []string `yaml:"allowedTypes,flow"` + PerUserExclusions map[string][]string `yaml:"exclusions,flow"` + ReportedMaxSizeBytes int64 `yaml:"reportedMaxBytes"` +} + +type DatastoreConfig struct { + Type string `yaml:"type"` + Enabled bool `yaml:"enabled"` + ForUploads bool `yaml:"forUploads"` // deprecated + MediaKinds []string `yaml:"forKinds,flow"` + Options map[string]string `yaml:"opts,flow"` +} + +type DownloadsConfig struct { + MaxSizeBytes int64 `yaml:"maxBytes"` + NumWorkers int `yaml:"numWorkers"` + FailureCacheMinutes int `yaml:"failureCacheMinutes"` + Cache *CacheConfig `yaml:"cache"` +} + +type ThumbnailsConfig struct { + MaxSourceBytes int64 `yaml:"maxSourceBytes"` + NumWorkers int `yaml:"numWorkers"` + Types []string `yaml:"types,flow"` + MaxAnimateSizeBytes int64 `yaml:"maxAnimateSizeBytes"` + Sizes []*ThumbnailSize `yaml:"sizes,flow"` + AllowAnimated bool `yaml:"allowAnimated"` + DefaultAnimated bool `yaml:"defaultAnimated"` + StillFrame float32 `yaml:"stillFrame"` +} + +type ThumbnailSize struct { + Width int `yaml:"width"` + Height int `yaml:"height"` +} + +type UrlPreviewsConfig struct { + Enabled bool `yaml:"enabled"` + NumWords int `yaml:"numWords"` + NumTitleWords int `yaml:"numTitleWords"` + MaxLength int `yaml:"maxLength"` + MaxTitleLength int `yaml:"maxTitleLength"` + MaxPageSizeBytes int64 `yaml:"maxPageSizeBytes"` + NumWorkers int `yaml:"numWorkers"` + FilePreviewTypes []string `yaml:"filePreviewTypes,flow"` + DisallowedNetworks []string `yaml:"disallowedNetworks,flow"` + AllowedNetworks []string `yaml:"allowedNetworks,flow"` + UnsafeCertificates bool `yaml:"previewUnsafeCertificates"` +} + +type RateLimitConfig struct { + RequestsPerSecond float64 `yaml:"requestsPerSecond"` + Enabled bool `yaml:"enabled"` + BurstCount int `yaml:"burst"` +} + +type IdenticonsConfig struct { + Enabled bool `yaml:"enabled"` +} + +type CacheConfig struct { + Enabled bool `yaml:"enabled"` + MaxSizeBytes int64 `yaml:"maxSizeBytes"` + MaxFileSizeBytes int64 `yaml:"maxFileSizeBytes"` + TrackedMinutes int `yaml:"trackedMinutes"` + MinCacheTimeSeconds int `yaml:"minCacheTimeSeconds"` + MinEvictedTimeSeconds int `yaml:"minEvictedTimeSeconds"` + MinDownloads int `yaml:"minDownloads"` +} + +type QuarantineConfig struct { + ReplaceThumbnails bool `yaml:"replaceThumbnails"` + ReplaceDownloads bool `yaml:"replaceDownloads"` + ThumbnailPath string `yaml:"thumbnailPath"` + AllowLocalAdmins bool `yaml:"allowLocalAdmins"` +} + +type TimeoutsConfig struct { + UrlPreviews int `yaml:"urlPreviewTimeoutSeconds"` + Federation int `yaml:"federationTimeoutSeconds"` + ClientServer int `yaml:"clientServerTimeoutSeconds"` +} + +type MetricsConfig struct { + Enabled bool `yaml:"enabled"` + BindAddress string `yaml:"bindAddress"` + Port int `yaml:"port"` +} + +type SharedSecretConfig struct { + Enabled bool `yaml:"enabled"` + Token string `yaml:"token"` +} + +type MediaRepoConfig struct { + General *GeneralConfig `yaml:"repo"` + Homeservers []*HomeserverConfig `yaml:"homeservers,flow"` + Admins []string `yaml:"admins,flow"` + Database *DatabaseConfig `yaml:"database"` + DataStores []DatastoreConfig `yaml:"datastores"` + Archiving *ArchivingConfig `yaml:"archiving"` + Uploads *UploadsConfig `yaml:"uploads"` + Downloads *DownloadsConfig `yaml:"downloads"` + Thumbnails *ThumbnailsConfig `yaml:"thumbnails"` + UrlPreviews *UrlPreviewsConfig `yaml:"urlPreviews"` + RateLimit *RateLimitConfig `yaml:"rateLimit"` + Identicons *IdenticonsConfig `yaml:"identicons"` + Quarantine *QuarantineConfig `yaml:"quarantine"` + TimeoutSeconds *TimeoutsConfig `yaml:"timeouts"` + Metrics *MetricsConfig `yaml:"metrics"` + SharedSecret *SharedSecretConfig `yaml:"sharedSecretAuth"` +} diff --git a/common/globals/singleflight_groups.go b/common/globals/singleflight_groups.go index f13df437bd8d9061dfce384384b0e4f5021943b2..4f5fd266f599d11dd9a17cc8bf715572b01bc9be 100644 --- a/common/globals/singleflight_groups.go +++ b/common/globals/singleflight_groups.go @@ -4,4 +4,4 @@ import ( "github.com/turt2live/matrix-media-repo/util/singleflight-counter" ) -var DefaultRequestGroup singleflight_counter.Group \ No newline at end of file +var DefaultRequestGroup singleflight_counter.Group diff --git a/common/rcontext/request_context.go b/common/rcontext/request_context.go new file mode 100644 index 0000000000000000000000000000000000000000..2f89fb244991b6aa2766fe391df2752b20940b83 --- /dev/null +++ b/common/rcontext/request_context.go @@ -0,0 +1,43 @@ +package rcontext + +import ( + "context" + + "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/config" +) + +func Initial() RequestContext { + return RequestContext{ + Context: context.Background(), + Log: &logrus.Entry{}, + Config: config.Get(), + }.populate() +} + +type RequestContext struct { + context.Context + + // These are also stored on the context object itself + Log *logrus.Entry // mr.logger + Config *config.MediaRepoConfig // mr.serverConfig +} + +func (c RequestContext) populate() RequestContext { + c.Context = context.WithValue(c.Context, "mr.logger", c.Log) + //c.Context = context.WithValue(c.Context, "mr.serverConfig", c.Config) + return c +} + +func (c RequestContext) ReplaceLogger(log *logrus.Entry) RequestContext { + ctx := context.WithValue(c.Context, "mr.logger", log) + return RequestContext{ + Context: ctx, + Log: log, + Config: c.Config, + } +} + +func (c RequestContext) LogWithFields(fields logrus.Fields) RequestContext { + return c.ReplaceLogger(c.Log.WithFields(fields)) +} diff --git a/controllers/data_controller/export_controller.go b/controllers/data_controller/export_controller.go index 22639b562c81961cbe2213fef8fbb25ef45f0c4d..d54d3f4ca4e52fee027a1127b2732fcf6d164512 100644 --- a/controllers/data_controller/export_controller.go +++ b/controllers/data_controller/export_controller.go @@ -4,16 +4,15 @@ import ( "archive/tar" "bytes" "compress/gzip" - "context" "encoding/json" "fmt" "io" "time" "github.com/dustin/go-humanize" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" "github.com/turt2live/matrix-media-repo/storage/datastore/ds_s3" @@ -45,15 +44,13 @@ type manifest struct { UserId string `json:"user_id,omitempty"` } -func StartServerExport(serverName string, s3urls bool, includeData bool, log *logrus.Entry) (*types.BackgroundTask, string, error) { - ctx := context.Background() - +func StartServerExport(serverName string, s3urls bool, includeData bool, ctx rcontext.RequestContext) (*types.BackgroundTask, string, error) { exportId, err := util.GenerateRandomString(128) if err != nil { return nil, "", err } - db := storage.GetDatabase().GetMetadataStore(ctx, log) + db := storage.GetDatabase().GetMetadataStore(ctx) task, err := db.CreateBackgroundTask("export_data", map[string]interface{}{ "server_name": serverName, "include_s3_urls": s3urls, @@ -66,42 +63,40 @@ func StartServerExport(serverName string, s3urls bool, includeData bool, log *lo } go func() { - ds, err := datastore.PickDatastore(common.KindArchives, ctx, log) + ds, err := datastore.PickDatastore(common.KindArchives, ctx) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) + mediaDb := storage.GetDatabase().GetMediaStore(ctx) media, err := mediaDb.GetAllMediaForServer(serverName) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } - compileArchive(exportId, serverName, ds, media, s3urls, includeData, ctx, log) + compileArchive(exportId, serverName, ds, media, s3urls, includeData, ctx) - log.Info("Finishing export task") + ctx.Log.Info("Finishing export task") err = db.FinishedBackgroundTask(task.ID) if err != nil { - log.Error(err) - log.Error("Failed to flag task as finished") + ctx.Log.Error(err) + ctx.Log.Error("Failed to flag task as finished") } - log.Info("Finished export") + ctx.Log.Info("Finished export") }() return task, exportId, nil } -func StartUserExport(userId string, s3urls bool, includeData bool, log *logrus.Entry) (*types.BackgroundTask, string, error) { - ctx := context.Background() - +func StartUserExport(userId string, s3urls bool, includeData bool, ctx rcontext.RequestContext) (*types.BackgroundTask, string, error) { exportId, err := util.GenerateRandomString(128) if err != nil { return nil, "", err } - db := storage.GetDatabase().GetMetadataStore(ctx, log) + db := storage.GetDatabase().GetMetadataStore(ctx) task, err := db.CreateBackgroundTask("export_data", map[string]interface{}{ "user_id": userId, "include_s3_urls": s3urls, @@ -114,38 +109,38 @@ func StartUserExport(userId string, s3urls bool, includeData bool, log *logrus.E } go func() { - ds, err := datastore.PickDatastore(common.KindArchives, ctx, log) + ds, err := datastore.PickDatastore(common.KindArchives, ctx, ) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) + mediaDb := storage.GetDatabase().GetMediaStore(ctx) media, err := mediaDb.GetMediaByUser(userId) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } - compileArchive(exportId, userId, ds, media, s3urls, includeData, ctx, log) + compileArchive(exportId, userId, ds, media, s3urls, includeData, ctx) - log.Info("Finishing export task") + ctx.Log.Info("Finishing export task") err = db.FinishedBackgroundTask(task.ID) if err != nil { - log.Error(err) - log.Error("Failed to flag task as finished") + ctx.Log.Error(err) + ctx.Log.Error("Failed to flag task as finished") } - log.Info("Finished export") + ctx.Log.Info("Finished export") }() return task, exportId, nil } -func compileArchive(exportId string, entityId string, archiveDs *datastore.DatastoreRef, media []*types.Media, s3urls bool, includeData bool, ctx context.Context, log *logrus.Entry) { - exportDb := storage.GetDatabase().GetExportStore(ctx, log) +func compileArchive(exportId string, entityId string, archiveDs *datastore.DatastoreRef, media []*types.Media, s3urls bool, includeData bool, ctx rcontext.RequestContext) { + exportDb := storage.GetDatabase().GetExportStore(ctx) err := exportDb.InsertExport(exportId, entityId) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } @@ -159,7 +154,7 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas currentTar.Close() // compress - log.Info("Compressing tar file") + ctx.Log.Info("Compressing tar file") gzipBytes := bytes.Buffer{} archiver := gzip.NewWriter(&gzipBytes) archiver.Name = fmt.Sprintf("export-part-%d.tar", part) @@ -169,10 +164,10 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas } archiver.Close() - log.Info("Uploading compressed tar file") + ctx.Log.Info("Uploading compressed tar file") buf := bytes.NewBuffer(gzipBytes.Bytes()) size := int64(buf.Len()) - obj, err := archiveDs.UploadFile(util.BufferToStream(buf), size, ctx, log) + obj, err := archiveDs.UploadFile(util.BufferToStream(buf), size, ctx) if err != nil { return err } @@ -189,14 +184,14 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas newTar := func() error { if part > 0 { - log.Info("Persisting complete tar file") + ctx.Log.Info("Persisting complete tar file") err := persistTar() if err != nil { return err } } - log.Info("Starting new tar file") + ctx.Log.Info("Starting new tar file") currentTarBytes = bytes.Buffer{} currentTar = tar.NewWriter(¤tTarBytes) part = part + 1 @@ -206,10 +201,10 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas } // Start the first tar file - log.Info("Creating first tar file") + ctx.Log.Info("Creating first tar file") err = newTar() if err != nil { - log.Error(err) + ctx.Log.Error(err) return } @@ -222,13 +217,13 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas } err := currentTar.WriteHeader(header) if err != nil { - log.Error("error writing header") + ctx.Log.Error("error writing header") return err } i, err := io.Copy(currentTar, file) if err != nil { - log.Error("error writing file") + ctx.Log.Error("error writing file") return err } @@ -243,7 +238,7 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas } // Build a manifest first (JSON) - log.Info("Building manifest") + ctx.Log.Info("Building manifest") indexModel := &templating.ExportIndexModel{ Entity: entityId, ExportID: exportId, @@ -255,7 +250,7 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas if s3urls { s3url, err = ds_s3.GetS3URL(m.DatastoreId, m.Location) if err != nil { - log.Warn(err) + ctx.Log.Warn(err) } } mediaManifest[m.MxcUri()] = &manifestRecord{ @@ -293,67 +288,67 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas } b, err := json.Marshal(manifest) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } - log.Info("Writing manifest") + ctx.Log.Info("Writing manifest") err = putFile("manifest.json", int64(len(b)), time.Now(), util.BufferToStream(bytes.NewBuffer(b))) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } if includeData { - log.Info("Building and writing index") + ctx.Log.Info("Building and writing index") t, err := templating.GetTemplate("export_index") if err != nil { - log.Error(err) + ctx.Log.Error(err) return } html := bytes.Buffer{} err = t.Execute(&html, indexModel) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } err = putFile("index.html", int64(html.Len()), time.Now(), util.BufferToStream(bytes.NewBuffer(html.Bytes()))) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } - log.Info("Including data in the archive") + ctx.Log.Info("Including data in the archive") for _, m := range media { - log.Info("Downloading ", m.MxcUri()) - s, err := datastore.DownloadStream(ctx, log, m.DatastoreId, m.Location) + ctx.Log.Info("Downloading ", m.MxcUri()) + s, err := datastore.DownloadStream(ctx, m.DatastoreId, m.Location) if err != nil { - log.Error(err) + ctx.Log.Error(err) continue } - log.Infof("Copying %s to memory", m.MxcUri()) + ctx.Log.Infof("Copying %s to memory", m.MxcUri()) b := bytes.Buffer{} _, err = io.Copy(&b, s) if err != nil { - log.Error(err) + ctx.Log.Error(err) continue } s.Close() s = util.BufferToStream(bytes.NewBuffer(b.Bytes())) - log.Info("Archiving ", m.MxcUri()) + ctx.Log.Info("Archiving ", m.MxcUri()) err = putFile(archivedName(m), m.SizeBytes, time.Unix(0, m.CreationTs*int64(time.Millisecond)), s) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } if currentSize >= config.Get().Archiving.TargetBytesPerPart { - log.Info("Rotating tar") + ctx.Log.Info("Rotating tar") err = newTar() if err != nil { - log.Error(err) + ctx.Log.Error(err) return } } @@ -361,10 +356,10 @@ func compileArchive(exportId string, entityId string, archiveDs *datastore.Datas } if currentSize > 0 { - log.Info("Persisting last tar") + ctx.Log.Info("Persisting last tar") err = persistTar() if err != nil { - log.Error(err) + ctx.Log.Error(err) return } } diff --git a/controllers/data_controller/import_controller.go b/controllers/data_controller/import_controller.go index 51a4029a8ba9a902d0b30bc7d33b988927f68ce1..ff619c01fd612aafc41d5fae2b7de15795638602 100644 --- a/controllers/data_controller/import_controller.go +++ b/controllers/data_controller/import_controller.go @@ -4,7 +4,6 @@ import ( "archive/tar" "bytes" "compress/gzip" - "context" "database/sql" "encoding/json" "errors" @@ -12,8 +11,9 @@ import ( "net/http" "sync" - "github.com/sirupsen/logrus" + "github.com/prometheus/common/log" "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/upload_controller" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" @@ -29,9 +29,7 @@ type importUpdate struct { var openImports = &sync.Map{} // importId => updateChan -func StartImport(data io.Reader, log *logrus.Entry) (*types.BackgroundTask, string, error) { - ctx := context.Background() - +func StartImport(data io.Reader, ctx rcontext.RequestContext) (*types.BackgroundTask, string, error) { // Prepare the first update for the import (sync, so we can error) // We do this before anything else because if the archive is invalid then we shouldn't // even bother with an import. @@ -45,7 +43,7 @@ func StartImport(data io.Reader, log *logrus.Entry) (*types.BackgroundTask, stri return nil, "", err } - db := storage.GetDatabase().GetMetadataStore(ctx, log) + db := storage.GetDatabase().GetMetadataStore(ctx) task, err := db.CreateBackgroundTask("import_data", map[string]interface{}{ "import_id": importId, }) @@ -56,7 +54,7 @@ func StartImport(data io.Reader, log *logrus.Entry) (*types.BackgroundTask, stri // Start the import and send it its first update updateChan := make(chan *importUpdate) - go doImport(updateChan, task.ID, importId, ctx, log) + go doImport(updateChan, task.ID, importId, ctx) openImports.Store(importId, updateChan) updateChan <- &importUpdate{stop: false, fileMap: results} @@ -131,29 +129,29 @@ func processArchive(data io.Reader) (map[string]*bytes.Buffer, error) { return index, nil } -func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx context.Context, log *logrus.Entry) { - log.Info("Preparing for import...") +func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx rcontext.RequestContext) { + ctx.Log.Info("Preparing for import...") fileMap := make(map[string]*bytes.Buffer) stopImport := false archiveManifest := &manifest{} haveManifest := false imported := make(map[string]bool) - db := storage.GetDatabase().GetMediaStore(ctx, log) + db := storage.GetDatabase().GetMediaStore(ctx) for !stopImport { update := <-updateChannel if update.stop { - log.Info("Close requested") + ctx.Log.Info("Close requested") stopImport = true } // Populate files for name, fileBytes := range update.fileMap { if _, ok := fileMap[name]; ok { - log.Warnf("Duplicate file name, skipping: %s", name) + ctx.Log.Warnf("Duplicate file name, skipping: %s", name) continue // file already known to us } - log.Infof("Tracking file: %s", name) + ctx.Log.Infof("Tracking file: %s", name) fileMap[name] = fileBytes } @@ -161,7 +159,7 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx var manifestBuf *bytes.Buffer var ok bool if manifestBuf, ok = fileMap["manifest.json"]; !ok { - log.Info("No manifest found - waiting for more files") + ctx.Log.Info("No manifest found - waiting for more files") continue } @@ -169,26 +167,26 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx haveManifest = true err := json.Unmarshal(manifestBuf.Bytes(), archiveManifest) if err != nil { - log.Error("Failed to parse manifest - giving up on import") - log.Error(err) + ctx.Log.Error("Failed to parse manifest - giving up on import") + ctx.Log.Error(err) break } if archiveManifest.Version != 1 && archiveManifest.Version != 2 { - log.Error("Unsupported archive version") + ctx.Log.Error("Unsupported archive version") break } if archiveManifest.Version == 1 { archiveManifest.EntityId = archiveManifest.UserId } if archiveManifest.EntityId == "" { - log.Error("Invalid manifest: no entity") + ctx.Log.Error("Invalid manifest: no entity") break } if archiveManifest.Media == nil { - log.Error("Invalid manifest: no media") + ctx.Log.Error("Invalid manifest: no media") break } - log.Infof("Using manifest for %s (v%d) created %d", archiveManifest.EntityId, archiveManifest.Version, archiveManifest.CreatedTs) + ctx.Log.Infof("Using manifest for %s (v%d) created %d", archiveManifest.EntityId, archiveManifest.Version, archiveManifest.CreatedTs) } if !haveManifest { @@ -212,7 +210,7 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx if userId != "" { _, s, err := util.SplitUserId(userId) if err != nil { - log.Errorf("Invalid user ID: %s", userId) + ctx.Log.Errorf("Invalid user ID: %s", userId) serverName = "" } else { serverName = s @@ -222,25 +220,25 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx kind = common.KindRemoteMedia } - log.Infof("Attempting to import %s for %s", mxc, archiveManifest.EntityId) + ctx.Log.Infof("Attempting to import %s for %s", mxc, archiveManifest.EntityId) buf, found := fileMap[record.ArchivedName] if found { - log.Info("Using file from memory") + ctx.Log.Info("Using file from memory") closer := util.BufferToStream(buf) - _, err := upload_controller.StoreDirect(closer, record.SizeBytes, record.ContentType, record.FileName, userId, record.Origin, record.MediaId, kind, ctx, log) + _, err := upload_controller.StoreDirect(closer, record.SizeBytes, record.ContentType, record.FileName, userId, record.Origin, record.MediaId, kind, ctx) if err != nil { - log.Errorf("Error importing file: %s", err.Error()) + ctx.Log.Errorf("Error importing file: %s", err.Error()) continue } } else if record.S3Url != "" { - log.Info("Using S3 URL") + ctx.Log.Info("Using S3 URL") endpoint, bucket, location, err := ds_s3.ParseS3URL(record.S3Url) if err != nil { - log.Errorf("Error importing file: %s", err.Error()) + ctx.Log.Errorf("Error importing file: %s", err.Error()) continue } - log.Infof("Seeing if a datastore for %s/%s exists", endpoint, bucket) + ctx.Log.Infof("Seeing if a datastore for %s/%s exists", endpoint, bucket) datastores, err := datastore.GetAvailableDatastores() if err != nil { log.Errorf("Error locating datastore: %s", err.Error()) @@ -254,19 +252,19 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx tmplUrl, err := ds_s3.GetS3URL(ds.DatastoreId, location) if err != nil { - log.Errorf("Error investigating s3 datastore: %s", err.Error()) + ctx.Log.Errorf("Error investigating s3 datastore: %s", err.Error()) continue } if tmplUrl == record.S3Url { - log.Infof("File matches! Assuming the file has been uploaded already") + ctx.Log.Infof("File matches! Assuming the file has been uploaded already") existingRecord, err := db.Get(record.Origin, record.MediaId) if err != nil && err != sql.ErrNoRows { - log.Errorf("Error testing file in database: %s", err.Error()) + ctx.Log.Errorf("Error testing file in database: %s", err.Error()) break } if err != sql.ErrNoRows && existingRecord != nil { - log.Warnf("Media %s already exists - skipping without altering record", existingRecord.MxcUri()) + ctx.Log.Warnf("Media %s already exists - skipping without altering record", existingRecord.MxcUri()) imported = true break } @@ -292,32 +290,32 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx err = db.Insert(media) if err != nil { - log.Errorf("Error creating media record: %s", err.Error()) + ctx.Log.Errorf("Error creating media record: %s", err.Error()) break } - log.Infof("Media %s has been imported", media.MxcUri()) + ctx.Log.Infof("Media %s has been imported", media.MxcUri()) imported = true break } } if !imported { - log.Info("No datastore found - trying to upload by downloading first") + ctx.Log.Info("No datastore found - trying to upload by downloading first") r, err := http.DefaultClient.Get(record.S3Url) if err != nil { - log.Errorf("Error trying to download file from S3 via HTTP: ", err.Error()) + ctx.Log.Errorf("Error trying to download file from S3 via HTTP: ", err.Error()) continue } - _, err = upload_controller.StoreDirect(r.Body, r.ContentLength, record.ContentType, record.FileName, userId, record.Origin, record.MediaId, kind, ctx, log) + _, err = upload_controller.StoreDirect(r.Body, r.ContentLength, record.ContentType, record.FileName, userId, record.Origin, record.MediaId, kind, ctx) if err != nil { - log.Errorf("Error importing file: %s", err.Error()) + ctx.Log.Errorf("Error importing file: %s", err.Error()) continue } } } else { - log.Warn("Missing usable file for import - assuming it will show up in a future upload") + ctx.Log.Warn("Missing usable file for import - assuming it will show up in a future upload") continue } @@ -336,19 +334,19 @@ func doImport(updateChannel chan *importUpdate, taskId int, importId string, ctx } if !missingAny { - log.Info("No more files to import - closing import") + ctx.Log.Info("No more files to import - closing import") stopImport = true } } openImports.Delete(importId) - log.Info("Finishing import task") - dbMeta := storage.GetDatabase().GetMetadataStore(ctx, log) + ctx.Log.Info("Finishing import task") + dbMeta := storage.GetDatabase().GetMetadataStore(ctx) err := dbMeta.FinishedBackgroundTask(taskId) if err != nil { - log.Error(err) - log.Error("Failed to flag task as finished") + ctx.Log.Error(err) + ctx.Log.Error("Failed to flag task as finished") } - log.Info("Finished import") + ctx.Log.Info("Finished import") } diff --git a/controllers/download_controller/download_controller.go b/controllers/download_controller/download_controller.go index 7ea4acb32868a2c52b51d1a5f966ab35a412648c..59fd6208a399add74fc6d8148235923df8306bd6 100644 --- a/controllers/download_controller/download_controller.go +++ b/controllers/download_controller/download_controller.go @@ -2,7 +2,6 @@ package download_controller import ( "bytes" - "context" "database/sql" "errors" "fmt" @@ -10,10 +9,10 @@ import ( "github.com/disintegration/imaging" "github.com/patrickmn/go-cache" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/common/globals" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/quarantine_controller" "github.com/turt2live/matrix-media-repo/internal_cache" "github.com/turt2live/matrix-media-repo/storage" @@ -24,14 +23,14 @@ import ( var localCache = cache.New(30*time.Second, 60*time.Second) -func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia bool, ctx context.Context, log *logrus.Entry) (*types.MinimalMedia, error) { +func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia bool, ctx rcontext.RequestContext) (*types.MinimalMedia, error) { cacheKey := fmt.Sprintf("%s/%s?r=%t&b=%t", origin, mediaId, downloadRemote, blockForMedia) v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) { var media *types.Media var minMedia *types.MinimalMedia var err error if blockForMedia { - media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx, log) + media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx) if media != nil { minMedia = &types.MinimalMedia{ Origin: media.Origin, @@ -44,7 +43,7 @@ func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia } } } else { - minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx, log) + minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx) if minMedia != nil { media = minMedia.KnownMedia } @@ -53,11 +52,11 @@ func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia return nil, err } if minMedia == nil { - log.Warn("Unexpected error while fetching media: no minimal media record") + ctx.Log.Warn("Unexpected error while fetching media: no minimal media record") return nil, common.ErrMediaNotFound } if media == nil && blockForMedia { - log.Warn("Unexpected error while fetching media: no regular media record (block for media in place)") + ctx.Log.Warn("Unexpected error while fetching media: no regular media record (block for media in place)") return nil, common.ErrMediaNotFound } @@ -67,10 +66,10 @@ func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia if media != nil { if media.Quarantined { - log.Warn("Quarantined media accessed") + ctx.Log.Warn("Quarantined media accessed") if config.Get().Quarantine.ReplaceDownloads { - log.Info("Replacing thumbnail with a quarantined one") + ctx.Log.Info("Replacing thumbnail with a quarantined one") img, err := quarantine_controller.GenerateQuarantineThumbnail(512, 512) if err != nil { @@ -94,15 +93,15 @@ func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia return nil, common.ErrMediaQuarantined } - err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(media.Sha256Hash, util.NowMillis()) + err = storage.GetDatabase().GetMetadataStore(ctx).UpsertLastAccess(media.Sha256Hash, util.NowMillis()) if err != nil { - logrus.Warn("Failed to upsert the last access time: ", err) + ctx.Log.Warn("Failed to upsert the last access time: ", err) } localCache.Set(origin+"/"+mediaId, media, cache.DefaultExpiration) internal_cache.Get().IncrementDownloads(media.Sha256Hash) - cached, err := internal_cache.Get().GetMedia(media, log) + cached, err := internal_cache.Get().GetMedia(media, ctx) if err != nil { return nil, err } @@ -113,17 +112,17 @@ func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia } if minMedia.Stream != nil { - log.Info("Returning minimal media record with a viable stream") + ctx.Log.Info("Returning minimal media record with a viable stream") return minMedia, nil } if media == nil { - log.Error("Failed to locate media") + ctx.Log.Error("Failed to locate media") return nil, errors.New("failed to locate media") } - log.Info("Reading media from disk") - mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) + ctx.Log.Info("Reading media from disk") + mediaStream, err := datastore.DownloadStream(ctx, media.DatastoreId, media.Location) if err != nil { return nil, err } @@ -165,26 +164,26 @@ func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia return value, err } -func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.MinimalMedia, error) { - db := storage.GetDatabase().GetMediaStore(ctx, log) +func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, ctx rcontext.RequestContext) (*types.MinimalMedia, error) { + db := storage.GetDatabase().GetMediaStore(ctx) var media *types.Media item, found := localCache.Get(origin + "/" + mediaId) if found { media = item.(*types.Media) } else { - log.Info("Getting media record from database") + ctx.Log.Info("Getting media record from database") dbMedia, err := db.Get(origin, mediaId) if err != nil { if err == sql.ErrNoRows { if util.IsServerOurs(origin) { - log.Warn("Media not found") + ctx.Log.Warn("Media not found") return nil, common.ErrMediaNotFound } } if !downloadRemote { - log.Warn("Remote media not being downloaded") + ctx.Log.Warn("Remote media not being downloaded") return nil, common.ErrMediaNotFound } @@ -196,13 +195,13 @@ func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, return nil, result.err } if result.stream == nil { - log.Info("No stream returned from remote download - attempting to create one") + ctx.Log.Info("No stream returned from remote download - attempting to create one") if result.media == nil { - log.Error("Fatal error: No stream and no media. Cannot acquire a stream for media") + ctx.Log.Error("Fatal error: No stream and no media. Cannot acquire a stream for media") return nil, errors.New("no stream available") } - stream, err := datastore.DownloadStream(ctx, log, result.media.DatastoreId, result.media.Location) + stream, err := datastore.DownloadStream(ctx, result.media.DatastoreId, result.media.Location) if err != nil { return nil, err } @@ -224,11 +223,11 @@ func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, } if media == nil { - log.Warn("Despite all efforts, a media record could not be found") + ctx.Log.Warn("Despite all efforts, a media record could not be found") return nil, common.ErrMediaNotFound } - mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) + mediaStream, err := datastore.DownloadStream(ctx, media.DatastoreId, media.Location) if err != nil { return nil, err } @@ -244,28 +243,28 @@ func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, }, nil } -func FindMediaRecord(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.Media, error) { +func FindMediaRecord(origin string, mediaId string, downloadRemote bool, ctx rcontext.RequestContext) (*types.Media, error) { cacheKey := origin + "/" + mediaId v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) { - db := storage.GetDatabase().GetMediaStore(ctx, log) + db := storage.GetDatabase().GetMediaStore(ctx) var media *types.Media item, found := localCache.Get(cacheKey) if found { media = item.(*types.Media) } else { - log.Info("Getting media record from database") + ctx.Log.Info("Getting media record from database") dbMedia, err := db.Get(origin, mediaId) if err != nil { if err == sql.ErrNoRows { if util.IsServerOurs(origin) { - log.Warn("Media not found") + ctx.Log.Warn("Media not found") return nil, common.ErrMediaNotFound } } if !downloadRemote { - log.Warn("Remote media not being downloaded") + ctx.Log.Warn("Remote media not being downloaded") return nil, common.ErrMediaNotFound } @@ -283,7 +282,7 @@ func FindMediaRecord(origin string, mediaId string, downloadRemote bool, ctx con } if media == nil { - log.Warn("Despite all efforts, a media record could not be found") + ctx.Log.Warn("Despite all efforts, a media record could not be found") return nil, common.ErrMediaNotFound } diff --git a/controllers/download_controller/download_resource_handler.go b/controllers/download_controller/download_resource_handler.go index 1bff8adc25d7567de1a7b4bdc6e23111eca0e5c8..a113bda1a1c893cf6dfb583fb10f14b229b462a9 100644 --- a/controllers/download_controller/download_resource_handler.go +++ b/controllers/download_controller/download_resource_handler.go @@ -1,7 +1,6 @@ package download_controller import ( - "context" "errors" "io" "io/ioutil" @@ -16,6 +15,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/upload_controller" "github.com/turt2live/matrix-media-repo/matrix" "github.com/turt2live/matrix-media-repo/metrics" @@ -115,24 +115,22 @@ func (h *mediaResourceHandler) DownloadRemoteMedia(origin string, mediaId string func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} { info := request.Metadata.(*downloadRequest) - log := logrus.WithFields(logrus.Fields{ + ctx := rcontext.Initial().LogWithFields(logrus.Fields{ "worker_requestId": request.Id, "worker_requestOrigin": info.origin, "worker_requestMediaId": info.mediaId, "worker_blockForMedia": info.blockForMedia, }) - log.Info("Downloading remote media") + ctx.Log.Info("Downloading remote media") - ctx := context.TODO() // TODO: Should we use a real context? - - downloaded, err := DownloadRemoteMediaDirect(info.origin, info.mediaId, log) + downloaded, err := DownloadRemoteMediaDirect(info.origin, info.mediaId, ctx) if err != nil { return &workerDownloadResponse{err: err} } - log.Info("Checking to ensure the reported content type is allowed...") - if downloaded.ContentType != "" && !upload_controller.IsAllowed(downloaded.ContentType, downloaded.ContentType, upload_controller.NoApplicableUploadUser, log) { - log.Error("Remote media failed the preliminary IsAllowed check based on content type (reported as " + downloaded.ContentType + ")") + ctx.Log.Info("Checking to ensure the reported content type is allowed...") + if downloaded.ContentType != "" && !upload_controller.IsAllowed(downloaded.ContentType, downloaded.ContentType, upload_controller.NoApplicableUploadUser, ctx) { + ctx.Log.Error("Remote media failed the preliminary IsAllowed check based on content type (reported as " + downloaded.ContentType + ")") return &workerDownloadResponse{err: common.ErrMediaNotAllowed} } @@ -140,22 +138,22 @@ func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} { defer fileStream.Close() userId := upload_controller.NoApplicableUploadUser - media, err := upload_controller.StoreDirect(fileStream, downloaded.ContentLength, downloaded.ContentType, downloaded.DesiredFilename, userId, info.origin, info.mediaId, common.KindRemoteMedia, ctx, log) + media, err := upload_controller.StoreDirect(fileStream, downloaded.ContentLength, downloaded.ContentType, downloaded.DesiredFilename, userId, info.origin, info.mediaId, common.KindRemoteMedia, ctx) if err != nil { - log.Error("Error persisting file: ", err) + ctx.Log.Error("Error persisting file: ", err) return &workerDownloadResponse{err: err} } - log.Info("Remote media persisted under datastore ", media.DatastoreId, " at ", media.Location) + ctx.Log.Info("Remote media persisted under datastore ", media.DatastoreId, " at ", media.Location) return &workerDownloadResponse{media: media} } if info.blockForMedia { - log.Warn("Not streaming remote media download request due to request for a block") + ctx.Log.Warn("Not streaming remote media download request due to request for a block") return persistFile(downloaded.Contents) } - log.Info("Streaming remote media to filesystem and requesting party at the same time") + ctx.Log.Info("Streaming remote media to filesystem and requesting party at the same time") reader, writer := io.Pipe() tr := io.TeeReader(downloaded.Contents, writer) @@ -174,7 +172,7 @@ func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} { } } -func DownloadRemoteMediaDirect(server string, mediaId string, log *logrus.Entry) (*downloadedMedia, error) { +func DownloadRemoteMediaDirect(server string, mediaId string, ctx rcontext.RequestContext) (*downloadedMedia, error) { if downloadErrorsCache == nil { downloadErrorCacheSingletonLock.Do(func() { cacheTime := time.Duration(config.Get().Downloads.FailureCacheMinutes) * time.Minute @@ -185,7 +183,7 @@ func DownloadRemoteMediaDirect(server string, mediaId string, log *logrus.Entry) cacheKey := server + "/" + mediaId item, found := downloadErrorsCache.Get(cacheKey) if found { - log.Warn("Returning cached error for remote media download failure") + ctx.Log.Warn("Returning cached error for remote media download failure") return nil, item.(error) } @@ -203,13 +201,13 @@ func DownloadRemoteMediaDirect(server string, mediaId string, log *logrus.Entry) } if resp.StatusCode == 404 { - log.Info("Remote media not found") + ctx.Log.Info("Remote media not found") err = common.ErrMediaNotFound downloadErrorsCache.Set(cacheKey, err, cache.DefaultExpiration) return nil, err } else if resp.StatusCode != 200 { - log.Info("Unknown error fetching remote media; received status code " + strconv.Itoa(resp.StatusCode)) + ctx.Log.Info("Unknown error fetching remote media; received status code " + strconv.Itoa(resp.StatusCode)) err = errors.New("could not fetch remote media") downloadErrorsCache.Set(cacheKey, err, cache.DefaultExpiration) @@ -223,11 +221,11 @@ func DownloadRemoteMediaDirect(server string, mediaId string, log *logrus.Entry) return nil, err } } else { - log.Warn("Missing Content-Length header on response - continuing anyway") + ctx.Log.Warn("Missing Content-Length header on response - continuing anyway") } if contentLength > 0 && config.Get().Downloads.MaxSizeBytes > 0 && contentLength > config.Get().Downloads.MaxSizeBytes { - log.Warn("Attempted to download media that was too large") + ctx.Log.Warn("Attempted to download media that was too large") err = common.ErrMediaTooLarge downloadErrorsCache.Set(cacheKey, err, cache.DefaultExpiration) @@ -236,7 +234,7 @@ func DownloadRemoteMediaDirect(server string, mediaId string, log *logrus.Entry) contentType := resp.Header.Get("Content-Type") if contentType == "" { - log.Warn("Remote media has no content type; Assuming application/octet-stream") + ctx.Log.Warn("Remote media has no content type; Assuming application/octet-stream") contentType = "application/octet-stream" // binary } @@ -252,7 +250,7 @@ func DownloadRemoteMediaDirect(server string, mediaId string, log *logrus.Entry) request.DesiredFilename = params["filename"] } - log.Info("Persisting downloaded media") + ctx.Log.Info("Persisting downloaded media") metrics.MediaDownloaded.With(prometheus.Labels{"origin": server}).Inc() return request, nil } diff --git a/controllers/maintenance_controller/maintainance_controller.go b/controllers/maintenance_controller/maintainance_controller.go index 0fc364d5f6260be4c56e60eef7b0e3116cf49d5b..3890dae6e0f4afc1291fdcdc943fc3bdf0c23dc3 100644 --- a/controllers/maintenance_controller/maintainance_controller.go +++ b/controllers/maintenance_controller/maintainance_controller.go @@ -1,12 +1,12 @@ package maintenance_controller import ( - "context" "database/sql" "fmt" "os" "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/download_controller" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" @@ -15,10 +15,8 @@ import ( ) // Returns an error only if starting up the background task failed. -func StartStorageMigration(sourceDs *datastore.DatastoreRef, targetDs *datastore.DatastoreRef, beforeTs int64, log *logrus.Entry) (*types.BackgroundTask, error) { - ctx := context.Background() - - db := storage.GetDatabase().GetMetadataStore(ctx, log) +func StartStorageMigration(sourceDs *datastore.DatastoreRef, targetDs *datastore.DatastoreRef, beforeTs int64, ctx rcontext.RequestContext) (*types.BackgroundTask, error) { + db := storage.GetDatabase().GetMetadataStore(ctx) task, err := db.CreateBackgroundTask("storage_migration", map[string]interface{}{ "source_datastore_id": sourceDs.DatastoreId, "target_datastore_id": targetDs.DatastoreId, @@ -29,82 +27,81 @@ func StartStorageMigration(sourceDs *datastore.DatastoreRef, targetDs *datastore } go func() { - log.Info("Starting transfer") + ctx.Log.Info("Starting transfer") - db := storage.GetDatabase().GetMetadataStore(ctx, log) + db := storage.GetDatabase().GetMetadataStore(ctx) - origLog := log doUpdate := func(records []*types.MinimalMediaMetadata) { for _, record := range records { - log := origLog.WithFields(logrus.Fields{"mediaSha256": record.Sha256Hash}) + rctx := ctx.LogWithFields(logrus.Fields{"mediaSha256": record.Sha256Hash}) - log.Info("Starting transfer of media") + rctx.Log.Info("Starting transfer of media") sourceStream, err := sourceDs.DownloadFile(record.Location) if err != nil { - log.Error(err) - log.Error("Failed to start download from source datastore") + rctx.Log.Error(err) + rctx.Log.Error("Failed to start download from source datastore") continue } - newLocation, err := targetDs.UploadFile(sourceStream, record.SizeBytes, ctx, log) + newLocation, err := targetDs.UploadFile(sourceStream, record.SizeBytes, rctx) if err != nil { - log.Error(err) - log.Error("Failed to upload file to target datastore") + rctx.Log.Error(err) + rctx.Log.Error("Failed to upload file to target datastore") continue } - log.Info("Updating media records...") + rctx.Log.Info("Updating media records...") err = db.ChangeDatastoreOfHash(targetDs.DatastoreId, newLocation.Location, record.Sha256Hash) if err != nil { - log.Error(err) - log.Error("Failed to update database records") + rctx.Log.Error(err) + rctx.Log.Error("Failed to update database records") continue } - log.Info("Deleting media from old datastore") + rctx.Log.Info("Deleting media from old datastore") err = sourceDs.DeleteObject(record.Location) if err != nil { - log.Error(err) - log.Error("Failed to delete old media") + rctx.Log.Error(err) + rctx.Log.Error("Failed to delete old media") continue } - log.Info("Media updated!") + rctx.Log.Info("Media updated!") } } media, err := db.GetOldMediaInDatastore(sourceDs.DatastoreId, beforeTs) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } doUpdate(media) thumbs, err := db.GetOldThumbnailsInDatastore(sourceDs.DatastoreId, beforeTs) if err != nil { - log.Error(err) + ctx.Log.Error(err) return } doUpdate(thumbs) err = db.FinishedBackgroundTask(task.ID) if err != nil { - log.Error(err) - log.Error("Failed to flag task as finished") + ctx.Log.Error(err) + ctx.Log.Error("Failed to flag task as finished") } - log.Info("Finished transfer") + ctx.Log.Info("Finished transfer") }() return task, nil } -func EstimateDatastoreSizeWithAge(beforeTs int64, datastoreId string, ctx context.Context, log *logrus.Entry) (*types.DatastoreMigrationEstimate, error) { +func EstimateDatastoreSizeWithAge(beforeTs int64, datastoreId string, ctx rcontext.RequestContext) (*types.DatastoreMigrationEstimate, error) { estimates := &types.DatastoreMigrationEstimate{} seenHashes := make(map[string]bool) seenMediaHashes := make(map[string]bool) seenThumbnailHashes := make(map[string]bool) - db := storage.GetDatabase().GetMetadataStore(ctx, log) + db := storage.GetDatabase().GetMetadataStore(ctx) media, err := db.GetOldMediaInDatastore(datastoreId, beforeTs) if err != nil { return nil, err @@ -150,9 +147,9 @@ func EstimateDatastoreSizeWithAge(beforeTs int64, datastoreId string, ctx contex return estimates, nil } -func PurgeRemoteMediaBefore(beforeTs int64, ctx context.Context, log *logrus.Entry) (int, error) { - db := storage.GetDatabase().GetMediaStore(ctx, log) - thumbsDb := storage.GetDatabase().GetThumbnailStore(ctx, log) +func PurgeRemoteMediaBefore(beforeTs int64, ctx rcontext.RequestContext) (int, error) { + db := storage.GetDatabase().GetMediaStore(ctx) + thumbsDb := storage.GetDatabase().GetThumbnailStore(ctx) origins, err := db.GetOrigins() if err != nil { @@ -171,74 +168,74 @@ func PurgeRemoteMediaBefore(beforeTs int64, ctx context.Context, log *logrus.Ent return 0, err } - log.Info(fmt.Sprintf("Starting removal of %d remote media files (db records will be kept)", len(oldMedia))) + ctx.Log.Info(fmt.Sprintf("Starting removal of %d remote media files (db records will be kept)", len(oldMedia))) removed := 0 for _, media := range oldMedia { if media.Quarantined { - log.Warn("Not removing quarantined media to maintain quarantined status: " + media.Origin + "/" + media.MediaId) + ctx.Log.Warn("Not removing quarantined media to maintain quarantined status: " + media.Origin + "/" + media.MediaId) continue } - ds, err := datastore.LocateDatastore(context.TODO(), &logrus.Entry{}, media.DatastoreId) + ds, err := datastore.LocateDatastore(ctx, media.DatastoreId) if err != nil { - log.Error("Error finding datastore for media " + media.Origin + "/" + media.MediaId + " because: " + err.Error()) + ctx.Log.Error("Error finding datastore for media " + media.Origin + "/" + media.MediaId + " because: " + err.Error()) continue } // Delete the file first err = ds.DeleteObject(media.Location) if err != nil { - log.Warn("Cannot remove media " + media.Origin + "/" + media.MediaId + " because: " + err.Error()) + ctx.Log.Warn("Cannot remove media " + media.Origin + "/" + media.MediaId + " because: " + err.Error()) } else { removed++ - log.Info("Removed remote media file: " + media.Origin + "/" + media.MediaId) + ctx.Log.Info("Removed remote media file: " + media.Origin + "/" + media.MediaId) } // Try to remove the record from the database now err = db.Delete(media.Origin, media.MediaId) if err != nil { - log.Warn("Error removing media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) + ctx.Log.Warn("Error removing media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) } // Delete the thumbnails too thumbs, err := thumbsDb.GetAllForMedia(media.Origin, media.MediaId) if err != nil { - log.Warn("Error getting thumbnails for media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) + ctx.Log.Warn("Error getting thumbnails for media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) continue } for _, thumb := range thumbs { - log.Info("Deleting thumbnail with hash: ", thumb.Sha256Hash) - ds, err := datastore.LocateDatastore(ctx, log, thumb.DatastoreId) + ctx.Log.Info("Deleting thumbnail with hash: ", thumb.Sha256Hash) + ds, err := datastore.LocateDatastore(ctx, thumb.DatastoreId) if err != nil { - log.Warn("Error removing thumbnail for media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) + ctx.Log.Warn("Error removing thumbnail for media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) continue } err = ds.DeleteObject(thumb.Location) if err != nil { - log.Warn("Error removing thumbnail for media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) + ctx.Log.Warn("Error removing thumbnail for media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) continue } } err = thumbsDb.DeleteAllForMedia(media.Origin, media.MediaId) if err != nil { - log.Warn("Error removing thumbnails for media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) + ctx.Log.Warn("Error removing thumbnails for media " + media.Origin + "/" + media.MediaId + " from database: " + err.Error()) } } return removed, nil } -func PurgeQuarantined(ctx context.Context, log *logrus.Entry) ([]*types.Media, error) { - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) +func PurgeQuarantined(ctx rcontext.RequestContext) ([]*types.Media, error) { + mediaDb := storage.GetDatabase().GetMediaStore(ctx) records, err := mediaDb.GetAllQuarantinedMedia() if err != nil { return nil, err } for _, r := range records { - err = doPurge(r, ctx, log) + err = doPurge(r, ctx) if err != nil { return nil, err } @@ -247,15 +244,15 @@ func PurgeQuarantined(ctx context.Context, log *logrus.Entry) ([]*types.Media, e return records, nil } -func PurgeQuarantinedFor(serverName string, ctx context.Context, log *logrus.Entry) ([]*types.Media, error) { - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) +func PurgeQuarantinedFor(serverName string, ctx rcontext.RequestContext) ([]*types.Media, error) { + mediaDb := storage.GetDatabase().GetMediaStore(ctx) records, err := mediaDb.GetQuarantinedMediaFor(serverName) if err != nil { return nil, err } for _, r := range records { - err = doPurge(r, ctx, log) + err = doPurge(r, ctx) if err != nil { return nil, err } @@ -264,15 +261,15 @@ func PurgeQuarantinedFor(serverName string, ctx context.Context, log *logrus.Ent return records, nil } -func PurgeUserMedia(userId string, beforeTs int64, ctx context.Context, log *logrus.Entry) ([]*types.Media, error) { - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) +func PurgeUserMedia(userId string, beforeTs int64, ctx rcontext.RequestContext) ([]*types.Media, error) { + mediaDb := storage.GetDatabase().GetMediaStore(ctx) records, err := mediaDb.GetMediaByUserBefore(userId, beforeTs) if err != nil { return nil, err } for _, r := range records { - err = doPurge(r, ctx, log) + err = doPurge(r, ctx) if err != nil { return nil, err } @@ -281,9 +278,9 @@ func PurgeUserMedia(userId string, beforeTs int64, ctx context.Context, log *log return records, nil } -func PurgeOldMedia(beforeTs int64, includeLocal bool, ctx context.Context, log *logrus.Entry) ([]*types.Media, error) { - metadataDb := storage.GetDatabase().GetMetadataStore(ctx, log) - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) +func PurgeOldMedia(beforeTs int64, includeLocal bool, ctx rcontext.RequestContext) ([]*types.Media, error) { + metadataDb := storage.GetDatabase().GetMetadataStore(ctx) + mediaDb := storage.GetDatabase().GetMediaStore(ctx) oldHashes, err := metadataDb.GetOldMedia(beforeTs) if err != nil { @@ -303,7 +300,7 @@ func PurgeOldMedia(beforeTs int64, includeLocal bool, ctx context.Context, log * continue } - err = doPurge(m, ctx, log) + err = doPurge(m, ctx) if err != nil { return nil, err } @@ -315,8 +312,8 @@ func PurgeOldMedia(beforeTs int64, includeLocal bool, ctx context.Context, log * return purged, nil } -func PurgeRoomMedia(mxcs []string, beforeTs int64, ctx context.Context, log *logrus.Entry) ([]*types.Media, error) { - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) +func PurgeRoomMedia(mxcs []string, beforeTs int64, ctx rcontext.RequestContext) ([]*types.Media, error) { + mediaDb := storage.GetDatabase().GetMediaStore(ctx) purged := make([]*types.Media, 0) @@ -339,7 +336,7 @@ func PurgeRoomMedia(mxcs []string, beforeTs int64, ctx context.Context, log *log continue } - err = doPurge(record, ctx, log) + err = doPurge(record, ctx) if err != nil { return nil, err } @@ -350,15 +347,15 @@ func PurgeRoomMedia(mxcs []string, beforeTs int64, ctx context.Context, log *log return purged, nil } -func PurgeDomainMedia(serverName string, beforeTs int64, ctx context.Context, log *logrus.Entry) ([]*types.Media, error) { - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) +func PurgeDomainMedia(serverName string, beforeTs int64, ctx rcontext.RequestContext) ([]*types.Media, error) { + mediaDb := storage.GetDatabase().GetMediaStore(ctx) records, err := mediaDb.GetMediaByDomainBefore(serverName, beforeTs) if err != nil { return nil, err } for _, r := range records { - err = doPurge(r, ctx, log) + err = doPurge(r, ctx) if err != nil { return nil, err } @@ -367,25 +364,25 @@ func PurgeDomainMedia(serverName string, beforeTs int64, ctx context.Context, lo return records, nil } -func PurgeMedia(origin string, mediaId string, ctx context.Context, log *logrus.Entry) error { - media, err := download_controller.FindMediaRecord(origin, mediaId, false, ctx, log) +func PurgeMedia(origin string, mediaId string, ctx rcontext.RequestContext) error { + media, err := download_controller.FindMediaRecord(origin, mediaId, false, ctx) if err != nil { return err } - return doPurge(media, ctx, log) + return doPurge(media, ctx) } -func doPurge(media *types.Media, ctx context.Context, log *logrus.Entry) error { +func doPurge(media *types.Media, ctx rcontext.RequestContext) error { // Delete all the thumbnails first - thumbsDb := storage.GetDatabase().GetThumbnailStore(ctx, log) + thumbsDb := storage.GetDatabase().GetThumbnailStore(ctx) thumbs, err := thumbsDb.GetAllForMedia(media.Origin, media.MediaId) if err != nil { return err } for _, thumb := range thumbs { - log.Info("Deleting thumbnail with hash: ", thumb.Sha256Hash) - ds, err := datastore.LocateDatastore(ctx, log, thumb.DatastoreId) + ctx.Log.Info("Deleting thumbnail with hash: ", thumb.Sha256Hash) + ds, err := datastore.LocateDatastore(ctx, thumb.DatastoreId) if err != nil { return err } @@ -400,12 +397,12 @@ func doPurge(media *types.Media, ctx context.Context, log *logrus.Entry) error { return err } - ds, err := datastore.LocateDatastore(ctx, log, media.DatastoreId) + ds, err := datastore.LocateDatastore(ctx, media.DatastoreId) if err != nil { return err } - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) + mediaDb := storage.GetDatabase().GetMediaStore(ctx) similarMedia, err := mediaDb.GetByHash(media.Sha256Hash) if err != nil { return err @@ -424,10 +421,10 @@ func doPurge(media *types.Media, ctx context.Context, log *logrus.Entry) error { return err } } else { - log.Warnf("Not deleting media from datastore: media is shared over %d objects", len(similarMedia)) + ctx.Log.Warnf("Not deleting media from datastore: media is shared over %d objects", len(similarMedia)) } - metadataDb := storage.GetDatabase().GetMetadataStore(ctx, log) + metadataDb := storage.GetDatabase().GetMetadataStore(ctx) reserved, err := metadataDb.IsReserved(media.Origin, media.MediaId) if err != nil { diff --git a/controllers/preview_controller/acl/acl.go b/controllers/preview_controller/acl/acl.go index 45e3b19315a00b985d9ff361dbc9af0df166f7ec..d46908d31f5f40f02488a9da4ae2bda8d1538231 100644 --- a/controllers/preview_controller/acl/acl.go +++ b/controllers/preview_controller/acl/acl.go @@ -1,7 +1,6 @@ package acl import ( - "context" "fmt" "net" "net/url" @@ -9,22 +8,23 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types" "github.com/turt2live/matrix-media-repo/storage" ) -func ValidateUrlForPreview(urlStr string, ctx context.Context, log *logrus.Entry) (*preview_types.UrlPayload, error) { - db := storage.GetDatabase().GetUrlStore(ctx, log) +func ValidateUrlForPreview(urlStr string, ctx rcontext.RequestContext) (*preview_types.UrlPayload, error) { + db := storage.GetDatabase().GetUrlStore(ctx) parsedUrl, err := url.ParseRequestURI(urlStr) if err != nil { - log.Error("Error parsing URL: ", err.Error()) + ctx.Log.Error("Error parsing URL: ", err.Error()) db.InsertPreviewError(urlStr, common.ErrCodeInvalidHost) return nil, common.ErrInvalidHost } realHost, _, err := net.SplitHostPort(parsedUrl.Host) if err != nil { - log.Error("Error parsing host and port: ", err.Error()) + ctx.Log.Error("Error parsing host and port: ", err.Error()) realHost = parsedUrl.Host } @@ -32,7 +32,7 @@ func ValidateUrlForPreview(urlStr string, ctx context.Context, log *logrus.Entry if realHost != "localhost" { addrs, err := net.LookupIP(realHost) if err != nil { - log.Error("Error getting host info: ", err.Error()) + ctx.Log.Error("Error getting host info: ", err.Error()) db.InsertPreviewError(urlStr, common.ErrCodeInvalidHost) return nil, common.ErrInvalidHost } @@ -56,7 +56,7 @@ func ValidateUrlForPreview(urlStr string, ctx context.Context, log *logrus.Entry deniedCidrs = append(deniedCidrs, "0.0.0.0/32") deniedCidrs = append(deniedCidrs, "::/128") - if !isAllowed(addr, allowedCidrs, deniedCidrs, log) { + if !isAllowed(addr, allowedCidrs, deniedCidrs, ctx) { db.InsertPreviewError(urlStr, common.ErrCodeHostBlacklisted) return nil, common.ErrHostBlacklisted } @@ -69,38 +69,38 @@ func ValidateUrlForPreview(urlStr string, ctx context.Context, log *logrus.Entry return urlToPreview, nil } -func isAllowed(ip net.IP, allowed []string, disallowed []string, log *logrus.Entry) bool { - log = log.WithFields(logrus.Fields{ +func isAllowed(ip net.IP, allowed []string, disallowed []string, ctx rcontext.RequestContext) bool { + ctx = ctx.LogWithFields(logrus.Fields{ "checkHost": ip, "allowedHosts": fmt.Sprintf("%v", allowed), "disallowedHosts": fmt.Sprintf("%v", allowed), }) - log.Info("Validating host") + ctx.Log.Info("Validating host") // First check if the IP fits the blacklist. This should be a much shorter list, and therefore // much faster to check. - log.Info("Checking blacklist for host...") - if inRange(ip, disallowed, log) { - log.Warn("Host found on blacklist - rejecting") + ctx.Log.Info("Checking blacklist for host...") + if inRange(ip, disallowed, ctx) { + ctx.Log.Warn("Host found on blacklist - rejecting") return false } // Now check the allowed list just to make sure the IP is actually allowed - if inRange(ip, allowed, log) { - log.Info("Host allowed due to whitelist") + if inRange(ip, allowed, ctx) { + ctx.Log.Info("Host allowed due to whitelist") return true } - log.Warn("Host is not on either whitelist or blacklist, considering blacklisted") + ctx.Log.Warn("Host is not on either whitelist or blacklist, considering blacklisted") return false } -func inRange(ip net.IP, cidrs []string, log *logrus.Entry) bool { +func inRange(ip net.IP, cidrs []string, ctx rcontext.RequestContext) bool { for i := 0; i < len(cidrs); i++ { cidr := cidrs[i] _, network, err := net.ParseCIDR(cidr) if err != nil { - log.Error("Error checking host: " + err.Error()) + ctx.Log.Error("Error checking host: " + err.Error()) return false } if network.Contains(ip) { diff --git a/controllers/preview_controller/preview_controller.go b/controllers/preview_controller/preview_controller.go index e0acd5afdc3151b08bccd64505b795b1b041e1bd..bd5ef113fd790b26a7d44d790285b0aff9720a8c 100644 --- a/controllers/preview_controller/preview_controller.go +++ b/controllers/preview_controller/preview_controller.go @@ -1,7 +1,6 @@ package preview_controller import ( - "context" "database/sql" "errors" "fmt" @@ -9,6 +8,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/globals" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/acl" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/stores" @@ -16,24 +16,24 @@ import ( "github.com/turt2live/matrix-media-repo/util" ) -func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, ctx context.Context, log *logrus.Entry) (*types.UrlPreview, error) { +func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, ctx rcontext.RequestContext) (*types.UrlPreview, error) { atTs = stores.GetBucketTs(atTs) cacheKey := fmt.Sprintf("%d_%s/%s", atTs, onHost, urlStr) v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) { - log = log.WithFields(logrus.Fields{ + ctx := ctx.LogWithFields(logrus.Fields{ "preview_controller_at_ts": atTs, }) - db := storage.GetDatabase().GetUrlStore(ctx, log) + db := storage.GetDatabase().GetUrlStore(ctx) cached, err := db.GetPreview(urlStr, atTs) if err != nil && err != sql.ErrNoRows { - log.Error("Error getting cached URL preview: ", err.Error()) + ctx.Log.Error("Error getting cached URL preview: ", err.Error()) return nil, err } if err != sql.ErrNoRows { - log.Info("Returning cached URL preview") + ctx.Log.Info("Returning cached URL preview") return cachedPreviewToReal(cached) } @@ -44,12 +44,12 @@ func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, ctx // Because we don't have a cached preview, we'll use the current time as the preview time. // We also give a 60 second buffer so we don't cause an infinite loop (considering we're // calling ourselves), and to give a lenient opportunity for slow execution. - return GetPreview(urlStr, onHost, forUserId, now, ctx, log) + return GetPreview(urlStr, onHost, forUserId, now, ctx) } - log.Info("Preview not cached - fetching resource") + ctx.Log.Info("Preview not cached - fetching resource") - urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx, log) + urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx) if err != nil { return nil, err } diff --git a/controllers/preview_controller/preview_resource_handler.go b/controllers/preview_controller/preview_resource_handler.go index 6e231b74159a4150a0956995fa0c00d3b1472d4a..a241cd8ae35ae18ecc09cf83936fa8c740aa3219 100644 --- a/controllers/preview_controller/preview_resource_handler.go +++ b/controllers/preview_controller/preview_resource_handler.go @@ -1,7 +1,6 @@ package preview_controller import ( - "context" "fmt" "sync" @@ -9,6 +8,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/previewers" "github.com/turt2live/matrix-media-repo/controllers/upload_controller" @@ -54,22 +54,21 @@ func getResourceHandler() *urlResourceHandler { func urlPreviewWorkFn(request *resource_handler.WorkRequest) interface{} { info := request.Metadata.(*urlPreviewRequest) - log := logrus.WithFields(logrus.Fields{ + ctx := rcontext.Initial().LogWithFields(logrus.Fields{ "worker_requestId": request.Id, "worker_url": info.urlPayload.UrlString, "worker_previewer": "OpenGraph", }) - log.Info("Processing url preview request") + ctx.Log.Info("Processing url preview request") - ctx := context.TODO() // TODO: Should we use a real context? - db := storage.GetDatabase().GetUrlStore(ctx, log) + db := storage.GetDatabase().GetUrlStore(ctx) - preview, err := previewers.GenerateOpenGraphPreview(info.urlPayload, log) + preview, err := previewers.GenerateOpenGraphPreview(info.urlPayload, ctx) if err == preview_types.ErrPreviewUnsupported { - log.Info("OpenGraph preview for this URL is unsupported - treating it as a file") - log = log.WithFields(logrus.Fields{"worker_previewer": "File"}) + ctx.Log.Info("OpenGraph preview for this URL is unsupported - treating it as a file") + ctx = ctx.LogWithFields(logrus.Fields{"worker_previewer": "File"}) - preview, err = previewers.GenerateCalculatedPreview(info.urlPayload, log) + preview, err = previewers.GenerateCalculatedPreview(info.urlPayload, ctx) } if err != nil { // Transparently convert "unsupported" to "not found" for processing @@ -98,17 +97,17 @@ func urlPreviewWorkFn(request *resource_handler.WorkRequest) interface{} { contentLength := upload_controller.EstimateContentLength(preview.Image.ContentLength, preview.Image.ContentLengthHeader) // UploadMedia will close the read stream for the thumbnail and dedupe the image - media, err := upload_controller.UploadMedia(preview.Image.Data, contentLength, preview.Image.ContentType, preview.Image.Filename, info.forUserId, info.onHost, ctx, log) + media, err := upload_controller.UploadMedia(preview.Image.Data, contentLength, preview.Image.ContentType, preview.Image.Filename, info.forUserId, info.onHost, ctx) if err != nil { - log.Warn("Non-fatal error storing preview thumbnail: " + err.Error()) + ctx.Log.Warn("Non-fatal error storing preview thumbnail: " + err.Error()) } else { - mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) + mediaStream, err := datastore.DownloadStream(ctx, media.DatastoreId, media.Location) if err != nil { - log.Warn("Non-fatal error streaming datastore file: " + err.Error()) + ctx.Log.Warn("Non-fatal error streaming datastore file: " + err.Error()) } else { img, err := imaging.Decode(mediaStream) if err != nil { - log.Warn("Non-fatal error getting thumbnail dimensions: " + err.Error()) + ctx.Log.Warn("Non-fatal error getting thumbnail dimensions: " + err.Error()) } else { result.ImageMxc = media.MxcUri() result.ImageType = media.ContentType @@ -128,7 +127,7 @@ func urlPreviewWorkFn(request *resource_handler.WorkRequest) interface{} { } err = db.InsertPreview(dbRecord) if err != nil { - log.Warn("Error caching URL preview: " + err.Error()) + ctx.Log.Warn("Error caching URL preview: " + err.Error()) // Non-fatal: Just report it and move on. The worst that happens is we re-cache it. } diff --git a/controllers/preview_controller/previewers/calculated_previewer.go b/controllers/preview_controller/previewers/calculated_previewer.go index 0bae659f1178a7cd20866e75280d325c4945ac74..7f933a2cb19d753e9df94dc743590b9db8173519 100644 --- a/controllers/preview_controller/previewers/calculated_previewer.go +++ b/controllers/preview_controller/previewers/calculated_previewer.go @@ -5,18 +5,18 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/ryanuber/go-glob" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types" "github.com/turt2live/matrix-media-repo/metrics" "github.com/turt2live/matrix-media-repo/util" ) -func GenerateCalculatedPreview(urlPayload *preview_types.UrlPayload, log *logrus.Entry) (preview_types.PreviewResult, error) { - bytes, filename, contentType, contentLength, err := downloadRawContent(urlPayload, config.Get().UrlPreviews.FilePreviewTypes, log) +func GenerateCalculatedPreview(urlPayload *preview_types.UrlPayload, ctx rcontext.RequestContext) (preview_types.PreviewResult, error) { + bytes, filename, contentType, contentLength, err := downloadRawContent(urlPayload, config.Get().UrlPreviews.FilePreviewTypes, ctx) if err != nil { - log.Error("Error downloading content: " + err.Error()) + ctx.Log.Error("Error downloading content: " + err.Error()) // Make sure the unsupported error gets passed through if err == preview_types.ErrPreviewUnsupported { diff --git a/controllers/preview_controller/previewers/http.go b/controllers/preview_controller/previewers/http.go index 642a24da942ac55d21841c68842a92b15dc94c2e..e66fdd125ffff3187f2e4df9d69decdc23b94003 100644 --- a/controllers/preview_controller/previewers/http.go +++ b/controllers/preview_controller/previewers/http.go @@ -14,13 +14,13 @@ import ( "time" "github.com/ryanuber/go-glob" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types" ) -func doHttpGet(urlPayload *preview_types.UrlPayload, log *logrus.Entry) (*http.Response, error) { +func doHttpGet(urlPayload *preview_types.UrlPayload, ctx rcontext.RequestContext) (*http.Response, error) { var client *http.Client dialer := &net.Dialer{ @@ -72,7 +72,7 @@ func doHttpGet(urlPayload *preview_types.UrlPayload, log *logrus.Entry) (*http.R } if config.Get().UrlPreviews.UnsafeCertificates { - log.Warn("Ignoring any certificate errors while making request") + ctx.Log.Warn("Ignoring any certificate errors while making request") tr := &http.Transport{ DisableKeepAlives: true, DialContext: dialContext, @@ -116,14 +116,14 @@ func doHttpGet(urlPayload *preview_types.UrlPayload, log *logrus.Entry) (*http.R return client.Do(req) } -func downloadRawContent(urlPayload *preview_types.UrlPayload, supportedTypes []string, log *logrus.Entry) ([]byte, string, string, string, error) { - log.Info("Fetching remote content...") - resp, err := doHttpGet(urlPayload, log) +func downloadRawContent(urlPayload *preview_types.UrlPayload, supportedTypes []string, ctx rcontext.RequestContext) ([]byte, string, string, string, error) { + ctx.Log.Info("Fetching remote content...") + resp, err := doHttpGet(urlPayload, ctx) if err != nil { return nil, "", "", "", err } if resp.StatusCode != http.StatusOK { - log.Warn("Received status code " + strconv.Itoa(resp.StatusCode)) + ctx.Log.Warn("Received status code " + strconv.Itoa(resp.StatusCode)) return nil, "", "", "", errors.New("error during transfer") } @@ -161,8 +161,8 @@ func downloadRawContent(urlPayload *preview_types.UrlPayload, supportedTypes []s return bytes, filename, contentType, resp.Header.Get("Content-Length"), nil } -func downloadHtmlContent(urlPayload *preview_types.UrlPayload, supportedTypes []string, log *logrus.Entry) (string, error) { - raw, _, _, _, err := downloadRawContent(urlPayload, supportedTypes, log) +func downloadHtmlContent(urlPayload *preview_types.UrlPayload, supportedTypes []string, ctx rcontext.RequestContext) (string, error) { + raw, _, _, _, err := downloadRawContent(urlPayload, supportedTypes, ctx) html := "" if raw != nil { html = string(raw) @@ -170,14 +170,14 @@ func downloadHtmlContent(urlPayload *preview_types.UrlPayload, supportedTypes [] return html, err } -func downloadImage(urlPayload *preview_types.UrlPayload, log *logrus.Entry) (*preview_types.PreviewImage, error) { - log.Info("Getting image from " + urlPayload.ParsedUrl.String()) - resp, err := doHttpGet(urlPayload, log) +func downloadImage(urlPayload *preview_types.UrlPayload, ctx rcontext.RequestContext) (*preview_types.PreviewImage, error) { + ctx.Log.Info("Getting image from " + urlPayload.ParsedUrl.String()) + resp, err := doHttpGet(urlPayload, ctx) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { - log.Warn("Received status code " + strconv.Itoa(resp.StatusCode)) + ctx.Log.Warn("Received status code " + strconv.Itoa(resp.StatusCode)) return nil, errors.New("error during transfer") } diff --git a/controllers/preview_controller/previewers/opengraph_previewer.go b/controllers/preview_controller/previewers/opengraph_previewer.go index c60b3f05ddf8d2be3bf8ecc6c4f640947a77c67f..19bfcf7fa3958f84c8aa459c52b2545fe1e8b4e7 100644 --- a/controllers/preview_controller/previewers/opengraph_previewer.go +++ b/controllers/preview_controller/previewers/opengraph_previewer.go @@ -1,7 +1,6 @@ package previewers import ( - "context" "fmt" "net/url" "strconv" @@ -10,9 +9,9 @@ import ( "github.com/PuerkitoBio/goquery" "github.com/dyatlov/go-opengraph/opengraph" "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/acl" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/preview_types" "github.com/turt2live/matrix-media-repo/metrics" @@ -20,10 +19,10 @@ import ( var ogSupportedTypes = []string{"text/*"} -func GenerateOpenGraphPreview(urlPayload *preview_types.UrlPayload, log *logrus.Entry) (preview_types.PreviewResult, error) { - html, err := downloadHtmlContent(urlPayload, ogSupportedTypes, log) +func GenerateOpenGraphPreview(urlPayload *preview_types.UrlPayload, ctx rcontext.RequestContext) (preview_types.PreviewResult, error) { + html, err := downloadHtmlContent(urlPayload, ogSupportedTypes, ctx) if err != nil { - log.Error("Error downloading content: " + err.Error()) + ctx.Log.Error("Error downloading content: " + err.Error()) // Make sure the unsupported error gets passed through if err == preview_types.ErrPreviewUnsupported { @@ -37,7 +36,7 @@ func GenerateOpenGraphPreview(urlPayload *preview_types.UrlPayload, log *logrus. og := opengraph.NewOpenGraph() err = og.ProcessHTML(strings.NewReader(html)) if err != nil { - log.Error("Error getting OpenGraph: " + err.Error()) + ctx.Log.Error("Error getting OpenGraph: " + err.Error()) return preview_types.PreviewResult{}, err } @@ -67,27 +66,27 @@ func GenerateOpenGraphPreview(urlPayload *preview_types.UrlPayload, log *logrus. baseUrlS := fmt.Sprintf("%s://%s", urlPayload.ParsedUrl.Scheme, urlPayload.Address.String()) baseUrl, err := url.Parse(baseUrlS) if err != nil { - log.Error("Non-fatal error getting thumbnail (parsing base url): " + err.Error()) + ctx.Log.Error("Non-fatal error getting thumbnail (parsing base url): " + err.Error()) return *graph, nil } imgUrl, err := url.Parse(og.Images[0].URL) if err != nil { - log.Error("Non-fatal error getting thumbnail (parsing image url): " + err.Error()) + ctx.Log.Error("Non-fatal error getting thumbnail (parsing image url): " + err.Error()) return *graph, nil } // Ensure images pass through the same validation check imgAbsUrl := baseUrl.ResolveReference(imgUrl) - imgUrlPayload, err := acl.ValidateUrlForPreview(imgAbsUrl.String(), context.TODO(), log) + imgUrlPayload, err := acl.ValidateUrlForPreview(imgAbsUrl.String(), ctx) if err != nil { - log.Error("Non-fatal error getting thumbnail (URL validation): " + err.Error()) + ctx.Log.Error("Non-fatal error getting thumbnail (URL validation): " + err.Error()) return *graph, nil } - img, err := downloadImage(imgUrlPayload, log) + img, err := downloadImage(imgUrlPayload, ctx) if err != nil { - log.Error("Non-fatal error getting thumbnail (downloading image): " + err.Error()) + ctx.Log.Error("Non-fatal error getting thumbnail (downloading image): " + err.Error()) return *graph, nil } diff --git a/controllers/thumbnail_controller/thumbnail_controller.go b/controllers/thumbnail_controller/thumbnail_controller.go index 72c65b3d12dd8b8b340273eabde0c4b5cc980877..b08ae5b2143967b21ba21a167c9d618559467bed 100644 --- a/controllers/thumbnail_controller/thumbnail_controller.go +++ b/controllers/thumbnail_controller/thumbnail_controller.go @@ -2,7 +2,6 @@ package thumbnail_controller import ( "bytes" - "context" "database/sql" "fmt" "time" @@ -10,10 +9,10 @@ import ( "github.com/disintegration/imaging" "github.com/patrickmn/go-cache" "github.com/pkg/errors" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/common/globals" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/download_controller" "github.com/turt2live/matrix-media-repo/controllers/quarantine_controller" "github.com/turt2live/matrix-media-repo/internal_cache" @@ -39,27 +38,27 @@ var animatedTypes = []string{"image/gif"} var localCache = cache.New(30*time.Second, 60*time.Second) -func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight int, animated bool, method string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.StreamedThumbnail, error) { - media, err := download_controller.FindMediaRecord(origin, mediaId, downloadRemote, ctx, log) +func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight int, animated bool, method string, downloadRemote bool, ctx rcontext.RequestContext) (*types.StreamedThumbnail, error) { + media, err := download_controller.FindMediaRecord(origin, mediaId, downloadRemote, ctx) if err != nil { return nil, err } if !util.ArrayContains(supportedThumbnailTypes, media.ContentType) { - log.Warn("Cannot generate thumbnail for " + media.ContentType + " because it is not supported") + ctx.Log.Warn("Cannot generate thumbnail for " + media.ContentType + " because it is not supported") return nil, errors.New("cannot generate thumbnail for this media's content type") } if !util.ArrayContains(config.Get().Thumbnails.Types, media.ContentType) { - log.Warn("Cannot generate thumbnail for " + media.ContentType + " because it is not listed in the config") + ctx.Log.Warn("Cannot generate thumbnail for " + media.ContentType + " because it is not listed in the config") return nil, errors.New("cannot generate thumbnail for this media's content type") } if media.Quarantined { - log.Warn("Quarantined media accessed") + ctx.Log.Warn("Quarantined media accessed") if config.Get().Quarantine.ReplaceThumbnails { - log.Info("Replacing thumbnail with a quarantined one") + ctx.Log.Info("Replacing thumbnail with a quarantined one") img, err := quarantine_controller.GenerateQuarantineThumbnail(desiredWidth, desiredHeight) if err != nil { @@ -90,17 +89,17 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight } if animated && config.Get().Thumbnails.MaxAnimateSizeBytes > 0 && config.Get().Thumbnails.MaxAnimateSizeBytes < media.SizeBytes { - log.Warn("Attempted to animate a media record that is too large. Assuming animated=false") + ctx.Log.Warn("Attempted to animate a media record that is too large. Assuming animated=false") animated = false } if animated && !util.ArrayContains(animatedTypes, media.ContentType) { - log.Warn("Attempted to animate a media record that isn't an animated type. Assuming animated=false") + ctx.Log.Warn("Attempted to animate a media record that isn't an animated type. Assuming animated=false") animated = false } if media.SizeBytes > config.Get().Thumbnails.MaxSourceBytes { - log.Warn("Media too large to thumbnail") + ctx.Log.Warn("Media too large to thumbnail") return nil, common.ErrMediaTooLarge } @@ -112,19 +111,19 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight cacheKey := fmt.Sprintf("%s/%s?w=%d&h=%d&m=%s&a=%t", media.Origin, media.MediaId, width, height, method, animated) v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) { - db := storage.GetDatabase().GetThumbnailStore(ctx, log) + db := storage.GetDatabase().GetThumbnailStore(ctx) var thumbnail *types.Thumbnail item, found := localCache.Get(cacheKey) if found { thumbnail = item.(*types.Thumbnail) } else { - log.Info("Getting thumbnail record from database") + ctx.Log.Info("Getting thumbnail record from database") dbThumb, err := db.Get(media.Origin, media.MediaId, width, height, method, animated) if err != nil { if err == sql.ErrNoRows { - log.Info("Thumbnail does not exist, attempting to generate it") - genThumb, err2 := GetOrGenerateThumbnail(media, width, height, animated, method, ctx, log) + ctx.Log.Info("Thumbnail does not exist, attempting to generate it") + genThumb, err2 := GetOrGenerateThumbnail(media, width, height, animated, method, ctx) if err2 != nil { return nil, err2 } @@ -139,18 +138,18 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight } if thumbnail == nil { - log.Warn("Despite all efforts, a thumbnail record could not be found or generated") + ctx.Log.Warn("Despite all efforts, a thumbnail record could not be found or generated") return nil, common.ErrMediaNotFound } - err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(thumbnail.Sha256Hash, util.NowMillis()) + err = storage.GetDatabase().GetMetadataStore(ctx).UpsertLastAccess(thumbnail.Sha256Hash, util.NowMillis()) if err != nil { - logrus.Warn("Failed to upsert the last access time: ", err) + ctx.Log.Warn("Failed to upsert the last access time: ", err) } localCache.Set(cacheKey, thumbnail, cache.DefaultExpiration) - cached, err := internal_cache.Get().GetThumbnail(thumbnail, log) + cached, err := internal_cache.Get().GetThumbnail(thumbnail, ctx) if err != nil { return nil, err } @@ -161,8 +160,8 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight }, nil } - log.Info("Reading thumbnail from disk") - mediaStream, err := datastore.DownloadStream(ctx, log, thumbnail.DatastoreId, thumbnail.Location) + ctx.Log.Info("Reading thumbnail from disk") + mediaStream, err := datastore.DownloadStream(ctx, thumbnail.DatastoreId, thumbnail.Location) if err != nil { return nil, err } @@ -196,18 +195,18 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight return value, err } -func GetOrGenerateThumbnail(media *types.Media, width int, height int, animated bool, method string, ctx context.Context, log *logrus.Entry) (*types.Thumbnail, error) { - db := storage.GetDatabase().GetThumbnailStore(ctx, log) +func GetOrGenerateThumbnail(media *types.Media, width int, height int, animated bool, method string, ctx rcontext.RequestContext) (*types.Thumbnail, error) { + db := storage.GetDatabase().GetThumbnailStore(ctx) thumbnail, err := db.Get(media.Origin, media.MediaId, width, height, method, animated) if err != nil && err != sql.ErrNoRows { return nil, err } if err != sql.ErrNoRows { - log.Info("Using thumbnail from database") + ctx.Log.Info("Using thumbnail from database") return thumbnail, nil } - log.Info("Generating thumbnail") + ctx.Log.Info("Generating thumbnail") thumbnailChan := getResourceHandler().GenerateThumbnail(media, width, height, method, animated) defer close(thumbnailChan) diff --git a/controllers/thumbnail_controller/thumbnail_resource_handler.go b/controllers/thumbnail_controller/thumbnail_resource_handler.go index ffcdb64c59764fa64eee3c44f3175dcc96b4887c..8c93373ab4ff65c99e85e1b39ad6b2c235c6e93b 100644 --- a/controllers/thumbnail_controller/thumbnail_resource_handler.go +++ b/controllers/thumbnail_controller/thumbnail_resource_handler.go @@ -2,7 +2,6 @@ package thumbnail_controller import ( "bytes" - "context" "errors" "fmt" "image" @@ -22,6 +21,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/metrics" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" @@ -77,7 +77,7 @@ func getResourceHandler() *thumbnailResourceHandler { func thumbnailWorkFn(request *resource_handler.WorkRequest) interface{} { info := request.Metadata.(*thumbnailRequest) - log := logrus.WithFields(logrus.Fields{ + ctx := rcontext.Initial().LogWithFields(logrus.Fields{ "worker_requestId": request.Id, "worker_media": info.media.Origin + "/" + info.media.MediaId, "worker_width": info.width, @@ -85,11 +85,9 @@ func thumbnailWorkFn(request *resource_handler.WorkRequest) interface{} { "worker_method": info.method, "worker_animated": info.animated, }) - log.Info("Processing thumbnail request") + ctx.Log.Info("Processing thumbnail request") - ctx := context.TODO() // TODO: Should we use a real context? - - generated, err := GenerateThumbnail(info.media, info.width, info.height, info.method, info.animated, ctx, log) + generated, err := GenerateThumbnail(info.media, info.width, info.height, info.method, info.animated, ctx) if err != nil { return &thumbnailResponse{err: err} } @@ -109,10 +107,10 @@ func thumbnailWorkFn(request *resource_handler.WorkRequest) interface{} { Sha256Hash: generated.Sha256Hash, } - db := storage.GetDatabase().GetThumbnailStore(ctx, log) + db := storage.GetDatabase().GetThumbnailStore(ctx) err = db.Insert(newThumb) if err != nil { - log.Error("Unexpected error caching thumbnail: " + err.Error()) + ctx.Log.Error("Unexpected error caching thumbnail: " + err.Error()) return &thumbnailResponse{err: err} } @@ -135,7 +133,7 @@ func (h *thumbnailResourceHandler) GenerateThumbnail(media *types.Media, width i return resultChan } -func GenerateThumbnail(media *types.Media, width int, height int, method string, animated bool, ctx context.Context, log *logrus.Entry) (*GeneratedThumbnail, error) { +func GenerateThumbnail(media *types.Media, width int, height int, method string, animated bool, ctx rcontext.RequestContext) (*GeneratedThumbnail, error) { var src image.Image var err error @@ -143,13 +141,13 @@ func GenerateThumbnail(media *types.Media, width int, height int, method string, allowAnimated := config.Get().Thumbnails.AllowAnimated if media.ContentType == "image/svg+xml" { - src, err = svgToImage(media, ctx, log) + src, err = svgToImage(media, ctx) } else if canAnimate && !animated { - src, err = pickImageFrame(media, ctx, log) + src, err = pickImageFrame(media, ctx) } else { - mediaStream, err2 := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) + mediaStream, err2 := datastore.DownloadStream(ctx, media.DatastoreId, media.Location) if err2 != nil { - log.Error("Error getting file: ", err2) + ctx.Log.Error("Error getting file: ", err2) return nil, err2 } src, err = imaging.Decode(mediaStream) @@ -167,7 +165,7 @@ func GenerateThumbnail(media *types.Media, width int, height int, method string, if aspectRatio == targetAspectRatio { // Highly unlikely, but if the aspect ratios match then just resize method = "scale" - log.Info("Aspect ratio is the same, converting method to 'scale'") + ctx.Log.Info("Aspect ratio is the same, converting method to 'scale'") } metric := metrics.ThumbnailsGenerated.With(prometheus.Labels{ @@ -184,11 +182,11 @@ func GenerateThumbnail(media *types.Media, width int, height int, method string, if srcWidth <= width && srcHeight <= height { if animated { - log.Warn("Image is too small but the image should be animated. Adjusting dimensions to fit image exactly.") + ctx.Log.Warn("Image is too small but the image should be animated. Adjusting dimensions to fit image exactly.") width = srcWidth height = srcHeight } else if canAnimate && !animated { - log.Warn("Image is too small, but the request calls for a static image. Adjusting dimensions to fit image exactly.") + ctx.Log.Warn("Image is too small, but the request calls for a static image. Adjusting dimensions to fit image exactly.") width = srcWidth height = srcHeight } else { @@ -198,7 +196,7 @@ func GenerateThumbnail(media *types.Media, width int, height int, method string, thumb.DatastoreLocation = media.Location thumb.SizeBytes = media.SizeBytes thumb.Sha256Hash = media.Sha256Hash - log.Warn("Image too small, returning raw image") + ctx.Log.Warn("Image too small, returning raw image") metric.Inc() return thumb, nil } @@ -208,7 +206,7 @@ func GenerateThumbnail(media *types.Media, width int, height int, method string, if media.ContentType == "image/jpeg" || media.ContentType == "image/jpg" { orientation, err = util_exif.GetExifOrientation(media) if err != nil { - log.Warn("Non-fatal error getting EXIF orientation: " + err.Error()) + ctx.Log.Warn("Non-fatal error getting EXIF orientation: " + err.Error()) orientation = nil // just in case } } @@ -216,21 +214,21 @@ func GenerateThumbnail(media *types.Media, width int, height int, method string, contentType := "image/png" imgData := &bytes.Buffer{} if allowAnimated && animated { - log.Info("Generating animated thumbnail") + ctx.Log.Info("Generating animated thumbnail") contentType = "image/gif" // Animated GIFs are a bit more special because we need to do it frame by frame. // This is fairly resource intensive. The calling code is responsible for limiting this case. - mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) + mediaStream, err := datastore.DownloadStream(ctx, media.DatastoreId, media.Location) if err != nil { - log.Error("Error resolving datastore path: ", err) + ctx.Log.Error("Error resolving datastore path: ", err) return nil, err } g, err := gif.DecodeAll(mediaStream) if err != nil { - log.Error("Error generating animated thumbnail: " + err.Error()) + ctx.Log.Error("Error generating animated thumbnail: " + err.Error()) return nil, err } @@ -249,7 +247,7 @@ func GenerateThumbnail(media *types.Media, width int, height int, method string, // Do the thumbnailing on the copied frame frameThumb, err := thumbnailFrame(frameImg, method, width, height, imaging.Linear, nil) if err != nil { - log.Error("Error generating animated thumbnail frame: " + err.Error()) + ctx.Log.Error("Error generating animated thumbnail frame: " + err.Error()) return nil, err } @@ -266,32 +264,32 @@ func GenerateThumbnail(media *types.Media, width int, height int, method string, err = gif.EncodeAll(imgData, g) if err != nil { - log.Error("Error generating animated thumbnail: " + err.Error()) + ctx.Log.Error("Error generating animated thumbnail: " + err.Error()) return nil, err } } else { src, err = thumbnailFrame(src, method, width, height, imaging.Lanczos, orientation) if err != nil { - log.Error("Error generating thumbnail: " + err.Error()) + ctx.Log.Error("Error generating thumbnail: " + err.Error()) return nil, err } // Put the image bytes into a memory buffer err = imaging.Encode(imgData, src, imaging.PNG) if err != nil { - log.Error("Unexpected error encoding thumbnail: " + err.Error()) + ctx.Log.Error("Unexpected error encoding thumbnail: " + err.Error()) return nil, err } } // Reset the buffer pointer and store the file - ds, err := datastore.PickDatastore(common.KindThumbnails, ctx, log) + ds, err := datastore.PickDatastore(common.KindThumbnails, ctx) if err != nil { return nil, err } - info, err := ds.UploadFile(util.BufferToStream(imgData), int64(len(imgData.Bytes())), ctx, log) + info, err := ds.UploadFile(util.BufferToStream(imgData), int64(len(imgData.Bytes())), ctx) if err != nil { - log.Error("Unexpected error saving thumbnail: " + err.Error()) + ctx.Log.Error("Unexpected error saving thumbnail: " + err.Error()) return nil, err } @@ -337,7 +335,7 @@ func thumbnailFrame(src image.Image, method string, width int, height int, filte return result, nil } -func svgToImage(media *types.Media, ctx context.Context, log *logrus.Entry) (image.Image, error) { +func svgToImage(media *types.Media, ctx rcontext.RequestContext) (image.Image, error) { tempFile1 := path.Join(os.TempDir(), "media_repo."+media.Origin+"."+media.MediaId+".1.png") tempFile2 := path.Join(os.TempDir(), "media_repo."+media.Origin+"."+media.MediaId+".2.png") @@ -345,9 +343,9 @@ func svgToImage(media *types.Media, ctx context.Context, log *logrus.Entry) (ima defer os.Remove(tempFile2) // requires imagemagick - mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) + mediaStream, err := datastore.DownloadStream(ctx, media.DatastoreId, media.Location) if err != nil { - log.Error("Error streaming file: ", err) + ctx.Log.Error("Error streaming file: ", err) return nil, err } @@ -372,22 +370,22 @@ func svgToImage(media *types.Media, ctx context.Context, log *logrus.Entry) (ima return imaging.Decode(imgData) } -func pickImageFrame(media *types.Media, ctx context.Context, log *logrus.Entry) (image.Image, error) { - mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) +func pickImageFrame(media *types.Media, ctx rcontext.RequestContext) (image.Image, error) { + mediaStream, err := datastore.DownloadStream(ctx, media.DatastoreId, media.Location) if err != nil { - log.Error("Error resolving datastore path: ", err) + ctx.Log.Error("Error resolving datastore path: ", err) return nil, err } g, err := gif.DecodeAll(mediaStream) if err != nil { - log.Error("Error picking frame: " + err.Error()) + ctx.Log.Error("Error picking frame: " + err.Error()) return nil, err } stillFrameRatio := float64(config.Get().Thumbnails.StillFrame) frameIndex := int(math.Floor(math.Min(1, math.Max(0, stillFrameRatio)) * float64(len(g.Image)))) - log.Info("Picking frame ", frameIndex, " for animated file") + ctx.Log.Info("Picking frame ", frameIndex, " for animated file") return g.Image[frameIndex], nil } diff --git a/controllers/upload_controller/upload_controller.go b/controllers/upload_controller/upload_controller.go index 9c1bbd196e886a82acfb6bd9afb3ea3ed17a54d9..2b4b233310ecc152b21884220196b7d6e03fe24b 100644 --- a/controllers/upload_controller/upload_controller.go +++ b/controllers/upload_controller/upload_controller.go @@ -1,7 +1,6 @@ package upload_controller import ( - "context" "database/sql" "io" "io/ioutil" @@ -12,6 +11,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" "github.com/turt2live/matrix-media-repo/types" @@ -77,7 +77,7 @@ func EstimateContentLength(contentLength int64, contentLengthHeader string) int6 return -1 // unknown } -func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string, filename string, userId string, origin string, ctx context.Context, log *logrus.Entry) (*types.Media, error) { +func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string, filename string, userId string, origin string, ctx rcontext.RequestContext) (*types.Media, error) { defer contents.Close() var data io.ReadCloser @@ -87,8 +87,8 @@ func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string data = contents } - metadataDb := storage.GetDatabase().GetMetadataStore(ctx, log) - mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) + metadataDb := storage.GetDatabase().GetMetadataStore(ctx) + mediaDb := storage.GetDatabase().GetMediaStore(ctx) mediaTaken := true var mediaId string @@ -125,17 +125,17 @@ func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string } } - return StoreDirect(data, contentLength, contentType, filename, userId, origin, mediaId, common.KindLocalMedia, ctx, log) + return StoreDirect(data, contentLength, contentType, filename, userId, origin, mediaId, common.KindLocalMedia, ctx) } -func trackUploadAsLastAccess(ctx context.Context, log *logrus.Entry, media *types.Media) { - err := storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(media.Sha256Hash, util.NowMillis()) +func trackUploadAsLastAccess(ctx rcontext.RequestContext, media *types.Media) { + err := storage.GetDatabase().GetMetadataStore(ctx).UpsertLastAccess(media.Sha256Hash, util.NowMillis()) if err != nil { logrus.Warn("Failed to upsert the last access time: ", err) } } -func IsAllowed(contentType string, reportedContentType string, userId string, log *logrus.Entry) bool { +func IsAllowed(contentType string, reportedContentType string, userId string, ctx rcontext.RequestContext) bool { allowed := false userMatched := false @@ -143,13 +143,13 @@ func IsAllowed(contentType string, reportedContentType string, userId string, lo for user, userExcl := range config.Get().Uploads.PerUserExclusions { if glob.Glob(user, userId) { if !userMatched { - log.Info("Per-user allowed types policy found for " + userId) + ctx.Log.Info("Per-user allowed types policy found for " + userId) userMatched = true } for _, exclType := range userExcl { if glob.Glob(exclType, contentType) { allowed = true - log.Info("Content type " + contentType + " (reported as " + reportedContentType + ") is allowed due to a per-user policy for " + userId) + ctx.Log.Info("Content type " + contentType + " (reported as " + reportedContentType + ") is allowed due to a per-user policy for " + userId) break } } @@ -162,7 +162,7 @@ func IsAllowed(contentType string, reportedContentType string, userId string, lo } if !userMatched && !allowed { - log.Info("Checking general allowed types due to no matching per-user policy") + ctx.Log.Info("Checking general allowed types due to no matching per-user policy") for _, allowedType := range config.Get().Uploads.AllowedTypes { if glob.Glob(allowedType, contentType) { allowed = true @@ -178,12 +178,12 @@ func IsAllowed(contentType string, reportedContentType string, userId string, lo return allowed } -func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, filename string, userId string, origin string, mediaId string, kind string, ctx context.Context, log *logrus.Entry) (*types.Media, error) { - ds, err := datastore.PickDatastore(kind, ctx, log) +func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, filename string, userId string, origin string, mediaId string, kind string, ctx rcontext.RequestContext) (*types.Media, error) { + ds, err := datastore.PickDatastore(kind, ctx) if err != nil { return nil, err } - info, err := ds.UploadFile(contents, expectedSize, ctx, log) + info, err := ds.UploadFile(contents, expectedSize, ctx) if err != nil { return nil, err } @@ -195,20 +195,20 @@ func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, fileMime, err := util.GetMimeType(stream) if err != nil { - log.Error("Error while checking content type of file: ", err.Error()) + ctx.Log.Error("Error while checking content type of file: ", err.Error()) ds.DeleteObject(info.Location) // delete temp object return nil, err } - allowed := IsAllowed(fileMime, contentType, userId, log) + allowed := IsAllowed(fileMime, contentType, userId, ctx) if !allowed { - log.Warn("Content type " + fileMime + " (reported as " + contentType + ") is not allowed to be uploaded") + ctx.Log.Warn("Content type " + fileMime + " (reported as " + contentType + ") is not allowed to be uploaded") ds.DeleteObject(info.Location) // delete temp object return nil, common.ErrMediaNotAllowed } - db := storage.GetDatabase().GetMediaStore(ctx, log) + db := storage.GetDatabase().GetMediaStore(ctx) records, err := db.GetByHash(info.Sha256Hash) if err != nil { ds.DeleteObject(info.Location) // delete temp object @@ -216,7 +216,7 @@ func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, } if len(records) > 0 { - log.Info("Duplicate media for hash ", info.Sha256Hash) + ctx.Log.Info("Duplicate media for hash ", info.Sha256Hash) // If the user is a real user (ie: actually uploaded media), then we'll see if there's // an exact duplicate that we can return. Otherwise we'll just pick the first record and @@ -224,13 +224,13 @@ func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, if userId != NoApplicableUploadUser { for _, record := range records { if record.Quarantined { - log.Warn("User attempted to upload quarantined content - rejecting") + ctx.Log.Warn("User attempted to upload quarantined content - rejecting") return nil, common.ErrMediaQuarantined } if record.UserId == userId && record.Origin == origin && record.ContentType == contentType { - log.Info("User has already uploaded this media before - returning unaltered media record") + ctx.Log.Info("User has already uploaded this media before - returning unaltered media record") ds.DeleteObject(info.Location) // delete temp object - trackUploadAsLastAccess(ctx, log, record) + trackUploadAsLastAccess(ctx, record) return record, nil } } @@ -239,7 +239,7 @@ func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, // We'll use the location from the first record record := records[0] if record.Quarantined { - log.Warn("User attempted to upload quarantined content - rejecting") + ctx.Log.Warn("User attempted to upload quarantined content - rejecting") return nil, common.ErrMediaQuarantined } @@ -260,7 +260,7 @@ func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, // If the media's file exists, we'll delete the temp file // If the media's file doesn't exist, we'll move the temp file to where the media expects it to be if media.DatastoreId != ds.DatastoreId && media.Location != info.Location { - ds2, err := datastore.LocateDatastore(ctx, log, media.DatastoreId) + ds2, err := datastore.LocateDatastore(ctx, media.DatastoreId) if err != nil { ds.DeleteObject(info.Location) // delete temp object return nil, err @@ -271,14 +271,14 @@ func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, return nil, err } - ds2.OverwriteObject(media.Location, stream, ctx, log) + ds2.OverwriteObject(media.Location, stream, ctx) ds.DeleteObject(info.Location) } else { ds.DeleteObject(info.Location) } } - trackUploadAsLastAccess(ctx, log, media) + trackUploadAsLastAccess(ctx, media) return media, nil } @@ -289,7 +289,7 @@ func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, return nil, errors.New("file has no contents") } - log.Info("Persisting new media record") + ctx.Log.Info("Persisting new media record") media := &types.Media{ Origin: origin, @@ -310,6 +310,6 @@ func StoreDirect(contents io.ReadCloser, expectedSize int64, contentType string, return nil, err } - trackUploadAsLastAccess(ctx, log, media) + trackUploadAsLastAccess(ctx, media) return media, nil } diff --git a/docs/config.md b/docs/config.md index b6ffd2cc4e4ec7e8abf884074857b0a200206008..3bb110186390d0d47319f65dbd8ca311c3e1f443 100644 --- a/docs/config.md +++ b/docs/config.md @@ -54,4 +54,5 @@ identicons: enabled: false ``` -Per-domain configs can also be layered - just ensure that each layer has the `homeserver` property in it. +Per-domain configs can also be layered - just ensure that each layer has the `homeserver` property in it. They inherit +from the main config for options not defined in their layers. diff --git a/go.mod b/go.mod index 2faf4e22df897e9e6e6c4ea829933c3d347b99d5..34aa48f18371408c20ba57dbbc00b8c628e3fdd8 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( github.com/DavidHuie/gomigrate v0.0.0-20160809001028-4004e6142040 github.com/PuerkitoBio/goquery v0.0.0-20171206121606-bc4e06eb0792 github.com/ajstarks/svgo v0.0.0-20171111115224-f9be02f22f2c // indirect + github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect + github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect github.com/alioygur/is v0.0.0-20170213121024-204f48747743 github.com/andybalholm/cascadia v0.0.0-20161224141413-349dd0209470 // indirect github.com/bep/debounce v1.2.0 @@ -43,21 +45,21 @@ require ( github.com/peterbourgon/g2s v0.0.0-20170223122336-d4e7ad98afea // indirect github.com/pkg/errors v0.0.0-20171018195549-f15c970de5b7 github.com/prometheus/client_golang v0.0.0-20181116151817-3fb53dff765f - github.com/prometheus/common v0.0.0-20181116084131-1f2c4f3cd6db // indirect + github.com/prometheus/common v0.0.0-20181116084131-1f2c4f3cd6db github.com/rifflock/lfshook v0.0.0-20170910022531-3bcf86f879c7 github.com/rubyist/circuitbreaker v2.2.1+incompatible github.com/rwcarlsen/goexif v0.0.0-20180110181140-17202558c8d9 github.com/ryanuber/go-glob v0.0.0-20170128012129-256dc444b735 github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35 github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2 - github.com/stretchr/testify v1.3.0 // indirect github.com/tebeka/strftime v0.0.0-20140926081919-3f9c7761e312 // indirect golang.org/x/image v0.0.0-20171214225156-12117c17ca67 golang.org/x/time v0.0.0-20170927054726-6dc17368e09b // indirect google.golang.org/appengine v1.6.1 // indirect gopkg.in/airbrake/gobrake.v2 v2.0.9 // indirect + gopkg.in/alecthomas/kingpin.v2 v2.2.6 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2 // indirect gopkg.in/h2non/filetype.v1 v1.0.5 // indirect - gopkg.in/yaml.v2 v2.2.1 + gopkg.in/yaml.v2 v2.2.2 ) diff --git a/go.sum b/go.sum index f44fccf4481f887190246ff4181d666ce9fa296d..cd8ded2240496dee1c6eaa6faf58b4422e02525f 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,10 @@ github.com/PuerkitoBio/goquery v0.0.0-20171206121606-bc4e06eb0792 h1:z1EfQR3dcEw github.com/PuerkitoBio/goquery v0.0.0-20171206121606-bc4e06eb0792/go.mod h1:T9ezsOHcCrDCgA8aF1Cqr3sSYbO/xgdy8/R/XiIMAhA= github.com/ajstarks/svgo v0.0.0-20171111115224-f9be02f22f2c h1:EkbaKSiLkX5LJZp5gWk3DPP6dflpRDpqTii1fBAw/G4= github.com/ajstarks/svgo v0.0.0-20171111115224-f9be02f22f2c/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= +github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alioygur/is v0.0.0-20170213121024-204f48747743 h1:Ou8l+Rf3eEEM4mukAbO7uvy4+5qgx6HBasKKLl8rsFk= github.com/alioygur/is v0.0.0-20170213121024-204f48747743/go.mod h1:fmXi78K26iMaOs0fINRVLl1TIPCYcLfOopoZ5+mc8AE= github.com/andybalholm/cascadia v0.0.0-20161224141413-349dd0209470 h1:4jHLmof+Hba81591gfH5xYA8QXzuvgksxwPNrmjR2BA= @@ -126,8 +130,8 @@ github.com/smartystreets/assertions v0.0.0-20190116191733-b6c0e53d7304/go.mod h1 github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c h1:Ho+uVpkel/udgjbwB5Lktg9BtvJSh2DT0Hi6LPSyI2w= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/tebeka/strftime v0.0.0-20140926081919-3f9c7761e312 h1:frNEkk4P8mq+47LAMvj9LvhDq01kFDUhpJZzzei8IuM= github.com/tebeka/strftime v0.0.0-20140926081919-3f9c7761e312/go.mod h1:o6CrSUtupq/A5hylbvAsdydn0d5yokJExs8VVdx4wwI= golang.org/x/crypto v0.0.0-20190128193316-c7b33c32a30b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -166,6 +170,8 @@ google.golang.org/appengine v1.6.1 h1:QzqyMA1tlu6CgqCDUtU9V+ZKhLFT2dkJuANu5QaxI3 google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= gopkg.in/airbrake/gobrake.v2 v2.0.9 h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= +gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -181,3 +187,5 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal_cache/media_cache.go b/internal_cache/media_cache.go index 9f66c245e1e6e4a58f5a0ef826465600a2539e94..f0d99f38a71c36e7d4510ea9e11c072d8c70c2ad 100644 --- a/internal_cache/media_cache.go +++ b/internal_cache/media_cache.go @@ -3,7 +3,6 @@ package internal_cache import ( "bytes" "container/list" - "context" "fmt" "io/ioutil" "sync" @@ -11,8 +10,10 @@ import ( "github.com/patrickmn/go-cache" "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/log" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/metrics" "github.com/turt2live/matrix-media-repo/storage/datastore" "github.com/turt2live/matrix-media-repo/types" @@ -90,14 +91,14 @@ func (c *MediaCache) IncrementDownloads(fileHash string) { c.tracker.Increment(fileHash) } -func (c *MediaCache) GetMedia(media *types.Media, log *logrus.Entry) (*cachedFile, error) { +func (c *MediaCache) GetMedia(media *types.Media, ctx rcontext.RequestContext) (*cachedFile, error) { if !c.enabled { metrics.CacheMisses.With(prometheus.Labels{"cache": "media"}).Inc() return nil, nil } cacheFn := func() (*cachedFile, error) { - mediaStream, err := datastore.DownloadStream(context.TODO(), log, media.DatastoreId, media.Location) + mediaStream, err := datastore.DownloadStream(ctx, media.DatastoreId, media.Location) if err != nil { return nil, err } @@ -110,17 +111,17 @@ func (c *MediaCache) GetMedia(media *types.Media, log *logrus.Entry) (*cachedFil return &cachedFile{media: media, Contents: bytes.NewBuffer(data)}, nil } - return c.updateItemInCache(media.Sha256Hash, media.SizeBytes, cacheFn, log) + return c.updateItemInCache(media.Sha256Hash, media.SizeBytes, cacheFn, ctx) } -func (c *MediaCache) GetThumbnail(thumbnail *types.Thumbnail, log *logrus.Entry) (*cachedFile, error) { +func (c *MediaCache) GetThumbnail(thumbnail *types.Thumbnail, ctx rcontext.RequestContext) (*cachedFile, error) { if !c.enabled { metrics.CacheMisses.With(prometheus.Labels{"cache": "media"}).Inc() return nil, nil } cacheFn := func() (*cachedFile, error) { - mediaStream, err := datastore.DownloadStream(context.TODO(), log, thumbnail.DatastoreId, thumbnail.Location) + mediaStream, err := datastore.DownloadStream(ctx, thumbnail.DatastoreId, thumbnail.Location) if err != nil { return nil, err } @@ -133,10 +134,10 @@ func (c *MediaCache) GetThumbnail(thumbnail *types.Thumbnail, log *logrus.Entry) return &cachedFile{thumbnail: thumbnail, Contents: bytes.NewBuffer(data)}, nil } - return c.updateItemInCache(thumbnail.Sha256Hash, thumbnail.SizeBytes, cacheFn, log) + return c.updateItemInCache(thumbnail.Sha256Hash, thumbnail.SizeBytes, cacheFn, ctx) } -func (c *MediaCache) updateItemInCache(recordId string, mediaSize int64, cacheFn func() (*cachedFile, error), log *logrus.Entry) (*cachedFile, error) { +func (c *MediaCache) updateItemInCache(recordId string, mediaSize int64, cacheFn func() (*cachedFile, error), ctx rcontext.RequestContext) (*cachedFile, error) { downloads := c.tracker.NumDownloads(recordId) enoughDownloads := downloads >= config.Get().Downloads.Cache.MinDownloads canCache := c.canJoinCache(recordId) @@ -145,7 +146,7 @@ func (c *MediaCache) updateItemInCache(recordId string, mediaSize int64, cacheFn // No longer eligible for the cache - delete item // The cached bytes will leave memory over time if found && !enoughDownloads { - log.Info("Removing media from cache because it does not have enough downloads") + ctx.Log.Info("Removing media from cache because it does not have enough downloads") metrics.CacheMisses.With(prometheus.Labels{"cache": "media"}).Inc() metrics.CacheEvictions.With(prometheus.Labels{"cache": "media", "reason": "not_enough_downloads"}).Inc() c.cache.Delete(recordId) @@ -162,14 +163,14 @@ func (c *MediaCache) updateItemInCache(recordId string, mediaSize int64, cacheFn // Don't bother checking for space if it won't fit anyways if mediaSize > maxSpace { - log.Warn("Media too large to cache") + ctx.Log.Warn("Media too large to cache") metrics.CacheMisses.With(prometheus.Labels{"cache": "media"}).Inc() return nil, nil } if freeSpace >= mediaSize { // Perfect! It'll fit - just cache it - log.Info("Caching file in memory") + ctx.Log.Info("Caching file in memory") c.size = usedSpace + mediaSize c.flagCached(recordId) @@ -184,9 +185,9 @@ func (c *MediaCache) updateItemInCache(recordId string, mediaSize int64, cacheFn // We need to clean up some space neededSize := (usedSpace + mediaSize) - maxSpace - log.Info(fmt.Sprintf("Attempting to clear %d bytes from media cache", neededSize)) + ctx.Log.Info(fmt.Sprintf("Attempting to clear %d bytes from media cache", neededSize)) clearedSpace := c.clearSpace(neededSize, downloads, mediaSize) - log.Info(fmt.Sprintf("Cleared %d bytes from media cache", clearedSpace)) + ctx.Log.Info(fmt.Sprintf("Cleared %d bytes from media cache", clearedSpace)) freeSpace += clearedSpace if freeSpace >= mediaSize { // Now it'll fit - cache it @@ -205,7 +206,7 @@ func (c *MediaCache) updateItemInCache(recordId string, mediaSize int64, cacheFn return cachedItem, nil } - log.Warn("Unable to clear enough space for file to be cached") + ctx.Log.Warn("Unable to clear enough space for file to be cached") return nil, nil } diff --git a/matrix/admin.go b/matrix/admin.go index 633fe2fefae635c26041b98bf8056bda249f79f2..ed70e0ab9063dfe897e605897a3f5d887ae7296b 100644 --- a/matrix/admin.go +++ b/matrix/admin.go @@ -1,11 +1,12 @@ package matrix import ( - "context" "time" + + "github.com/turt2live/matrix-media-repo/common/rcontext" ) -func IsUserAdmin(ctx context.Context, serverName string, accessToken string, ipAddr string) (bool, error) { +func IsUserAdmin(ctx rcontext.RequestContext, serverName string, accessToken string, ipAddr string) (bool, error) { fakeUser := "@media.repo.admin.check:" + serverName hs, cb := getBreakerAndConfig(serverName) @@ -32,7 +33,7 @@ func IsUserAdmin(ctx context.Context, serverName string, accessToken string, ipA return isAdmin, replyError } -func ListMedia(ctx context.Context, serverName string, accessToken string, roomId string, ipAddr string) (*mediaListResponse, error) { +func ListMedia(ctx rcontext.RequestContext, serverName string, accessToken string, roomId string, ipAddr string) (*mediaListResponse, error) { hs, cb := getBreakerAndConfig(serverName) response := &mediaListResponse{} diff --git a/matrix/auth.go b/matrix/auth.go index 2355be2393c8f799e22a8f3ef7814183651aad37..98b9315472c68f2d862bd83749b1927dfbf34d92 100644 --- a/matrix/auth.go +++ b/matrix/auth.go @@ -1,16 +1,16 @@ package matrix import ( - "context" "net/url" "time" "github.com/pkg/errors" + "github.com/turt2live/matrix-media-repo/common/rcontext" ) var ErrNoToken = errors.New("Missing access token") -func GetUserIdFromToken(ctx context.Context, serverName string, accessToken string, appserviceUserId string, ipAddr string) (string, error) { +func GetUserIdFromToken(ctx rcontext.RequestContext, serverName string, accessToken string, appserviceUserId string, ipAddr string) (string, error) { if accessToken == "" { return "", ErrNoToken } diff --git a/storage/datastore/datastore.go b/storage/datastore/datastore.go index 9fe6ba7d0cbe773eb385cff30b0eabd4782866c9..63355e835c954d12b4bd7eff4c4633fd9f61a7a7 100644 --- a/storage/datastore/datastore.go +++ b/storage/datastore/datastore.go @@ -1,7 +1,6 @@ package datastore import ( - "context" "fmt" "io" @@ -9,6 +8,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/types" ) @@ -22,7 +22,7 @@ func GetAvailableDatastores() ([]*types.Datastore, error) { uri := GetUriForDatastore(ds) - dsInstance, err := storage.GetOrCreateDatastoreOfType(context.TODO(), &logrus.Entry{}, ds.Type, uri) + dsInstance, err := storage.GetOrCreateDatastoreOfType(rcontext.Initial(), ds.Type, uri) if err != nil { return nil, err } @@ -33,8 +33,8 @@ func GetAvailableDatastores() ([]*types.Datastore, error) { return datastores, nil } -func LocateDatastore(ctx context.Context, log *logrus.Entry, datastoreId string) (*DatastoreRef, error) { - ds, err := storage.GetDatabase().GetMediaStore(ctx, log).GetDatastore(datastoreId) +func LocateDatastore(ctx rcontext.RequestContext, datastoreId string) (*DatastoreRef, error) { + ds, err := storage.GetDatabase().GetMediaStore(ctx).GetDatastore(datastoreId) if err != nil { return nil, err } @@ -47,8 +47,8 @@ func LocateDatastore(ctx context.Context, log *logrus.Entry, datastoreId string) return newDatastoreRef(ds, conf), nil } -func DownloadStream(ctx context.Context, log *logrus.Entry, datastoreId string, location string) (io.ReadCloser, error) { - ref, err := LocateDatastore(ctx, log, datastoreId) +func DownloadStream(ctx rcontext.RequestContext, datastoreId string, location string) (io.ReadCloser, error) { + ref, err := LocateDatastore(ctx, datastoreId) if err != nil { return nil, err } @@ -86,11 +86,11 @@ func GetUriForDatastore(dsConf config.DatastoreConfig) string { return "" } -func PickDatastore(forKind string, ctx context.Context, log *logrus.Entry) (*DatastoreRef, error) { +func PickDatastore(forKind string, ctx rcontext.RequestContext) (*DatastoreRef, error) { // If we haven't found a legacy option, pick a datastore - log.Info("Finding a suitable datastore to pick for uploads") + ctx.Log.Info("Finding a suitable datastore to pick for uploads") confDatastores := config.Get().DataStores - mediaStore := storage.GetDatabase().GetMediaStore(ctx, log) + mediaStore := storage.GetDatabase().GetMediaStore(ctx) var targetDs *types.Datastore var targetDsConf config.DatastoreConfig @@ -101,7 +101,7 @@ func PickDatastore(forKind string, ctx context.Context, log *logrus.Entry) (*Dat } if len(dsConf.MediaKinds) == 0 && dsConf.ForUploads { - log.Warnf("Datastore of type %s is using a deprecated flag (forUploads) - please use forKinds instead", dsConf.Type) + ctx.Log.Warnf("Datastore of type %s is using a deprecated flag (forUploads) - please use forKinds instead", dsConf.Type) dsConf.MediaKinds = common.AllKinds } @@ -121,7 +121,7 @@ func PickDatastore(forKind string, ctx context.Context, log *logrus.Entry) (*Dat continue } - size, err := estimatedDatastoreSize(ds, ctx, log) + size, err := estimatedDatastoreSize(ds, ctx) if err != nil { continue } @@ -134,13 +134,13 @@ func PickDatastore(forKind string, ctx context.Context, log *logrus.Entry) (*Dat } if targetDs != nil { - logrus.Info("Using ", targetDs.Uri) + ctx.Log.Info("Using ", targetDs.Uri) return newDatastoreRef(targetDs, targetDsConf), nil } return nil, errors.New("failed to pick a datastore: none available") } -func estimatedDatastoreSize(ds *types.Datastore, ctx context.Context, log *logrus.Entry) (int64, error) { - return storage.GetDatabase().GetMetadataStore(ctx, log).GetEstimatedSizeOfDatastore(ds.DatastoreId) +func estimatedDatastoreSize(ds *types.Datastore, ctx rcontext.RequestContext) (int64, error) { + return storage.GetDatabase().GetMetadataStore(ctx).GetEstimatedSizeOfDatastore(ds.DatastoreId) } diff --git a/storage/datastore/datastore_ref.go b/storage/datastore/datastore_ref.go index d1d8c7576409bdbeed0c8709ad229da784909a2d..919d1386b9e79bd43e8ff8f22e7ba593f2289bd5 100644 --- a/storage/datastore/datastore_ref.go +++ b/storage/datastore/datastore_ref.go @@ -1,14 +1,14 @@ package datastore import ( - "context" "errors" "io" "os" "path" "github.com/sirupsen/logrus" - "github.com/turt2live/matrix-media-repo/common/config" + config2 "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage/datastore/ds_file" "github.com/turt2live/matrix-media-repo/storage/datastore/ds_s3" "github.com/turt2live/matrix-media-repo/types" @@ -22,10 +22,10 @@ type DatastoreRef struct { Uri string datastore *types.Datastore - config config.DatastoreConfig + config config2.DatastoreConfig } -func newDatastoreRef(ds *types.Datastore, config config.DatastoreConfig) *DatastoreRef { +func newDatastoreRef(ds *types.Datastore, config config2.DatastoreConfig) *DatastoreRef { return &DatastoreRef{ DatastoreId: ds.DatastoreId, Type: ds.Type, @@ -35,17 +35,17 @@ func newDatastoreRef(ds *types.Datastore, config config.DatastoreConfig) *Datast } } -func (d *DatastoreRef) UploadFile(file io.ReadCloser, expectedLength int64, ctx context.Context, log *logrus.Entry) (*types.ObjectInfo, error) { - log = log.WithFields(logrus.Fields{"datastoreId": d.DatastoreId, "datastoreUri": d.Uri}) +func (d *DatastoreRef) UploadFile(file io.ReadCloser, expectedLength int64, ctx rcontext.RequestContext) (*types.ObjectInfo, error) { + ctx = ctx.LogWithFields(logrus.Fields{"datastoreId": d.DatastoreId, "datastoreUri": d.Uri}) if d.Type == "file" { - return ds_file.PersistFile(d.Uri, file, ctx, log) + return ds_file.PersistFile(d.Uri, file, ctx) } else if d.Type == "s3" { s3, err := ds_s3.GetOrCreateS3Datastore(d.DatastoreId, d.config) if err != nil { return nil, err } - return s3.UploadFile(file, expectedLength, ctx, log) + return s3.UploadFile(file, expectedLength, ctx) } else { return nil, errors.New("unknown datastore type") } @@ -97,9 +97,9 @@ func (d *DatastoreRef) ObjectExists(location string) bool { } } -func (d *DatastoreRef) OverwriteObject(location string, stream io.ReadCloser, ctx context.Context, log *logrus.Entry) error { +func (d *DatastoreRef) OverwriteObject(location string, stream io.ReadCloser, ctx rcontext.RequestContext) error { if d.Type == "file" { - _, _, err := ds_file.PersistFileAtLocation(path.Join(d.Uri, location), stream, ctx, log) + _, _, err := ds_file.PersistFileAtLocation(path.Join(d.Uri, location), stream, ctx) return err } else if d.Type == "s3" { s3, err := ds_s3.GetOrCreateS3Datastore(d.DatastoreId, d.config) diff --git a/storage/datastore/ds_file/file_store.go b/storage/datastore/ds_file/file_store.go index b2a91eb47d16105fef24b201c7491129d9ab084a..327bd842420844cfc5e0009363b43d090e4529e0 100644 --- a/storage/datastore/ds_file/file_store.go +++ b/storage/datastore/ds_file/file_store.go @@ -1,19 +1,18 @@ package ds_file import ( - "context" "errors" "io" "io/ioutil" "os" "path" - "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) -func PersistFile(basePath string, file io.ReadCloser, ctx context.Context, log *logrus.Entry) (*types.ObjectInfo, error) { +func PersistFile(basePath string, file io.ReadCloser, ctx rcontext.RequestContext) (*types.ObjectInfo, error) { defer file.Close() exists := true @@ -35,13 +34,13 @@ func PersistFile(basePath string, file io.ReadCloser, ctx context.Context, log * targetDir = path.Join(basePath, primaryContainer, secondaryContainer) targetFile = path.Join(targetDir, fileName) - log.Info("Checking if file exists: " + targetFile) + ctx.Log.Info("Checking if file exists: " + targetFile) exists, err = util.FileExists(targetFile) attempts++ if err != nil { - log.Error("Error checking if the file exists: " + err.Error()) + ctx.Log.Error("Error checking if the file exists: " + err.Error()) } // Infinite loop protection @@ -55,7 +54,7 @@ func PersistFile(basePath string, file io.ReadCloser, ctx context.Context, log * return nil, err } - sizeBytes, hash, err := PersistFileAtLocation(targetFile, file, ctx, log) + sizeBytes, hash, err := PersistFileAtLocation(targetFile, file, ctx) if err != nil { return nil, err } @@ -68,7 +67,7 @@ func PersistFile(basePath string, file io.ReadCloser, ctx context.Context, log * }, nil } -func PersistFileAtLocation(targetFile string, file io.ReadCloser, ctx context.Context, log *logrus.Entry) (int64, string, error) { +func PersistFileAtLocation(targetFile string, file io.ReadCloser, ctx rcontext.RequestContext) (int64, string, error) { defer file.Close() f, err := os.OpenFile(targetFile, os.O_WRONLY|os.O_CREATE, 0644) @@ -90,16 +89,16 @@ func PersistFileAtLocation(targetFile string, file io.ReadCloser, ctx context.Co go func() { defer wfile.Close() - log.Info("Calculating hash of stream...") + ctx.Log.Info("Calculating hash of stream...") hash, hashErr = util.GetSha256HashOfStream(ioutil.NopCloser(tr)) - log.Info("Hash of file is ", hash) + ctx.Log.Info("Hash of file is ", hash) done <- true }() go func() { - log.Info("Writing file...") + ctx.Log.Info("Writing file...") sizeBytes, writeErr = io.Copy(f, rfile) - log.Info("Wrote ", sizeBytes, " bytes to file") + ctx.Log.Info("Wrote ", sizeBytes, " bytes to file") done <- true }() diff --git a/storage/datastore/ds_s3/s3_store.go b/storage/datastore/ds_s3/s3_store.go index e29c833d6f357e23c2b76e09568e31bbf6f5c4f5..389a756393251a1d6d8161c723ca45db51090289 100644 --- a/storage/datastore/ds_s3/s3_store.go +++ b/storage/datastore/ds_s3/s3_store.go @@ -1,7 +1,6 @@ package ds_s3 import ( - "context" "fmt" "io" "io/ioutil" @@ -13,6 +12,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) @@ -110,7 +110,7 @@ func (s *s3Datastore) EnsureTempPathExists() error { return nil } -func (s *s3Datastore) UploadFile(file io.ReadCloser, expectedLength int64, ctx context.Context, log *logrus.Entry) (*types.ObjectInfo, error) { +func (s *s3Datastore) UploadFile(file io.ReadCloser, expectedLength int64, ctx rcontext.RequestContext) (*types.ObjectInfo, error) { defer file.Close() objectName, err := util.GenerateRandomString(512) @@ -133,16 +133,16 @@ func (s *s3Datastore) UploadFile(file io.ReadCloser, expectedLength int64, ctx c go func() { defer ws3.Close() - log.Info("Calculating hash of stream...") + ctx.Log.Info("Calculating hash of stream...") hash, hashErr = util.GetSha256HashOfStream(ioutil.NopCloser(tr)) - log.Info("Hash of file is ", hash) + ctx.Log.Info("Hash of file is ", hash) done <- true }() go func() { if expectedLength <= 0 { if s.tempPath != "" { - log.Info("Buffering file to temp path due to unknown file size") + ctx.Log.Info("Buffering file to temp path due to unknown file size") var f *os.File f, uploadErr = ioutil.TempFile(s.tempPath, "mr*") if uploadErr != nil { @@ -165,13 +165,13 @@ func (s *s3Datastore) UploadFile(file io.ReadCloser, expectedLength int64, ctx c rs3 = f defer f.Close() } else { - log.Warn("Uploading content of unknown length to s3 - this could result in high memory usage") + ctx.Log.Warn("Uploading content of unknown length to s3 - this could result in high memory usage") expectedLength = -1 } } - log.Info("Uploading file...") + ctx.Log.Info("Uploading file...") sizeBytes, uploadErr = s.client.PutObjectWithContext(ctx, s.bucket, objectName, rs3, expectedLength, minio.PutObjectOptions{}) - log.Info("Uploaded ", sizeBytes, " bytes to s3") + ctx.Log.Info("Uploaded ", sizeBytes, " bytes to s3") done <- true }() diff --git a/storage/ds_utils.go b/storage/ds_utils.go index a9635da98b68d764aad00677d6b751fd83551a94..6c8a553fea6d488c72d715be9e7e49eb8380f2e7 100644 --- a/storage/ds_utils.go +++ b/storage/ds_utils.go @@ -1,22 +1,22 @@ package storage import ( - "context" "database/sql" "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage/stores" "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) -func GetOrCreateDatastoreOfType(ctx context.Context, log *logrus.Entry, dsType string, dsUri string) (*types.Datastore, error) { - mediaService := GetDatabase().GetMediaStore(ctx, log) +func GetOrCreateDatastoreOfType(ctx rcontext.RequestContext, dsType string, dsUri string) (*types.Datastore, error) { + mediaService := GetDatabase().GetMediaStore(ctx) datastore, err := mediaService.GetDatastoreByUri(dsUri) if err != nil && err == sql.ErrNoRows { id, err2 := util.GenerateRandomString(32) if err2 != nil { - logrus.Error("Error generating datastore ID for URI ", dsUri, ": ", err) + ctx.Log.Error("Error generating datastore ID for URI ", dsUri, ": ", err) return nil, err2 } datastore = &types.Datastore{ @@ -26,7 +26,7 @@ func GetOrCreateDatastoreOfType(ctx context.Context, log *logrus.Entry, dsType s } err2 = mediaService.InsertDatastore(datastore) if err2 != nil { - logrus.Error("Error creating datastore for URI ", dsUri, ": ", err) + ctx.Log.Error("Error creating datastore for URI ", dsUri, ": ", err) return nil, err2 } } diff --git a/storage/startup_migrations.go b/storage/startup_migrations.go index c6ed4d7448b771832ebb5f289f363beb000a97f4..e903132bef38ef5844924d2939122e8b7bc2165d 100644 --- a/storage/startup_migrations.go +++ b/storage/startup_migrations.go @@ -1,16 +1,16 @@ package storage import ( - "context" "path" "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/util" ) func populateThumbnailHashes(db *Database) error { - svc := db.GetThumbnailStore(context.TODO(), &logrus.Entry{}) - mediaSvc := db.GetMediaStore(context.TODO(), &logrus.Entry{}) + svc := db.GetThumbnailStore(rcontext.Initial()) + mediaSvc := db.GetMediaStore(rcontext.Initial()) thumbs, err := svc.GetAllWithoutHash() if err != nil { @@ -52,8 +52,8 @@ func populateThumbnailHashes(db *Database) error { func populateDatastores(db *Database) error { logrus.Info("Starting to populate datastores...") - thumbService := db.GetThumbnailStore(context.TODO(), &logrus.Entry{}) - mediaService := db.GetMediaStore(context.TODO(), &logrus.Entry{}) + thumbService := db.GetThumbnailStore(rcontext.Initial()) + mediaService := db.GetMediaStore(rcontext.Initial()) logrus.Info("Fetching thumbnails...") thumbs, err := thumbService.GetAllWithoutDatastore() diff --git a/storage/storage.go b/storage/storage.go index cc28826a35c882b613d4281ab14b944b4d9c7877..6fb7ba94509ec05d2b51e2a9f3450c72f6a00931 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,7 +1,6 @@ package storage import ( - "context" "database/sql" "sync" @@ -9,6 +8,7 @@ import ( _ "github.com/lib/pq" // postgres driver "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage/stores" ) @@ -109,22 +109,22 @@ func OpenDatabase(connectionString string, maxConns int, maxIdleConns int) error return nil } -func (d *Database) GetMediaStore(ctx context.Context, log *logrus.Entry) *stores.MediaStore { - return d.repos.mediaStore.Create(ctx, log) +func (d *Database) GetMediaStore(ctx rcontext.RequestContext) *stores.MediaStore { + return d.repos.mediaStore.Create(ctx) } -func (d *Database) GetThumbnailStore(ctx context.Context, log *logrus.Entry) *stores.ThumbnailStore { - return d.repos.thumbnailStore.New(ctx, log) +func (d *Database) GetThumbnailStore(ctx rcontext.RequestContext) *stores.ThumbnailStore { + return d.repos.thumbnailStore.New(ctx) } -func (d *Database) GetUrlStore(ctx context.Context, log *logrus.Entry) *stores.UrlStore { - return d.repos.urlStore.Create(ctx, log) +func (d *Database) GetUrlStore(ctx rcontext.RequestContext) *stores.UrlStore { + return d.repos.urlStore.Create(ctx) } -func (d *Database) GetMetadataStore(ctx context.Context, log *logrus.Entry) *stores.MetadataStore { - return d.repos.metadataStore.Create(ctx, log) +func (d *Database) GetMetadataStore(ctx rcontext.RequestContext) *stores.MetadataStore { + return d.repos.metadataStore.Create(ctx) } -func (d *Database) GetExportStore(ctx context.Context, log *logrus.Entry) *stores.ExportStore { - return d.repos.exportStore.Create(ctx, log) +func (d *Database) GetExportStore(ctx rcontext.RequestContext) *stores.ExportStore { + return d.repos.exportStore.Create(ctx) } diff --git a/storage/stores/export_store.go b/storage/stores/export_store.go index 87d094d6b9592190f03babf1ab508b9f2f705449..9fc233ae954049de41ea55cf9acc04f6bcfc4907 100644 --- a/storage/stores/export_store.go +++ b/storage/stores/export_store.go @@ -1,10 +1,9 @@ package stores import ( - "context" "database/sql" - "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/types" ) @@ -33,8 +32,7 @@ type ExportStoreFactory struct { type ExportStore struct { factory *ExportStoreFactory // just for reference - ctx context.Context - log *logrus.Entry + ctx rcontext.RequestContext statements *exportStoreStatements // copied from factory } @@ -69,11 +67,10 @@ func InitExportStore(sqlDb *sql.DB) (*ExportStoreFactory, error) { return &store, nil } -func (f *ExportStoreFactory) Create(ctx context.Context, entry *logrus.Entry) *ExportStore { +func (f *ExportStoreFactory) Create(ctx rcontext.RequestContext) *ExportStore { return &ExportStore{ factory: f, ctx: ctx, - log: entry, statements: f.stmts, // we copy this intentionally } } diff --git a/storage/stores/media_store.go b/storage/stores/media_store.go index 65a15c27882a48bb0df2901c2ec7eaa04c64e1a0..b838eb27f5f090b737d81143dda2490b108c2ba1 100644 --- a/storage/stores/media_store.go +++ b/storage/stores/media_store.go @@ -1,12 +1,11 @@ package stores import ( - "context" "database/sql" "sync" "github.com/lib/pq" - "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/types" ) @@ -67,8 +66,7 @@ type MediaStoreFactory struct { type MediaStore struct { factory *MediaStoreFactory // just for reference - ctx context.Context - log *logrus.Entry + ctx rcontext.RequestContext statements *mediaStoreStatements // copied from factory } @@ -142,11 +140,10 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) { return &store, nil } -func (f *MediaStoreFactory) Create(ctx context.Context, entry *logrus.Entry) *MediaStore { +func (f *MediaStoreFactory) Create(ctx rcontext.RequestContext) *MediaStore { return &MediaStore{ factory: f, ctx: ctx, - log: entry, statements: f.stmts, // we copy this intentionally } } diff --git a/storage/stores/metadata_store.go b/storage/stores/metadata_store.go index ab63822969275de2766a94153e9875b78196de25..a47d30dc8b28ef642a3cf8fbe37a2bb36ae6dd6f 100644 --- a/storage/stores/metadata_store.go +++ b/storage/stores/metadata_store.go @@ -1,11 +1,10 @@ package stores import ( - "context" "database/sql" "encoding/json" - "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) @@ -57,8 +56,7 @@ type MetadataStoreFactory struct { type MetadataStore struct { factory *MetadataStoreFactory // just for reference - ctx context.Context - log *logrus.Entry + ctx rcontext.RequestContext statements *metadataStoreStatements // copied from factory } @@ -120,11 +118,10 @@ func InitMetadataStore(sqlDb *sql.DB) (*MetadataStoreFactory, error) { return &store, nil } -func (f *MetadataStoreFactory) Create(ctx context.Context, entry *logrus.Entry) *MetadataStore { +func (f *MetadataStoreFactory) Create(ctx rcontext.RequestContext) *MetadataStore { return &MetadataStore{ factory: f, ctx: ctx, - log: entry, statements: f.stmts, // we copy this intentionally } } diff --git a/storage/stores/thumbnail_store.go b/storage/stores/thumbnail_store.go index 4502147ffcf9eb06258ab3853a6938386540f447..f5fa028bf1744f50e9adcb7b744cb6a395f04dd3 100644 --- a/storage/stores/thumbnail_store.go +++ b/storage/stores/thumbnail_store.go @@ -1,10 +1,9 @@ package stores import ( - "context" "database/sql" - "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/types" ) @@ -35,8 +34,7 @@ type ThumbnailStoreFactory struct { type ThumbnailStore struct { factory *ThumbnailStoreFactory // just for reference - ctx context.Context - log *logrus.Entry + ctx rcontext.RequestContext statements *thumbnailStatements // copied from factory } @@ -74,11 +72,10 @@ func InitThumbnailStore(sqlDb *sql.DB) (*ThumbnailStoreFactory, error) { return &store, nil } -func (f *ThumbnailStoreFactory) New(ctx context.Context, entry *logrus.Entry) *ThumbnailStore { +func (f *ThumbnailStoreFactory) New(ctx rcontext.RequestContext) *ThumbnailStore { return &ThumbnailStore{ factory: f, ctx: ctx, - log: entry, statements: f.stmts, // we copy this intentionally } } diff --git a/storage/stores/url_store.go b/storage/stores/url_store.go index 4e321ac0a2300b073cea22a48d8e208940c38b5b..53163dac30a7a7e6ad369d86e57c979456ed641e 100644 --- a/storage/stores/url_store.go +++ b/storage/stores/url_store.go @@ -1,10 +1,9 @@ package stores import ( - "context" "database/sql" - "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) @@ -24,8 +23,7 @@ type UrlStoreFactory struct { type UrlStore struct { factory *UrlStoreFactory // just for reference - ctx context.Context - log *logrus.Entry + ctx rcontext.RequestContext statements *urlStatements // copied from factory } @@ -45,11 +43,10 @@ func InitUrlStore(sqlDb *sql.DB) (*UrlStoreFactory, error) { return &store, nil } -func (f *UrlStoreFactory) Create(ctx context.Context, entry *logrus.Entry) *UrlStore { +func (f *UrlStoreFactory) Create(ctx rcontext.RequestContext) *UrlStore { return &UrlStore{ factory: f, ctx: ctx, - log: entry, statements: f.stmts, // we copy this intentionally } } diff --git a/util/config.go b/util/config.go index 05c762cfd69ab4d14bd8727e8099be81c13cbce0..3bf820c56773ac23cf50615528ccdc6378b1286a 100644 --- a/util/config.go +++ b/util/config.go @@ -1,6 +1,8 @@ package util -import "github.com/turt2live/matrix-media-repo/common/config" +import ( + "github.com/turt2live/matrix-media-repo/common/config" +) func IsServerOurs(server string) bool { hs := GetHomeserverConfig(server) diff --git a/util/util_exif/exif.go b/util/util_exif/exif.go index eef4352c80760228c08adb210ad17082284bd4dd..7403bb59e7d95547e203871be0fa3fcd9155424c 100644 --- a/util/util_exif/exif.go +++ b/util/util_exif/exif.go @@ -1,12 +1,11 @@ package util_exif import ( - "context" "fmt" "github.com/pkg/errors" "github.com/rwcarlsen/goexif/exif" - "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/storage/datastore" "github.com/turt2live/matrix-media-repo/types" ) @@ -22,7 +21,7 @@ func GetExifOrientation(media *types.Media) (*ExifOrientation, error) { return nil, errors.New("image is not a jpeg") } - mediaStream, err := datastore.DownloadStream(context.TODO(), &logrus.Entry{}, media.DatastoreId, media.Location) + mediaStream, err := datastore.DownloadStream(rcontext.Initial(), media.DatastoreId, media.Location) if err != nil { return nil, err }