From dad1f2d37f92bd66805918c562cdaf3c481f9d1e Mon Sep 17 00:00:00 2001 From: Travis Ralston <travpc@gmail.com> Date: Mon, 17 Jul 2023 00:24:20 -0600 Subject: [PATCH] Update media attributes API --- api/custom/media_attributes.go | 61 +++++++++++------------- database/db.go | 26 ++++++----- database/table_media_attributes.go | 75 ++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 46 deletions(-) create mode 100644 database/table_media_attributes.go diff --git a/api/custom/media_attributes.go b/api/custom/media_attributes.go index 5ed12724..d69a6049 100644 --- a/api/custom/media_attributes.go +++ b/api/custom/media_attributes.go @@ -1,27 +1,22 @@ package custom import ( - "database/sql" "encoding/json" - "io" "net/http" "github.com/getsentry/sentry-go" + "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api/_apimeta" "github.com/turt2live/matrix-media-repo/api/_responses" "github.com/turt2live/matrix-media-repo/api/_routers" - "github.com/turt2live/matrix-media-repo/util/stream_util" - - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" "github.com/turt2live/matrix-media-repo/matrix" - "github.com/turt2live/matrix-media-repo/storage" - "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) type Attributes struct { - Purpose string `json:"purpose"` + Purpose database.Purpose `json:"purpose"` } func canChangeAttributes(rctx rcontext.RequestContext, r *http.Request, origin string, user _apimeta.UserInfo) bool { @@ -55,29 +50,32 @@ func GetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta. } // Check to see if the media exists - mediaDb := storage.GetDatabase().GetMediaStore(rctx) - media, err := mediaDb.Get(origin, mediaId) - if err != nil && err != sql.ErrNoRows { + mediaDb := database.GetInstance().Media.Prepare(rctx) + media, err := mediaDb.GetById(origin, mediaId) + if err != nil { rctx.Log.Error(err) sentry.CaptureException(err) return _responses.InternalServerError("failed to get media record") } - if media == nil || err == sql.ErrNoRows { + if media == nil { return _responses.NotFoundError() } - db := storage.GetDatabase().GetMediaAttributesStore(rctx) - - attrs, err := db.GetAttributesDefaulted(origin, mediaId) + attrDb := database.GetInstance().MediaAttributes.Prepare(rctx) + attrs, err := attrDb.Get(origin, mediaId) if err != nil { rctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("failed to get attributes") + return _responses.InternalServerError("failed to get attributes record") + } + retAttrs := &Attributes{ + Purpose: database.PurposeNone, + } + if attrs != nil { + retAttrs.Purpose = attrs.Purpose } - return &_responses.DoNotCacheResponse{Payload: &Attributes{ - Purpose: attrs.Purpose, - }} + return &_responses.DoNotCacheResponse{Payload: retAttrs} } func SetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { @@ -97,36 +95,29 @@ func SetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta. return _responses.AuthFailed() } - defer stream_util.DumpAndCloseStream(r.Body) - b, err := io.ReadAll(r.Body) - if err != nil { - rctx.Log.Error(err) - sentry.CaptureException(err) - return _responses.InternalServerError("failed to read attributes") - } - + defer r.Body.Close() newAttrs := &Attributes{} - err = json.Unmarshal(b, &newAttrs) + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&newAttrs) if err != nil { rctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("failed to parse attributes") + return _responses.InternalServerError("failed to read attributes") } - db := storage.GetDatabase().GetMediaAttributesStore(rctx) - - attrs, err := db.GetAttributesDefaulted(origin, mediaId) + attrDb := database.GetInstance().MediaAttributes.Prepare(rctx) + attrs, err := attrDb.Get(origin, mediaId) if err != nil { rctx.Log.Error(err) sentry.CaptureException(err) return _responses.InternalServerError("failed to get attributes") } - if attrs.Purpose != newAttrs.Purpose { - if !util.ArrayContains(types.AllPurposes, newAttrs.Purpose) { + if attrs == nil || attrs.Purpose != newAttrs.Purpose { + if !database.IsPurpose(newAttrs.Purpose) { return _responses.BadRequest("unknown purpose") } - err = db.UpsertPurpose(origin, mediaId, newAttrs.Purpose) + err = attrDb.UpsertPurpose(origin, mediaId, newAttrs.Purpose) if err != nil { rctx.Log.Error(err) sentry.CaptureException(err) diff --git a/database/db.go b/database/db.go index e738f566..5eaee18c 100644 --- a/database/db.go +++ b/database/db.go @@ -14,17 +14,18 @@ import ( ) type Database struct { - conn *sql.DB - Media *mediaTableStatements - ExpiringMedia *expiringMediaTableStatements - UserStats *userStatsTableStatements - ReservedMedia *reservedMediaTableStatements - MetadataView *metadataVirtualTableStatements - Blurhashes *blurhashesTableStatements - HeldMedia *heldMediaTableStatements - Thumbnails *thumbnailsTableStatements - LastAccess *lastAccessTableStatements - UrlPreviews *urlPreviewsTableStatements + conn *sql.DB + Media *mediaTableStatements + ExpiringMedia *expiringMediaTableStatements + UserStats *userStatsTableStatements + ReservedMedia *reservedMediaTableStatements + MetadataView *metadataVirtualTableStatements + Blurhashes *blurhashesTableStatements + HeldMedia *heldMediaTableStatements + Thumbnails *thumbnailsTableStatements + LastAccess *lastAccessTableStatements + UrlPreviews *urlPreviewsTableStatements + MediaAttributes *mediaAttributesTableStatements } var instance *Database @@ -108,6 +109,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error if d.UrlPreviews, err = prepareUrlPreviewsTables(d.conn); err != nil { return errors.New("failed to create url previews table accessor: " + err.Error()) } + if d.MediaAttributes, err = prepareMediaAttributesTables(d.conn); err != nil { + return errors.New("failed to create media attributes table accessor: " + err.Error()) + } instance = d return nil diff --git a/database/table_media_attributes.go b/database/table_media_attributes.go new file mode 100644 index 00000000..2513fc55 --- /dev/null +++ b/database/table_media_attributes.go @@ -0,0 +1,75 @@ +package database + +import ( + "database/sql" + "errors" + + "github.com/turt2live/matrix-media-repo/common/rcontext" +) + +type DbMediaAttributes struct { + Origin string + MediaId string + Purpose Purpose +} + +type Purpose string + +const ( + PurposeNone Purpose = "none" + PurposePinned Purpose = "pinned" +) + +func IsPurpose(purpose Purpose) bool { + return purpose == PurposeNone || purpose == PurposePinned +} + +const selectMediaAttributes = "SELECT origin, media_id, purpose FROM media_attributes WHERE origin = $1 AND media_id = $2;" +const upsertMediaPurpose = "INSERT INTO media_attributes (origin, media_id, purpose) VALUES ($1, $2, $3) ON CONFLICT (origin, media_id) DO UPDATE SET purpose = $3;" + +type mediaAttributesTableStatements struct { + selectMediaAttributes *sql.Stmt + upsertMediaPurpose *sql.Stmt +} + +type mediaAttributesTableWithContext struct { + statements *mediaAttributesTableStatements + ctx rcontext.RequestContext +} + +func prepareMediaAttributesTables(db *sql.DB) (*mediaAttributesTableStatements, error) { + var err error + var stmts = &mediaAttributesTableStatements{} + + if stmts.selectMediaAttributes, err = db.Prepare(selectMediaAttributes); err != nil { + return nil, errors.New("error preparing selectMediaAttributes: " + err.Error()) + } + if stmts.upsertMediaPurpose, err = db.Prepare(upsertMediaPurpose); err != nil { + return nil, errors.New("error preparing upsertMediaPurpose: " + err.Error()) + } + + return stmts, nil +} + +func (s *mediaAttributesTableStatements) Prepare(ctx rcontext.RequestContext) *mediaAttributesTableWithContext { + return &mediaAttributesTableWithContext{ + statements: s, + ctx: ctx, + } +} + +func (s *mediaAttributesTableWithContext) Get(origin string, mediaId string) (*DbMediaAttributes, error) { + row := s.statements.selectMediaAttributes.QueryRowContext(s.ctx, origin, mediaId) + val := &DbMediaAttributes{} + err := row.Scan(&val.Origin, &val.MediaId, &val.Purpose) + if err == sql.ErrNoRows { + err = nil + val = nil + } + return val, err +} + +func (s *mediaAttributesTableWithContext) UpsertPurpose(origin string, mediaId string, purpose Purpose) error { + _, err := s.statements.upsertMediaPurpose.ExecContext(s.ctx, origin, mediaId, purpose) + return err +} -- GitLab