From 28241dafd6684a93a069bc120b3637e5cac7c56e Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Mon, 2 Sep 2019 22:33:32 -0600
Subject: [PATCH] Reserve media IDs to avoid re-upload attacks

---
 .../maintainance_controller.go                |  6 +++
 .../upload_controller/upload_controller.go    | 40 +++++++++++++++++--
 migrations/11_add_reserved_ids_table_down.sql |  2 +
 migrations/11_add_reserved_ids_table_up.sql   |  6 +++
 storage/stores/metadata_store.go              | 34 ++++++++++++++++
 5 files changed, 85 insertions(+), 3 deletions(-)
 create mode 100644 migrations/11_add_reserved_ids_table_down.sql
 create mode 100644 migrations/11_add_reserved_ids_table_up.sql

diff --git a/controllers/maintenance_controller/maintainance_controller.go b/controllers/maintenance_controller/maintainance_controller.go
index 03f6cf60..cdc70349 100644
--- a/controllers/maintenance_controller/maintainance_controller.go
+++ b/controllers/maintenance_controller/maintainance_controller.go
@@ -288,6 +288,12 @@ func doPurge(media *types.Media, ctx context.Context, log *logrus.Entry) error {
 		return err
 	}
 
+	metadataDb := storage.GetDatabase().GetMetadataStore(ctx, log)
+	err = metadataDb.ReserveMediaId(media.Origin, media.MediaId, "purged / deleted")
+	if err != nil {
+		return err
+	}
+
 	mediaDb := storage.GetDatabase().GetMediaStore(ctx, log)
 	err = mediaDb.Delete(media.Origin, media.MediaId)
 	if err != nil {
diff --git a/controllers/upload_controller/upload_controller.go b/controllers/upload_controller/upload_controller.go
index acc01da4..5abb01d0 100644
--- a/controllers/upload_controller/upload_controller.go
+++ b/controllers/upload_controller/upload_controller.go
@@ -2,6 +2,7 @@ package upload_controller
 
 import (
 	"context"
+	"database/sql"
 	"io"
 	"io/ioutil"
 	"strconv"
@@ -86,9 +87,42 @@ func UploadMedia(contents io.ReadCloser, contentLength int64, contentType string
 		data = contents
 	}
 
-	mediaId, err := util.GenerateRandomString(64)
-	if err != nil {
-		return nil, err
+	metadataDb := storage.GetDatabase().GetMetadataStore(ctx, log)
+	mediaDb := storage.GetDatabase().GetMediaStore(ctx, log)
+
+	mediaTaken := true
+	var mediaId string
+	var err error
+	attempts := 0
+	for mediaTaken {
+		attempts += 1
+		if attempts > 10 {
+			return nil, errors.New("failed to generate a media ID after 10 rounds")
+		}
+
+		mediaId, err = util.GenerateRandomString(64)
+		if err != nil {
+			return nil, err
+		}
+
+		mediaTaken, err = metadataDb.IsReserved(origin, mediaId)
+		if err != nil {
+			return nil, err
+		}
+
+		if !mediaTaken {
+			// Double check it isn't already in use
+			var media *types.Media
+			media, err = mediaDb.Get(origin, mediaId)
+			if err == sql.ErrNoRows {
+				mediaTaken = false
+				continue
+			}
+			if err != nil {
+				return nil, err
+			}
+			mediaTaken = media != nil
+		}
 	}
 
 	return StoreDirect(data, contentLength, contentType, filename, userId, origin, mediaId, ctx, log)
diff --git a/migrations/11_add_reserved_ids_table_down.sql b/migrations/11_add_reserved_ids_table_down.sql
new file mode 100644
index 00000000..0f93b6d8
--- /dev/null
+++ b/migrations/11_add_reserved_ids_table_down.sql
@@ -0,0 +1,2 @@
+DROP INDEX reserved_media_index;
+DROP TABLE reserved_media;
diff --git a/migrations/11_add_reserved_ids_table_up.sql b/migrations/11_add_reserved_ids_table_up.sql
new file mode 100644
index 00000000..fae93417
--- /dev/null
+++ b/migrations/11_add_reserved_ids_table_up.sql
@@ -0,0 +1,6 @@
+CREATE TABLE IF NOT EXISTS reserved_media (
+	origin TEXT NOT NULL,
+	media_id TEXT NOT NULL,
+	reason TEXT NOT NULL
+);
+CREATE UNIQUE INDEX IF NOT EXISTS reserved_media_index ON reserved_media (media_id, origin);
diff --git a/storage/stores/metadata_store.go b/storage/stores/metadata_store.go
index 6dfd90e2..0179e6b8 100644
--- a/storage/stores/metadata_store.go
+++ b/storage/stores/metadata_store.go
@@ -27,6 +27,8 @@ const insertNewBackgroundTask = "INSERT INTO background_tasks (task, params, sta
 const selectBackgroundTask = "SELECT id, task, params, start_ts, end_ts FROM background_tasks WHERE id = $1"
 const updateBackgroundTask = "UPDATE background_tasks SET end_ts = $2 WHERE id = $1"
 const selectAllBackgroundTasks = "SELECT id, task, params, start_ts, end_ts FROM background_tasks"
+const insertReservation = "INSERT INTO reserved_media (origin, media_id, reason) VALUES ($1, $2, $3);"
+const selectReservation = "SELECT origin, media_id, reason FROM reserved_media WHERE origin = $1 AND media_id = $2;"
 
 type metadataStoreStatements struct {
 	upsertLastAccessed                            *sql.Stmt
@@ -42,6 +44,8 @@ type metadataStoreStatements struct {
 	selectBackgroundTask                          *sql.Stmt
 	updateBackgroundTask                          *sql.Stmt
 	selectAllBackgroundTasks                      *sql.Stmt
+	insertReservation                             *sql.Stmt
+	selectReservation                             *sql.Stmt
 }
 
 type MetadataStoreFactory struct {
@@ -101,6 +105,12 @@ func InitMetadataStore(sqlDb *sql.DB) (*MetadataStoreFactory, error) {
 	if store.stmts.selectAllBackgroundTasks, err = store.sqlDb.Prepare(selectAllBackgroundTasks); err != nil {
 		return nil, err
 	}
+	if store.stmts.insertReservation, err = store.sqlDb.Prepare(insertReservation); err != nil {
+		return nil, err
+	}
+	if store.stmts.selectReservation, err = store.sqlDb.Prepare(selectReservation); err != nil {
+		return nil, err
+	}
 
 	return &store, nil
 }
@@ -314,3 +324,27 @@ func (s *MetadataStore) GetAllBackgroundTasks() ([]*types.BackgroundTask, error)
 
 	return results, nil
 }
+
+func (s *MetadataStore) ReserveMediaId(origin string, mediaId string, reason string) error {
+	_, err := s.statements.insertReservation.ExecContext(s.ctx, origin, mediaId, reason)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *MetadataStore) IsReserved(origin string, mediaId string) (bool, error) {
+	r := s.statements.selectReservation.QueryRowContext(s.ctx, origin, mediaId)
+	var dbOrigin string
+	var dbMediaId string
+	var dbReason string
+
+	err := r.Scan(&dbOrigin, &dbMediaId, &dbReason)
+	if err == sql.ErrNoRows {
+		return false, nil
+	}
+	if err != nil {
+		return true, err
+	}
+	return true, nil
+}
-- 
GitLab