Skip to content
Snippets Groups Projects
Commit 86e7e6e2 authored by Travis Ralston's avatar Travis Ralston
Browse files
parent 35642e3f
No related branches found
No related tags found
No related merge requests found
......@@ -169,6 +169,9 @@ beforeParseDownload:
case common.ErrCodeForbidden:
proposedStatusCode = http.StatusForbidden
break
case common.ErrCodeCannotOverwrite:
proposedStatusCode = http.StatusConflict
break
default: // Treat as unknown (a generic server error)
proposedStatusCode = http.StatusInternalServerError
break
......
......@@ -18,7 +18,7 @@ import (
)
type MediaUploadedResponse struct {
ContentUri string `json:"content_uri"`
ContentUri string `json:"content_uri,omitempty"`
Blurhash string `json:"xyz.amorgan.blurhash,omitempty"`
}
......@@ -35,6 +35,34 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.Us
}
// Early sizing constraints (reject requests which claim to be too large/small)
if sizeRes := uploadRequestSizeCheck(rctx, r); sizeRes != nil {
return sizeRes
}
// Actually upload
media, err := pipeline_upload.Execute(rctx, r.Host, "", r.Body, contentType, filename, user.UserId, datastores.LocalMediaKind)
if err != nil {
if err == common.ErrQuotaExceeded {
return _responses.QuotaExceeded()
}
rctx.Log.Error("Unexpected error uploading media: " + err.Error())
sentry.CaptureException(err)
return _responses.InternalServerError("Unexpected Error")
}
blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash)
if err != nil {
rctx.Log.Warn("Unexpected error getting media's blurhash from DB: " + err.Error())
sentry.CaptureException(err)
}
return &MediaUploadedResponse{
ContentUri: util.MxcUri(media.Origin, media.MediaId),
Blurhash: blurhash,
}
}
func uploadRequestSizeCheck(rctx rcontext.RequestContext, r *http.Request) *_responses.ErrorResponse {
maxSize := rctx.Config.Uploads.MaxSizeBytes
minSize := rctx.Config.Uploads.MinSizeBytes
if maxSize > 0 || minSize > 0 {
......@@ -58,26 +86,5 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.Us
}
}
}
// Actually upload
media, err := pipeline_upload.Execute(rctx, r.Host, "", r.Body, contentType, filename, user.UserId, datastores.LocalMediaKind)
if err != nil {
if err == common.ErrQuotaExceeded {
return _responses.QuotaExceeded()
}
rctx.Log.Error("Unexpected error uploading media: " + err.Error())
sentry.CaptureException(err)
return _responses.InternalServerError("Unexpected Error")
}
blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash)
if err != nil {
rctx.Log.Error("Unexpected error getting media's blurhash from DB: " + err.Error())
sentry.CaptureException(err)
}
return &MediaUploadedResponse{
ContentUri: util.MxcUri(media.Origin, media.MediaId),
Blurhash: blurhash,
}
return nil
}
package r0
import (
"net/http"
"path/filepath"
"github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_upload"
)
func UploadMediaAsync(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
server := _routers.GetParam("server", r)
mediaId := _routers.GetParam("mediaId", r)
filename := filepath.Base(r.URL.Query().Get("filename"))
rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
"filename": filename,
})
if r.Host != server {
return _responses.ErrorResponse{
Code: common.ErrCodeNotFound,
Message: "Upload request is for another domain.",
InternalCode: common.ErrCodeForbidden,
}
}
contentType := r.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/octet-stream" // binary
}
// Early sizing constraints (reject requests which claim to be too large/small)
if sizeRes := uploadRequestSizeCheck(rctx, r); sizeRes != nil {
return sizeRes
}
// Actually upload
media, err := pipeline_upload.ExecutePut(rctx, server, mediaId, r.Body, contentType, filename, user.UserId)
if err != nil {
if err == common.ErrQuotaExceeded {
return _responses.QuotaExceeded()
} else if err == common.ErrAlreadyUploaded {
return _responses.ErrorResponse{
Code: common.ErrCodeCannotOverwrite,
Message: "This media has already been uploaded.",
InternalCode: common.ErrCodeCannotOverwrite,
}
} else if err == common.ErrWrongUser {
return _responses.ErrorResponse{
Code: common.ErrCodeForbidden,
Message: "You do not have permission to upload this media.",
InternalCode: common.ErrCodeForbidden,
}
} else if err == common.ErrExpired {
return _responses.ErrorResponse{
Code: common.ErrCodeNotFound,
Message: "Media expired or not found.",
InternalCode: common.ErrCodeNotFound,
}
}
rctx.Log.Error("Unexpected error uploading media: " + err.Error())
sentry.CaptureException(err)
return _responses.InternalServerError("Unexpected Error")
}
blurhash, err := database.GetInstance().Blurhashes.Prepare(rctx).Get(media.Sha256Hash)
if err != nil {
rctx.Log.Warn("Unexpected error getting media's blurhash from DB: " + err.Error())
sentry.CaptureException(err)
}
return &MediaUploadedResponse{
//ContentUri: util.MxcUri(media.Origin, media.MediaId), // This endpoint doesn't return a URI
Blurhash: blurhash,
}
}
......@@ -31,6 +31,7 @@ func buildRoutes() http.Handler {
// Standard (spec) features
register([]string{"POST"}, PrefixMedia, "upload", false, router, makeRoute(_routers.RequireAccessToken(r0.UploadMedia), "upload", false, counter))
register([]string{"PUT"}, PrefixMedia, "upload/:server/:mediaId", false, router, makeRoute(_routers.RequireAccessToken(r0.UploadMediaAsync), "upload_async", false, counter))
downloadRoute := makeRoute(_routers.OptionalAccessToken(r0.DownloadMedia), "download", false, counter)
register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId/:filename", false, router, downloadRoute)
register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId", false, router, downloadRoute)
......
......@@ -16,3 +16,4 @@ const ErrCodeRateLimitExceeded = "M_LIMIT_EXCEEDED"
const ErrCodeUnknown = "M_UNKNOWN"
const ErrCodeForbidden = "M_FORBIDDEN"
const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED"
const ErrCodeCannotOverwrite = "M_CANNOT_OVERWRITE_MEDIA"
......@@ -11,3 +11,6 @@ var ErrHostNotFound = errors.New("host not found")
var ErrHostBlacklisted = errors.New("host not allowed")
var ErrMediaQuarantined = errors.New("media quarantined")
var ErrQuotaExceeded = errors.New("quota exceeded")
var ErrWrongUser = errors.New("wrong user")
var ErrExpired = errors.New("expired")
var ErrAlreadyUploaded = errors.New("already uploaded")
......@@ -17,10 +17,14 @@ type DbExpiringMedia struct {
const insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_id, expires_ts) VALUES ($1, $2, $3, $4);"
const selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;"
const selectExpiringMediaById = "SELECT origin, media_id, user_id, expires_ts FROM expiring_media WHERE origin = $1 AND media_id = $2;"
const deleteExpiringMediaById = "DELETE FROM expiring_media WHERE origin = $1 AND media_id = $2;"
type expiringMediaTableStatements struct {
insertExpiringMedia *sql.Stmt
selectExpiringMediaByUserCount *sql.Stmt
selectExpiringMediaById *sql.Stmt
deleteExpiringMediaById *sql.Stmt
}
type expiringMediaTableWithContext struct {
......@@ -38,6 +42,12 @@ func prepareExpiringMediaTables(db *sql.DB) (*expiringMediaTableStatements, erro
if stmts.selectExpiringMediaByUserCount, err = db.Prepare(selectExpiringMediaByUserCount); err != nil {
return nil, errors.New("error preparing selectExpiringMediaByUserCount: " + err.Error())
}
if stmts.selectExpiringMediaById, err = db.Prepare(selectExpiringMediaById); err != nil {
return nil, errors.New("error preparing selectExpiringMediaById: " + err.Error())
}
if stmts.deleteExpiringMediaById, err = db.Prepare(deleteExpiringMediaById); err != nil {
return nil, errors.New("error preparing deleteExpiringMediaById: " + err.Error())
}
return stmts, nil
}
......@@ -64,3 +74,19 @@ func (s *expiringMediaTableWithContext) ByUserCount(userId string) (int64, error
}
return val, err
}
func (s *expiringMediaTableWithContext) Get(origin string, mediaId string) (*DbExpiringMedia, error) {
row := s.statements.selectExpiringMediaById.QueryRowContext(s.ctx, origin, mediaId)
val := &DbExpiringMedia{}
err := row.Scan(&val.Origin, &val.MediaId, &val.UserId, &val.ExpiresTs)
if err == sql.ErrNoRows {
err = nil
val = nil
}
return val, err
}
func (s *expiringMediaTableWithContext) Delete(origin string, mediaId string) error {
_, err := s.statements.deleteExpiringMediaById.ExecContext(s.ctx, origin, mediaId)
return err
}
package pipeline_upload
import (
"io"
"github.com/getsentry/sentry-go"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/util"
)
func ExecutePut(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string) (*database.DbMedia, error) {
// Step 1: Do we already have a media record for this?
mediaDb := database.GetInstance().Media.Prepare(ctx)
mediaRecord, err := mediaDb.GetById(origin, mediaId)
if err != nil {
return nil, err
}
if mediaRecord != nil {
return nil, common.ErrAlreadyUploaded
}
// Step 2: Try to find the holding record
expiringDb := database.GetInstance().ExpiringMedia.Prepare(ctx)
record, err := expiringDb.Get(origin, mediaId)
if err != nil {
return nil, err
}
// Step 3: Is the record expired?
if record == nil || record.ExpiresTs < util.NowMillis() {
return nil, common.ErrExpired
}
// Step 4: Is the correct user uploading this media?
if record.UserId != userId {
return nil, common.ErrWrongUser
}
// Step 5: Do the upload
newRecord, err := Execute(ctx, origin, mediaId, r, contentType, fileName, userId, datastores.LocalMediaKind)
if err != nil {
return nil, err
}
// Step 6: Delete the holding record
if err2 := expiringDb.Delete(origin, mediaId); err2 != nil {
ctx.Log.Warn("Non-fatal error while deleting expiring media record: " + err2.Error())
sentry.CaptureException(err2)
}
return newRecord, err
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment