From 86e7e6e245241655ca97b031211cf4ba0e4ded8a Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sat, 10 Jun 2023 18:32:42 -0600
Subject: [PATCH] Support MSC2246 PUT /upload requests

Closes https://github.com/turt2live/matrix-media-repo/issues/411
For https://github.com/turt2live/matrix-media-repo/issues/407
---
 api/_routers/98-use-rcontext.go        |  3 +
 api/r0/upload.go                       | 53 +++++++++-------
 api/r0/upload_async.go                 | 86 ++++++++++++++++++++++++++
 api/routes.go                          |  1 +
 common/errorcodes.go                   |  1 +
 common/errors.go                       |  3 +
 database/table_expiring_media.go       | 26 ++++++++
 pipelines/pipeline_upload/pipeline2.go | 55 ++++++++++++++++
 8 files changed, 205 insertions(+), 23 deletions(-)
 create mode 100644 api/r0/upload_async.go
 create mode 100644 pipelines/pipeline_upload/pipeline2.go

diff --git a/api/_routers/98-use-rcontext.go b/api/_routers/98-use-rcontext.go
index fcbb025d..ba774e81 100644
--- a/api/_routers/98-use-rcontext.go
+++ b/api/_routers/98-use-rcontext.go
@@ -169,6 +169,9 @@ beforeParseDownload:
 		case common.ErrCodeForbidden:
 			proposedStatusCode = http.StatusForbidden
 			break
+		case common.ErrCodeCannotOverwrite:
+			proposedStatusCode = http.StatusConflict
+			break
 		default: // Treat as unknown (a generic server error)
 			proposedStatusCode = http.StatusInternalServerError
 			break
diff --git a/api/r0/upload.go b/api/r0/upload.go
index 4734ba0a..51418721 100644
--- a/api/r0/upload.go
+++ b/api/r0/upload.go
@@ -18,7 +18,7 @@ import (
 )
 
 type MediaUploadedResponse struct {
-	ContentUri string `json:"content_uri"`
+	ContentUri string `json:"content_uri,omitempty"`
 	Blurhash   string `json:"xyz.amorgan.blurhash,omitempty"`
 }
 
@@ -35,6 +35,34 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.Us
 	}
 
 	// Early sizing constraints (reject requests which claim to be too large/small)
+	if sizeRes := uploadRequestSizeCheck(rctx, r); sizeRes != nil {
+		return sizeRes
+	}
+
+	// Actually upload
+	media, err := pipeline_upload.Execute(rctx, r.Host, "", r.Body, contentType, filename, user.UserId, datastores.LocalMediaKind)
+	if err != nil {
+		if err == common.ErrQuotaExceeded {
+			return _responses.QuotaExceeded()
+		}
+		rctx.Log.Error("Unexpected error uploading media: " + err.Error())
+		sentry.CaptureException(err)
+		return _responses.InternalServerError("Unexpected Error")
+	}
+
+	blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash)
+	if err != nil {
+		rctx.Log.Warn("Unexpected error getting media's blurhash from DB: " + err.Error())
+		sentry.CaptureException(err)
+	}
+
+	return &MediaUploadedResponse{
+		ContentUri: util.MxcUri(media.Origin, media.MediaId),
+		Blurhash:   blurhash,
+	}
+}
+
+func uploadRequestSizeCheck(rctx rcontext.RequestContext, r *http.Request) *_responses.ErrorResponse {
 	maxSize := rctx.Config.Uploads.MaxSizeBytes
 	minSize := rctx.Config.Uploads.MinSizeBytes
 	if maxSize > 0 || minSize > 0 {
@@ -58,26 +86,5 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.Us
 			}
 		}
 	}
-
-	// Actually upload
-	media, err := pipeline_upload.Execute(rctx, r.Host, "", r.Body, contentType, filename, user.UserId, datastores.LocalMediaKind)
-	if err != nil {
-		if err == common.ErrQuotaExceeded {
-			return _responses.QuotaExceeded()
-		}
-		rctx.Log.Error("Unexpected error uploading media: " + err.Error())
-		sentry.CaptureException(err)
-		return _responses.InternalServerError("Unexpected Error")
-	}
-
-	blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash)
-	if err != nil {
-		rctx.Log.Error("Unexpected error getting media's blurhash from DB: " + err.Error())
-		sentry.CaptureException(err)
-	}
-
-	return &MediaUploadedResponse{
-		ContentUri: util.MxcUri(media.Origin, media.MediaId),
-		Blurhash:   blurhash,
-	}
+	return nil
 }
