From 86e7e6e245241655ca97b031211cf4ba0e4ded8a Mon Sep 17 00:00:00 2001 From: Travis Ralston <travpc@gmail.com> Date: Sat, 10 Jun 2023 18:32:42 -0600 Subject: [PATCH] Support MSC2246 PUT /upload requests Closes https://github.com/turt2live/matrix-media-repo/issues/411 For https://github.com/turt2live/matrix-media-repo/issues/407 --- api/_routers/98-use-rcontext.go | 3 + api/r0/upload.go | 53 +++++++++------- api/r0/upload_async.go | 86 ++++++++++++++++++++++++++ api/routes.go | 1 + common/errorcodes.go | 1 + common/errors.go | 3 + database/table_expiring_media.go | 26 ++++++++ pipelines/pipeline_upload/pipeline2.go | 55 ++++++++++++++++ 8 files changed, 205 insertions(+), 23 deletions(-) create mode 100644 api/r0/upload_async.go create mode 100644 pipelines/pipeline_upload/pipeline2.go diff --git a/api/_routers/98-use-rcontext.go b/api/_routers/98-use-rcontext.go index fcbb025d..ba774e81 100644 --- a/api/_routers/98-use-rcontext.go +++ b/api/_routers/98-use-rcontext.go @@ -169,6 +169,9 @@ beforeParseDownload: case common.ErrCodeForbidden: proposedStatusCode = http.StatusForbidden break + case common.ErrCodeCannotOverwrite: + proposedStatusCode = http.StatusConflict + break default: // Treat as unknown (a generic server error) proposedStatusCode = http.StatusInternalServerError break diff --git a/api/r0/upload.go b/api/r0/upload.go index 4734ba0a..51418721 100644 --- a/api/r0/upload.go +++ b/api/r0/upload.go @@ -18,7 +18,7 @@ import ( ) type MediaUploadedResponse struct { - ContentUri string `json:"content_uri"` + ContentUri string `json:"content_uri,omitempty"` Blurhash string `json:"xyz.amorgan.blurhash,omitempty"` } @@ -35,6 +35,34 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.Us } // Early sizing constraints (reject requests which claim to be too large/small) + if sizeRes := uploadRequestSizeCheck(rctx, r); sizeRes != nil { + return sizeRes + } + + // Actually upload + media, err := pipeline_upload.Execute(rctx, r.Host, "", r.Body, contentType, filename, user.UserId, datastores.LocalMediaKind) + if err != nil { + if err == common.ErrQuotaExceeded { + return _responses.QuotaExceeded() + } + rctx.Log.Error("Unexpected error uploading media: " + err.Error()) + sentry.CaptureException(err) + return _responses.InternalServerError("Unexpected Error") + } + + blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash) + if err != nil { + rctx.Log.Warn("Unexpected error getting media's blurhash from DB: " + err.Error()) + sentry.CaptureException(err) + } + + return &MediaUploadedResponse{ + ContentUri: util.MxcUri(media.Origin, media.MediaId), + Blurhash: blurhash, + } +} + +func uploadRequestSizeCheck(rctx rcontext.RequestContext, r *http.Request) *_responses.ErrorResponse { maxSize := rctx.Config.Uploads.MaxSizeBytes minSize := rctx.Config.Uploads.MinSizeBytes if maxSize > 0 || minSize > 0 { @@ -58,26 +86,5 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.Us } } } - - // Actually upload - media, err := pipeline_upload.Execute(rctx, r.Host, "", r.Body, contentType, filename, user.UserId, datastores.LocalMediaKind) - if err != nil { - if err == common.ErrQuotaExceeded { - return _responses.QuotaExceeded() - } - rctx.Log.Error("Unexpected error uploading media: " + err.Error()) - sentry.CaptureException(err) - return _responses.InternalServerError("Unexpected Error") - } - - blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash) - if err != nil { - rctx.Log.Error("Unexpected error getting media's blurhash from DB: " + err.Error()) - sentry.CaptureException(err) - } - - return &MediaUploadedResponse{ - ContentUri: util.MxcUri(media.Origin, media.MediaId), - Blurhash: blurhash, - } + return nil } diff --git a/api/r0/upload_async.go b/api/r0/upload_async.go new file mode 100644 index 00000000..05f90584 --- /dev/null +++ b/api/r0/upload_async.go @@ -0,0 +1,86 @@ +package r0 + +import ( + "net/http" + "path/filepath" + + "github.com/getsentry/sentry-go" + "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/api/_apimeta" + "github.com/turt2live/matrix-media-repo/api/_responses" + "github.com/turt2live/matrix-media-repo/api/_routers" + "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/pipelines/pipeline_upload" +) + +func UploadMediaAsync(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + server := _routers.GetParam("server", r) + mediaId := _routers.GetParam("mediaId", r) + filename := filepath.Base(r.URL.Query().Get("filename")) + + rctx = rctx.LogWithFields(logrus.Fields{ + "mediaId": mediaId, + "server": server, + "filename": filename, + }) + + if r.Host != server { + return _responses.ErrorResponse{ + Code: common.ErrCodeNotFound, + Message: "Upload request is for another domain.", + InternalCode: common.ErrCodeForbidden, + } + } + + contentType := r.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/octet-stream" // binary + } + + // Early sizing constraints (reject requests which claim to be too large/small) + if sizeRes := uploadRequestSizeCheck(rctx, r); sizeRes != nil { + return sizeRes + } + + // Actually upload + media, err := pipeline_upload.ExecutePut(rctx, server, mediaId, r.Body, contentType, filename, user.UserId) + if err != nil { + if err == common.ErrQuotaExceeded { + return _responses.QuotaExceeded() + } else if err == common.ErrAlreadyUploaded { + return _responses.ErrorResponse{ + Code: common.ErrCodeCannotOverwrite, + Message: "This media has already been uploaded.", + InternalCode: common.ErrCodeCannotOverwrite, + } + } else if err == common.ErrWrongUser { + return _responses.ErrorResponse{ + Code: common.ErrCodeForbidden, + Message: "You do not have permission to upload this media.", + InternalCode: common.ErrCodeForbidden, + } + } else if err == common.ErrExpired { + return _responses.ErrorResponse{ + Code: common.ErrCodeNotFound, + Message: "Media expired or not found.", + InternalCode: common.ErrCodeNotFound, + } + } + rctx.Log.Error("Unexpected error uploading media: " + err.Error()) + sentry.CaptureException(err) + return _responses.InternalServerError("Unexpected Error") + } + + blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash) + if err != nil { + rctx.Log.Warn("Unexpected error getting media's blurhash from DB: " + err.Error()) + sentry.CaptureException(err) + } + + return &MediaUploadedResponse{ + //ContentUri: util.MxcUri(media.Origin, media.MediaId), // This endpoint doesn't return a URI + Blurhash: blurhash, + } +} diff --git a/api/routes.go b/api/routes.go index 766cfc30..e16dbae2 100644 --- a/api/routes.go +++ b/api/routes.go @@ -31,6 +31,7 @@ func buildRoutes() http.Handler { // Standard (spec) features register([]string{"POST"}, PrefixMedia, "upload", false, router, makeRoute(_routers.RequireAccessToken(r0.UploadMedia), "upload", false, counter)) + register([]string{"PUT"}, PrefixMedia, "upload/:server/:mediaId", false, router, makeRoute(_routers.RequireAccessToken(r0.UploadMediaAsync), "upload_async", false, counter)) downloadRoute := makeRoute(_routers.OptionalAccessToken(r0.DownloadMedia), "download", false, counter) register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId/:filename", false, router, downloadRoute) register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId", false, router, downloadRoute) diff --git a/common/errorcodes.go b/common/errorcodes.go index 599b32c9..0be63484 100644 --- a/common/errorcodes.go +++ b/common/errorcodes.go @@ -16,3 +16,4 @@ const ErrCodeRateLimitExceeded = "M_LIMIT_EXCEEDED" const ErrCodeUnknown = "M_UNKNOWN" const ErrCodeForbidden = "M_FORBIDDEN" const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED" +const ErrCodeCannotOverwrite = "M_CANNOT_OVERWRITE_MEDIA" diff --git a/common/errors.go b/common/errors.go index c7712e84..3da519c9 100644 --- a/common/errors.go +++ b/common/errors.go @@ -11,3 +11,6 @@ var ErrHostNotFound = errors.New("host not found") var ErrHostBlacklisted = errors.New("host not allowed") var ErrMediaQuarantined = errors.New("media quarantined") var ErrQuotaExceeded = errors.New("quota exceeded") +var ErrWrongUser = errors.New("wrong user") +var ErrExpired = errors.New("expired") +var ErrAlreadyUploaded = errors.New("already uploaded") diff --git a/database/table_expiring_media.go b/database/table_expiring_media.go index daac75e9..dd1f6546 100644 --- a/database/table_expiring_media.go +++ b/database/table_expiring_media.go @@ -17,10 +17,14 @@ type DbExpiringMedia struct { const insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_id, expires_ts) VALUES ($1, $2, $3, $4);" const selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;" +const selectExpiringMediaById = "SELECT origin, media_id, user_id, expires_ts FROM expiring_media WHERE origin = $1 AND media_id = $2;" +const deleteExpiringMediaById = "DELETE FROM expiring_media WHERE origin = $1 AND media_id = $2;" type expiringMediaTableStatements struct { insertExpiringMedia *sql.Stmt selectExpiringMediaByUserCount *sql.Stmt + selectExpiringMediaById *sql.Stmt + deleteExpiringMediaById *sql.Stmt } type expiringMediaTableWithContext struct { @@ -38,6 +42,12 @@ func prepareExpiringMediaTables(db *sql.DB) (*expiringMediaTableStatements, erro if stmts.selectExpiringMediaByUserCount, err = db.Prepare(selectExpiringMediaByUserCount); err != nil { return nil, errors.New("error preparing selectExpiringMediaByUserCount: " + err.Error()) } + if stmts.selectExpiringMediaById, err = db.Prepare(selectExpiringMediaById); err != nil { + return nil, errors.New("error preparing selectExpiringMediaById: " + err.Error()) + } + if stmts.deleteExpiringMediaById, err = db.Prepare(deleteExpiringMediaById); err != nil { + return nil, errors.New("error preparing deleteExpiringMediaById: " + err.Error()) + } return stmts, nil } @@ -64,3 +74,19 @@ func (s *expiringMediaTableWithContext) ByUserCount(userId string) (int64, error } return val, err } + +func (s *expiringMediaTableWithContext) Get(origin string, mediaId string) (*DbExpiringMedia, error) { + row := s.statements.selectExpiringMediaById.QueryRowContext(s.ctx, origin, mediaId) + val := &DbExpiringMedia{} + err := row.Scan(&val.Origin, &val.MediaId, &val.UserId, &val.ExpiresTs) + if err == sql.ErrNoRows { + err = nil + val = nil + } + return val, err +} + +func (s *expiringMediaTableWithContext) Delete(origin string, mediaId string) error { + _, err := s.statements.deleteExpiringMediaById.ExecContext(s.ctx, origin, mediaId) + return err +} diff --git a/pipelines/pipeline_upload/pipeline2.go b/pipelines/pipeline_upload/pipeline2.go new file mode 100644 index 00000000..6d89c1b4 --- /dev/null +++ b/pipelines/pipeline_upload/pipeline2.go @@ -0,0 +1,55 @@ +package pipeline_upload + +import ( + "io" + + "github.com/getsentry/sentry-go" + "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/datastores" + "github.com/turt2live/matrix-media-repo/util" +) + +func ExecutePut(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string) (*database.DbMedia, error) { + // Step 1: Do we already have a media record for this? + mediaDb := database.GetInstance().Media.Prepare(ctx) + mediaRecord, err := mediaDb.GetById(origin, mediaId) + if err != nil { + return nil, err + } + if mediaRecord != nil { + return nil, common.ErrAlreadyUploaded + } + + // Step 2: Try to find the holding record + expiringDb := database.GetInstance().ExpiringMedia.Prepare(ctx) + record, err := expiringDb.Get(origin, mediaId) + if err != nil { + return nil, err + } + + // Step 3: Is the record expired? + if record == nil || record.ExpiresTs < util.NowMillis() { + return nil, common.ErrExpired + } + + // Step 4: Is the correct user uploading this media? + if record.UserId != userId { + return nil, common.ErrWrongUser + } + + // Step 5: Do the upload + newRecord, err := Execute(ctx, origin, mediaId, r, contentType, fileName, userId, datastores.LocalMediaKind) + if err != nil { + return nil, err + } + + // Step 6: Delete the holding record + if err2 := expiringDb.Delete(origin, mediaId); err2 != nil { + ctx.Log.Warn("Non-fatal error while deleting expiring media record: " + err2.Error()) + sentry.CaptureException(err2) + } + + return newRecord, err +} -- GitLab