pipeline.go
package pipeline_upload

import (
	"errors"
	"io"

	"github.com/getsentry/sentry-go"
	"github.com/turt2live/matrix-media-repo/common"
	"github.com/turt2live/matrix-media-repo/common/config"
	"github.com/turt2live/matrix-media-repo/common/rcontext"
	"github.com/turt2live/matrix-media-repo/database"
	"github.com/turt2live/matrix-media-repo/datastores"
	"github.com/turt2live/matrix-media-repo/notifier"
	"github.com/turt2live/matrix-media-repo/pipelines/_steps/meta"
	"github.com/turt2live/matrix-media-repo/pipelines/_steps/quota"
	"github.com/turt2live/matrix-media-repo/pipelines/_steps/upload"
	"github.com/turt2live/matrix-media-repo/util"
	"github.com/turt2live/matrix-media-repo/util/readers"
)
// Execute runs a media upload. If mediaId is an empty string, one will be generated.
func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string, kind datastores.Kind) (*database.DbMedia, error) {
	uploadDone := func(record *database.DbMedia) {
		meta.FlagAccess(ctx, record.Sha256Hash)
		if err := notifier.UploadDone(ctx, record); err != nil {
			ctx.Log.Warn("Non-fatal error notifying about completed upload: ", err)
			sentry.CaptureException(err)
		}
	}
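
	// The request body is consumed in a single pass: each step below either
	// wraps the reader (limiting, teeing) or buffers it, so nothing is read twice.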

	// Step 1: Limit the stream's length
	r = upload.LimitStream(ctx, r)

	// Step 2: Create a media ID (if needed)
	mustUseMediaId := true
	if mediaId == "" {
		var err error
		mediaId, err = upload.GenerateMediaId(ctx, origin)
		if err != nil {
			return nil, err
		}
		mustUseMediaId = false
	}

	// Step 3: Pick a datastore
	dsConf, err := datastores.Pick(ctx, kind)
	if err != nil {
		return nil, err
	}
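
	// Spam checking happens concurrently with buffering: the TeeReader mirrors
	// every byte that BufferTemp reads into spamW, and the checker goroutine
	// consumes the matching spamR end of the pipe.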

	// Step 4: Buffer to the datastore's temporary path, and check for spam
	spamR, spamW := io.Pipe()
	spamTee := io.TeeReader(r, spamW)
	spamChan := upload.CheckSpamAsync(ctx, spamR, upload.FileMetadata{
		Name:        fileName,
		ContentType: contentType,
		UserId:      userId,
		Origin:      origin,
		MediaId:     mediaId,
	})
	sha256hash, sizeBytes, reader, err := datastores.BufferTemp(dsConf, readers.NewCancelCloser(io.NopCloser(spamTee), func() {
		r.Close()
	}))
	if err != nil {
		return nil, err
	}
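
	// Closing the write half of the pipe delivers EOF to the spam checker so it
	// can produce a verdict; on failure, an error result is injected so the
	// receive below does not wait on a checker that may never see EOF.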
	if err = spamW.Close(); err != nil {
		ctx.Log.Warn("Failed to close writer for spam checker: ", err)
		spamChan <- upload.SpamResponse{Err: errors.New("failed to close")}
	}
	defer reader.Close()

	spam := <-spamChan
	if spam.Err != nil {
		return nil, spam.Err
	}
	if spam.IsSpam {
		return nil, common.ErrMediaQuarantined
	}

	// Step 5: Split the buffer to calculate a blurhash & populate cache later
	bhR, bhW := io.Pipe()
	cacheR, cacheW := io.Pipe()
	allWriters := io.MultiWriter(cacheW, bhW)
	tee := io.TeeReader(reader, allWriters)
	defer func(bhW *io.PipeWriter, err error) {
		_ = bhW.CloseWithError(err)
	}(bhW, errors.New("failed to finish write"))
	defer func(cacheW *io.PipeWriter, err error) {
		_ = cacheW.CloseWithError(err)
	}(cacheW, errors.New("failed to finish write"))
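
	// The deferred CloseWithError calls are a safety net for early returns: if
	// the writers were already closed cleanly in Step 12, the later
	// CloseWithError is a no-op, because the first close of an io.Pipe half wins.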

	// Step 6: Check quarantine
	if err = upload.CheckQuarantineStatus(ctx, sha256hash); err != nil {
		return nil, err
	}

	// Step 7: Ensure user can upload within quota
	if userId != "" && !config.Runtime.IsImportProcess {
		err = quota.CanUpload(ctx, userId, sizeBytes)
		if err != nil {
			return nil, err
		}
	}

	// Step 8: Acquire a lock on the media hash for uploading
	unlockFn, err := upload.LockForUpload(ctx, sha256hash)
	if err != nil {
		return nil, err
	}
	//goland:noinspection GoUnhandledErrorResult
	defer unlockFn()
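
	// Locking on the content hash (rather than the media ID) serialises
	// concurrent uploads of identical bytes, so the deduplication lookup below
	// sees a consistent view of existing records.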

	// Step 9: Pull all upload records (to check if an upload has already happened)
	newRecord := &database.DbMedia{
		Origin:      origin,
		MediaId:     mediaId,
		UploadName:  fileName,
		ContentType: contentType,
		UserId:      userId,
		SizeBytes:   sizeBytes,
		CreationTs:  util.NowMillis(),
		Quarantined: false,
		Locatable: &database.Locatable{
			Sha256Hash:  sha256hash,
			DatastoreId: "", // Populated later
			Location:    "", // Populated later
		},
	}
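
	// Judging by FindRecord's arguments, a "perfect" match shares the hash plus
	// the uploader, content type, and file name; anything less still lets us
	// reuse the stored bytes under a new record.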
	record, perfect, err := upload.FindRecord(ctx, sha256hash, userId, contentType, fileName)
	if err != nil {
		return nil, err
	}
	if record != nil {
		// We already had this record in some capacity
		if perfect && !mustUseMediaId {
			// Exact match - deduplicate, skip upload to datastore
			return record, nil
		} else {
			// We already uploaded it somewhere else - use the datastore ID and location
			newRecord.Quarantined = record.Quarantined // just in case (shouldn't be a different value by here)
			newRecord.DatastoreId = record.DatastoreId
			newRecord.Location = record.Location
			if err = database.GetInstance().Media.Prepare(ctx).Insert(newRecord); err != nil {
				return nil, err
			}
			uploadDone(newRecord)
			return newRecord, nil
		}
	}

	// Step 10: Asynchronously calculate blurhash
	bhChan := upload.CalculateBlurhashAsync(ctx, bhR, sizeBytes, sha256hash)

	// Step 11: Asynchronously upload to cache
	cacheChan := upload.PopulateCacheAsync(ctx, cacheR, sizeBytes, sha256hash)
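
	// The blurhash and cache goroutines read from their pipe ends while the
	// datastore upload drives the tee, so all three consumers progress over a
	// single pass of the buffered file.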

	// Step 12: Since we didn't find a duplicate, upload it to the datastore
	dsLocation, err := datastores.Upload(ctx, dsConf, io.NopCloser(tee), sizeBytes, contentType, sha256hash)
	if err != nil {
		return nil, err
	}
	if err = bhW.Close(); err != nil {
		ctx.Log.Warn("Failed to close writer for blurhash: ", err)
		close(bhChan)
	}
	if err = cacheW.Close(); err != nil {
		ctx.Log.Warn("Failed to close writer for cache layer: ", err)
		close(cacheChan)
	}

	// Step 13: Wait for channels
	<-bhChan
	<-cacheChan
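
	// Receiving from both channels guarantees the blurhash and cache goroutines
	// have drained their pipes before the record is persisted; a channel closed
	// by the error paths above yields immediately.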

	// Step 14: Everything finally looks good - persist and return the new record
	newRecord.DatastoreId = dsConf.Id
	newRecord.Location = dsLocation
	if err = database.GetInstance().Media.Prepare(ctx).Insert(newRecord); err != nil {
		if err2 := datastores.Remove(ctx, dsConf, dsLocation); err2 != nil {
			sentry.CaptureException(err2)
			ctx.Log.Warn("Error deleting upload (delete attempted due to persistence error): ", err2)
		}
		return nil, err
	}
	uploadDone(newRecord)
	return newRecord, nil
}
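
// Usage sketch (hypothetical caller; the Kind constant and request wiring are
// assumptions and may differ from this repo's actual HTTP layer):
//
//	record, err := pipeline_upload.Execute(ctx, "example.org", "", req.Body,
//		req.Header.Get("Content-Type"), "upload.bin", "@alice:example.org",
//		datastores.LocalMediaKind)
//	if err != nil {
//		// map err to an HTTP error response
//	}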