diff --git a/api/custom/usage.go b/api/custom/usage.go index 387c27c778e872e7911f8385d0ad8ae8f26f0ec6..be92ed2916a8acf3243abf60f7cf9a37b830bc16 100644 --- a/api/custom/usage.go +++ b/api/custom/usage.go @@ -1,7 +1,6 @@ package custom import ( - "fmt" "net/http" "github.com/gorilla/mux" @@ -9,6 +8,7 @@ import ( "github.com/turt2live/matrix-media-repo/api" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/types" + "github.com/turt2live/matrix-media-repo/util" ) type MinimalUsageInfo struct { @@ -32,6 +32,18 @@ type UserUsageEntry struct { UploadedMxcs []string `json:"uploaded,flow"` } +type MediaUsageEntry struct { + SizeBytes int64 `json:"size_bytes"` + UploadedBy string `json:"uploaded_by"` + DatastoreId string `json:"datastore_id"` + DatastoreLocation string `json:"datastore_location"` + Sha256Hash string `json:"sha256_hash"` + Quarantined bool `json:"quarantined"` + UploadName string `json:"upload_name"` + ContentType string `json:"content_type"` + CreatedTs int64 `json:"created_ts"` +} + func GetDomainUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { params := mux.Vars(r) @@ -123,7 +135,65 @@ func GetUserUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interfa entry.RawCounts.Total += 1 entry.RawCounts.Media += 1 - entry.UploadedMxcs = append(entry.UploadedMxcs, fmt.Sprintf("mxc://%s/%s", media.Origin, media.MediaId)) + entry.UploadedMxcs = append(entry.UploadedMxcs, media.MxcUri()) + } + + return parsed +} + +func GetUploadsUsage(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { + params := mux.Vars(r) + + serverName := params["serverName"] + mxcs := r.URL.Query()["mxc"] + + log = log.WithFields(logrus.Fields{ + "serverName": serverName, + }) + + db := storage.GetDatabase().GetMediaStore(r.Context(), log) + + var records []*types.Media + var err error + if mxcs == nil || len(mxcs) == 0 { + records, err = db.GetAllMediaForServer(serverName) + } else { + split := make([]string, 0) + for _, mxc := range mxcs { + o, i, err := util.SplitMxc(mxc) + if err != nil { + log.Error(err) + return api.InternalServerError("Error parsing MXC " + mxc) + } + + if o != serverName { + return api.BadRequest("MXC URIs must match the requested server") + } + + split = append(split, i) + } + records, err = db.GetAllMediaInIds(serverName, split) + } + + if err != nil { + log.Error(err) + return api.InternalServerError("Failed to get media records for users") + } + + parsed := make(map[string]*MediaUsageEntry) + + for _, media := range records { + parsed[media.MxcUri()] = &MediaUsageEntry{ + SizeBytes: media.SizeBytes, + UploadName: media.UploadName, + ContentType: media.ContentType, + CreatedTs: media.CreationTs, + DatastoreId: media.DatastoreId, + DatastoreLocation: media.Location, + Quarantined: media.Quarantined, + Sha256Hash: media.Sha256Hash, + UploadedBy: media.UserId, + } } return parsed diff --git a/api/webserver/webserver.go b/api/webserver/webserver.go index 0b5cdd39261e8b22bf412169717ba7790fc94048..e259cfd189ae97642448670783f163db05e54d5d 100644 --- a/api/webserver/webserver.go +++ b/api/webserver/webserver.go @@ -44,6 +44,7 @@ func Init() { healthzHandler := handler{api.AccessTokenOptionalRoute(custom.GetHealthz), "healthz", counter, true} domainUsageHandler := handler{api.RepoAdminRoute(custom.GetDomainUsage), "domain_usage", counter, false} userUsageHandler := handler{api.RepoAdminRoute(custom.GetUserUsage), "user_usage", counter, false} + uploadsUsageHandler := handler{api.RepoAdminRoute(custom.GetUploadsUsage), "uploads_usage", counter, false} routes := make(map[string]route) versions := []string{"r0", "v1", "unstable"} // r0 is typically clients and v1 is typically servers. v1 is deprecated. @@ -68,6 +69,7 @@ func Init() { routes["/_matrix/media/"+version+"/admin/federation/test/{serverName:[a-zA-Z0-9.:\\-_]+}"] = route{"GET", fedTestHandler} routes["/_matrix/media/"+version+"/admin/usage/{serverName:[a-zA-Z0-9.:\\-_]+}"] = route{"GET", domainUsageHandler} routes["/_matrix/media/"+version+"/admin/usage/{serverName:[a-zA-Z0-9.:\\-_]+}/users"] = route{"GET", userUsageHandler} + routes["/_matrix/media/"+version+"/admin/usage/{serverName:[a-zA-Z0-9.:\\-_]+}/uploads"] = route{"GET", uploadsUsageHandler} // Routes that we should handle but aren't in the media namespace (synapse compat) routes["/_matrix/client/"+version+"/admin/purge_media_cache"] = route{"POST", purgeHandler} diff --git a/storage/stores/media_store.go b/storage/stores/media_store.go index a42a2aa864d80a9474ef14e77e0d6c281a34b7f8..d2e86a2fbe0c25ba839bedcb039da74c6af5e58b 100644 --- a/storage/stores/media_store.go +++ b/storage/stores/media_store.go @@ -25,6 +25,7 @@ const updateMediaDatastoreAndLocation = "UPDATE media SET location = $4, datasto const selectAllDatastores = "SELECT datastore_id, ds_type, uri FROM datastores;" const selectAllMediaForServer = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1" const selectAllMediaForServerUsers = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND user_id = ANY($2)" +const selectAllMediaForServerIds = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND media_id = ANY($2)" var dsCacheByPath = sync.Map{} // [string] => Datastore var dsCacheById = sync.Map{} // [string] => Datastore @@ -46,6 +47,7 @@ type mediaStoreStatements struct { selectMediaInDatastoreOlderThan *sql.Stmt selectAllMediaForServer *sql.Stmt selectAllMediaForServerUsers *sql.Stmt + selectAllMediaForServerIds *sql.Stmt } type MediaStoreFactory struct { @@ -111,6 +113,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) { if store.stmts.selectAllMediaForServerUsers, err = store.sqlDb.Prepare(selectAllMediaForServerUsers); err != nil { return nil, err } + if store.stmts.selectAllMediaForServerIds, err = store.sqlDb.Prepare(selectAllMediaForServerIds); err != nil { + return nil, err + } return &store, nil } @@ -453,3 +458,34 @@ func (s *MediaStore) GetAllMediaForServerUsers(serverName string, userIds []stri return results, nil } + +func (s *MediaStore) GetAllMediaInIds(serverName string, mediaIds []string) ([]*types.Media, error) { + rows, err := s.statements.selectAllMediaForServerIds.QueryContext(s.ctx, serverName, pq.Array(mediaIds)) + if err != nil { + return nil, err + } + + var results []*types.Media + for rows.Next() { + obj := &types.Media{} + err = rows.Scan( + &obj.Origin, + &obj.MediaId, + &obj.UploadName, + &obj.ContentType, + &obj.UserId, + &obj.Sha256Hash, + &obj.SizeBytes, + &obj.DatastoreId, + &obj.Location, + &obj.CreationTs, + &obj.Quarantined, + ) + if err != nil { + return nil, err + } + results = append(results, obj) + } + + return results, nil +}