From 89ec270432e6f615d8be16baf59df6ea74e1f790 Mon Sep 17 00:00:00 2001
From: mattc <buffless-matt@users.noreply.github.com>
Date: Thu, 10 Mar 2022 16:26:24 +1100
Subject: [PATCH] Add - media store parameterized media store function to fetch
 users' usage statistics for a server.

---
 storage/stores/media_store.go | 112 ++++++++++++++++++++++++++++++++++
 types/media.go                |   6 ++
 2 files changed, 118 insertions(+)

diff --git a/storage/stores/media_store.go b/storage/stores/media_store.go
index cf00064c..af94fe5b 100644
--- a/storage/stores/media_store.go
+++ b/storage/stores/media_store.go
@@ -2,6 +2,10 @@ package stores
 
 import (
 	"database/sql"
+	"errors"
+	"fmt"
+	"github.com/turt2live/matrix-media-repo/util"
+	"strings"
 	"sync"
 
 	"github.com/lib/pq"
@@ -33,6 +37,8 @@ const selectMediaByDomainBefore = "SELECT origin, media_id, upload_name, content
 const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE datastore_id = $1 AND location = $2"
 const selectIfQuarantined = "SELECT 1 FROM media WHERE sha256_hash = $1 AND quarantined = $2 LIMIT 1;"
 
+var UsersUsageStatsSorts = []string{"media_count", "media_length", "user_id"}
+
 var dsCacheByPath = sync.Map{} // [string] => Datastore
 var dsCacheById = sync.Map{}   // [string] => Datastore
 
@@ -460,6 +466,112 @@ func (s *MediaStore) GetAllMediaForServer(serverName string) ([]*types.Media, er
 	return results, nil
 }
 
+func (s *MediaStore) GetUsersUsageStatsForServer(
+	serverName string,
+	orderBy string,
+	start int64,
+	limit int64,
+	fromTS int64,
+	untilTS int64,
+	searchTerm string,
+	isAscendingOrder bool,
+) ([]*types.UserUsageStats, int64, error) {
+	if !util.ArrayContains(UsersUsageStatsSorts, orderBy) {
+		return nil, 0, errors.New("invalid orderBy")
+	}
+	if start < 0 {
+		return nil, 0, errors.New("invalid start")
+	}
+	if limit < 0 {
+		return nil, 0, errors.New("invalid limit")
+	}
+
+	orderDirection := "DESC"
+	if isAscendingOrder {
+		orderDirection = "ASC"
+	}
+
+	var queryParamIdx int64 = 1
+	var commonQueryParams = []interface{}{serverName}
+
+	var filters = []string{fmt.Sprintf("origin = $%d", queryParamIdx)}
+	queryParamIdx++
+	if fromTS >= 0 {
+		filters = append(filters, fmt.Sprintf("creation_ts >= $%d", queryParamIdx))
+		queryParamIdx++
+		commonQueryParams = append(commonQueryParams, fromTS)
+	}
+	if untilTS >= 0 {
+		filters = append(filters, fmt.Sprintf("creation_ts <= $%d", queryParamIdx))
+		queryParamIdx++
+		commonQueryParams = append(commonQueryParams, untilTS)
+	}
+	if searchTerm != "" {
+		filters = append(filters, fmt.Sprintf("user_id LIKE $%d", queryParamIdx))
+		queryParamIdx++
+		commonQueryParams = append(commonQueryParams, fmt.Sprintf("@%%%s%%:%%", searchTerm))
+	}
+
+	var otherPaginationParams []interface{}
+	limitClause := fmt.Sprintf("LIMIT $%d", queryParamIdx)
+	queryParamIdx++
+	otherPaginationParams = append(otherPaginationParams, limit)
+
+	offsetClause := fmt.Sprintf("OFFSET $%d", queryParamIdx)
+	queryParamIdx++
+	otherPaginationParams = append(otherPaginationParams, start)
+
+	commonQueryPortion := fmt.Sprintf(
+		"FROM media "+
+			"WHERE %s "+
+			"GROUP BY user_id", strings.Join(filters, " AND "))
+
+	paginationQuery := fmt.Sprintf(
+		"SELECT COUNT(user_id) AS media_count, SUM(size_bytes) AS media_length, user_id "+
+			"%s "+
+			"ORDER BY %s %s "+
+			"%s "+
+			"%s;",
+		commonQueryPortion,
+		orderBy,
+		orderDirection,
+		limitClause,
+		offsetClause)
+	rows, err := s.factory.sqlDb.Query(paginationQuery, append(commonQueryParams, otherPaginationParams...)...)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	var results []*types.UserUsageStats
+	for rows.Next() {
+		obj := &types.UserUsageStats{}
+		err = rows.Scan(
+			&obj.MediaCount,
+			&obj.MediaLength,
+			&obj.UserId,
+		)
+		if err != nil {
+			return nil, 0, err
+		}
+		results = append(results, obj)
+	}
+
+	totalQuery := fmt.Sprintf(
+		"SELECT COUNT(*) "+
+			"FROM ("+
+			"  SELECT user_id "+
+			"  %s "+
+			") as count_user_ids; ",
+		commonQueryPortion)
+	var totalNumRows int64 = 0
+	err = s.factory.sqlDb.QueryRow(totalQuery, commonQueryParams...).Scan(&totalNumRows)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return results, totalNumRows, nil
+}
+
 func (s *MediaStore) GetAllMediaForServerUsers(serverName string, userIds []string) ([]*types.Media, error) {
 	rows, err := s.statements.selectAllMediaForServerUsers.QueryContext(s.ctx, serverName, pq.Array(userIds))
 	if err != nil {
diff --git a/types/media.go b/types/media.go
index ce8ea0a1..873da958 100644
--- a/types/media.go
+++ b/types/media.go
@@ -35,6 +35,12 @@ type MinimalMediaMetadata struct {
 	DatastoreId  string
 }
 
+type UserUsageStats struct {
+	MediaCount  int64
+	MediaLength int64
+	UserId      string
+}
+
 func (m *Media) MxcUri() string {
 	return "mxc://" + m.Origin + "/" + m.MediaId
 }
-- 
GitLab