From dad1f2d37f92bd66805918c562cdaf3c481f9d1e Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Mon, 17 Jul 2023 00:24:20 -0600
Subject: [PATCH] Update media attributes API

---
 api/custom/media_attributes.go     | 61 +++++++++++-------------
 database/db.go                     | 26 ++++++-----
 database/table_media_attributes.go | 75 ++++++++++++++++++++++++++++++
 3 files changed, 116 insertions(+), 46 deletions(-)
 create mode 100644 database/table_media_attributes.go

diff --git a/api/custom/media_attributes.go b/api/custom/media_attributes.go
index 5ed12724..d69a6049 100644
--- a/api/custom/media_attributes.go
+++ b/api/custom/media_attributes.go
@@ -1,27 +1,22 @@
 package custom
 
 import (
-	"database/sql"
 	"encoding/json"
-	"io"
 	"net/http"
 
 	"github.com/getsentry/sentry-go"
+	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/api/_apimeta"
 	"github.com/turt2live/matrix-media-repo/api/_responses"
 	"github.com/turt2live/matrix-media-repo/api/_routers"
-	"github.com/turt2live/matrix-media-repo/util/stream_util"
-
-	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/database"
 	"github.com/turt2live/matrix-media-repo/matrix"
-	"github.com/turt2live/matrix-media-repo/storage"
-	"github.com/turt2live/matrix-media-repo/types"
 	"github.com/turt2live/matrix-media-repo/util"
 )
 
 type Attributes struct {
-	Purpose string `json:"purpose"`
+	Purpose database.Purpose `json:"purpose"`
 }
 
 func canChangeAttributes(rctx rcontext.RequestContext, r *http.Request, origin string, user _apimeta.UserInfo) bool {
@@ -55,29 +50,32 @@ func GetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
 	}
 
 	// Check to see if the media exists
-	mediaDb := storage.GetDatabase().GetMediaStore(rctx)
-	media, err := mediaDb.Get(origin, mediaId)
-	if err != nil && err != sql.ErrNoRows {
+	mediaDb := database.GetInstance().Media.Prepare(rctx)
+	media, err := mediaDb.GetById(origin, mediaId)
+	if err != nil {
 		rctx.Log.Error(err)
 		sentry.CaptureException(err)
 		return _responses.InternalServerError("failed to get media record")
 	}
-	if media == nil || err == sql.ErrNoRows {
+	if media == nil {
 		return _responses.NotFoundError()
 	}
 
-	db := storage.GetDatabase().GetMediaAttributesStore(rctx)
-
-	attrs, err := db.GetAttributesDefaulted(origin, mediaId)
+	attrDb := database.GetInstance().MediaAttributes.Prepare(rctx)
+	attrs, err := attrDb.Get(origin, mediaId)
 	if err != nil {
 		rctx.Log.Error(err)
 		sentry.CaptureException(err)
-		return _responses.InternalServerError("failed to get attributes")
+		return _responses.InternalServerError("failed to get attributes record")
+	}
+	retAttrs := &Attributes{
+		Purpose: database.PurposeNone,
+	}
+	if attrs != nil {
+		retAttrs.Purpose = attrs.Purpose
 	}
 
-	return &_responses.DoNotCacheResponse{Payload: &Attributes{
-		Purpose: attrs.Purpose,
-	}}
+	return &_responses.DoNotCacheResponse{Payload: retAttrs}
 }
 
 func SetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
@@ -97,36 +95,29 @@ func SetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
 		return _responses.AuthFailed()
 	}
 
-	defer stream_util.DumpAndCloseStream(r.Body)
-	b, err := io.ReadAll(r.Body)
-	if err != nil {
-		rctx.Log.Error(err)
-		sentry.CaptureException(err)
-		return _responses.InternalServerError("failed to read attributes")
-	}
-
+	defer r.Body.Close()
 	newAttrs := &Attributes{}
-	err = json.Unmarshal(b, &newAttrs)
+	decoder := json.NewDecoder(r.Body)
+	err := decoder.Decode(&newAttrs)
 	if err != nil {
 		rctx.Log.Error(err)
 		sentry.CaptureException(err)
-		return _responses.InternalServerError("failed to parse attributes")
+		return _responses.InternalServerError("failed to read attributes")
 	}
 
-	db := storage.GetDatabase().GetMediaAttributesStore(rctx)
-
-	attrs, err := db.GetAttributesDefaulted(origin, mediaId)
+	attrDb := database.GetInstance().MediaAttributes.Prepare(rctx)
+	attrs, err := attrDb.Get(origin, mediaId)
 	if err != nil {
 		rctx.Log.Error(err)
 		sentry.CaptureException(err)
 		return _responses.InternalServerError("failed to get attributes")
 	}
 
