diff --git a/database/db.go b/database/db.go index 02affb0463ca3d61b53136191a7a189339d95de3..86eceae276c2f88d1e09f19a345c7cba6d9c9c33 100644 --- a/database/db.go +++ b/database/db.go @@ -21,6 +21,7 @@ type Database struct { ReservedMedia *reservedMediaTableStatements MetadataView *metadataVirtualTableStatements Blurhashes *blurhashesTableStatements + HeldMedia *heldMediaTableStatements } var instance *Database @@ -92,6 +93,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error if d.Blurhashes, err = prepareBlurhashesTables(d.conn); err != nil { return errors.New("failed to create blurhashes table accessor: " + err.Error()) } + if d.HeldMedia, err = prepareHeldMediaTables(d.conn); err != nil { + return errors.New("failed to create held media table accessor: " + err.Error()) + } instance = d return nil diff --git a/database/table_media.go b/database/table_media.go index 1c2da8acdfe368d44035550daa9938c0a062a398..3f49bad0707ca3ac8f0ece20fe5149e9a462496e 100644 --- a/database/table_media.go +++ b/database/table_media.go @@ -25,12 +25,14 @@ const selectDistinctMediaDatastoreIds = "SELECT DISTINCT datastore_id FROM media const selectMediaIsQuarantinedByHash = "SELECT quarantined FROM media WHERE quarantined = TRUE AND sha256_hash = $1;" const selectMediaByHash = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE sha256_hash = $1;" const insertMedia = "INSERT INTO media (origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);" +const selectMediaExists = "SELECT TRUE FROM media WHERE origin = $1 AND media_id = $2 LIMIT 1;" type mediaTableStatements struct { selectDistinctMediaDatastoreIds *sql.Stmt selectMediaIsQuarantinedByHash *sql.Stmt selectMediaByHash *sql.Stmt insertMedia *sql.Stmt + selectMediaExists *sql.Stmt } type mediaTableWithContext struct { @@ -54,6 +56,9 @@ func prepareMediaTables(db *sql.DB) (*mediaTableStatements, error) { if stmts.insertMedia, err = db.Prepare(insertMedia); err != nil { return nil, errors.New("error preparing insertMedia: " + err.Error()) } + if stmts.selectMediaExists, err = db.Prepare(selectMediaExists); err != nil { + return nil, errors.New("error preparing selectMediaExists: " + err.Error()) + } return stmts, nil } @@ -118,6 +123,17 @@ func (s *mediaTableWithContext) GetByHash(sha256hash string) ([]*DbMedia, error) return results, nil } +func (s *mediaTableWithContext) IdExists(origin string, mediaId string) (bool, error) { + row := s.statements.selectMediaExists.QueryRowContext(s.ctx, origin, mediaId) + val := false + err := row.Scan(&val) + if err == sql.ErrNoRows { + err = nil + val = false + } + return val, err +} + func (s *mediaTableWithContext) Insert(record *DbMedia) error { _, err := s.statements.insertMedia.ExecContext(s.ctx, record.Origin, record.MediaId, record.UploadName, record.ContentType, record.UserId, record.Sha256Hash, record.SizeBytes, record.CreationTs, record.Quarantined, record.DatastoreId, record.Location) return err diff --git a/database/table_media_hold.go b/database/table_media_hold.go new file mode 100644 index 0000000000000000000000000000000000000000..0911f9572c1c1f720b92edf37da0ee2bff13f7f0 --- /dev/null +++ b/database/table_media_hold.go @@ -0,0 +1,54 @@ +package database + +import ( + "database/sql" + "errors" + + "github.com/turt2live/matrix-media-repo/common/rcontext" +) + +type DbHeldMedia struct { + Origin string + MediaId string + Reason string +} + +type HeldReason string + +const ( + ForCreateHeldReason HeldReason = "media_create" +) + +const insertHeldMedia = "INSERT INTO media_id_hold (origin, media_id, reason) VALUES ($1, $2, $3);" + +type heldMediaTableStatements struct { + insertHeldMedia *sql.Stmt +} + +type heldMediaTableWithContext struct { + statements *heldMediaTableStatements + ctx rcontext.RequestContext +} + +func prepareHeldMediaTables(db *sql.DB) (*heldMediaTableStatements, error) { + var err error + var stmts = &heldMediaTableStatements{} + + if stmts.insertHeldMedia, err = db.Prepare(insertHeldMedia); err != nil { + return nil, errors.New("error preparing insertHeldMedia: " + err.Error()) + } + + return stmts, nil +} + +func (s *heldMediaTableStatements) Prepare(ctx rcontext.RequestContext) *heldMediaTableWithContext { + return &heldMediaTableWithContext{ + statements: s, + ctx: ctx, + } +} + +func (s *heldMediaTableWithContext) TryInsert(origin string, mediaId string, reason HeldReason) error { + _, err := s.statements.insertHeldMedia.ExecContext(s.ctx, origin, mediaId, reason) + return err +} diff --git a/database/table_reserved_media.go b/database/table_reserved_media.go index 807ccdfd0ba5394f48de11ec0ca0b319c242f4fe..239b0976c8377b10f6a4582c38340c1bdda19da1 100644 --- a/database/table_reserved_media.go +++ b/database/table_reserved_media.go @@ -13,12 +13,6 @@ type DbReservedMedia struct { Reason string } -type ReserveReason string - -const ( - ForCreateReserveReason ReserveReason = "media_create" -) - const insertReservedMedia = "INSERT INTO reserved_media (origin, media_id, reason) VALUES ($1, $2, $3);" type reservedMediaTableStatements struct { @@ -48,7 +42,7 @@ func (s *reservedMediaTableStatements) Prepare(ctx rcontext.RequestContext) *res } } -func (s *reservedMediaTableWithContext) TryInsert(origin string, mediaId string, reason ReserveReason) error { +func (s *reservedMediaTableWithContext) TryInsert(origin string, mediaId string, reason string) error { _, err := s.statements.insertReservedMedia.ExecContext(s.ctx, origin, mediaId, reason) return err } diff --git a/migrations/20_create_id_hold_table_down.sql b/migrations/20_create_id_hold_table_down.sql new file mode 100644 index 0000000000000000000000000000000000000000..ad21bf1ce5d4a14b57282f2e03dfcb16dc356bee --- /dev/null +++ b/migrations/20_create_id_hold_table_down.sql @@ -0,0 +1,2 @@ +DROP INDEX idx_media_id_hold; +DROP TABLE media_id_hold; \ No newline at end of file diff --git a/migrations/20_create_id_hold_table_up.sql b/migrations/20_create_id_hold_table_up.sql new file mode 100644 index 0000000000000000000000000000000000000000..a8692400b06327ab1e3c2b286601604048292688 --- /dev/null +++ b/migrations/20_create_id_hold_table_up.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS media_id_hold ( + origin TEXT NOT NULL, + media_id TEXT NOT NULL, + reason TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_media_id_hold ON media_id_hold (media_id, origin); \ No newline at end of file diff --git a/pipline/_steps/upload/generate_media_id.go b/pipline/_steps/upload/generate_media_id.go index 6e4782f02d5ae8c5646339ff1ac6b9fa52947f5f..330aa9b00daa1b154d6e3497fdd1cc3347373063 100644 --- a/pipline/_steps/upload/generate_media_id.go +++ b/pipline/_steps/upload/generate_media_id.go @@ -9,9 +9,11 @@ import ( ) func GenerateMediaId(ctx rcontext.RequestContext, origin string) (string, error) { - db := database.GetInstance().ReservedMedia.Prepare(ctx) + heldDb := database.GetInstance().HeldMedia.Prepare(ctx) + mediaDb := database.GetInstance().Media.Prepare(ctx) var mediaId string var err error + var exists bool attempts := 0 for true { attempts += 1 @@ -21,13 +23,21 @@ func GenerateMediaId(ctx rcontext.RequestContext, origin string) (string, error) mediaId, err = ids.NewUniqueId() - err = db.TryInsert(origin, mediaId, database.ForCreateReserveReason) + err = heldDb.TryInsert(origin, mediaId, database.ForCreateHeldReason) if err != nil { return "", err } // Check if there's a media table record for this media as well (there shouldn't be) - return mediaId, nil // TODO: @@TR - This + exists, err = mediaDb.IdExists(origin, mediaId) + if err != nil { + return "", err + } + if exists { + continue + } + + return mediaId, nil } return "", errors.New("internal limit reached: fell out of media ID generation loop") }