Skip to content
Snippets Groups Projects
Commit fccd0df9 authored by Travis Ralston's avatar Travis Ralston
Browse files

Merge branch 'travis/new-pipelines'

parents 784db37f cc988987
No related branches found
No related tags found
No related merge requests found
package upload_pipeline
import (
"bytes"
"errors"
"io"
"io/ioutil"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util/cleanup"
)
func UploadMedia(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, contentType string, fileName string, userId string) (*types.Media, error) {
defer cleanup.DumpAndCloseStream(r)
// Step 1: Limit the stream's length
r = limitStreamLength(ctx, r)
// Step 2: Buffer the stream
b, err := bufferStream(ctx, r)
if err != nil {
return nil, err
}
// Create a utility function for getting at the buffer easily
stream := func() io.ReadCloser {
return ioutil.NopCloser(bytes.NewBuffer(b))
}
// Step 3: Get a hash of the file
hash, err := hashFile(ctx, stream())
if err != nil {
return nil, err
}
// Step 4: Check if the media is quarantined
err = checkQuarantineStatus(ctx, hash)
if err != nil {
return nil, err
}
// Step 5: Generate a media ID if we need to
if mediaId == "" {
mediaId, err = generateMediaID(ctx, origin)
if err != nil {
return nil, err
}
}
// Step 6: De-duplicate the media
// TODO: Implementation. Check to see if uploading is required, also if the user has already uploaded a copy.
// Step 7: Cache the file before uploading
// TODO
// Step 8: Prepare an async job to persist the media
// TODO: Implementation. Limit the number of concurrent jobs on this to avoid queue flooding.
// TODO: Should this be configurable?
// TODO: Handle partial uploads/incomplete uploads.
// Step 9: Return the media while it gets persisted
// TODO
return nil, errors.New("not yet implemented")
}
\ No newline at end of file
package upload_pipeline
import (
"io"
"io/ioutil"
"github.com/turt2live/matrix-media-repo/common/rcontext"
)
func bufferStream(ctx rcontext.RequestContext, r io.ReadCloser) ([]byte, error) {
return ioutil.ReadAll(r)
}
package upload_pipeline
import (
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/storage"
)
func checkQuarantineStatus(ctx rcontext.RequestContext, hash string) error {
db := storage.GetDatabase().GetMediaStore(ctx)
q, err := db.IsQuarantined(hash)
if err != nil {
return err
}
if q {
return common.ErrMediaQuarantined
}
return nil
}
package upload_pipeline
import (
"errors"
"strconv"
"time"
"github.com/patrickmn/go-cache"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/util"
)
var recentMediaIds = cache.New(30*time.Second, 60*time.Second)
func generateMediaID(ctx rcontext.RequestContext, origin string) (string, error) {
metadataDb := storage.GetDatabase().GetMetadataStore(ctx)
mediaTaken := true
var mediaId string
var err error
attempts := 0
for mediaTaken {
attempts += 1
if attempts > 10 {
return "", errors.New("failed to generate a media ID after 10 rounds")
}
mediaId, err = util.GenerateRandomString(64)
if err != nil {
return "", err
}
mediaId, err = util.GetSha1OfString(mediaId + strconv.FormatInt(util.NowMillis(), 10))
if err != nil {
return "", err
}
// Because we use the current time in the media ID, we don't need to worry about
// collisions from the database.
if _, present := recentMediaIds.Get(mediaId); present {
mediaTaken = true
continue
}
mediaTaken, err = metadataDb.IsReserved(origin, mediaId)
if err != nil {
return "", err
}
}
_ = recentMediaIds.Add(mediaId, true, cache.DefaultExpiration)
return mediaId, nil
}
package upload_pipeline
import (
"io"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util"
)
func hashFile(ctx rcontext.RequestContext, r io.ReadCloser) (string, error) {
return util.GetSha256HashOfStream(r)
}
package upload_pipeline
import (
"io"
"io/ioutil"
"github.com/turt2live/matrix-media-repo/common/rcontext"
)
func limitStreamLength(ctx rcontext.RequestContext, r io.ReadCloser) io.ReadCloser {
if ctx.Config.Uploads.MaxSizeBytes > 0 {
return ioutil.NopCloser(io.LimitReader(r, ctx.Config.Uploads.MaxSizeBytes))
} else {
return r
}
}
...@@ -31,6 +31,7 @@ const selectMediaByUser = "SELECT origin, media_id, upload_name, content_type, u ...@@ -31,6 +31,7 @@ const selectMediaByUser = "SELECT origin, media_id, upload_name, content_type, u
const selectMediaByUserBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE user_id = $1 AND creation_ts <= $2" const selectMediaByUserBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE user_id = $1 AND creation_ts <= $2"
const selectMediaByDomainBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND creation_ts <= $2" const selectMediaByDomainBefore = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND creation_ts <= $2"
const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE datastore_id = $1 AND location = $2" const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE datastore_id = $1 AND location = $2"
const selectIfQuarantined = "SELECT 1 FROM media WHERE sha256_hash = $1 AND quarantined = $2 LIMIT 1;"
var dsCacheByPath = sync.Map{} // [string] => Datastore var dsCacheByPath = sync.Map{} // [string] => Datastore
var dsCacheById = sync.Map{} // [string] => Datastore var dsCacheById = sync.Map{} // [string] => Datastore
...@@ -59,6 +60,7 @@ type mediaStoreStatements struct { ...@@ -59,6 +60,7 @@ type mediaStoreStatements struct {
selectMediaByUserBefore *sql.Stmt selectMediaByUserBefore *sql.Stmt
selectMediaByDomainBefore *sql.Stmt selectMediaByDomainBefore *sql.Stmt
selectMediaByLocation *sql.Stmt selectMediaByLocation *sql.Stmt
selectIfQuarantined *sql.Stmt
} }
type MediaStoreFactory struct { type MediaStoreFactory struct {
...@@ -144,6 +146,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) { ...@@ -144,6 +146,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) {
if store.stmts.selectMediaByLocation, err = store.sqlDb.Prepare(selectMediaByLocation); err != nil { if store.stmts.selectMediaByLocation, err = store.sqlDb.Prepare(selectMediaByLocation); err != nil {
return nil, err return nil, err
} }
if store.stmts.selectIfQuarantined, err = store.sqlDb.Prepare(selectIfQuarantined); err != nil {
return nil, err
}
return &store, nil return &store, nil
} }
...@@ -702,3 +707,15 @@ func (s *MediaStore) GetMediaByLocation(datastoreId string, location string) ([] ...@@ -702,3 +707,15 @@ func (s *MediaStore) GetMediaByLocation(datastoreId string, location string) ([]
return results, nil return results, nil
} }
func (s *MediaStore) IsQuarantined(sha256hash string) (bool, error) {
r := s.statements.selectIfQuarantined.QueryRow(sha256hash, true)
var i int
err := r.Scan(&i)
if err == sql.ErrNoRows {
return false, nil
} else if err != nil {
return false, err
}
return true, nil
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment