From cfd20e3ddb21d6f06aa277a49fdc694bd9f0d5d3 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Wed, 31 Jul 2019 15:32:46 -0600
Subject: [PATCH] Add early stats for the per-domain level

Fixes https://github.com/turt2live/matrix-media-repo/issues/169
---
 api/custom/usage.go                  | 70 ++++++++++++++++++++++++++++
 api/webserver/webserver.go           |  2 +
 migrations/9_origin_indexes_down.sql |  3 ++
 migrations/9_origin_indexes_up.sql   |  3 ++
 storage/stores/metadata_store.go     | 60 ++++++++++++++++++++++++
 5 files changed, 138 insertions(+)
 create mode 100644 api/custom/usage.go
 create mode 100644 migrations/9_origin_indexes_down.sql
 create mode 100644 migrations/9_origin_indexes_up.sql

diff --git a/api/custom/usage.go b/api/custom/usage.go
new file mode 100644
index 00000000..e7bed29f
--- /dev/null
+++ b/api/custom/usage.go
@@ -0,0 +1,70 @@
+package custom
+
+import (
+	"net/http"
+
+	"github.com/gorilla/mux"
+	"github.com/sirupsen/logrus"
+	"github.com/turt2live/matrix-media-repo/api"
+	"github.com/turt2live/matrix-media-repo/storage"
+)
+
+type UsageInfo struct {
+	Total      int64 `json:"total"`
+	Media      int64 `json:"media"`
+	Thumbnails int64 `json:"thumbnails"`
+}
+
+type DomainUsageResponse struct {
+	NumUsers  int       `json:"user_count"`
+	UserIDs   []string  `json:"known_user_ids"`
+	RawBytes  UsageInfo `json:"raw_bytes"`
+	RawCounts UsageInfo `json:"raw_counts"`
+}
+
+func GetDomainUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
+	// TODO: Auth check on host (allow local admin to query their own domain)
+
+	params := mux.Vars(r)
+
+	serverName := params["serverName"]
+
+	log = log.WithFields(logrus.Fields{
+		"serverName": serverName,
+	})
+
+	db := storage.GetDatabase().GetMetadataStore(r.Context(), log)
+
+	userIds, err := db.GetUsersForServer(serverName)
+	if err != nil {
+		log.Error(err)
+		return api.InternalServerError("Failed to get users belonging to server")
+	}
+
+	mediaBytes, thumbBytes, err := db.GetByteUsageForServer(serverName)
+	if err != nil {
+		log.Error(err)
+		return api.InternalServerError("Failed to get byte usage for server")
+	}
+
+	mediaCount, thumbCount, err := db.GetCountUsageForServer(serverName)
+	if err != nil {
+		log.Error(err)
+		return api.InternalServerError("Failed to get count usage for server")
+	}
+
+	return &DomainUsageResponse{
+		NumUsers: len(userIds),
+		UserIDs:  userIds,
+		RawBytes: UsageInfo{
+			Total:      mediaBytes + thumbBytes,
+			Media:      mediaBytes,
+			Thumbnails: thumbBytes,
+		},
+		RawCounts: UsageInfo{
+			Total:      mediaCount + thumbCount,
+			Media:      mediaCount,
+			Thumbnails: thumbCount,
+		},
+	}
+}
diff --git a/api/webserver/webserver.go b/api/webserver/webserver.go
index 8890bac7..f653f757 100644
--- a/api/webserver/webserver.go
+++ b/api/webserver/webserver.go
@@ -42,6 +42,7 @@ func Init() {
 	dsTransferHandler := handler{api.RepoAdminRoute(custom.MigrateBetweenDatastores), "datastore_transfer", counter, false}
 	fedTestHandler := handler{api.RepoAdminRoute(custom.GetFederationInfo), "federation_test", counter, false}
 	healthzHandler := handler{api.AccessTokenOptionalRoute(custom.GetHealthz), "healthz", counter, true}
+	domainUsageHandler := handler{api.AccessTokenOptionalRoute(custom.GetDomainUsage), "domain_usage", counter, false}
 
 	routes := make(map[string]route)
 	versions := []string{"r0", "v1", "unstable"} // r0 is typically clients and v1 is typically servers. v1 is deprecated.
@@ -64,6 +65,7 @@ func Init() {
 		routes["/_matrix/media/"+version+"/admin/datastores"] = route{"GET", datastoreListHandler}
 		routes["/_matrix/media/"+version+"/admin/datastores/{sourceDsId:[^/]+}/transfer_to/{targetDsId:[^/]+}"] = route{"POST", dsTransferHandler}
 		routes["/_matrix/media/"+version+"/admin/federation/test/{serverName:[a-zA-Z0-9.:\\-_]+}"] = route{"GET", fedTestHandler}
+		routes["/_matrix/media/"+version+"/admin/usage/{serverName:[a-zA-Z0-9.:\\-_]+}"] = route{"GET", domainUsageHandler}
 
 		// Routes that we should handle but aren't in the media namespace (synapse compat)
 		routes["/_matrix/client/"+version+"/admin/purge_media_cache"] = route{"POST", purgeHandler}
diff --git a/migrations/9_origin_indexes_down.sql b/migrations/9_origin_indexes_down.sql
new file mode 100644
index 00000000..fee15b35
--- /dev/null
+++ b/migrations/9_origin_indexes_down.sql
@@ -0,0 +1,3 @@
+DROP INDEX IF EXISTS idx_origin_media;
+DROP INDEX IF EXISTS idx_origin_thumbnails;
+DROP INDEX IF EXISTS idx_origin_user_id_media;
diff --git a/migrations/9_origin_indexes_up.sql b/migrations/9_origin_indexes_up.sql
new file mode 100644
index 00000000..326ef709
--- /dev/null
+++ b/migrations/9_origin_indexes_up.sql
@@ -0,0 +1,3 @@
+CREATE INDEX IF NOT EXISTS idx_origin_media ON media(origin);
+CREATE INDEX IF NOT EXISTS idx_origin_thumbnails ON thumbnails(origin);
+CREATE INDEX IF NOT EXISTS idx_origin_user_id_media ON media(origin, user_id);
diff --git a/storage/stores/metadata_store.go b/storage/stores/metadata_store.go
index 61347caa..12a3cf32 100644
--- a/storage/stores/metadata_store.go
+++ b/storage/stores/metadata_store.go
@@ -18,6 +18,9 @@ const selectMediaLastAccessedBeforeInDatastore = "SELECT m.sha256_hash, m.size_b
 const selectThumbnailsLastAccessedBeforeInDatastore = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts FROM thumbnails AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1 AND m.datastore_id = $2"
 const changeDatastoreOfMediaHash = "UPDATE media SET datastore_id = $1, location = $2 WHERE sha256_hash = $3"
 const changeDatastoreOfThumbnailHash = "UPDATE thumbnails SET datastore_id = $1, location = $2 WHERE sha256_hash = $3"
+const selectUploadCountsForServer = "SELECT COALESCE((SELECT COUNT(origin) FROM media WHERE origin = $1), 0) AS media, COALESCE((SELECT COUNT(origin) FROM thumbnails WHERE origin = $1), 0) AS thumbnails"
+const selectUploadSizesForServer = "SELECT COALESCE((SELECT SUM(size_bytes) FROM media WHERE origin = $1), 0) AS media, COALESCE((SELECT SUM(size_bytes) FROM thumbnails WHERE origin = $1), 0) AS thumbnails"
+const selectUsersForServer = "SELECT DISTINCT user_id FROM media WHERE origin = $1 AND user_id IS NOT NULL AND LENGTH(user_id) > 0"
 
 type metadataStoreStatements struct {
 	upsertLastAccessed                            *sql.Stmt
@@ -26,6 +29,9 @@ type metadataStoreStatements struct {
 	selectThumbnailsLastAccessedBeforeInDatastore *sql.Stmt
 	changeDatastoreOfMediaHash                    *sql.Stmt
 	changeDatastoreOfThumbnailHash                *sql.Stmt
+	selectUploadCountsForServer                   *sql.Stmt
+	selectUploadSizesForServer                    *sql.Stmt
+	selectUsersForServer                          *sql.Stmt
 }
 
 type MetadataStoreFactory struct {
@@ -64,6 +70,15 @@ func InitMetadataStore(sqlDb *sql.DB) (*MetadataStoreFactory, error) {
 	if store.stmts.changeDatastoreOfThumbnailHash, err = store.sqlDb.Prepare(changeDatastoreOfThumbnailHash); err != nil {
 		return nil, err
 	}
+	if store.stmts.selectUsersForServer, err = store.sqlDb.Prepare(selectUsersForServer); err != nil {
+		return nil, err
+	}
+	if store.stmts.selectUploadSizesForServer, err = store.sqlDb.Prepare(selectUploadSizesForServer); err != nil {
+		return nil, err
+	}
+	if store.stmts.selectUploadCountsForServer, err = store.sqlDb.Prepare(selectUploadCountsForServer); err != nil {
+		return nil, err
+	}
 
 	return &store, nil
 }
@@ -151,3 +166,48 @@ func (s *MetadataStore) GetOldThumbnailsInDatastore(datastoreId string, beforeTs
 
 	return results, nil
 }
+
+func (s *MetadataStore) GetUsersForServer(serverName string) ([]string, error) {
+	rows, err := s.statements.selectUsersForServer.QueryContext(s.ctx, serverName)
+	if err != nil {
+		return nil, err
+	}
+
+	results := make([]string, 0)
+	for rows.Next() {
+		v := ""
+		err = rows.Scan(&v)
+		if err != nil {
+			return nil, err
+		}
+		results = append(results, v)
+	}
+
+	return results, nil
+}
+
+func (s *MetadataStore) GetByteUsageForServer(serverName string) (int64, int64, error) {
+	row := s.statements.selectUploadSizesForServer.QueryRowContext(s.ctx, serverName)
+
+	media := int64(0)
+	thumbs := int64(0)
+	err := row.Scan(&media, &thumbs)
+	if err != nil {
+		return 0, 0, err
+	}
+
+	return media, thumbs, nil
+}
+
+func (s *MetadataStore) GetCountUsageForServer(serverName string) (int64, int64, error) {
+	row := s.statements.selectUploadCountsForServer.QueryRowContext(s.ctx, serverName)
+
+	media := int64(0)
+	thumbs := int64(0)
+	err := row.Scan(&media, &thumbs)
+	if err != nil {
+		return 0, 0, err
+	}
+
+	return media, thumbs, nil
+}
-- 
GitLab