diff --git a/api/r0/upload_async.go b/api/r0/upload_async.go
new file mode 100644
index 00000000..05f90584
--- /dev/null
+++ b/api/r0/upload_async.go
@@ -0,0 +1,86 @@
+package r0
+
+import (
+	"net/http"
+	"path/filepath"
+
+	"github.com/getsentry/sentry-go"
+	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/api/_apimeta"
+	"github.com/turt2live/matrix-media-repo/api/_responses"
+	"github.com/turt2live/matrix-media-repo/api/_routers"
+	"github.com/turt2live/matrix-media-repo/common"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/database"
+	"github.com/turt2live/matrix-media-repo/pipelines/pipeline_upload"
+)
+
+func UploadMediaAsync(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
+	server := _routers.GetParam("server", r)
+	mediaId := _routers.GetParam("mediaId", r)
+	filename := filepath.Base(r.URL.Query().Get("filename"))
+
+	rctx = rctx.LogWithFields(logrus.Fields{
+		"mediaId":  mediaId,
+		"server":   server,
+		"filename": filename,
+	})
+
+	if r.Host != server {
+		return _responses.ErrorResponse{
+			Code:         common.ErrCodeNotFound,
+			Message:      "Upload request is for another domain.",
+			InternalCode: common.ErrCodeForbidden,
+		}
+	}
+
+	contentType := r.Header.Get("Content-Type")
+	if contentType == "" {
+		contentType = "application/octet-stream" // binary
+	}
+
+	// Early sizing constraints (reject requests which claim to be too large/small)
+	if sizeRes := uploadRequestSizeCheck(rctx, r); sizeRes != nil {
+		return sizeRes
+	}
+
+	// Actually upload
+	media, err := pipeline_upload.ExecutePut(rctx, server, mediaId, r.Body, contentType, filename, user.UserId)
+	if err != nil {
+		if err == common.ErrQuotaExceeded {
+			return _responses.QuotaExceeded()
+		} else if err == common.ErrAlreadyUploaded {
+			return _responses.ErrorResponse{
+				Code:         common.ErrCodeCannotOverwrite,
+				Message:      "This media has already been uploaded.",
+				InternalCode: common.ErrCodeCannotOverwrite,
+			}
+		} else if err == common.ErrWrongUser {
+			return _responses.ErrorResponse{
+				Code:         common.ErrCodeForbidden,
+				Message:      "You do not have permission to upload this media.",
+				InternalCode: common.ErrCodeForbidden,
+			}
+		} else if err == common.ErrExpired {
+			return _responses.ErrorResponse{
+				Code:         common.ErrCodeNotFound,
+				Message:      "Media expired or not found.",
+				InternalCode: common.ErrCodeNotFound,
+			}
+		}
+		rctx.Log.Error("Unexpected error uploading media: " + err.Error())
+		sentry.CaptureException(err)
+		return _responses.InternalServerError("Unexpected Error")
+	}
+
+	blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash)
+	if err != nil {
+		rctx.Log.Warn("Unexpected error getting media's blurhash from DB: " + err.Error())
+		sentry.CaptureException(err)
+	}
+
+	return &MediaUploadedResponse{
+		//ContentUri: util.MxcUri(media.Origin, media.MediaId), // This endpoint doesn't return a URI
+		Blurhash: blurhash,
+	}
+}
diff --git a/api/routes.go b/api/routes.go
index 766cfc30..e16dbae2 100644
--- a/api/routes.go
+++ b/api/routes.go
@@ -31,6 +31,7 @@ func buildRoutes() http.Handler {
 
 	// Standard (spec) features
 	register([]string{"POST"}, PrefixMedia, "upload", false, router, makeRoute(_routers.RequireAccessToken(r0.UploadMedia), "upload", false, counter))
+	register([]string{"PUT"}, PrefixMedia, "upload/:server/:mediaId", false, router, makeRoute(_routers.RequireAccessToken(r0.UploadMediaAsync), "upload_async", false, counter))
 	downloadRoute := makeRoute(_routers.OptionalAccessToken(r0.DownloadMedia), "download", false, counter)
 	register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId/:filename", false, router, downloadRoute)
 	register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId", false, router, downloadRoute)
diff --git a/common/errorcodes.go b/common/errorcodes.go
index 599b32c9..0be63484 100644
--- a/common/errorcodes.go
+++ b/common/errorcodes.go
@@ -16,3 +16,4 @@ const ErrCodeRateLimitExceeded = "M_LIMIT_EXCEEDED"
 const ErrCodeUnknown = "M_UNKNOWN"
 const ErrCodeForbidden = "M_FORBIDDEN"
 const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED"
+const ErrCodeCannotOverwrite = "M_CANNOT_OVERWRITE_MEDIA"
diff --git a/common/errors.go b/common/errors.go
index c7712e84..3da519c9 100644
--- a/common/errors.go
+++ b/common/errors.go
@@ -11,3 +11,6 @@ var ErrHostNotFound = errors.New("host not found")
 var ErrHostBlacklisted = errors.New("host not allowed")
 var ErrMediaQuarantined = errors.New("media quarantined")
 var ErrQuotaExceeded = errors.New("quota exceeded")
+var ErrWrongUser = errors.New("wrong user")
+var ErrExpired = errors.New("expired")
+var ErrAlreadyUploaded = errors.New("already uploaded")
diff --git a/database/table_expiring_media.go b/database/table_expiring_media.go
index daac75e9..dd1f6546 100644
--- a/database/table_expiring_media.go
+++ b/database/table_expiring_media.go
@@ -17,10 +17,14 @@ type DbExpiringMedia struct {
 
 const insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_id, expires_ts) VALUES ($1, $2, $3, $4);"
 const selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;"
+const selectExpiringMediaById = "SELECT origin, media_id, user_id, expires_ts FROM expiring_media WHERE origin = $1 AND media_id = $2;"
+const deleteExpiringMediaById = "DELETE FROM expiring_media WHERE origin = $1 AND media_id = $2;"
 
 type expiringMediaTableStatements struct {
 	insertExpiringMedia            *sql.Stmt
 	selectExpiringMediaByUserCount *sql.Stmt
+	selectExpiringMediaById        *sql.Stmt
+	deleteExpiringMediaById        *sql.Stmt
 }
 
 type expiringMediaTableWithContext struct {
@@ -38,6 +42,12 @@ func prepareExpiringMediaTables(db *sql.DB) (*expiringMediaTableStatements, erro
 	if stmts.selectExpiringMediaByUserCount, err = db.Prepare(selectExpiringMediaByUserCount); err != nil {
 		return nil, errors.New("error preparing selectExpiringMediaByUserCount: " + err.Error())
 	}
+	if stmts.selectExpiringMediaById, err = db.Prepare(selectExpiringMediaById); err != nil {
+		return nil, errors.New("error preparing selectExpiringMediaById: " + err.Error())
+	}
+	if stmts.deleteExpiringMediaById, err = db.Prepare(deleteExpiringMediaById); err != nil {
+		return nil, errors.New("error preparing deleteExpiringMediaById: " + err.Error())
+	}
 
 	return stmts, nil
 }
@@ -64,3 +74,19 @@ func (s *expiringMediaTableWithContext) ByUserCount(userId string) (int64, error
 	}
 	return val, err
 }
+
+func (s *expiringMediaTableWithContext) Get(origin string, mediaId string) (*DbExpiringMedia, error) {
+	row := s.statements.selectExpiringMediaById.QueryRowContext(s.ctx, origin, mediaId)
+	val := &DbExpiringMedia{}
+	err := row.Scan(&val.Origin, &val.MediaId, &val.UserId, &val.ExpiresTs)
+	if err == sql.ErrNoRows {
+		err = nil
+		val = nil
+	}
+	return val, err
+}
+
+func (s *expiringMediaTableWithContext) Delete(origin string, mediaId string) error {
+	_, err := s.statements.deleteExpiringMediaById.ExecContext(s.ctx, origin, mediaId)
+	return err
+}
diff --git a/pipelines/pipeline_upload/pipeline2.go b/pipelines/pipeline_upload/pipeline2.go
new file mode 100644
index 00000000..6d89c1b4
--- /dev/null
+++ b/pipelines/pipeline_upload/pipeline2.go
@@ -0,0 +1,55 @@
+package pipeline_upload
+
+import (
+	"io"
+
+	"github.com/getsentry/sentry-go"
+	"github.com/turt2live/matrix-media-repo/common"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/database"
+	"github.com/turt2live/matrix-media-repo/datastores"
+	"github.com/turt2live/matrix-media-repo/util"
+)
+
+func ExecutePut(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string) (*database.DbMedia, error) {
+	// Step 1: Do we already have a media record for this?
+	mediaDb := database.GetInstance().Media.Prepare(ctx)
+	mediaRecord, err := mediaDb.GetById(origin, mediaId)
+	if err != nil {
+		return nil, err
+	}
+	if mediaRecord != nil {
+		return nil, common.ErrAlreadyUploaded
+	}
+
+	// Step 2: Try to find the holding record
+	expiringDb := database.GetInstance().ExpiringMedia.Prepare(ctx)
+	record, err := expiringDb.Get(origin, mediaId)
+	if err != nil {
+		return nil, err
+	}
+
+	// Step 3: Is the record expired?
+	if record == nil || record.ExpiresTs < util.NowMillis() {
+		return nil, common.ErrExpired
+	}
+
+	// Step 4: Is the correct user uploading this media?
+	if record.UserId != userId {
+		return nil, common.ErrWrongUser
+	}
+
+	// Step 5: Do the upload
+	newRecord, err := Execute(ctx, origin, mediaId, r, contentType, fileName, userId, datastores.LocalMediaKind)
+	if err != nil {
+		return nil, err
+	}
+
+	// Step 6: Delete the holding record
+	if err2 := expiringDb.Delete(origin, mediaId); err2 != nil {
+		ctx.Log.Warn("Non-fatal error while deleting expiring media record: " + err2.Error())
+		sentry.CaptureException(err2)
+	}
+
+	return newRecord, err
+}
-- 
GitLab