From eaf7415b0a4d209885aa41c64bd57e952b91a001 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sun, 16 Jul 2023 20:38:35 -0600
Subject: [PATCH] Support MSC4034

https://github.com/matrix-org/matrix-spec-proposals/pull/4034
---
 CHANGELOG.md                    |  3 +-
 api/r0/public_config.go         | 31 +++++++++++++++++++--
 api/routes.go                   |  4 ++-
 api/unstable/public_usage.go    | 49 +++++++++++++++++++++++++++++++++
 common/config/models_domain.go  |  1 +
 config.sample.yaml              |  5 ++++
 database/table_media.go         | 16 +++++++++++
 pipelines/_steps/quota/check.go | 39 ++++++++++++++++++++------
 8 files changed, 135 insertions(+), 13 deletions(-)
 create mode 100644 api/unstable/public_usage.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ed9e246a..d48e093f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -83,7 +83,8 @@ path/server, for example, then you can simply update the path in the config for
 ### Added
 
 * Added a `federation.ignoredHosts` config option to block media from individual homeservers.
-* Support for MSC2246 (async uploads) is added, with per-user quota limiting options.
+* Support for [MSC2246](https://github.com/matrix-org/matrix-spec-proposals/pull/2246) (async uploads) is added, with per-user quota limiting options.
+* Support for [MSC4034](https://github.com/matrix-org/matrix-spec-proposals/pull/4034) (self-serve usage information) is added, alongside a new "maximum file count" quota limit.
 
 ### Removed
 
diff --git a/api/r0/public_config.go b/api/r0/public_config.go
index fc9bf3d6..148c5086 100644
--- a/api/r0/public_config.go
+++ b/api/r0/public_config.go
@@ -3,12 +3,16 @@ package r0
 import (
 	"net/http"
 
+	"github.com/getsentry/sentry-go"
 	"github.com/turt2live/matrix-media-repo/api/_apimeta"
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/pipelines/_steps/quota"
 )
 
 type PublicConfigResponse struct {
-	UploadMaxSize int64 `json:"m.upload.size,omitempty"`
+	UploadMaxSize   int64 `json:"m.upload.size,omitempty"`
+	StorageMaxSize  int64 `json:"org.matrix.msc4034.storage.size,omitempty"`
+	StorageMaxFiles int64 `json:"org.matrix.msc4034.storage.max_files,omitempty"`
 }
 
 func PublicConfig(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
@@ -21,7 +25,30 @@ func PublicConfig(r *http.Request, rctx rcontext.RequestContext, user _apimeta.U
 		uploadSize = 0 // invokes the omitEmpty
 	}
 
+	storageSize := int64(0)
+	limit, err := quota.Limit(rctx, user.UserId, quota.MaxBytes)
+	if err != nil {
+		rctx.Log.Warn("Non-fatal error getting per-user quota limit (max bytes): ", err)
+		sentry.CaptureException(err)
+	} else {
+		storageSize = limit
+	}
+	if storageSize < 0 {
+		storageSize = 0 // invokes the omitEmpty
+	}
+
+	maxFiles := int64(0)
+	limit, err = quota.Limit(rctx, user.UserId, quota.MaxCount)
+	if err != nil {
+		rctx.Log.Warn("Non-fatal error getting per-user quota limit (max files count): ", err)
+		sentry.CaptureException(err)
+	} else {
+		maxFiles = limit
+	}
+
 	return &PublicConfigResponse{
-		UploadMaxSize: uploadSize,
+		UploadMaxSize:   uploadSize,
+		StorageMaxSize:  storageSize,
+		StorageMaxFiles: maxFiles,
 	}
 }
diff --git a/api/routes.go b/api/routes.go
index 60b6e8a2..5f879ee1 100644
--- a/api/routes.go
+++ b/api/routes.go
@@ -47,6 +47,7 @@ func buildRoutes() http.Handler {
 	register([]string{"GET"}, PrefixMedia, "info/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.MediaInfo), "info", counter))
 	purgeOneRoute := makeRoute(_routers.RequireAccessToken(custom.PurgeIndividualRecord), "purge_individual_media", counter)
 	register([]string{"DELETE"}, PrefixMedia, "download/:server/:mediaId", mxUnstable, router, purgeOneRoute)
+	register([]string{"GET"}, PrefixMedia, "usage", msc4034, router, makeRoute(_routers.RequireAccessToken(unstable.PublicUsage), "usage", counter))
 
 	// Custom and top-level features
 	router.Handler("GET", fmt.Sprintf("%s/version", PrefixMedia), makeRoute(_routers.OptionalAccessToken(custom.GetVersion), "get_version", counter))
@@ -111,8 +112,9 @@ func makeRoute(generator _routers.GeneratorFn, name string, counter *_routers.Re
 type matrixVersions []string
 
 var (
-	//mxAllSpec            matrixVersions = []string{"r0", "v1", "v3", "unstable", "unstable/io.t2bot.media"}
+	//mxAllSpec            matrixVersions = []string{"r0", "v1", "v3", "unstable", "unstable/io.t2bot.media" /* and MSC routes */}
 	mxUnstable           matrixVersions = []string{"unstable", "unstable/io.t2bot.media"}
+	msc4034              matrixVersions = []string{"unstable/org.matrix.msc4034"}
 	mxSpecV3Transition   matrixVersions = []string{"r0", "v1", "v3"}
 	mxSpecV3TransitionCS matrixVersions = []string{"r0", "v3"}
 	mxR0                 matrixVersions = []string{"r0"}
diff --git a/api/unstable/public_usage.go b/api/unstable/public_usage.go
new file mode 100644
index 00000000..1225b248
--- /dev/null
+++ b/api/unstable/public_usage.go
@@ -0,0 +1,49 @@
+package unstable
+
+import (
+	"net/http"
+
+	"github.com/getsentry/sentry-go"
+	"github.com/turt2live/matrix-media-repo/api/_apimeta"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/pipelines/_steps/quota"
+)
+
+type PublicUsageResponse struct {
+	StorageFree  int64 `json:"org.matrix.msc4034.storage.free,omitempty"`
+	StorageFiles int64 `json:"org.matrix.msc4034.storage.files,omitempty"`
+}
+
+func PublicUsage(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
+	storageUsed := int64(0)
+	storageLimit := int64(0)
+	limit, err := quota.Limit(rctx, user.UserId, quota.MaxBytes)
+	if err != nil {
+		rctx.Log.Warn("Non-fatal error getting per-user quota limit (max bytes): ", err)
+		sentry.CaptureException(err)
+	} else if limit > 0 {
+		storageLimit = limit
+	}
+	if storageLimit > 0 {
+		current, err := quota.Current(rctx, user.UserId, quota.MaxBytes)
+		if err != nil {
+			rctx.Log.Warn("Non-fatal error getting per-user quota usage (max bytes @ now): ", err)
+			sentry.CaptureException(err)
+		} else {
+			storageUsed = current
+		}
+	} else {
+		storageLimit = 0
+	}
+
+	fileCount, err := quota.Current(rctx, user.UserId, quota.MaxCount)
+	if err != nil {
+		rctx.Log.Warn("Non-fatal error getting per-user quota usage (files count @ now): ", err)
+		sentry.CaptureException(err)
+	}
+
+	return &PublicUsageResponse{
+		StorageFree:  storageLimit - storageUsed,
+		StorageFiles: fileCount,
+	}
+}
diff --git a/common/config/models_domain.go b/common/config/models_domain.go
index 466d17e0..fb2b9975 100644
--- a/common/config/models_domain.go
+++ b/common/config/models_domain.go
@@ -10,6 +10,7 @@ type QuotaUserConfig struct {
 	Glob       string `yaml:"glob"`
 	MaxBytes   int64  `yaml:"maxBytes"`
 	MaxPending int64  `yaml:"maxPending"`
+	MaxFiles   int64  `yaml:"maxFiles"`
 }
 
 type QuotasConfig struct {
diff --git a/config.sample.yaml b/config.sample.yaml
index be88b11a..89f9aac2 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -259,6 +259,11 @@ uploads:
         # complete before starting another one. Defaults to maxPending above. Set to 0 to
         # disable.
         maxPending: 5
+        # The maximum number of uploaded files a user can have. Defaults to zero (no limit).
+        # If both maxBytes and maxFiles are in use then the first condition a user triggers
+        # will prevent upload. Note that a user can still have uploads contributing to maxPending,
+        # but will not be able to complete them if they are at maxFiles.
+        maxFiles: 0
 
 # Settings related to downloading files from the media repository
 downloads:
diff --git a/database/table_media.go b/database/table_media.go
index f1523b6d..060ff0a5 100644
--- a/database/table_media.go
+++ b/database/table_media.go
@@ -37,6 +37,7 @@ const selectMediaById = "SELECT origin, media_id, upload_name, content_type, use
 const selectMediaByUserId = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE user_id = $1;"
 const selectMediaByOrigin = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1;"
 const selectMediaByLocationExists = "SELECT TRUE FROM media WHERE datastore_id = $1 AND location = $2 LIMIT 1;"
+const selectMediaByUserCount = "SELECT COUNT(*) FROM media WHERE user_id = $1;"
 
 type mediaTableStatements struct {
 	selectDistinctMediaDatastoreIds *sql.Stmt
@@ -48,6 +49,7 @@ type mediaTableStatements struct {
 	selectMediaByUserId             *sql.Stmt
 	selectMediaByOrigin             *sql.Stmt
 	selectMediaByLocationExists     *sql.Stmt
+	selectMediaByUserCount          *sql.Stmt
 }
 
 type mediaTableWithContext struct {
@@ -86,6 +88,9 @@ func prepareMediaTables(db *sql.DB) (*mediaTableStatements, error) {
 	if stmts.selectMediaByLocationExists, err = db.Prepare(selectMediaByLocationExists); err != nil {
 		return nil, errors.New("error preparing selectMediaByLocationExists: " + err.Error())
 	}
+	if stmts.selectMediaByUserCount, err = db.Prepare(selectMediaByUserCount); err != nil {
+		return nil, errors.New("error preparing selectMediaByUserCount: " + err.Error())
+	}
 
 	return stmts, nil
 }
@@ -172,6 +177,17 @@ func (s *mediaTableWithContext) GetById(origin string, mediaId string) (*DbMedia
 	return val, err
 }
 
+func (s *mediaTableWithContext) ByUserCount(userId string) (int64, error) {
+	row := s.statements.selectMediaByUserCount.QueryRowContext(s.ctx, userId)
+	val := int64(0)
+	err := row.Scan(&val)
+	if err == sql.ErrNoRows {
+		err = nil
+		val = 0
+	}
+	return val, err
+}
+
 func (s *mediaTableWithContext) IdExists(origin string, mediaId string) (bool, error) {
 	row := s.statements.selectMediaExists.QueryRowContext(s.ctx, origin, mediaId)
 	val := false
diff --git a/pipelines/_steps/quota/check.go b/pipelines/_steps/quota/check.go
index 878f2850..8bba9b90 100644
--- a/pipelines/_steps/quota/check.go
+++ b/pipelines/_steps/quota/check.go
@@ -14,6 +14,7 @@ type Type int64
 const (
 	MaxBytes   Type = 0
 	MaxPending Type = 1
+	MaxCount   Type = 2
 )
 
 func Check(ctx rcontext.RequestContext, userId string, quotaType Type) error {
@@ -22,18 +23,13 @@ func Check(ctx rcontext.RequestContext, userId string, quotaType Type) error {
 		return err
 	}
 
-	var count int64
-	if quotaType == MaxBytes {
-		if limit < 0 {
+	if quotaType == MaxBytes || quotaType == MaxCount {
+		if limit <= 0 {
 			return nil
 		}
-		count, err = database.GetInstance().UserStats.Prepare(ctx).UserUploadedBytes(userId)
-	} else if quotaType == MaxPending {
-		count, err = database.GetInstance().ExpiringMedia.Prepare(ctx).ByUserCount(userId)
-	} else {
-		return errors.New("missing check for quota type - contact developer")
 	}
 
+	count, err := Current(ctx, userId, quotaType)
 	if err != nil {
 		return err
 	}
@@ -44,7 +40,24 @@ func Check(ctx rcontext.RequestContext, userId string, quotaType Type) error {
 	}
 }
 
+func Current(ctx rcontext.RequestContext, userId string, quotaType Type) (int64, error) {
+	var count int64
+	var err error
+	if quotaType == MaxBytes {
+		count, err = database.GetInstance().UserStats.Prepare(ctx).UserUploadedBytes(userId)
+	} else if quotaType == MaxPending {
+		count, err = database.GetInstance().ExpiringMedia.Prepare(ctx).ByUserCount(userId)
+	} else if quotaType == MaxCount {
+		count, err = database.GetInstance().Media.Prepare(ctx).ByUserCount(userId)
+	} else {
+		return 0, errors.New("missing current count for quota type - contact developer")
+	}
+
+	return count, err
+}
+
 func CanUpload(ctx rcontext.RequestContext, userId string, bytes int64) error {
+	// We can't use Check() for MaxBytes because we're testing limit+to_be_uploaded_size
 	limit, err := Limit(ctx, userId, MaxBytes)
 	if err != nil {
 		return err
@@ -53,7 +66,7 @@ func CanUpload(ctx rcontext.RequestContext, userId string, bytes int64) error {
 		return nil
 	}
 
-	count, err := database.GetInstance().UserStats.Prepare(ctx).UserUploadedBytes(userId)
+	count, err := Current(ctx, userId, MaxBytes)
 	if err != nil {
 		return err
 	}
@@ -62,6 +75,10 @@ func CanUpload(ctx rcontext.RequestContext, userId string, bytes int64) error {
 		return common.ErrQuotaExceeded
 	}
 
+	if err = Check(ctx, userId, MaxCount); err != nil {
+		return err
+	}
+
 	return nil
 }
 
@@ -76,6 +93,8 @@ func Limit(ctx rcontext.RequestContext, userId string, quotaType Type) (int64, e
 				return q.MaxBytes, nil
 			} else if quotaType == MaxPending {
 				return q.MaxPending, nil
+			} else if quotaType == MaxCount {
+				return q.MaxFiles, nil
 			} else {
 				return 0, errors.New("missing glob switch for quota type - contact developer")
 			}
@@ -90,6 +109,8 @@ func defaultLimit(ctx rcontext.RequestContext, quotaType Type) (int64, error) {
 		return -1, nil
 	} else if quotaType == MaxPending {
 		return ctx.Config.Uploads.MaxPending, nil
+	} else if quotaType == MaxCount {
+		return 0, nil
 	}
 	return 0, errors.New("no default for quota type - contact developer")
 }
-- 
GitLab