Newer
Older

Travis Ralston
committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package database
import (
"database/sql"
"errors"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util"
)
// DbExpiringMedia describes one expiring upload: the media identified by
// Origin and MediaId, the UserId it is attributed to, and the ExpiresTs
// timestamp at which it expires.
//
// NOTE(review): the fields line up with the expiring_media columns used by
// insertExpiringMedia; ExpiresTs is presumably epoch milliseconds (the count
// query compares expires_ts with util.NowMillis()) — confirm.
type DbExpiringMedia struct {
	Origin, MediaId, UserId string
	ExpiresTs               int64
}
// SQL for the expiring_media table. Both queries use $N positional
// placeholders and are prepared by prepareExpiringMediaTables.
const (
	// insertExpiringMedia adds one row: origin, media_id, user_id, expires_ts.
	insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_id, expires_ts) VALUES ($1, $2, $3, $4);"

	// selectExpiringMediaByUserCount counts the rows for user_id $1 whose
	// expires_ts is greater than or equal to $2.
	selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;"
)
// expiringMediaTableStatements holds the prepared statements for the
// expiring_media table, as produced by prepareExpiringMediaTables.
type expiringMediaTableStatements struct {
	insertExpiringMedia, selectExpiringMediaByUserCount *sql.Stmt
}
// expiringMediaTableWithContext pairs the prepared expiring_media statements
// with the request context that its Insert and ByUserCount methods run under.
type expiringMediaTableWithContext struct {
// statements points at the shared prepared statements; Prepare stores its
// receiver here without copying it.
statements *expiringMediaTableStatements
// ctx is passed to ExecContext / QueryRowContext on every call.
ctx rcontext.RequestContext
}
// prepareExpiringMediaTables prepares every statement used by the
// expiring_media table against db.
//
// On success, all fields of the returned value are set and the caller owns
// the statements. If any statement fails to prepare, the statements already
// prepared by this call are closed before returning. The caller gets nil in
// that case and could never close them itself. The returned error names the
// statement that failed.
func prepareExpiringMediaTables(db *sql.DB) (*expiringMediaTableStatements, error) {
	var err error
	var stmts = &expiringMediaTableStatements{}
	if stmts.insertExpiringMedia, err = db.Prepare(insertExpiringMedia); err != nil {
		return nil, errors.New("error preparing insertExpiringMedia: " + err.Error())
	}
	if stmts.selectExpiringMediaByUserCount, err = db.Prepare(selectExpiringMediaByUserCount); err != nil {
		// Release the statement prepared above, otherwise it leaks. Its Close
		// error is deliberately dropped: the prepare failure is the one to report.
		_ = stmts.insertExpiringMedia.Close()
		return nil, errors.New("error preparing selectExpiringMediaByUserCount: " + err.Error())
	}
	return stmts, nil
}
// Prepare binds these prepared statements to ctx. It returns a handle whose
// methods execute the statements under that context. The handle references
// s directly; no statements are copied or re-prepared.
func (s *expiringMediaTableStatements) Prepare(ctx rcontext.RequestContext) *expiringMediaTableWithContext {
	bound := new(expiringMediaTableWithContext)
	bound.statements = s
	bound.ctx = ctx
	return bound
}
// Insert adds an expiring_media row for origin/mediaId, attributed to userId
// and expiring at expiresTs. It runs under the handle's context. The
// sql.Result is discarded; only the execution error, if any, is returned.
func (s *expiringMediaTableWithContext) Insert(origin string, mediaId string, userId string, expiresTs int64) error {
	stmt := s.statements.insertExpiringMedia
	if _, execErr := stmt.ExecContext(s.ctx, origin, mediaId, userId, expiresTs); execErr != nil {
		return execErr
	}
	return nil
}
// ByUserCount returns how many expiring_media rows belong to userId and have
// expires_ts >= util.NowMillis(), i.e. have not yet expired, assuming
// NowMillis is the current time in the same units as expires_ts (confirm).
// A sql.ErrNoRows result from the scan is reported as a count of 0 with a
// nil error. Any other scan error is returned alongside whatever the count
// holds.
func (s *expiringMediaTableWithContext) ByUserCount(userId string) (int64, error) {
	var count int64
	now := util.NowMillis()
	row := s.statements.selectExpiringMediaByUserCount.QueryRowContext(s.ctx, userId, now)
	scanErr := row.Scan(&count)
	switch {
	case scanErr == nil:
		return count, nil
	case scanErr == sql.ErrNoRows:
		return 0, nil
	default:
		return count, scanErr
	}
}