From eb625dd968a34579559faf51a50b4f0f90b21ecf Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sun, 2 Aug 2020 00:12:11 -0600
Subject: [PATCH] Add support for quotas

Fixes https://github.com/turt2live/matrix-media-repo/issues/100
---
 api/r0/upload.go                            | 12 +++++++
 api/responses.go                            |  4 +++
 api/webserver/route_handler.go              |  3 ++
 common/config/conf_min_shared.go            |  4 +++
 common/config/models_domain.go              | 17 +++++++--
 common/errorcodes.go                        |  2 ++
 config.sample.yaml                          | 19 ++++++++++
 migrations/17_add_user_stats_table_down.sql |  3 ++
 migrations/17_add_user_stats_table_up.sql   | 40 +++++++++++++++++++++
 quota/quota.go                              | 35 ++++++++++++++++++
 storage/stores/metadata_store.go            | 19 ++++++++++
 types/stats.go                              |  6 ++++
 12 files changed, 161 insertions(+), 3 deletions(-)
 create mode 100644 migrations/17_add_user_stats_table_down.sql
 create mode 100644 migrations/17_add_user_stats_table_up.sql
 create mode 100644 quota/quota.go
 create mode 100644 types/stats.go

diff --git a/api/r0/upload.go b/api/r0/upload.go
index 74d7eef8..1f7b803c 100644
--- a/api/r0/upload.go
+++ b/api/r0/upload.go
@@ -13,6 +13,7 @@ import (
 	"github.com/turt2live/matrix-media-repo/common/rcontext"
 	"github.com/turt2live/matrix-media-repo/controllers/info_controller"
 	"github.com/turt2live/matrix-media-repo/controllers/upload_controller"
+	"github.com/turt2live/matrix-media-repo/quota"
 	"github.com/turt2live/matrix-media-repo/util/cleanup"
 )
 
@@ -44,6 +45,17 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInf
 		return api.RequestTooSmall()
 	}
 
+	inQuota, err := quota.IsUserWithinQuota(rctx, user.UserId)
+	if err != nil {
+		io.Copy(ioutil.Discard, r.Body) // Ditch the entire request
+		rctx.Log.Error("Unexpected error checking quota: " + err.Error())
+		return api.InternalServerError("Unexpected Error")
+	}
+	if !inQuota {
+		io.Copy(ioutil.Discard, r.Body) // Ditch the entire request
+		return api.QuotaExceeded()
+	}
+
 	contentLength := upload_controller.EstimateContentLength(r.ContentLength, r.Header.Get("Content-Length"))
 
 	media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, rctx)
