From 0205efa9e3ac672575f60b64f6a725a3b15a80bb Mon Sep 17 00:00:00 2001 From: Travis Ralston <travpc@gmail.com> Date: Tue, 25 Jun 2019 20:01:24 -0600 Subject: [PATCH] Use custom singleflight everywhere --- api/r0/thumbnail.go | 60 +---- common/globals/singleflight_groups.go | 7 + .../download_controller.go | 225 ++++++++++-------- .../preview_controller/preview_controller.go | 68 +++--- .../thumbnail_controller.go | 116 +++++---- storage/stores/url_store.go | 6 +- util/singleflight-counter/singleflight.go | 6 + 7 files changed, 265 insertions(+), 223 deletions(-) create mode 100644 common/globals/singleflight_groups.go diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index afae5e4f..39ee66f3 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -1,7 +1,6 @@ package r0 import ( - "fmt" "net/http" "strconv" @@ -11,13 +10,8 @@ import ( "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/controllers/thumbnail_controller" - "github.com/turt2live/matrix-media-repo/types" - "github.com/turt2live/matrix-media-repo/util" - "github.com/turt2live/matrix-media-repo/util/singleflight-counter" ) -var thumbnailRequestGroup singleflight_counter.Group - func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { params := mux.Vars(r) @@ -81,55 +75,21 @@ func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter "requestedAnimated": animated, }) - // TODO: Move this to a lower layer (somewhere where the thumbnail dimensions are known, before media is downloaded) - requestKey := fmt.Sprintf("thumbnail_%s_%s_%d_%d_%s_%t", server, mediaId, width, height, method, animated) - v, count, err := thumbnailRequestGroup.Do(requestKey, func() (interface{}, error) { - streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log) - if err != nil { - if err == common.ErrMediaNotFound { - return api.NotFoundError(), nil - } else if err == common.ErrMediaTooLarge { - return api.RequestTooLarge(), nil - } - log.Error("Unexpected error locating media: " + err.Error()) - return api.InternalServerError("Unexpected Error"), nil - } - - return streamedThumbnail, nil - }, func(v interface{}, count int, err error) []interface{} { - if err != nil { - return nil - } - - rv := v.(*types.StreamedThumbnail) - vals := make([]interface{}, 0) - streams := util.CloneReader(rv.Stream, count) - - for i := 0; i < count; i++ { - vals = append(vals, &types.StreamedThumbnail{ - Thumbnail: rv.Thumbnail, - Stream: streams[i], - }) - } - - return vals - }) - + streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log) if err != nil { - log.Error("Unexpected error handling request: " + err.Error()) + if err == common.ErrMediaNotFound { + return api.NotFoundError() + } else if err == common.ErrMediaTooLarge { + return api.RequestTooLarge() + } + log.Error("Unexpected error locating media: " + err.Error()) return api.InternalServerError("Unexpected Error") } - rv := v.(*types.StreamedThumbnail) - - if count > 0 { - log.Info("Request response was shared ", count, " times") - } - return &DownloadMediaResponse{ - ContentType: rv.Thumbnail.ContentType, - SizeBytes: rv.Thumbnail.SizeBytes, - Data: rv.Stream, + ContentType: streamedThumbnail.Thumbnail.ContentType, + SizeBytes: streamedThumbnail.Thumbnail.SizeBytes, + Data: streamedThumbnail.Stream, Filename: "thumbnail", } } diff --git a/common/globals/singleflight_groups.go b/common/globals/singleflight_groups.go new file mode 100644 index 00000000..4f5fd266 --- /dev/null +++ b/common/globals/singleflight_groups.go @@ -0,0 +1,7 @@ +package globals + +import ( + "github.com/turt2live/matrix-media-repo/util/singleflight-counter" +) + +var DefaultRequestGroup singleflight_counter.Group diff --git a/controllers/download_controller/download_controller.go b/controllers/download_controller/download_controller.go index 61ce3d2c..7dae57d7 100644 --- a/controllers/download_controller/download_controller.go +++ b/controllers/download_controller/download_controller.go @@ -4,11 +4,13 @@ import ( "context" "database/sql" "errors" + "fmt" "time" "github.com/patrickmn/go-cache" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/globals" "github.com/turt2live/matrix-media-repo/internal_cache" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" @@ -19,86 +21,116 @@ import ( var localCache = cache.New(30*time.Second, 60*time.Second) func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia bool, ctx context.Context, log *logrus.Entry) (*types.MinimalMedia, error) { - var media *types.Media - var minMedia *types.MinimalMedia - var err error - if blockForMedia { - media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx, log) - if media != nil { - minMedia = &types.MinimalMedia{ - Origin: media.Origin, - MediaId: media.MediaId, - ContentType: media.ContentType, - UploadName: media.UploadName, - SizeBytes: media.SizeBytes, - Stream: nil, // we'll populate this later if we need to - KnownMedia: media, + cacheKey := fmt.Sprintf("%s/%s?r=%t&b=%t", origin, mediaId, downloadRemote, blockForMedia) + v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) { + var media *types.Media + var minMedia *types.MinimalMedia + var err error + if blockForMedia { + media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx, log) + if media != nil { + minMedia = &types.MinimalMedia{ + Origin: media.Origin, + MediaId: media.MediaId, + ContentType: media.ContentType, + UploadName: media.UploadName, + SizeBytes: media.SizeBytes, + Stream: nil, // we'll populate this later if we need to + KnownMedia: media, + } + } + } else { + minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx, log) + if minMedia != nil { + media = minMedia.KnownMedia } } - } else { - minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx, log) - if minMedia != nil { - media = minMedia.KnownMedia + if err != nil { + return nil, err } - } - if err != nil { - return nil, err - } - if minMedia == nil { - log.Warn("Unexpected error while fetching media: no minimal media record") - return nil, common.ErrMediaNotFound - } - if media == nil && blockForMedia { - log.Warn("Unexpected error while fetching media: no regular media record (block for media in place)") - return nil, common.ErrMediaNotFound - } + if minMedia == nil { + log.Warn("Unexpected error while fetching media: no minimal media record") + return nil, common.ErrMediaNotFound + } + if media == nil && blockForMedia { + log.Warn("Unexpected error while fetching media: no regular media record (block for media in place)") + return nil, common.ErrMediaNotFound + } + + // if we have a known media record, we might as well set it + // if we don't, this won't do anything different + minMedia.KnownMedia = media + + if media != nil { + if media.Quarantined { + log.Warn("Quarantined media accessed") + return nil, common.ErrMediaQuarantined + } - // if we have a known media record, we might as well set it - // if we don't, this won't do anything different - minMedia.KnownMedia = media + err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(media.Sha256Hash, util.NowMillis()) + if err != nil { + logrus.Warn("Failed to upsert the last access time: ", err) + } - if media != nil { - if media.Quarantined { - log.Warn("Quarantined media accessed") - return nil, common.ErrMediaQuarantined + localCache.Set(origin+"/"+mediaId, media, cache.DefaultExpiration) + internal_cache.Get().IncrementDownloads(media.Sha256Hash) + + cached, err := internal_cache.Get().GetMedia(media, log) + if err != nil { + return nil, err + } + if cached != nil && cached.Contents != nil { + minMedia.Stream = util.BufferToStream(cached.Contents) + return minMedia, nil + } } - err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(media.Sha256Hash, util.NowMillis()) - if err != nil { - logrus.Warn("Failed to upsert the last access time: ", err) + if minMedia.Stream != nil { + log.Info("Returning minimal media record with a viable stream") + return minMedia, nil } - localCache.Set(origin+"/"+mediaId, media, cache.DefaultExpiration) - internal_cache.Get().IncrementDownloads(media.Sha256Hash) + if media == nil { + log.Error("Failed to locate media") + return nil, errors.New("failed to locate media") + } - cached, err := internal_cache.Get().GetMedia(media, log) + log.Info("Reading media from disk") + mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) if err != nil { return nil, err } - if cached != nil && cached.Contents != nil { - minMedia.Stream = util.BufferToStream(cached.Contents) - return minMedia, nil - } - } - if minMedia.Stream != nil { - log.Info("Returning minimal media record with a viable stream") + minMedia.Stream = mediaStream return minMedia, nil - } + }, func(v interface{}, count int, err error) []interface{} { + if err != nil { + return nil + } - if media == nil { - log.Error("Failed to locate media") - return nil, errors.New("failed to locate media") - } + rv := v.(*types.MinimalMedia) + vals := make([]interface{}, 0) + streams := util.CloneReader(rv.Stream, count) - log.Info("Reading media from disk") - mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location) - if err != nil { - return nil, err - } + for i := 0; i < count; i++ { + if rv.KnownMedia != nil { + internal_cache.Get().IncrementDownloads(rv.KnownMedia.Sha256Hash) + } + vals = append(vals, &types.MinimalMedia{ + Origin: rv.Origin, + MediaId: rv.MediaId, + UploadName: rv.UploadName, + ContentType: rv.ContentType, + SizeBytes: rv.SizeBytes, + KnownMedia: rv.KnownMedia, + Stream: streams[i], + }) + } - minMedia.Stream = mediaStream - return minMedia, nil + return vals + }) + + return v.(*types.MinimalMedia), err } func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.MinimalMedia, error) { @@ -181,45 +213,50 @@ func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, } func FindMediaRecord(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.Media, error) { - db := storage.GetDatabase().GetMediaStore(ctx, log) + cacheKey := origin + "/" + mediaId + v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) { + db := storage.GetDatabase().GetMediaStore(ctx, log) - var media *types.Media - item, found := localCache.Get(origin + "/" + mediaId) - if found { - media = item.(*types.Media) - } else { - log.Info("Getting media record from database") - dbMedia, err := db.Get(origin, mediaId) - if err != nil { - if err == sql.ErrNoRows { - if util.IsServerOurs(origin) { - log.Warn("Media not found") - return nil, common.ErrMediaNotFound + var media *types.Media + item, found := localCache.Get(cacheKey) + if found { + media = item.(*types.Media) + } else { + log.Info("Getting media record from database") + dbMedia, err := db.Get(origin, mediaId) + if err != nil { + if err == sql.ErrNoRows { + if util.IsServerOurs(origin) { + log.Warn("Media not found") + return nil, common.ErrMediaNotFound + } } - } - if !downloadRemote { - log.Warn("Remote media not being downloaded") - return nil, common.ErrMediaNotFound - } + if !downloadRemote { + log.Warn("Remote media not being downloaded") + return nil, common.ErrMediaNotFound + } - mediaChan := getResourceHandler().DownloadRemoteMedia(origin, mediaId, true) - defer close(mediaChan) + mediaChan := getResourceHandler().DownloadRemoteMedia(origin, mediaId, true) + defer close(mediaChan) - result := <-mediaChan - if result.err != nil { - return nil, result.err + result := <-mediaChan + if result.err != nil { + return nil, result.err + } + media = result.media + } else { + media = dbMedia } - media = result.media - } else { - media = dbMedia } - } - if media == nil { - log.Warn("Despite all efforts, a media record could not be found") - return nil, common.ErrMediaNotFound - } + if media == nil { + log.Warn("Despite all efforts, a media record could not be found") + return nil, common.ErrMediaNotFound + } + + return media, nil + }) - return media, nil + return v.(*types.Media), err } diff --git a/controllers/preview_controller/preview_controller.go b/controllers/preview_controller/preview_controller.go index cb720fdf..82c7b600 100644 --- a/controllers/preview_controller/preview_controller.go +++ b/controllers/preview_controller/preview_controller.go @@ -4,52 +4,62 @@ import ( "context" "database/sql" "errors" + "fmt" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/globals" "github.com/turt2live/matrix-media-repo/controllers/preview_controller/acl" "github.com/turt2live/matrix-media-repo/storage" + "github.com/turt2live/matrix-media-repo/storage/stores" "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, ctx context.Context, log *logrus.Entry) (*types.UrlPreview, error) { - log = log.WithFields(logrus.Fields{ - "preview_controller_at_ts": atTs, - }) + atTs = stores.GetBucketTs(atTs) + cacheKey := fmt.Sprintf("%d_%s/%s", atTs, onHost, urlStr) + v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) { - db := storage.GetDatabase().GetUrlStore(ctx, log) + log = log.WithFields(logrus.Fields{ + "preview_controller_at_ts": atTs, + }) - cached, err := db.GetPreview(urlStr, atTs) - if err != nil && err != sql.ErrNoRows { - log.Error("Error getting cached URL preview: ", err.Error()) - return nil, err - } - if err != sql.ErrNoRows { - log.Info("Returning cached URL preview") - return cachedPreviewToReal(cached) - } + db := storage.GetDatabase().GetUrlStore(ctx, log) - now := util.NowMillis() - if (now - atTs) > 60000 { - // Because we don't have a cached preview, we'll use the current time as the preview time. - // We also give a 60 second buffer so we don't cause an infinite loop (considering we're - // calling ourselves), and to give a lenient opportunity for slow execution. - return GetPreview(urlStr, onHost, forUserId, now, ctx, log) - } + cached, err := db.GetPreview(urlStr, atTs) + if err != nil && err != sql.ErrNoRows { + log.Error("Error getting cached URL preview: ", err.Error()) + return nil, err + } + if err != sql.ErrNoRows { + log.Info("Returning cached URL preview") + return cachedPreviewToReal(cached) + } - log.Info("Preview not cached - fetching resource") + now := util.NowMillis() + if (now - atTs) > 60000 { + // Because we don't have a cached preview, we'll use the current time as the preview time. + // We also give a 60 second buffer so we don't cause an infinite loop (considering we're + // calling ourselves), and to give a lenient opportunity for slow execution. + return GetPreview(urlStr, onHost, forUserId, now, ctx, log) + } - urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx, log) - if err != nil { - return nil, err - } + log.Info("Preview not cached - fetching resource") + + urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx, log) + if err != nil { + return nil, err + } - previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost) - defer close(previewChan) + previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost) + defer close(previewChan) + + result := <-previewChan + return result.preview, result.err + }) - result := <-previewChan - return result.preview, result.err + return v.(*types.UrlPreview), err } func cachedPreviewToReal(cached *types.CachedUrlPreview) (*types.UrlPreview, error) { diff --git a/controllers/thumbnail_controller/thumbnail_controller.go b/controllers/thumbnail_controller/thumbnail_controller.go index 616f6637..3c9a3219 100644 --- a/controllers/thumbnail_controller/thumbnail_controller.go +++ b/controllers/thumbnail_controller/thumbnail_controller.go @@ -18,6 +18,7 @@ import ( "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/globals" "github.com/turt2live/matrix-media-repo/controllers/download_controller" "github.com/turt2live/matrix-media-repo/internal_cache" "github.com/turt2live/matrix-media-repo/storage" @@ -108,8 +109,6 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight return nil, common.ErrMediaTooLarge } - db := storage.GetDatabase().GetThumbnailStore(ctx, log) - width, height, method, err := pickThumbnailDimensions(desiredWidth, desiredHeight, method) if err != nil { return nil, err @@ -117,61 +116,84 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight cacheKey := fmt.Sprintf("%s/%s?w=%d&h=%d&m=%s&a=%t", media.Origin, media.MediaId, width, height, method, animated) - var thumbnail *types.Thumbnail - item, found := localCache.Get(cacheKey) - if found { - thumbnail = item.(*types.Thumbnail) - } else { - log.Info("Getting thumbnail record from database") - dbThumb, err := db.Get(media.Origin, media.MediaId, width, height, method, animated) - if err != nil { - if err == sql.ErrNoRows { - log.Info("Thumbnail does not exist, attempting to generate it") - genThumb, err2 := GetOrGenerateThumbnail(media, width, height, animated, method, ctx, log) - if err2 != nil { - return nil, err2 - } + v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) { + db := storage.GetDatabase().GetThumbnailStore(ctx, log) - thumbnail = genThumb + var thumbnail *types.Thumbnail + item, found := localCache.Get(cacheKey) + if found { + thumbnail = item.(*types.Thumbnail) + } else { + log.Info("Getting thumbnail record from database") + dbThumb, err := db.Get(media.Origin, media.MediaId, width, height, method, animated) + if err != nil { + if err == sql.ErrNoRows { + log.Info("Thumbnail does not exist, attempting to generate it") + genThumb, err2 := GetOrGenerateThumbnail(media, width, height, animated, method, ctx, log) + if err2 != nil { + return nil, err2 + } + + thumbnail = genThumb + } else { + return nil, err + } } else { - return nil, err + thumbnail = dbThumb } - } else { - thumbnail = dbThumb } - } - if thumbnail == nil { - log.Warn("Despite all efforts, a thumbnail record could not be found or generated") - return nil, common.ErrMediaNotFound - } + if thumbnail == nil { + log.Warn("Despite all efforts, a thumbnail record could not be found or generated") + return nil, common.ErrMediaNotFound + } - err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(thumbnail.Sha256Hash, util.NowMillis()) - if err != nil { - logrus.Warn("Failed to upsert the last access time: ", err) - } + err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(thumbnail.Sha256Hash, util.NowMillis()) + if err != nil { + logrus.Warn("Failed to upsert the last access time: ", err) + } - localCache.Set(cacheKey, thumbnail, cache.DefaultExpiration) - internal_cache.Get().IncrementDownloads(thumbnail.Sha256Hash) + localCache.Set(cacheKey, thumbnail, cache.DefaultExpiration) - cached, err := internal_cache.Get().GetThumbnail(thumbnail, log) - if err != nil { - return nil, err - } - if cached != nil && cached.Contents != nil { - return &types.StreamedThumbnail{ - Thumbnail: thumbnail, - Stream: util.BufferToStream(cached.Contents), - }, nil - } + cached, err := internal_cache.Get().GetThumbnail(thumbnail, log) + if err != nil { + return nil, err + } + if cached != nil && cached.Contents != nil { + return &types.StreamedThumbnail{ + Thumbnail: thumbnail, + Stream: util.BufferToStream(cached.Contents), + }, nil + } - log.Info("Reading thumbnail from disk") - mediaStream, err := datastore.DownloadStream(ctx, log, thumbnail.DatastoreId, thumbnail.Location) - if err != nil { - return nil, err - } + log.Info("Reading thumbnail from disk") + mediaStream, err := datastore.DownloadStream(ctx, log, thumbnail.DatastoreId, thumbnail.Location) + if err != nil { + return nil, err + } + + return &types.StreamedThumbnail{Thumbnail: thumbnail, Stream: mediaStream}, nil + }, func(v interface{}, count int, err error) []interface{} { + if err != nil { + return nil + } + + rv := v.(*types.StreamedThumbnail) + vals := make([]interface{}, 0) + streams := util.CloneReader(rv.Stream, count) + + for i := 0; i < count; i++ { + internal_cache.Get().IncrementDownloads(rv.Thumbnail.Sha256Hash) + vals = append(vals, &types.StreamedThumbnail{ + Thumbnail: rv.Thumbnail, + Stream: streams[i], + }) + } + + return vals + }) - return &types.StreamedThumbnail{Thumbnail: thumbnail, Stream: mediaStream}, nil + return v.(*types.StreamedThumbnail), err } func GetOrGenerateThumbnail(media *types.Media, width int, height int, animated bool, method string, ctx context.Context, log *logrus.Entry) (*types.Thumbnail, error) { diff --git a/storage/stores/url_store.go b/storage/stores/url_store.go index cf65bfdb..4e321ac0 100644 --- a/storage/stores/url_store.go +++ b/storage/stores/url_store.go @@ -58,7 +58,7 @@ func (s *UrlStore) GetPreview(url string, ts int64) (*types.CachedUrlPreview, er r := &types.CachedUrlPreview{ Preview: &types.UrlPreview{}, } - err := s.statements.selectUrlPreview.QueryRowContext(s.ctx, url, getBucketTs(ts)).Scan( + err := s.statements.selectUrlPreview.QueryRowContext(s.ctx, url, GetBucketTs(ts)).Scan( &r.SearchUrl, &r.ErrorCode, &r.FetchedTs, @@ -82,7 +82,7 @@ func (s *UrlStore) InsertPreview(record *types.CachedUrlPreview) error { s.ctx, record.SearchUrl, record.ErrorCode, - getBucketTs(record.FetchedTs), + GetBucketTs(record.FetchedTs), record.Preview.Url, record.Preview.SiteName, record.Preview.Type, @@ -107,7 +107,7 @@ func (s *UrlStore) InsertPreviewError(url string, errorCode string) error { }) } -func getBucketTs(ts int64) int64 { +func GetBucketTs(ts int64) int64 { // 1 hour buckets return (ts / 3600000) * 3600000 } diff --git a/util/singleflight-counter/singleflight.go b/util/singleflight-counter/singleflight.go index cf4e5485..b32095c9 100644 --- a/util/singleflight-counter/singleflight.go +++ b/util/singleflight-counter/singleflight.go @@ -35,6 +35,12 @@ func (c *call) NextVal() interface{} { return val } +func (g *Group) DoWithoutPost(key string, fn func() (interface{}, error)) (interface{}, int, error) { + return g.Do(key, fn, func(v interface{}, total int, e error) []interface{} { + return nil + }) +} + func (g *Group) Do(key string, fn func() (interface{}, error), postprocess func(v interface{}, total int, e error) []interface{}) (interface{}, int, error) { g.mu.Lock() if g.m == nil { -- GitLab