diff --git a/src/github.com/turt2live/matrix-media-repo/media_handler/media_handler.go b/src/github.com/turt2live/matrix-media-repo/media_handler/media_handler.go
index 5f89d9e7cd03b9efe680f7fd5e527e7dc5a776f6..8ca48a53ff486f530c4879829d1b735b3a08b26a 100644
--- a/src/github.com/turt2live/matrix-media-repo/media_handler/media_handler.go
+++ b/src/github.com/turt2live/matrix-media-repo/media_handler/media_handler.go
@@ -24,7 +24,51 @@ func (r MediaUploadRequest) StoreMedia(ctx context.Context, db storage.Database)
 		return "", err
 	}
 
-	// TODO: Dedupe media here
+	// Dedupe: look up media already stored under the same content hash so an
+	// existing record can be reused instead of persisting the file again.
+	records, err := db.GetMediaByHash(ctx, hash)
+	if err != nil {
+		return "", err
+	}
+	if len(records) > 0 {
+		// An identical upload (same host, name, uploader and content type)
+		// can sit anywhere in the result set, so check every record before
+		// deciding that a new row is needed.
+		for i := range records {
+			if IsMediaSame(records[i], r) {
+				return util.MediaToMxc(&records[i]), nil
+			}
+		}
+
+		// No identical upload exists. Base the new row on an existing record
+		// so its remaining fields are carried over; fall back to the last
+		// record when none belongs to this host.
+		media := records[len(records)-1]
+		for i := range records {
+			if records[i].Origin == r.Host {
+				media = records[i]
+
+				// This host already has a row for this hash, so the new upload
+				// gets a freshly generated ID.
+				media.MediaId = GenerateMediaId()
+				break
+			}
+		}
+
+		media.Origin = r.Host
+		media.UserId = r.UploadedBy
+		media.UploadName = r.DesiredFilename
+		media.ContentType = r.ContentType
+		// Creation time in milliseconds since the Unix epoch.
+		media.CreationTs = time.Now().UnixNano() / 1000000
+
+		err = db.InsertMedia(ctx, &media)
+		if err != nil {
+			return "", err
+		}
+
+		return util.MediaToMxc(&media), nil
+	}
 
 	destination, err := storage.PersistTempFile(r.TempLocation)
 	if err != nil {
@@ -71,4 +115,14 @@ func GenerateMediaId() string {
 	}
 
 	return str
+}
+
+// IsMediaSame reports whether an existing media record describes the same
+// upload as the request r: same origin host, upload name, uploading user and
+// content type. The file contents themselves are not compared here.
+func IsMediaSame(media types.Media, r MediaUploadRequest) bool {
+	return media.Origin == r.Host &&
+		media.UploadName == r.DesiredFilename &&
+		media.UserId == r.UploadedBy &&
+		media.ContentType == r.ContentType
 }
\ No newline at end of file
diff --git a/src/github.com/turt2live/matrix-media-repo/storage/storage.go b/src/github.com/turt2live/matrix-media-repo/storage/storage.go
index 5090030a3324aad8c80366eacdf4badfad4021ed..1389a4132684a123fd1c33b8a12b4db0dc1688c3 100644
--- a/src/github.com/turt2live/matrix-media-repo/storage/storage.go
+++ b/src/github.com/turt2live/matrix-media-repo/storage/storage.go
@@ -22,8 +22,8 @@ type statements struct {
 	insertOrigin *sql.Stmt
 }
 
-const selectMedia = "SELECT * FROM media WHERE origin = $1 and media_id = $2;"
-const selectMediaByHash = "SELECT origin, media_id FROM media WHERE sha256_hash = $1;"
+const selectMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts FROM media WHERE origin = $1 and media_id = $2;"
+const selectMediaByHash = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts FROM media WHERE sha256_hash = $1;"
 const insertMedia = "INSERT INTO media (origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9);"
 
 func OpenDatabase(connectionString string) (*Database, error) {
@@ -59,4 +59,45 @@ func (d *Database) InsertMedia(ctx context.Context, media *types.Media) error {
 	)
 
 	return err
+}
+
+// GetMediaByHash runs the selectMediaByHash statement with hash as its only
+// argument and returns one types.Media per row. The returned slice is nil when
+// no row matches. Any query, scan or iteration error is returned unchanged
+// together with a nil slice.
+func (d *Database) GetMediaByHash(ctx context.Context, hash string) ([]types.Media, error) {
+	rows, err := d.statements.selectMediaByHash.QueryContext(ctx, hash)
+	if err != nil {
+		return nil, err
+	}
+	// Close the result set on every return path so the underlying database
+	// connection is released even when Scan fails part-way through.
+	defer rows.Close()
+
+	var results []types.Media
+	for rows.Next() {
+		var obj types.Media
+		// Destination order follows the column list of the selectMediaByHash
+		// SQL (assuming the prepared statement is built from that constant).
+		err = rows.Scan(
+			&obj.Origin,
+			&obj.MediaId,
+			&obj.UploadName,
+			&obj.ContentType,
+			&obj.UserId,
+			&obj.Sha256Hash,
+			&obj.SizeBytes,
+			&obj.Location,
+			&obj.CreationTs,
+		)
+		if err != nil {
+			return nil, err
+		}
+		results = append(results, obj)
+	}
+	// Next also returns false when iteration stops on an error, so check Err
+	// to avoid reporting a truncated result set as success.
+	if err = rows.Err(); err != nil {
+		return nil, err
+	}
+
+	return results, nil
 }
\ No newline at end of file