diff --git a/api/responses.go b/api/responses.go
index f439191e..d032ce80 100644
--- a/api/responses.go
+++ b/api/responses.go
@@ -49,3 +49,11 @@ func AuthFailed() *ErrorResponse {
 func BadRequest(message string) *ErrorResponse {
 	return &ErrorResponse{common.ErrCodeUnknown, message, common.ErrCodeBadRequest}
 }
+
+// QuotaExceeded builds the error response returned when an upload is refused
+// because the uploader is over their quota.
+// NOTE(review): the last field is ErrCodeQuotaExceeded, not ErrCodeForbidden -
+// confirm the route handler maps that value to an HTTP status.
+func QuotaExceeded() *ErrorResponse {
+	return &ErrorResponse{common.ErrCodeForbidden, "Quota Exceeded", common.ErrCodeQuotaExceeded}
+}
diff --git a/api/webserver/route_handler.go b/api/webserver/route_handler.go
index cbf09359..eff47562 100644
--- a/api/webserver/route_handler.go
+++ b/api/webserver/route_handler.go
@@ -150,6 +150,9 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		case common.ErrCodeMethodNotAllowed:
 			statusCode = http.StatusMethodNotAllowed
 			break
+		case common.ErrCodeForbidden:
+			statusCode = http.StatusForbidden
+			break
 		default: // Treat as unknown (a generic server error)
 			statusCode = http.StatusInternalServerError
 			break
diff --git a/common/config/conf_min_shared.go b/common/config/conf_min_shared.go
index fb212855..00408488 100644
--- a/common/config/conf_min_shared.go
+++ b/common/config/conf_min_shared.go
@@ -23,6 +23,10 @@ func NewDefaultMinimumRepoConfig() MinimumRepoConfig {
 			MaxSizeBytes:         104857600, // 100mb
 			MinSizeBytes:         100,
 			ReportedMaxSizeBytes: 0,
+			Quota: QuotasConfig{
+				Enabled:    false,
+				UserQuotas: []QuotaUserConfig{},
+			},
 		},
 		Identicons: IdenticonsConfig{
 			Enabled: true,
diff --git a/common/config/models_domain.go b/common/config/models_domain.go
index 8ede29d9..caaf7286 100644
--- a/common/config/models_domain.go
+++ b/common/config/models_domain.go
@@ -6,10 +6,25 @@ type ArchivingConfig struct {
 	TargetBytesPerPart int64 `yaml:"targetBytesPerPart"`
 }
 
+// QuotaUserConfig is one quota rule: user IDs matching Glob may hold at most
+// MaxBytes of uploaded media. A MaxBytes of zero means no limit.
+type QuotaUserConfig struct {
+	Glob     string `yaml:"glob"`
+	MaxBytes int64  `yaml:"maxBytes"`
+}
+
+// QuotasConfig holds the upload quota settings: whether quotas are enforced and
+// the ordered list of per-user rules (the first matching rule wins).
+type QuotasConfig struct {
+	Enabled    bool              `yaml:"enabled"`
+	UserQuotas []QuotaUserConfig `yaml:"users,flow"`
+}
+
 type UploadsConfig struct {
-	MaxSizeBytes         int64 `yaml:"maxBytes"`
-	MinSizeBytes         int64 `yaml:"minBytes"`
-	ReportedMaxSizeBytes int64 `yaml:"reportedMaxBytes"`
+	MaxSizeBytes         int64        `yaml:"maxBytes"`
+	MinSizeBytes         int64        `yaml:"minBytes"`
+	ReportedMaxSizeBytes int64        `yaml:"reportedMaxBytes"`
+	Quota                QuotasConfig `yaml:"quotas"`
 }
 
 type DatastoreConfig struct {
 
 type DatastoreConfig struct {
diff --git a/common/errorcodes.go b/common/errorcodes.go
index a8cc23ee..8deb6494 100644
--- a/common/errorcodes.go
+++ b/common/errorcodes.go
@@ -13,3 +13,5 @@ const ErrCodeMethodNotAllowed = "M_METHOD_NOT_ALLOWED"
 const ErrCodeBadRequest = "M_BAD_REQUEST"
 const ErrCodeRateLimitExceeded = "M_LIMIT_EXCEEDED"
 const ErrCodeUnknown = "M_UNKNOWN"
+const ErrCodeForbidden = "M_FORBIDDEN"
+const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED"
diff --git a/config.sample.yaml b/config.sample.yaml
index 57ec20a4..04b35845 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -195,6 +195,7 @@ archiving:
 
 # The file upload settings for the media repository
 uploads:
+  # The maximum individual file size a user can upload.
   maxBytes: 104857600 # 100MB default, 0 to disable
 
   # The minimum number of bytes to let people upload. This is recommended to be non-zero to
@@ -210,6 +211,24 @@ uploads:
   # Set this to -1 to indicate that there is no limit. Zero will force the use of maxBytes.
   #reportedMaxBytes: 104857600
 
+  # Options for limiting how much content a user can upload. Quotas are applied to content
+  # associated with a user regardless of de-duplication. Quotas which affect remote servers
+  # or users will not take effect. When a user exceeds their quota they will be unable to
+  # upload any more media.
+  quotas:
+    # Whether or not quotas are enabled/enforced. Note that even when disabled the media repo
+    # will track how much media a user has uploaded. This is disabled by default.
+    enabled: false
+
+    # The quota rules that affect users. The first rule to match the uploader will take effect.
+    # An implied rule which matches all users and has no quota is always last in this list,
+    # meaning that if no rules are supplied then users will be able to upload anything. Similarly,
+    # if no rules match a user then the implied rule will match, allowing the user to have no
+    # quota. Quotas are checked before an upload starts, so a user can upload one media item past
+    # their quota: in the statistics they may exceed it, but only by the size of that one upload.
+    users:
+      - glob: "@*:*"  # Affect all users. Use asterisks (*) to match any sequence of characters.
+        maxBytes: 53687091200 # 50GB default, 0 to disable
 
 # Settings related to downloading files from the media repository
 downloads:
diff --git a/migrations/17_add_user_stats_table_down.sql b/migrations/17_add_user_stats_table_down.sql
new file mode 100644
index 00000000..287f0eb0
--- /dev/null
+++ b/migrations/17_add_user_stats_table_down.sql
@@ -0,0 +1,3 @@
+DROP TRIGGER IF EXISTS media_change_for_user ON media;
+DROP FUNCTION IF EXISTS track_update_user_media();
+DROP TABLE IF EXISTS user_stats;
diff --git a/migrations/17_add_user_stats_table_up.sql b/migrations/17_add_user_stats_table_up.sql
new file mode 100644
index 00000000..67f1b85c
--- /dev/null
+++ b/migrations/17_add_user_stats_table_up.sql
@@ -0,0 +1,48 @@
+-- Running total of uploaded media bytes, one row per user.
+CREATE TABLE IF NOT EXISTS user_stats (
+	user_id TEXT PRIMARY KEY NOT NULL,
+	uploaded_bytes BIGINT NOT NULL
+);
+-- Trigger function that applies each row change on media to user_stats.
+CREATE OR REPLACE FUNCTION track_update_user_media()
+    RETURNS TRIGGER
+    LANGUAGE PLPGSQL
+    AS
+$$
+BEGIN
+    IF TG_OP = 'UPDATE' THEN
+        -- Make sure both the new and the old owner have a row to adjust.
+        INSERT INTO user_stats (user_id, uploaded_bytes) VALUES (NEW.user_id, 0) ON CONFLICT (user_id) DO NOTHING;
+        INSERT INTO user_stats (user_id, uploaded_bytes) VALUES (OLD.user_id, 0) ON CONFLICT (user_id) DO NOTHING;
+
+        IF NEW.user_id <> OLD.user_id THEN
+            -- Owner changed: move the bytes from the old user to the new one.
+            UPDATE user_stats SET uploaded_bytes = user_stats.uploaded_bytes - OLD.size_bytes WHERE user_stats.user_id = OLD.user_id;
+            UPDATE user_stats SET uploaded_bytes = user_stats.uploaded_bytes + NEW.size_bytes WHERE user_stats.user_id = NEW.user_id;
+        ELSIF NEW.size_bytes <> OLD.size_bytes THEN
+            -- Same owner, new size: swap the old size for the new one.
+            UPDATE user_stats SET uploaded_bytes = user_stats.uploaded_bytes - OLD.size_bytes + NEW.size_bytes WHERE user_stats.user_id = NEW.user_id;
+        END IF;
+        RETURN NEW;
+    ELSIF TG_OP = 'DELETE' THEN
+        -- Take the deleted media's bytes off the owner's total.
+        UPDATE user_stats SET uploaded_bytes = user_stats.uploaded_bytes - OLD.size_bytes WHERE user_stats.user_id = OLD.user_id;
+        RETURN OLD;
+    ELSIF TG_OP = 'INSERT' THEN
+        -- Create the owner's row, or add to it if one already exists.
+        INSERT INTO user_stats (user_id, uploaded_bytes) VALUES (NEW.user_id, NEW.size_bytes) ON CONFLICT (user_id) DO UPDATE SET uploaded_bytes = user_stats.uploaded_bytes + NEW.size_bytes;
+        RETURN NEW;
+    END IF;
+END;
+$$;
+-- Drop and recreate the trigger so it runs for every inserted, updated or deleted media row.
+DROP TRIGGER IF EXISTS media_change_for_user ON media;
+CREATE TRIGGER media_change_for_user AFTER INSERT OR UPDATE OR DELETE ON media FOR EACH ROW EXECUTE PROCEDURE track_update_user_media();
+
+-- Populate the new table from existing media, but only if it is still empty.
+DO $$
+BEGIN
+    IF ((SELECT COUNT(*) FROM user_stats)) = 0 THEN
+        INSERT INTO user_stats SELECT user_id, SUM(size_bytes) FROM media GROUP BY user_id;
+    END IF;
+END $$;
diff --git a/quota/quota.go b/quota/quota.go
new file mode 100644
index 00000000..1a8649ff
--- /dev/null
+++ b/quota/quota.go
@@ -0,0 +1,47 @@
+package quota
+
+import (
+	"database/sql"
+	"errors"
+
+	"github.com/ryanuber/go-glob"
+	"github.com/turt2live/matrix-media-repo/common/rcontext"
+	"github.com/turt2live/matrix-media-repo/storage"
+)
+
+// IsUserWithinQuota reports whether userId may start another upload under the
+// configured quota rules. It returns true when quotas are disabled, when no
+// rule matches the user, when the first matching rule has a MaxBytes of zero,
+// or when the user has no recorded stats. Otherwise the user is within quota
+// while their recorded uploaded bytes are strictly below the rule's MaxBytes.
+// A non-nil error means the user's stats could not be read.
+func IsUserWithinQuota(ctx rcontext.RequestContext, userId string) (bool, error) {
+	quotas := ctx.Config.Uploads.Quota
+	if !quotas.Enabled {
+		return true, nil
+	}
+
+	// Resolve the rule before touching the database: only the first matching
+	// rule applies, and uploads that no rule limits need no stats lookup.
+	var maxBytes int64 // stays zero (no limit) when no rule matches
+	for _, q := range quotas.UserQuotas {
+		if glob.Glob(q.Glob, userId) {
+			maxBytes = q.MaxBytes
+			break
+		}
+	}
+	if maxBytes == 0 {
+		return true, nil // no rule matched, or the rule is unlimited
+	}
+
+	db := storage.GetDatabase().GetMetadataStore(ctx)
+	stat, err := db.GetUserStats(userId)
+	if errors.Is(err, sql.ErrNoRows) {
+		return true, nil // no stats == within quota
+	}
+	if err != nil {
+		return false, err
+	}
+
+	return stat.UploadedBytes < maxBytes, nil
+}
diff --git a/storage/stores/metadata_store.go b/storage/stores/metadata_store.go
index 28c5f712..82a4c22f 100644
--- a/storage/stores/metadata_store.go
+++ b/storage/stores/metadata_store.go
@@ -31,6 +31,7 @@ const selectReservation = "SELECT origin, media_id, reason FROM reserved_media W
 const selectMediaLastAccessed = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts FROM media AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1;"
 const insertBlurhash = "INSERT INTO blurhashes (sha256_hash, blurhash) VALUES ($1, $2);"
 const selectBlurhash = "SELECT blurhash FROM blurhashes WHERE sha256_hash = $1;"
+const selectUserStats = "SELECT user_id, uploaded_bytes FROM user_stats WHERE user_id = $1;"
 
 type metadataStoreStatements struct {
 	upsertLastAccessed                            *sql.Stmt
@@ -51,6 +52,7 @@ type metadataStoreStatements struct {
 	selectMediaLastAccessed                       *sql.Stmt
 	insertBlurhash                                *sql.Stmt
 	selectBlurhash                                *sql.Stmt
+	selectUserStats                               *sql.Stmt
 }
 
 type MetadataStoreFactory struct {
@@ -124,6 +126,9 @@ func InitMetadataStore(sqlDb *sql.DB) (*MetadataStoreFactory, error) {
 	if store.stmts.selectBlurhash, err = store.sqlDb.Prepare(selectBlurhash); err != nil {
 		return nil, err
 	}
+	if store.stmts.selectUserStats, err = store.sqlDb.Prepare(selectUserStats); err != nil {
+		return nil, err
+	}
 
 	return &store, nil
 }
@@ -408,3 +413,16 @@ func (s *MetadataStore) GetBlurhash(sha256Hash string) (string, error) {
 	}
 	return blurhash, nil
 }
+
+// GetUserStats loads the user_stats row for userId. The error from Scan is
+// returned unchanged, so a user with no row yields sql.ErrNoRows.
+func (s *MetadataStore) GetUserStats(userId string) (*types.UserStats, error) {
+	var stats types.UserStats
+
+	row := s.statements.selectUserStats.QueryRowContext(s.ctx, userId)
+	if scanErr := row.Scan(&stats.UserId, &stats.UploadedBytes); scanErr != nil {
+		return nil, scanErr
+	}
+
+	return &stats, nil
+}
diff --git a/types/stats.go b/types/stats.go
new file mode 100644
index 00000000..01188c9d
--- /dev/null
+++ b/types/stats.go
@@ -0,0 +1,8 @@
+package types
+
+// UserStats is one row of the user_stats table: the running total of media
+// bytes attributed to a single user.
+type UserStats struct {
+	UserId        string
+	UploadedBytes int64
+}
-- 
GitLab