From 89ec270432e6f615d8be16baf59df6ea74e1f790 Mon Sep 17 00:00:00 2001 From: mattc <buffless-matt@users.noreply.github.com> Date: Thu, 10 Mar 2022 16:26:24 +1100 Subject: [PATCH] Add - media store parameterized media store function to fetch users' usage statistics for a server. --- storage/stores/media_store.go | 112 ++++++++++++++++++++++++++++++++++ types/media.go | 6 ++ 2 files changed, 118 insertions(+) diff --git a/storage/stores/media_store.go b/storage/stores/media_store.go index cf00064c..af94fe5b 100644 --- a/storage/stores/media_store.go +++ b/storage/stores/media_store.go @@ -2,6 +2,10 @@ package stores import ( "database/sql" + "errors" + "fmt" + "github.com/turt2live/matrix-media-repo/util" + "strings" "sync" "github.com/lib/pq" @@ -33,6 +37,8 @@ const selectMediaByDomainBefore = "SELECT origin, media_id, upload_name, content const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE datastore_id = $1 AND location = $2" const selectIfQuarantined = "SELECT 1 FROM media WHERE sha256_hash = $1 AND quarantined = $2 LIMIT 1;" +var UsersUsageStatsSorts = []string{"media_count", "media_length", "user_id"} + var dsCacheByPath = sync.Map{} // [string] => Datastore var dsCacheById = sync.Map{} // [string] => Datastore @@ -460,6 +466,112 @@ func (s *MediaStore) GetAllMediaForServer(serverName string) ([]*types.Media, er return results, nil } +func (s *MediaStore) GetUsersUsageStatsForServer( + serverName string, + orderBy string, + start int64, + limit int64, + fromTS int64, + untilTS int64, + searchTerm string, + isAscendingOrder bool, +) ([]*types.UserUsageStats, int64, error) { + if !util.ArrayContains(UsersUsageStatsSorts, orderBy) { + return nil, 0, errors.New("invalid orderBy") + } + if start < 0 { + return nil, 0, errors.New("invalid start") + } + if limit < 0 { + return nil, 0, errors.New("invalid limit") + } + + orderDirection := "DESC" + if isAscendingOrder { + orderDirection = "ASC" + } + + var queryParamIdx int64 = 1 + var commonQueryParams = []interface{}{serverName} + + var filters = []string{fmt.Sprintf("origin = $%d", queryParamIdx)} + queryParamIdx++ + if fromTS >= 0 { + filters = append(filters, fmt.Sprintf("creation_ts >= $%d", queryParamIdx)) + queryParamIdx++ + commonQueryParams = append(commonQueryParams, fromTS) + } + if untilTS >= 0 { + filters = append(filters, fmt.Sprintf("creation_ts <= $%d", queryParamIdx)) + queryParamIdx++ + commonQueryParams = append(commonQueryParams, untilTS) + } + if searchTerm != "" { + filters = append(filters, fmt.Sprintf("user_id LIKE $%d", queryParamIdx)) + queryParamIdx++ + commonQueryParams = append(commonQueryParams, fmt.Sprintf("@%%%s%%:%%", searchTerm)) + } + + var otherPaginationParams []interface{} + limitClause := fmt.Sprintf("LIMIT $%d", queryParamIdx) + queryParamIdx++ + otherPaginationParams = append(otherPaginationParams, limit) + + offsetClause := fmt.Sprintf("OFFSET $%d", queryParamIdx) + queryParamIdx++ + otherPaginationParams = append(otherPaginationParams, start) + + commonQueryPortion := fmt.Sprintf( + "FROM media "+ + "WHERE %s "+ + "GROUP BY user_id", strings.Join(filters, " AND ")) + + paginationQuery := fmt.Sprintf( + "SELECT COUNT(user_id) AS media_count, SUM(size_bytes) AS media_length, user_id "+ + "%s "+ + "ORDER BY %s %s "+ + "%s "+ + "%s;", + commonQueryPortion, + orderBy, + orderDirection, + limitClause, + offsetClause) + rows, err := s.factory.sqlDb.Query(paginationQuery, append(commonQueryParams, otherPaginationParams...)...) + if err != nil { + return nil, 0, err + } + + var results []*types.UserUsageStats + for rows.Next() { + obj := &types.UserUsageStats{} + err = rows.Scan( + &obj.MediaCount, + &obj.MediaLength, + &obj.UserId, + ) + if err != nil { + return nil, 0, err + } + results = append(results, obj) + } + + totalQuery := fmt.Sprintf( + "SELECT COUNT(*) "+ + "FROM ("+ + " SELECT user_id "+ + " %s "+ + ") as count_user_ids; ", + commonQueryPortion) + var totalNumRows int64 = 0 + err = s.factory.sqlDb.QueryRow(totalQuery, commonQueryParams...).Scan(&totalNumRows) + if err != nil { + return nil, 0, err + } + + return results, totalNumRows, nil +} + func (s *MediaStore) GetAllMediaForServerUsers(serverName string, userIds []string) ([]*types.Media, error) { rows, err := s.statements.selectAllMediaForServerUsers.QueryContext(s.ctx, serverName, pq.Array(userIds)) if err != nil { diff --git a/types/media.go b/types/media.go index ce8ea0a1..873da958 100644 --- a/types/media.go +++ b/types/media.go @@ -35,6 +35,12 @@ type MinimalMediaMetadata struct { DatastoreId string } +type UserUsageStats struct { + MediaCount int64 + MediaLength int64 + UserId string +} + func (m *Media) MxcUri() string { return "mxc://" + m.Origin + "/" + m.MediaId } -- GitLab