diff --git a/pipline/upload_pipeline/pipeline.go b/pipline/upload_pipeline/pipeline.go new file mode 100644 index 0000000000000000000000000000000000000000..a2c46d02e804eb17137227bcffce4ffebac247fb --- /dev/null +++ b/pipline/upload_pipeline/pipeline.go @@ -0,0 +1,66 @@ +package upload_pipeline + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/types" + "github.com/turt2live/matrix-media-repo/util/cleanup" +) + +func UploadMedia(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string) (*types.Media, error) { + defer cleanup.DumpAndCloseStream(r) + + // Step 1: Limit the stream's length + r = limitStreamLength(ctx, r) + + // Step 2: Buffer the stream + b, err := bufferStream(ctx, r) + if err != nil { + return nil, err + } + + // Create a utility function for getting at the buffer easily + stream := func() io.ReadCloser { + return ioutil.NopCloser(bytes.NewBuffer(b)) + } + + // Step 3: Get a hash of the file + hash, err := hashFile(ctx, stream()) + if err != nil { + return nil, err + } + + // Step 4: Check if the media is quarantined + err = checkQuarantineStatus(ctx, hash) + if err != nil { + return nil, err + } + + // Step 5: Generate a media ID if we need to + if mediaId == "" { + mediaId, err = generateMediaID(ctx, origin) + if err != nil { + return nil, err + } + } + + // Step 6: De-duplicate the media + // TODO: Implementation. Check to see if uploading is required, also if the user has already uploaded a copy. + + // Step 7: Cache the file before uploading + // TODO + + // Step 8: Prepare an async job to persist the media + // TODO: Implementation. Limit the number of concurrent jobs on this to avoid queue flooding. + // TODO: Should this be configurable? + // TODO: Handle partial uploads/incomplete uploads. 
+ + // Step 9: Return the media while it gets persisted + // TODO + + return nil, errors.New("not yet implemented") +} diff --git a/pipline/upload_pipeline/step_buffer.go b/pipline/upload_pipeline/step_buffer.go new file mode 100644 index 0000000000000000000000000000000000000000..5af2ea3c45c5dc93ef4ca99c51accb812ce4b988 --- /dev/null +++ b/pipline/upload_pipeline/step_buffer.go @@ -0,0 +1,12 @@ +package upload_pipeline + +import ( + "io" + "io/ioutil" + + "github.com/turt2live/matrix-media-repo/common/rcontext" +) + +func bufferStream(ctx rcontext.RequestContext, r io.ReadCloser) ([]byte, error) { + return ioutil.ReadAll(r) +} diff --git a/pipline/upload_pipeline/step_check_quarantine.go b/pipline/upload_pipeline/step_check_quarantine.go new file mode 100644 index 0000000000000000000000000000000000000000..a840ade027cc7b234da9832cd0cb6f52a5cd12dd --- /dev/null +++ b/pipline/upload_pipeline/step_check_quarantine.go @@ -0,0 +1,19 @@ +package upload_pipeline + +import ( + "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/storage" +) + +func checkQuarantineStatus(ctx rcontext.RequestContext, hash string) error { + db := storage.GetDatabase().GetMediaStore(ctx) + q, err := db.IsQuarantined(hash) + if err != nil { + return err + } + if q { + return common.ErrMediaQuarantined + } + return nil +} diff --git a/pipline/upload_pipeline/step_gen_media_id.go b/pipline/upload_pipeline/step_gen_media_id.go new file mode 100644 index 0000000000000000000000000000000000000000..a2ce11acb280c640074ab09ee912dec82f133631 --- /dev/null +++ b/pipline/upload_pipeline/step_gen_media_id.go @@ -0,0 +1,53 @@ +package upload_pipeline + +import ( + "errors" + "strconv" + "time" + + "github.com/patrickmn/go-cache" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/storage" + "github.com/turt2live/matrix-media-repo/util" +) 
+ +var recentMediaIds = cache.New(30*time.Second, 60*time.Second) + +func generateMediaID(ctx rcontext.RequestContext, origin string) (string, error) { + metadataDb := storage.GetDatabase().GetMetadataStore(ctx) + mediaTaken := true + var mediaId string + var err error + attempts := 0 + for mediaTaken { + attempts += 1 + if attempts > 10 { + return "", errors.New("failed to generate a media ID after 10 rounds") + } + + mediaId, err = util.GenerateRandomString(64) + if err != nil { + return "", err + } + mediaId, err = util.GetSha1OfString(mediaId + strconv.FormatInt(util.NowMillis(), 10)) + if err != nil { + return "", err + } + + // Because we use the current time in the media ID, we don't need to worry about + // collisions from the database. + if _, present := recentMediaIds.Get(mediaId); present { + mediaTaken = true + continue + } + + mediaTaken, err = metadataDb.IsReserved(origin, mediaId) + if err != nil { + return "", err + } + } + + _ = recentMediaIds.Add(mediaId, true, cache.DefaultExpiration) + + return mediaId, nil +} diff --git a/pipline/upload_pipeline/step_hash.go b/pipline/upload_pipeline/step_hash.go new file mode 100644 index 0000000000000000000000000000000000000000..66072ee7ce7a5763508309e35e60479f4571e1f1 --- /dev/null +++ b/pipline/upload_pipeline/step_hash.go @@ -0,0 +1,12 @@ +package upload_pipeline + +import ( + "io" + + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/util" +) + +func hashFile(ctx rcontext.RequestContext, r io.ReadCloser) (string, error) { + return util.GetSha256HashOfStream(r) +} diff --git a/pipline/upload_pipeline/step_limited.go b/pipline/upload_pipeline/step_limited.go new file mode 100644 index 0000000000000000000000000000000000000000..4c0be6c9cb1a973a60bc36f84e9982cbbaaa1dc9 --- /dev/null +++ b/pipline/upload_pipeline/step_limited.go @@ -0,0 +1,16 @@ +package upload_pipeline + +import ( + "io" + "io/ioutil" + + "github.com/turt2live/matrix-media-repo/common/rcontext" 
+) + +func limitStreamLength(ctx rcontext.RequestContext, r io.ReadCloser) io.ReadCloser { + if ctx.Config.Uploads.MaxSizeBytes > 0 { + return ioutil.NopCloser(io.LimitReader(r, ctx.Config.Uploads.MaxSizeBytes)) + } else { + return r + } +} diff --git a/storage/stores/media_store.go b/storage/stores/media_store.go index b1eb5b3b1b7678903954d1042cf9d65ff2e572b8..cf00064cf6216497bf1149b20b511a5d386c26d4 100644 --- a/storage/stores/media_store.go +++ b/storage/stores/media_store.go @@ -31,6 +31,7 @@ const selectMediaByUser = "SELECT origin, media_id, upload_name, content_type, u const selectMediaByUserBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE user_id = $1 AND creation_ts <= $2" const selectMediaByDomainBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND creation_ts <= $2" const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE datastore_id = $1 AND location = $2" +const selectIfQuarantined = "SELECT 1 FROM media WHERE sha256_hash = $1 AND quarantined = $2 LIMIT 1" var dsCacheByPath = sync.Map{} // [string] => Datastore var dsCacheById = sync.Map{} // [string] => Datastore @@ -59,6 +60,7 @@ type mediaStoreStatements struct { selectMediaByUserBefore *sql.Stmt selectMediaByDomainBefore *sql.Stmt selectMediaByLocation *sql.Stmt + selectIfQuarantined *sql.Stmt } type MediaStoreFactory struct { @@ -144,6 +146,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) { if store.stmts.selectMediaByLocation, err = store.sqlDb.Prepare(selectMediaByLocation); err != nil { return nil, err } + if store.stmts.selectIfQuarantined, err = store.sqlDb.Prepare(selectIfQuarantined); err != nil { + return nil, 
err + } return &store, nil } @@ -702,3 +707,15 @@ func (s *MediaStore) GetMediaByLocation(datastoreId string, location string) ([] return results, nil } + +func (s *MediaStore) IsQuarantined(sha256hash string) (bool, error) { + r := s.statements.selectIfQuarantined.QueryRow(sha256hash, true) + var i int + err := r.Scan(&i) + if err == sql.ErrNoRows { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +}