-
Travis Ralston authoredTravis Ralston authored
virtualtable_metadata.go 10.20 KiB
package database
import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/turt2live/matrix-media-repo/common/rcontext"
)
// VirtLastAccess is one row produced by the "…WithLastAccess" queries in this
// file: a media or thumbnail record joined with its last_access entry. It is
// populated by scanLastAccess.
type VirtLastAccess struct {
// Locatable is embedded; scanLastAccess fills its Sha256Hash, DatastoreId and
// Location fields from the sha256_hash, datastore_id and location columns.
*Locatable
// SizeBytes is read from the size_bytes column.
SizeBytes int64
// CreationTs is read from the creation_ts column.
CreationTs int64
// LastAccessTs is read from last_access.last_access_ts.
LastAccessTs int64
// ContentType is read from the content_type column.
ContentType string
}
// selectEstimatedDatastoreSize yields a single size_total column for datastore
// $1: the sum over media of the largest size_bytes per distinct sha256_hash,
// plus the same sum over thumbnails. Rows sharing a hash are counted once.
const selectEstimatedDatastoreSize = "SELECT COALESCE(SUM(m2.size_bytes), 0) + COALESCE((SELECT SUM(t2.size_bytes) FROM (SELECT DISTINCT t.sha256_hash, MAX(t.size_bytes) AS size_bytes FROM thumbnails AS t WHERE t.datastore_id = $1 GROUP BY t.sha256_hash) AS t2), 0) AS size_total FROM (SELECT DISTINCT m.sha256_hash, MAX(m.size_bytes) AS size_bytes FROM media AS m WHERE m.datastore_id = $1 GROUP BY m.sha256_hash) AS m2;"
// selectUploadSizesForServer yields two columns, media and thumbnails: the
// summed size_bytes of each table's rows whose origin is $1 (0 when none).
const selectUploadSizesForServer = "SELECT COALESCE((SELECT SUM(size_bytes) FROM media WHERE origin = $1), 0) AS media, COALESCE((SELECT SUM(size_bytes) FROM thumbnails WHERE origin = $1), 0) AS thumbnails;"
// selectUploadCountsForServer yields two columns, media and thumbnails: the
// number of rows in each table whose origin is $1.
const selectUploadCountsForServer = "SELECT COALESCE((SELECT COUNT(origin) FROM media WHERE origin = $1), 0) AS media, COALESCE((SELECT COUNT(origin) FROM thumbnails WHERE origin = $1), 0) AS thumbnails;"
// selectMediaForDatastoreWithLastAccess selects media rows stored in datastore
// $2, joined to last_access on sha256_hash, whose last_access_ts is strictly
// below $1. The column order matches what scanLastAccess scans.
const selectMediaForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts, m.content_type FROM media AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;"
// selectThumbnailsForDatastoreWithLastAccess is the thumbnails-table
// counterpart of selectMediaForDatastoreWithLastAccess ($1 = last access
// cutoff, $2 = datastore ID).
const selectThumbnailsForDatastoreWithLastAccess = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts, m.content_type FROM thumbnails AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2;"
// updateQuarantineByHash sets quarantined = $3 on every media row whose
// sha256_hash is $1 and whose current quarantined value differs from $3,
// skipping rows whose media_attributes purpose equals $2. Rows with no
// media_attributes entry (NULL purpose) are included.
const updateQuarantineByHash = "WITH t AS (SELECT m.origin AS origin, m.media_id AS media_id, a.purpose AS purpose FROM media AS m LEFT JOIN media_attributes AS a ON m.origin = a.origin AND m.media_id = a.media_id WHERE m.sha256_hash = $1 AND (a.purpose IS NULL OR a.purpose <> $2) AND m.quarantined <> $3) UPDATE media AS m2 SET quarantined = $3 FROM t WHERE m2.origin = t.origin AND m2.media_id = t.media_id;"
// updateQuarantineByHashAndOrigin is updateQuarantineByHash additionally
// restricted to a single origin: $1 = origin, $2 = sha256_hash, $3 = purpose
// to skip, $4 = new quarantined value.
const updateQuarantineByHashAndOrigin = "WITH t AS (SELECT m.origin AS origin, m.media_id AS media_id, a.purpose AS purpose FROM media AS m LEFT JOIN media_attributes AS a ON m.origin = a.origin AND m.media_id = a.media_id WHERE m.origin = $1 AND m.sha256_hash = $2 AND (a.purpose IS NULL OR a.purpose <> $3) AND m.quarantined <> $4) UPDATE media AS m2 SET quarantined = $4 FROM t WHERE m2.origin = t.origin AND m2.media_id = t.media_id;"
// SynStatUserOrderBy names the column a per-user statistics page is sorted by.
// The value is interpolated into an ORDER BY clause by
// UnoptimizedSynapseUserStatsPage, so only the constants declared below are
// acceptable.
type SynStatUserOrderBy string

const (
	// SynStatUserOrderByMediaCount sorts by the media_count column.
	SynStatUserOrderByMediaCount SynStatUserOrderBy = "media_count"
	// SynStatUserOrderByMediaLength sorts by the media_length column.
	SynStatUserOrderByMediaLength SynStatUserOrderBy = "media_length"
	// SynStatUserOrderByUserId sorts by the user_id column.
	SynStatUserOrderByUserId SynStatUserOrderBy = "user_id"
	// DefaultSynStatUserOrderBy is the ordering to fall back on: user_id.
	DefaultSynStatUserOrderBy = SynStatUserOrderByUserId
)

// IsSynStatUserOrderBy reports whether orderBy is exactly one of the declared
// SynStatUserOrderBy constants. The comparison is case- and
// whitespace-sensitive.
func IsSynStatUserOrderBy(orderBy SynStatUserOrderBy) bool {
	switch orderBy {
	case SynStatUserOrderByMediaCount,
		SynStatUserOrderByMediaLength,
		SynStatUserOrderByUserId:
		return true
	default:
		return false
	}
}
// DbSynUserStat is one row of the per-user statistics page built by
// UnoptimizedSynapseUserStatsPage, which groups the media table by user_id.
type DbSynUserStat struct {
// UserId is the user_id the row was grouped on.
UserId string
// MediaCount is COUNT(user_id) for that user, i.e. the number of media rows.
MediaCount int64
// MediaLength is SUM(size_bytes) over that user's media rows.
MediaLength int64
}
// metadataVirtualTableStatements holds the database handle and the prepared
// statements used by the metadata "virtual table" queries in this file. It is
// built by prepareMetadataVirtualTables; each *sql.Stmt field is prepared from
// the query constant of the same name.
type metadataVirtualTableStatements struct {
// db is kept for queries that are assembled at call time and therefore cannot
// be prepared up front (see UnoptimizedSynapseUserStatsPage).
db *sql.DB
selectEstimatedDatastoreSize *sql.Stmt
selectUploadSizesForServer *sql.Stmt
selectUploadCountsForServer *sql.Stmt
selectMediaForDatastoreWithLastAccess *sql.Stmt
selectThumbnailsForDatastoreWithLastAccess *sql.Stmt
updateQuarantineByHash *sql.Stmt
updateQuarantineByHashAndOrigin *sql.Stmt
}
// metadataVirtualTableWithContext binds the shared prepared statements to a
// single request context. Instances are created by
// metadataVirtualTableStatements.Prepare.
type metadataVirtualTableWithContext struct {
// statements is the shared set of prepared statements and the db handle.
statements *metadataVirtualTableStatements
// ctx is passed as the context argument to every query issued through this
// value.
ctx rcontext.RequestContext
}
// prepareMetadataVirtualTables prepares every statement used by the metadata
// virtual table against db and returns them bundled with the handle.
//
// If any statement fails to prepare, the statements prepared before it are
// closed so they are not leaked, and the returned error wraps the driver error
// (usable with errors.Is / errors.As). The error text is
// "error preparing <statement name>: <driver error>".
func prepareMetadataVirtualTables(db *sql.DB) (*metadataVirtualTableStatements, error) {
	stmts := &metadataVirtualTableStatements{db: db}

	// Each entry pairs a query with the struct field that receives its prepared
	// statement. Order matches the field order for predictable error reporting.
	targets := []struct {
		name  string
		query string
		dest  **sql.Stmt
	}{
		{"selectEstimatedDatastoreSize", selectEstimatedDatastoreSize, &stmts.selectEstimatedDatastoreSize},
		{"selectUploadSizesForServer", selectUploadSizesForServer, &stmts.selectUploadSizesForServer},
		{"selectUploadCountsForServer", selectUploadCountsForServer, &stmts.selectUploadCountsForServer},
		{"selectMediaForDatastoreWithLastAccess", selectMediaForDatastoreWithLastAccess, &stmts.selectMediaForDatastoreWithLastAccess},
		{"selectThumbnailsForDatastoreWithLastAccess", selectThumbnailsForDatastoreWithLastAccess, &stmts.selectThumbnailsForDatastoreWithLastAccess},
		{"updateQuarantineByHash", updateQuarantineByHash, &stmts.updateQuarantineByHash},
		{"updateQuarantineByHashAndOrigin", updateQuarantineByHashAndOrigin, &stmts.updateQuarantineByHashAndOrigin},
	}

	for i, target := range targets {
		prepared, err := db.Prepare(target.query)
		if err != nil {
			// The caller never sees stmts on failure, so release what was
			// already prepared. Close errors are secondary to the prepare
			// failure being reported and are intentionally ignored.
			for _, done := range targets[:i] {
				_ = (*done.dest).Close()
			}
			return nil, fmt.Errorf("error preparing %s: %w", target.name, err)
		}
		*target.dest = prepared
	}

	return stmts, nil
}
// Prepare returns a view of the prepared statements bound to ctx. Every query
// issued through the returned value uses ctx as its context.
func (s *metadataVirtualTableStatements) Prepare(ctx rcontext.RequestContext) *metadataVirtualTableWithContext {
	bound := new(metadataVirtualTableWithContext)
	bound.statements = s
	bound.ctx = ctx
	return bound
}
// EstimateDatastoreSize runs selectEstimatedDatastoreSize for datastoreId and
// returns the resulting size_total. A missing row is reported as a size of
// zero with a nil error; any other scan error is returned as-is.
func (s *metadataVirtualTableWithContext) EstimateDatastoreSize(datastoreId string) (int64, error) {
	var sizeTotal int64
	result := s.statements.selectEstimatedDatastoreSize.QueryRowContext(s.ctx, datastoreId)
	if scanErr := result.Scan(&sizeTotal); scanErr != nil {
		if scanErr == sql.ErrNoRows {
			return 0, nil
		}
		return sizeTotal, scanErr
	}
	return sizeTotal, nil
}
// ByteUsageForServer runs selectUploadSizesForServer for serverName and
// returns, in order, the summed size_bytes of its media rows and of its
// thumbnail rows. A missing row is reported as zero usage with a nil error.
func (s *metadataVirtualTableWithContext) ByteUsageForServer(serverName string) (int64, int64, error) {
	var mediaBytes, thumbnailBytes int64
	result := s.statements.selectUploadSizesForServer.QueryRowContext(s.ctx, serverName)
	if scanErr := result.Scan(&mediaBytes, &thumbnailBytes); scanErr != nil {
		if scanErr == sql.ErrNoRows {
			return 0, 0, nil
		}
		return mediaBytes, thumbnailBytes, scanErr
	}
	return mediaBytes, thumbnailBytes, nil
}
// CountUsageForServer runs selectUploadCountsForServer for serverName and
// returns, in order, the number of media rows and the number of thumbnail rows
// with that origin. A missing row is reported as zero counts with a nil error.
func (s *metadataVirtualTableWithContext) CountUsageForServer(serverName string) (int64, int64, error) {
	var mediaCount, thumbnailCount int64
	result := s.statements.selectUploadCountsForServer.QueryRowContext(s.ctx, serverName)
	if scanErr := result.Scan(&mediaCount, &thumbnailCount); scanErr != nil {
		if scanErr == sql.ErrNoRows {
			return 0, 0, nil
		}
		return mediaCount, thumbnailCount, scanErr
	}
	return mediaCount, thumbnailCount, nil
}
// UnoptimizedSynapseUserStatsPage returns one page of per-user upload
// statistics for media rows whose origin is serverName, plus the total number
// of distinct users matching the same filters (ignoring paging).
//
// The SQL is assembled at call time, so nothing here uses a prepared statement:
//   - orderBy picks the sort column. It is interpolated into the query text and
//     must therefore be one of the SynStatUserOrderBy constants; any other
//     value is rejected with an error before touching the database.
//   - asc selects ascending order; otherwise the page is sorted descending.
//   - limit and startIdx are bound as LIMIT and OFFSET.
//   - fromTs and untilTs, when >= 0, restrict creation_ts to the inclusive
//     range; a negative value disables that bound.
//   - search, when non-empty, is bound into the pattern "@%<search>%:%" and
//     matched against user_id with LIKE. It is not escaped, so any '%' or '_'
//     in search acts as a LIKE wildcard.
//
// Rows with an empty user_id are always excluded. The returned slice is
// non-nil (possibly empty) on success and nil on error.
func (s *metadataVirtualTableWithContext) UnoptimizedSynapseUserStatsPage(serverName string, orderBy SynStatUserOrderBy, startIdx int64, limit int64, fromTs int64, untilTs int64, search string, asc bool) ([]*DbSynUserStat, int64, error) {
	sqlDir := "DESC"
	if asc {
		sqlDir = "ASC"
	}
	if !IsSynStatUserOrderBy(orderBy) {
		return nil, 0, errors.New("sql injection prevented: orderBy must be recognized")
	}

	sqlParams := make([]interface{}, 0)
	sqlWhere := make([]string, 0)
	addParam := func(val interface{}) {
		sqlParams = append(sqlParams, val)
	}
	// addWhere binds val and records a condition whose single %d verb is
	// replaced by the 1-based positional placeholder index of val.
	addWhere := func(str string, val interface{}) {
		addParam(val)
		sqlWhere = append(sqlWhere, fmt.Sprintf(str, len(sqlParams)))
	}

	addWhere("origin = $%d", serverName)
	if fromTs >= 0 {
		addWhere("creation_ts >= $%d", fromTs)
	}
	if untilTs >= 0 {
		addWhere("creation_ts <= $%d", untilTs)
	}
	if search != "" {
		addWhere("user_id LIKE $%d", fmt.Sprintf("@%%%s%%:%%", search))
	}
	addWhere("user_id <> $%d", "")

	// The count query below shares the WHERE clause but not LIMIT/OFFSET, so
	// remember how many parameters belong to the filters alone.
	numWhereParams := len(sqlParams)

	addParam(limit)
	sqlLimit := fmt.Sprintf("LIMIT $%d", len(sqlParams))
	addParam(startIdx)
	sqlOffset := fmt.Sprintf("OFFSET $%d", len(sqlParams))

	sqlStart := fmt.Sprintf("FROM media WHERE %s GROUP BY user_id", strings.Join(sqlWhere, " AND "))
	sqlPageQ := fmt.Sprintf("SELECT COUNT(user_id) AS media_count, SUM(size_bytes) AS media_length, user_id %s ORDER BY %s %s %s %s;", sqlStart, orderBy, sqlDir, sqlLimit, sqlOffset)

	results := make([]*DbSynUserStat, 0)
	rows, err := s.statements.db.QueryContext(s.ctx, sqlPageQ, sqlParams...)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return results, 0, nil
		}
		return nil, 0, err
	}
	// Guarantees the connection is released on every exit path, including the
	// early return on a Scan failure. Close is a no-op once iteration finished.
	defer rows.Close()

	for rows.Next() {
		val := &DbSynUserStat{}
		if err = rows.Scan(&val.MediaCount, &val.MediaLength, &val.UserId); err != nil {
			return nil, 0, err
		}
		results = append(results, val)
	}
	// Next returning false can mean "done" or "failed"; without this check a
	// mid-iteration failure would be reported as a short, successful page.
	if err = rows.Err(); err != nil {
		return nil, 0, err
	}

	sqlTotalQ := fmt.Sprintf("SELECT COUNT(*) FROM (SELECT user_id %s) AS count_user_ids;", sqlStart)
	row := s.statements.db.QueryRowContext(s.ctx, sqlTotalQ, sqlParams[:numWhereParams]...)
	total := int64(0)
	if err = row.Scan(&total); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return make([]*DbSynUserStat, 0), 0, nil
		}
		return nil, 0, err
	}

	return results, total, nil
}
// scanLastAccess converts the result of one of the "…WithLastAccess" queries
// into VirtLastAccess values. It accepts the (rows, err) pair straight from
// QueryContext so callers can forward the call result directly.
//
// A query error equal to sql.ErrNoRows yields an empty, non-nil slice; any
// other query, scan or iteration error yields a nil slice and that error.
// rows is always closed before returning.
func (s *metadataVirtualTableWithContext) scanLastAccess(rows *sql.Rows, err error) ([]*VirtLastAccess, error) {
	results := make([]*VirtLastAccess, 0)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return results, nil
		}
		return nil, err
	}
	// Release the underlying connection on every exit path, including the
	// early return on a Scan failure.
	defer rows.Close()

	for rows.Next() {
		val := &VirtLastAccess{Locatable: &Locatable{}}
		// Column order must match the SELECT list of the last-access queries.
		if err = rows.Scan(&val.Sha256Hash, &val.SizeBytes, &val.DatastoreId, &val.Location, &val.CreationTs, &val.LastAccessTs, &val.ContentType); err != nil {
			return nil, err
		}
		results = append(results, val)
	}
	// Distinguish normal exhaustion from an error that stopped iteration, so a
	// failure is not mistaken for a complete result set.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	return results, nil
}
// GetMediaForDatastoreByLastAccess returns the media rows stored in
// datastoreId whose recorded last access is strictly before lastAccessTs.
func (s *metadataVirtualTableWithContext) GetMediaForDatastoreByLastAccess(datastoreId string, lastAccessTs int64) ([]*VirtLastAccess, error) {
	// The query takes the cutoff first ($1) and the datastore second ($2).
	stmt := s.statements.selectMediaForDatastoreWithLastAccess
	rows, queryErr := stmt.QueryContext(s.ctx, lastAccessTs, datastoreId)
	return s.scanLastAccess(rows, queryErr)
}
// GetThumbnailsForDatastoreByLastAccess returns the thumbnail rows stored in
// datastoreId whose recorded last access is strictly before lastAccessTs.
func (s *metadataVirtualTableWithContext) GetThumbnailsForDatastoreByLastAccess(datastoreId string, lastAccessTs int64) ([]*VirtLastAccess, error) {
	// The query takes the cutoff first ($1) and the datastore second ($2).
	stmt := s.statements.selectThumbnailsForDatastoreWithLastAccess
	rows, queryErr := stmt.QueryContext(s.ctx, lastAccessTs, datastoreId)
	return s.scanLastAccess(rows, queryErr)
}
// UpdateQuarantineByHash sets the quarantined flag to the given value on media
// rows with the given sha256 hash, skipping rows whose purpose is
// PurposePinned and rows already in the requested state. It returns the number
// of rows the database reports as affected.
func (s *metadataVirtualTableWithContext) UpdateQuarantineByHash(hash string, quarantined bool) (int64, error) {
	stmt := s.statements.updateQuarantineByHash
	result, execErr := stmt.ExecContext(s.ctx, hash, PurposePinned, quarantined)
	if execErr == nil {
		return result.RowsAffected()
	}
	return 0, execErr
}
// UpdateQuarantineByHashAndOrigin sets the quarantined flag to the given value
// on media rows matching both origin and sha256 hash, skipping rows whose
// purpose is PurposePinned and rows already in the requested state. It returns
// the number of rows the database reports as affected.
func (s *metadataVirtualTableWithContext) UpdateQuarantineByHashAndOrigin(origin string, hash string, quarantined bool) (int64, error) {
	stmt := s.statements.updateQuarantineByHashAndOrigin
	result, execErr := stmt.ExecContext(s.ctx, origin, hash, PurposePinned, quarantined)
	if execErr == nil {
		return result.RowsAffected()
	}
	return 0, execErr
}