-	if attrs.Purpose != newAttrs.Purpose {
-		if !util.ArrayContains(types.AllPurposes, newAttrs.Purpose) {
+	if attrs == nil || attrs.Purpose != newAttrs.Purpose {
+		if !database.IsPurpose(newAttrs.Purpose) {
 			return _responses.BadRequest("unknown purpose")
 		}
-		err = db.UpsertPurpose(origin, mediaId, newAttrs.Purpose)
+		err = attrDb.UpsertPurpose(origin, mediaId, newAttrs.Purpose)
 		if err != nil {
 			rctx.Log.Error(err)
 			sentry.CaptureException(err)
diff --git a/database/db.go b/database/db.go
index e738f566..5eaee18c 100644
--- a/database/db.go
+++ b/database/db.go
@@ -14,17 +14,18 @@ import (
 )
 
 type Database struct {
-	conn          *sql.DB
-	Media         *mediaTableStatements
-	ExpiringMedia *expiringMediaTableStatements
-	UserStats     *userStatsTableStatements
-	ReservedMedia *reservedMediaTableStatements
-	MetadataView  *metadataVirtualTableStatements
-	Blurhashes    *blurhashesTableStatements
-	HeldMedia     *heldMediaTableStatements
-	Thumbnails    *thumbnailsTableStatements
-	LastAccess    *lastAccessTableStatements
-	UrlPreviews   *urlPreviewsTableStatements
+	conn            *sql.DB
+	Media           *mediaTableStatements
+	ExpiringMedia   *expiringMediaTableStatements
+	UserStats       *userStatsTableStatements
+	ReservedMedia   *reservedMediaTableStatements
+	MetadataView    *metadataVirtualTableStatements
+	Blurhashes      *blurhashesTableStatements
+	HeldMedia       *heldMediaTableStatements
+	Thumbnails      *thumbnailsTableStatements
+	LastAccess      *lastAccessTableStatements
+	UrlPreviews     *urlPreviewsTableStatements
+	MediaAttributes *mediaAttributesTableStatements
 }
 
 var instance *Database
@@ -108,6 +109,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error
 	if d.UrlPreviews, err = prepareUrlPreviewsTables(d.conn); err != nil {
 		return errors.New("failed to create url previews table accessor: " + err.Error())
 	}
+	if d.MediaAttributes, err = prepareMediaAttributesTables(d.conn); err != nil {
+		return errors.New("failed to create media attributes table accessor: " + err.Error())
+	}
 
 	instance = d
 	return nil
diff --git a/database/table_media_attributes.go b/database/table_media_attributes.go
new file mode 100644
index 00000000..2513fc55
--- /dev/null
+++ b/database/table_media_attributes.go
@@ -0,0 +1,75 @@
+package database
+
+import (
+	"database/sql"
+	"errors"
+
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+)
+
+type DbMediaAttributes struct {
+	Origin  string
+	MediaId string
+	Purpose Purpose
+}
+
+type Purpose string
+
+const (
+	PurposeNone   Purpose = "none"
+	PurposePinned Purpose = "pinned"
+)
+
+func IsPurpose(purpose Purpose) bool {
+	return purpose == PurposeNone || purpose == PurposePinned
+}
+
+const selectMediaAttributes = "SELECT origin, media_id, purpose FROM media_attributes WHERE origin = $1 AND media_id = $2;"
+const upsertMediaPurpose = "INSERT INTO media_attributes (origin, media_id, purpose) VALUES ($1, $2, $3) ON CONFLICT (origin, media_id) DO UPDATE SET purpose = $3;"
+
+type mediaAttributesTableStatements struct {
+	selectMediaAttributes *sql.Stmt
+	upsertMediaPurpose    *sql.Stmt
+}
+
+type mediaAttributesTableWithContext struct {
+	statements *mediaAttributesTableStatements
+	ctx        rcontext.RequestContext
+}
+
+func prepareMediaAttributesTables(db *sql.DB) (*mediaAttributesTableStatements, error) {
+	var err error
+	var stmts = &mediaAttributesTableStatements{}
+
+	if stmts.selectMediaAttributes, err = db.Prepare(selectMediaAttributes); err != nil {
+		return nil, errors.New("error preparing selectMediaAttributes: " + err.Error())
+	}
+	if stmts.upsertMediaPurpose, err = db.Prepare(upsertMediaPurpose); err != nil {
+		return nil, errors.New("error preparing upsertMediaPurpose: " + err.Error())
+	}
+
+	return stmts, nil
+}
+
+func (s *mediaAttributesTableStatements) Prepare(ctx rcontext.RequestContext) *mediaAttributesTableWithContext {
+	return &mediaAttributesTableWithContext{
+		statements: s,
+		ctx:        ctx,
+	}
+}
+
+func (s *mediaAttributesTableWithContext) Get(origin string, mediaId string) (*DbMediaAttributes, error) {
+	row := s.statements.selectMediaAttributes.QueryRowContext(s.ctx, origin, mediaId)
+	val := &DbMediaAttributes{}
+	err := row.Scan(&val.Origin, &val.MediaId, &val.Purpose)
+	if err == sql.ErrNoRows {
+		err = nil
+		val = nil
+	}
+	return val, err
+}
+
+func (s *mediaAttributesTableWithContext) UpsertPurpose(origin string, mediaId string, purpose Purpose) error {
+	_, err := s.statements.upsertMediaPurpose.ExecContext(s.ctx, origin, mediaId, purpose)
+	return err
+}
-- 
GitLab