Skip to content
Snippets Groups Projects
Commit cc988987 authored by Travis Ralston
Browse files

WIP on new upload pipeline

parent 883c84ef
No related branches found
No related tags found
No related merge requests found
package upload_pipeline
import (
"bytes"
"errors"
"io"
"io/ioutil"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util/cleanup"
)
// UploadMedia runs an incoming upload through the (work in progress) upload pipeline.
//
// The stream r is capped by limitStreamLength, read fully into memory, hashed, and the
// hash is checked against the quarantine list. When mediaId is empty a new ID is
// generated for origin. The remaining steps (de-duplication, caching, asynchronous
// persistence) are not written yet, so a call that passes every check still ends with
// a "not yet implemented" error.
//
// contentType, fileName and userId are accepted but not used by the steps implemented
// so far. The original stream r is handed to cleanup.DumpAndCloseStream when the
// function returns.
func UploadMedia(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string) (*types.Media, error) {
	defer cleanup.DumpAndCloseStream(r)

	// Steps 1 & 2: cap the stream's length, then pull everything into memory.
	contents, err := bufferStream(ctx, limitStreamLength(ctx, r))
	if err != nil {
		return nil, err
	}

	// Each pipeline step gets its own fresh reader over the buffered bytes.
	newStream := func() io.ReadCloser {
		return ioutil.NopCloser(bytes.NewBuffer(contents))
	}

	// Step 3: hash the file contents.
	sha256hash, err := hashFile(ctx, newStream())
	if err != nil {
		return nil, err
	}

	// Step 4: refuse media whose hash is quarantined.
	if err = checkQuarantineStatus(ctx, sha256hash); err != nil {
		return nil, err
	}

	// Step 5: only generate a media ID when the caller did not supply one.
	if mediaId == "" {
		if mediaId, err = generateMediaID(ctx, origin); err != nil {
			return nil, err
		}
	}

	// Step 6: De-duplicate the media
	// TODO: Implementation. Check to see if uploading is required, also if the user has already uploaded a copy.

	// Step 7: Cache the file before uploading
	// TODO

	// Step 8: Prepare an async job to persist the media
	// TODO: Implementation. Limit the number of concurrent jobs on this to avoid queue flooding.
	// TODO: Should this be configurable?
	// TODO: Handle partial uploads/incomplete uploads.

	// Step 9: Return the media while it gets persisted
	// TODO

	return nil, errors.New("not yet implemented")
}
\ No newline at end of file
package upload_pipeline
import (
"io"
"io/ioutil"
"github.com/turt2live/matrix-media-repo/common/rcontext"
)
// bufferStream reads r until EOF (or the first read error) and returns the bytes that
// were read together with any error from ioutil.ReadAll. ctx is currently unused.
func bufferStream(ctx rcontext.RequestContext, r io.ReadCloser) ([]byte, error) {
	contents, err := ioutil.ReadAll(r)
	return contents, err
}
package upload_pipeline
import (
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/storage"
)
// checkQuarantineStatus asks the media store whether hash is quarantined.
//
// It returns the store's error if the lookup fails, common.ErrMediaQuarantined if the
// hash is quarantined, and nil otherwise.
func checkQuarantineStatus(ctx rcontext.RequestContext, hash string) error {
	quarantined, err := storage.GetDatabase().GetMediaStore(ctx).IsQuarantined(hash)
	switch {
	case err != nil:
		return err
	case quarantined:
		return common.ErrMediaQuarantined
	default:
		return nil
	}
}
package upload_pipeline
import (
"errors"
"strconv"
"time"
"github.com/patrickmn/go-cache"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/util"
)
// recentMediaIds tracks media IDs that generateMediaID has handed out recently, so the
// same ID is not issued again while it is still in the cache. Entries are stored with
// the cache's default expiration (30 seconds); the second argument (60 seconds) is the
// cache's cleanup interval.
var recentMediaIds = cache.New(30*time.Second, 60*time.Second)
// generateMediaID produces a media ID for origin that is neither in the recently-issued
// cache nor reported as reserved by the metadata store.
//
// Each round hashes (SHA-1) a 64 character random string concatenated with the current
// time in milliseconds. A candidate is rejected when it is already present in
// recentMediaIds, when the metadata store reports it as reserved for origin, or when
// it cannot be registered in recentMediaIds because another caller registered the same
// ID first. After 10 rejected rounds an error is returned. Errors from the random
// generator, the hash helper, or the metadata store are returned as-is.
func generateMediaID(ctx rcontext.RequestContext, origin string) (string, error) {
	const maxRounds = 10

	metadataDb := storage.GetDatabase().GetMetadataStore(ctx)

	for round := 0; round < maxRounds; round++ {
		seed, err := util.GenerateRandomString(64)
		if err != nil {
			return "", err
		}

		mediaId, err := util.GetSha1OfString(seed + strconv.FormatInt(util.NowMillis(), 10))
		if err != nil {
			return "", err
		}

		// Cheap in-memory check first: skip IDs we handed out moments ago without
		// bothering the database.
		if _, present := recentMediaIds.Get(mediaId); present {
			continue
		}

		reserved, err := metadataDb.IsReserved(origin, mediaId)
		if err != nil {
			return "", err
		}
		if reserved {
			continue
		}

		// Add fails when the key is already in the cache, which means a concurrent
		// caller claimed this ID between our Get above and now. Treat that as a
		// collision and try again instead of handing out a duplicate.
		if err := recentMediaIds.Add(mediaId, true, cache.DefaultExpiration); err != nil {
			continue
		}

		return mediaId, nil
	}

	return "", errors.New("failed to generate a media ID after 10 rounds")
}
package upload_pipeline
import (
"io"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util"
)
// hashFile returns the result of util.GetSha256HashOfStream for r, passing through
// whatever hash string and error that helper yields. ctx is currently unused.
func hashFile(ctx rcontext.RequestContext, r io.ReadCloser) (string, error) {
	sha256hash, err := util.GetSha256HashOfStream(r)
	return sha256hash, err
}
package upload_pipeline
import (
"io"
"io/ioutil"
"github.com/turt2live/matrix-media-repo/common/rcontext"
)
// limitStreamLength caps r at the configured Uploads.MaxSizeBytes.
//
// When the configured maximum is zero or negative, r is returned untouched. Otherwise
// the result is an io.LimitReader over r wrapped in a no-op closer: reads stop (with
// EOF) once the limit is reached, and closing the returned stream does not close r.
func limitStreamLength(ctx rcontext.RequestContext, r io.ReadCloser) io.ReadCloser {
	maxBytes := ctx.Config.Uploads.MaxSizeBytes
	if maxBytes <= 0 {
		return r
	}
	return ioutil.NopCloser(io.LimitReader(r, maxBytes))
}
......@@ -31,6 +31,7 @@ const selectMediaByUser = "SELECT origin, media_id, upload_name, content_type, u
// selectMediaByUserBefore fetches media rows for user_id $1 whose creation_ts is at or before $2.
const selectMediaByUserBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE user_id = $1 AND creation_ts <= $2"
// selectMediaByDomainBefore fetches media rows for origin $1 whose creation_ts is at or before $2.
const selectMediaByDomainBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND creation_ts <= $2"
// selectMediaByLocation fetches media rows stored in datastore $1 at location $2.
const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE datastore_id = $1 AND location = $2"
// selectIfQuarantined yields at most one row (the constant 1) when some media row has
// sha256_hash $1 and a quarantined flag equal to $2; no row means no such media exists.
const selectIfQuarantined = "SELECT 1 FROM media WHERE sha256_hash = $1 AND quarantined = $2 LIMIT 1;"
var dsCacheByPath = sync.Map{} // [string] => Datastore
var dsCacheById = sync.Map{} // [string] => Datastore
......@@ -59,6 +60,7 @@ type mediaStoreStatements struct {
selectMediaByUserBefore *sql.Stmt
selectMediaByDomainBefore *sql.Stmt
selectMediaByLocation *sql.Stmt
selectIfQuarantined *sql.Stmt
}
type MediaStoreFactory struct {
......@@ -144,6 +146,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) {
if store.stmts.selectMediaByLocation, err = store.sqlDb.Prepare(selectMediaByLocation); err != nil {
return nil, err
}
if store.stmts.selectIfQuarantined, err = store.sqlDb.Prepare(selectIfQuarantined); err != nil {
return nil, err
}
return &store, nil
}
......@@ -702,3 +707,15 @@ func (s *MediaStore) GetMediaByLocation(datastoreId string, location string) ([]
return results, nil
}
// IsQuarantined reports whether any media row with the given sha256 hash is flagged as
// quarantined.
//
// It runs the selectIfQuarantined statement with (sha256hash, true). A returned row
// means true; sql.ErrNoRows means false with no error; any other scan error is
// returned alongside false.
func (s *MediaStore) IsQuarantined(sha256hash string) (bool, error) {
	var found int
	row := s.statements.selectIfQuarantined.QueryRow(sha256hash, true)
	switch err := row.Scan(&found); err {
	case nil:
		return true, nil
	case sql.ErrNoRows:
		return false, nil
	default:
		return false, err
	}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment