Skip to content
Snippets Groups Projects
Commit dad1f2d3 authored by Travis Ralston's avatar Travis Ralston
Browse files

Update media attributes API

parent 65308fa7
No related branches found
No related tags found
No related merge requests found
package custom
import (
"database/sql"
"encoding/json"
"io"
"net/http"
"github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/util/stream_util"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/matrix"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
)
type Attributes struct {
Purpose string `json:"purpose"`
Purpose database.Purpose `json:"purpose"`
}
func canChangeAttributes(rctx rcontext.RequestContext, r *http.Request, origin string, user _apimeta.UserInfo) bool {
......@@ -55,29 +50,32 @@ func GetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
}
// Check to see if the media exists
mediaDb := storage.GetDatabase().GetMediaStore(rctx)
media, err := mediaDb.Get(origin, mediaId)
if err != nil && err != sql.ErrNoRows {
mediaDb := database.GetInstance().Media.Prepare(rctx)
media, err := mediaDb.GetById(origin, mediaId)
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("failed to get media record")
}
if media == nil || err == sql.ErrNoRows {
if media == nil {
return _responses.NotFoundError()
}
db := storage.GetDatabase().GetMediaAttributesStore(rctx)
attrs, err := db.GetAttributesDefaulted(origin, mediaId)
attrDb := database.GetInstance().MediaAttributes.Prepare(rctx)
attrs, err := attrDb.Get(origin, mediaId)
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("failed to get attributes")
return _responses.InternalServerError("failed to get attributes record")
}
retAttrs := &Attributes{
Purpose: database.PurposeNone,
}
if attrs != nil {
retAttrs.Purpose = attrs.Purpose
}
return &_responses.DoNotCacheResponse{Payload: &Attributes{
Purpose: attrs.Purpose,
}}
return &_responses.DoNotCacheResponse{Payload: retAttrs}
}
func SetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
......@@ -97,36 +95,29 @@ func SetAttributes(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
return _responses.AuthFailed()
}
defer stream_util.DumpAndCloseStream(r.Body)
b, err := io.ReadAll(r.Body)
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("failed to read attributes")
}
defer r.Body.Close()
newAttrs := &Attributes{}
err = json.Unmarshal(b, &newAttrs)
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&newAttrs)
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("failed to parse attributes")
return _responses.InternalServerError("failed to read attributes")
}
db := storage.GetDatabase().GetMediaAttributesStore(rctx)
attrs, err := db.GetAttributesDefaulted(origin, mediaId)
attrDb := database.GetInstance().MediaAttributes.Prepare(rctx)
attrs, err := attrDb.Get(origin, mediaId)
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError("failed to get attributes")
}
if attrs.Purpose != newAttrs.Purpose {
if !util.ArrayContains(types.AllPurposes, newAttrs.Purpose) {
if attrs == nil || attrs.Purpose != newAttrs.Purpose {
if !database.IsPurpose(newAttrs.Purpose) {
return _responses.BadRequest("unknown purpose")
}
err = db.UpsertPurpose(origin, mediaId, newAttrs.Purpose)
err = attrDb.UpsertPurpose(origin, mediaId, newAttrs.Purpose)
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
......
......@@ -14,17 +14,18 @@ import (
)
type Database struct {
conn *sql.DB
Media *mediaTableStatements
ExpiringMedia *expiringMediaTableStatements
UserStats *userStatsTableStatements
ReservedMedia *reservedMediaTableStatements
MetadataView *metadataVirtualTableStatements
Blurhashes *blurhashesTableStatements
HeldMedia *heldMediaTableStatements
Thumbnails *thumbnailsTableStatements
LastAccess *lastAccessTableStatements
UrlPreviews *urlPreviewsTableStatements
conn *sql.DB
Media *mediaTableStatements
ExpiringMedia *expiringMediaTableStatements
UserStats *userStatsTableStatements
ReservedMedia *reservedMediaTableStatements
MetadataView *metadataVirtualTableStatements
Blurhashes *blurhashesTableStatements
HeldMedia *heldMediaTableStatements
Thumbnails *thumbnailsTableStatements
LastAccess *lastAccessTableStatements
UrlPreviews *urlPreviewsTableStatements
MediaAttributes *mediaAttributesTableStatements
}
var instance *Database
......@@ -108,6 +109,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error
if d.UrlPreviews, err = prepareUrlPreviewsTables(d.conn); err != nil {
return errors.New("failed to create url previews table accessor: " + err.Error())
}
if d.MediaAttributes, err = prepareMediaAttributesTables(d.conn); err != nil {
return errors.New("failed to create media attributes table accessor: " + err.Error())
}
instance = d
return nil
......
package database
import (
"database/sql"
"errors"
"github.com/turt2live/matrix-media-repo/common/rcontext"
)
type DbMediaAttributes struct {
Origin string
MediaId string
Purpose Purpose
}
type Purpose string
const (
PurposeNone Purpose = "none"
PurposePinned Purpose = "pinned"
)
func IsPurpose(purpose Purpose) bool {
return purpose == PurposeNone || purpose == PurposePinned
}
const selectMediaAttributes = "SELECT origin, media_id, purpose FROM media_attributes WHERE origin = $1 AND media_id = $2;"
const upsertMediaPurpose = "INSERT INTO media_attributes (origin, media_id, purpose) VALUES ($1, $2, $3) ON CONFLICT (origin, media_id) DO UPDATE SET purpose = $3;"
type mediaAttributesTableStatements struct {
selectMediaAttributes *sql.Stmt
upsertMediaPurpose *sql.Stmt
}
type mediaAttributesTableWithContext struct {
statements *mediaAttributesTableStatements
ctx rcontext.RequestContext
}
func prepareMediaAttributesTables(db *sql.DB) (*mediaAttributesTableStatements, error) {
var err error
var stmts = &mediaAttributesTableStatements{}
if stmts.selectMediaAttributes, err = db.Prepare(selectMediaAttributes); err != nil {
return nil, errors.New("error preparing selectMediaAttributes: " + err.Error())
}
if stmts.upsertMediaPurpose, err = db.Prepare(upsertMediaPurpose); err != nil {
return nil, errors.New("error preparing upsertMediaPurpose: " + err.Error())
}
return stmts, nil
}
func (s *mediaAttributesTableStatements) Prepare(ctx rcontext.RequestContext) *mediaAttributesTableWithContext {
return &mediaAttributesTableWithContext{
statements: s,
ctx: ctx,
}
}
func (s *mediaAttributesTableWithContext) Get(origin string, mediaId string) (*DbMediaAttributes, error) {
row := s.statements.selectMediaAttributes.QueryRowContext(s.ctx, origin, mediaId)
val := &DbMediaAttributes{}
err := row.Scan(&val.Origin, &val.MediaId, &val.Purpose)
if err == sql.ErrNoRows {
err = nil
val = nil
}
return val, err
}
func (s *mediaAttributesTableWithContext) UpsertPurpose(origin string, mediaId string, purpose Purpose) error {
_, err := s.statements.upsertMediaPurpose.ExecContext(s.ctx, origin, mediaId, purpose)
return err
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment