From 0205efa9e3ac672575f60b64f6a725a3b15a80bb Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Tue, 25 Jun 2019 20:01:24 -0600
Subject: [PATCH] Use custom singleflight everywhere

---
 api/r0/thumbnail.go                           |  60 +----
 common/globals/singleflight_groups.go         |   7 +
 .../download_controller.go                    | 225 ++++++++++--------
 .../preview_controller/preview_controller.go  |  68 +++---
 .../thumbnail_controller.go                   | 116 +++++----
 storage/stores/url_store.go                   |   6 +-
 util/singleflight-counter/singleflight.go     |   6 +
 7 files changed, 265 insertions(+), 223 deletions(-)
 create mode 100644 common/globals/singleflight_groups.go

diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go
index afae5e4f..39ee66f3 100644
--- a/api/r0/thumbnail.go
+++ b/api/r0/thumbnail.go
@@ -1,7 +1,6 @@
 package r0
 
 import (
-	"fmt"
 	"net/http"
 	"strconv"
 
@@ -11,13 +10,8 @@ import (
 	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/config"
 	"github.com/turt2live/matrix-media-repo/controllers/thumbnail_controller"
-	"github.com/turt2live/matrix-media-repo/types"
-	"github.com/turt2live/matrix-media-repo/util"
-	"github.com/turt2live/matrix-media-repo/util/singleflight-counter"
 )
 
-var thumbnailRequestGroup singleflight_counter.Group
-
 func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
 	params := mux.Vars(r)
 
@@ -81,55 +75,21 @@ func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter
 		"requestedAnimated": animated,
 	})
 
-	// TODO: Move this to a lower layer (somewhere where the thumbnail dimensions are known, before media is downloaded)
-	requestKey := fmt.Sprintf("thumbnail_%s_%s_%d_%d_%s_%t", server, mediaId, width, height, method, animated)
-	v, count, err := thumbnailRequestGroup.Do(requestKey, func() (interface{}, error) {
-		streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log)
-		if err != nil {
-			if err == common.ErrMediaNotFound {
-				return api.NotFoundError(), nil
-			} else if err == common.ErrMediaTooLarge {
-				return api.RequestTooLarge(), nil
-			}
-			log.Error("Unexpected error locating media: " + err.Error())
-			return api.InternalServerError("Unexpected Error"), nil
-		}
-
-		return streamedThumbnail, nil
-	}, func(v interface{}, count int, err error) []interface{} {
-		if err != nil {
-			return nil
-		}
-
-		rv := v.(*types.StreamedThumbnail)
-		vals := make([]interface{}, 0)
-		streams := util.CloneReader(rv.Stream, count)
-
-		for i := 0; i < count; i++ {
-			vals = append(vals, &types.StreamedThumbnail{
-				Thumbnail: rv.Thumbnail,
-				Stream:    streams[i],
-			})
-		}
-
-		return vals
-	})
-
+	streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log)
 	if err != nil {
-		log.Error("Unexpected error handling request: " + err.Error())
+		if err == common.ErrMediaNotFound {
+			return api.NotFoundError()
+		} else if err == common.ErrMediaTooLarge {
+			return api.RequestTooLarge()
+		}
+		log.Error("Unexpected error locating media: " + err.Error())
 		return api.InternalServerError("Unexpected Error")
 	}
 
-	rv := v.(*types.StreamedThumbnail)
-
-	if count > 0 {
-		log.Info("Request response was shared ", count, " times")
-	}
-
 	return &DownloadMediaResponse{
-		ContentType: rv.Thumbnail.ContentType,
-		SizeBytes:   rv.Thumbnail.SizeBytes,
-		Data:        rv.Stream,
+		ContentType: streamedThumbnail.Thumbnail.ContentType,
+		SizeBytes:   streamedThumbnail.Thumbnail.SizeBytes,
+		Data:        streamedThumbnail.Stream,
 		Filename:    "thumbnail",
 	}
 }
diff --git a/common/globals/singleflight_groups.go b/common/globals/singleflight_groups.go
new file mode 100644
index 00000000..4f5fd266
--- /dev/null
+++ b/common/globals/singleflight_groups.go
@@ -0,0 +1,7 @@
+package globals
+
+import (
+	"github.com/turt2live/matrix-media-repo/util/singleflight-counter"
+)
+
+var DefaultRequestGroup singleflight_counter.Group
diff --git a/controllers/download_controller/download_controller.go b/controllers/download_controller/download_controller.go
index 61ce3d2c..7dae57d7 100644
--- a/controllers/download_controller/download_controller.go
+++ b/controllers/download_controller/download_controller.go
@@ -4,11 +4,13 @@ import (
 	"context"
 	"database/sql"
 	"errors"
+	"fmt"
 	"time"
 
 	"github.com/patrickmn/go-cache"
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common"
+	"github.com/turt2live/matrix-media-repo/common/globals"
 	"github.com/turt2live/matrix-media-repo/internal_cache"
 	"github.com/turt2live/matrix-media-repo/storage"
 	"github.com/turt2live/matrix-media-repo/storage/datastore"
@@ -19,86 +21,116 @@ import (
 var localCache = cache.New(30*time.Second, 60*time.Second)
 
 func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia bool, ctx context.Context, log *logrus.Entry) (*types.MinimalMedia, error) {
-	var media *types.Media
-	var minMedia *types.MinimalMedia
-	var err error
-	if blockForMedia {
-		media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx, log)
-		if media != nil {
-			minMedia = &types.MinimalMedia{
-				Origin:      media.Origin,
-				MediaId:     media.MediaId,
-				ContentType: media.ContentType,
-				UploadName:  media.UploadName,
-				SizeBytes:   media.SizeBytes,
-				Stream:      nil, // we'll populate this later if we need to
-				KnownMedia:  media,
+	cacheKey := fmt.Sprintf("%s/%s?r=%t&b=%t", origin, mediaId, downloadRemote, blockForMedia)
+	v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) {
+		var media *types.Media
+		var minMedia *types.MinimalMedia
+		var err error
+		if blockForMedia {
+			media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx, log)
+			if media != nil {
+				minMedia = &types.MinimalMedia{
+					Origin:      media.Origin,
+					MediaId:     media.MediaId,
+					ContentType: media.ContentType,
+					UploadName:  media.UploadName,
+					SizeBytes:   media.SizeBytes,
+					Stream:      nil, // we'll populate this later if we need to
+					KnownMedia:  media,
+				}
+			}
+		} else {
+			minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx, log)
+			if minMedia != nil {
+				media = minMedia.KnownMedia
 			}
 		}
-	} else {
-		minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx, log)
-		if minMedia != nil {
-			media = minMedia.KnownMedia
+		if err != nil {
+			return nil, err
 		}
-	}
-	if err != nil {
-		return nil, err
-	}
-	if minMedia == nil {
-		log.Warn("Unexpected error while fetching media: no minimal media record")
-		return nil, common.ErrMediaNotFound
-	}
-	if media == nil && blockForMedia {
-		log.Warn("Unexpected error while fetching media: no regular media record (block for media in place)")
-		return nil, common.ErrMediaNotFound
-	}
+		if minMedia == nil {
+			log.Warn("Unexpected error while fetching media: no minimal media record")
+			return nil, common.ErrMediaNotFound
+		}
+		if media == nil && blockForMedia {
+			log.Warn("Unexpected error while fetching media: no regular media record (block for media in place)")
+			return nil, common.ErrMediaNotFound
+		}
+
+		// if we have a known media record, we might as well set it
+		// if we don't, this won't do anything different
+		minMedia.KnownMedia = media
+
+		if media != nil {
+			if media.Quarantined {
+				log.Warn("Quarantined media accessed")
+				return nil, common.ErrMediaQuarantined
+			}
 
-	// if we have a known media record, we might as well set it
-	// if we don't, this won't do anything different
-	minMedia.KnownMedia = media
+			err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(media.Sha256Hash, util.NowMillis())
+			if err != nil {
+				logrus.Warn("Failed to upsert the last access time: ", err)
+			}
 
-	if media != nil {
-		if media.Quarantined {
-			log.Warn("Quarantined media accessed")
-			return nil, common.ErrMediaQuarantined
+			localCache.Set(origin+"/"+mediaId, media, cache.DefaultExpiration)
+			internal_cache.Get().IncrementDownloads(media.Sha256Hash)
+
+			cached, err := internal_cache.Get().GetMedia(media, log)
+			if err != nil {
+				return nil, err
+			}
+			if cached != nil && cached.Contents != nil {
+				minMedia.Stream = util.BufferToStream(cached.Contents)
+				return minMedia, nil
+			}
 		}
 
-		err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(media.Sha256Hash, util.NowMillis())
-		if err != nil {
-			logrus.Warn("Failed to upsert the last access time: ", err)
+		if minMedia.Stream != nil {
+			log.Info("Returning minimal media record with a viable stream")
+			return minMedia, nil
 		}
 
-		localCache.Set(origin+"/"+mediaId, media, cache.DefaultExpiration)
-		internal_cache.Get().IncrementDownloads(media.Sha256Hash)
+		if media == nil {
+			log.Error("Failed to locate media")
+			return nil, errors.New("failed to locate media")
+		}
 
-		cached, err := internal_cache.Get().GetMedia(media, log)
+		log.Info("Reading media from disk")
+		mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location)
 		if err != nil {
 			return nil, err
 		}
-		if cached != nil && cached.Contents != nil {
-			minMedia.Stream = util.BufferToStream(cached.Contents)
-			return minMedia, nil
-		}
-	}
 
-	if minMedia.Stream != nil {
-		log.Info("Returning minimal media record with a viable stream")
+		minMedia.Stream = mediaStream
 		return minMedia, nil
-	}
+	}, func(v interface{}, count int, err error) []interface{} {
+		if err != nil {
+			return nil
+		}
 
-	if media == nil {
-		log.Error("Failed to locate media")
-		return nil, errors.New("failed to locate media")
-	}
+		rv := v.(*types.MinimalMedia)
+		vals := make([]interface{}, 0)
+		streams := util.CloneReader(rv.Stream, count)
 
-	log.Info("Reading media from disk")
-	mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location)
-	if err != nil {
-		return nil, err
-	}
+		for i := 0; i < count; i++ {
+			if rv.KnownMedia != nil {
+				internal_cache.Get().IncrementDownloads(rv.KnownMedia.Sha256Hash)
+			}
+			vals = append(vals, &types.MinimalMedia{
+				Origin:      rv.Origin,
+				MediaId:     rv.MediaId,
+				UploadName:  rv.UploadName,
+				ContentType: rv.ContentType,
+				SizeBytes:   rv.SizeBytes,
+				KnownMedia:  rv.KnownMedia,
+				Stream:      streams[i],
+			})
+		}
 
-	minMedia.Stream = mediaStream
-	return minMedia, nil
+		return vals
+	})
+
+	return v.(*types.MinimalMedia), err
 }
 
 func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.MinimalMedia, error) {
@@ -181,45 +213,50 @@ func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool,
 }
 
 func FindMediaRecord(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.Media, error) {
-	db := storage.GetDatabase().GetMediaStore(ctx, log)
+	cacheKey := origin + "/" + mediaId
+	v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) {
+		db := storage.GetDatabase().GetMediaStore(ctx, log)
 
-	var media *types.Media
-	item, found := localCache.Get(origin + "/" + mediaId)
-	if found {
-		media = item.(*types.Media)
-	} else {
-		log.Info("Getting media record from database")
-		dbMedia, err := db.Get(origin, mediaId)
-		if err != nil {
-			if err == sql.ErrNoRows {
-				if util.IsServerOurs(origin) {
-					log.Warn("Media not found")
-					return nil, common.ErrMediaNotFound
+		var media *types.Media
+		item, found := localCache.Get(cacheKey)
+		if found {
+			media = item.(*types.Media)
+		} else {
+			log.Info("Getting media record from database")
+			dbMedia, err := db.Get(origin, mediaId)
+			if err != nil {
+				if err == sql.ErrNoRows {
+					if util.IsServerOurs(origin) {
+						log.Warn("Media not found")
+						return nil, common.ErrMediaNotFound
+					}
 				}
-			}
 
-			if !downloadRemote {
-				log.Warn("Remote media not being downloaded")
-				return nil, common.ErrMediaNotFound
-			}
+				if !downloadRemote {
+					log.Warn("Remote media not being downloaded")
+					return nil, common.ErrMediaNotFound
+				}
 
-			mediaChan := getResourceHandler().DownloadRemoteMedia(origin, mediaId, true)
-			defer close(mediaChan)
+				mediaChan := getResourceHandler().DownloadRemoteMedia(origin, mediaId, true)
+				defer close(mediaChan)
 
-			result := <-mediaChan
-			if result.err != nil {
-				return nil, result.err
+				result := <-mediaChan
+				if result.err != nil {
+					return nil, result.err
+				}
+				media = result.media
+			} else {
+				media = dbMedia
 			}
-			media = result.media
-		} else {
-			media = dbMedia
 		}
-	}
 
-	if media == nil {
-		log.Warn("Despite all efforts, a media record could not be found")
-		return nil, common.ErrMediaNotFound
-	}
+		if media == nil {
+			log.Warn("Despite all efforts, a media record could not be found")
+			return nil, common.ErrMediaNotFound
+		}
+
+		return media, nil
+	})
 
-	return media, nil
+	return v.(*types.Media), err
 }
diff --git a/controllers/preview_controller/preview_controller.go b/controllers/preview_controller/preview_controller.go
index cb720fdf..82c7b600 100644
--- a/controllers/preview_controller/preview_controller.go
+++ b/controllers/preview_controller/preview_controller.go
@@ -4,52 +4,62 @@ import (
 	"context"
 	"database/sql"
 	"errors"
+	"fmt"
 
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common"
+	"github.com/turt2live/matrix-media-repo/common/globals"
 	"github.com/turt2live/matrix-media-repo/controllers/preview_controller/acl"
 	"github.com/turt2live/matrix-media-repo/storage"
+	"github.com/turt2live/matrix-media-repo/storage/stores"
 	"github.com/turt2live/matrix-media-repo/types"
 	"github.com/turt2live/matrix-media-repo/util"
 )
 
 func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, ctx context.Context, log *logrus.Entry) (*types.UrlPreview, error) {
-	log = log.WithFields(logrus.Fields{
-		"preview_controller_at_ts": atTs,
-	})
+	atTs = stores.GetBucketTs(atTs)
+	cacheKey := fmt.Sprintf("%d_%s/%s", atTs, onHost, urlStr)
+	v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) {
 
-	db := storage.GetDatabase().GetUrlStore(ctx, log)
+		log = log.WithFields(logrus.Fields{
+			"preview_controller_at_ts": atTs,
+		})
 
-	cached, err := db.GetPreview(urlStr, atTs)
-	if err != nil && err != sql.ErrNoRows {
-		log.Error("Error getting cached URL preview: ", err.Error())
-		return nil, err
-	}
-	if err != sql.ErrNoRows {
-		log.Info("Returning cached URL preview")
-		return cachedPreviewToReal(cached)
-	}
+		db := storage.GetDatabase().GetUrlStore(ctx, log)
 
-	now := util.NowMillis()
-	if (now - atTs) > 60000 {
-		// Because we don't have a cached preview, we'll use the current time as the preview time.
-		// We also give a 60 second buffer so we don't cause an infinite loop (considering we're
-		// calling ourselves), and to give a lenient opportunity for slow execution.
-		return GetPreview(urlStr, onHost, forUserId, now, ctx, log)
-	}
+		cached, err := db.GetPreview(urlStr, atTs)
+		if err != nil && err != sql.ErrNoRows {
+			log.Error("Error getting cached URL preview: ", err.Error())
+			return nil, err
+		}
+		if err != sql.ErrNoRows {
+			log.Info("Returning cached URL preview")
+			return cachedPreviewToReal(cached)
+		}
 
-	log.Info("Preview not cached - fetching resource")
+		now := util.NowMillis()
+		if (now - atTs) > 60000 {
+			// Because we don't have a cached preview, we'll use the current time as the preview time.
+			// We also give a 60 second buffer so we don't cause an infinite loop (considering we're
+			// calling ourselves), and to give a lenient opportunity for slow execution.
+			return GetPreview(urlStr, onHost, forUserId, now, ctx, log)
+		}
 
-	urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx, log)
-	if err != nil {
-		return nil, err
-	}
+		log.Info("Preview not cached - fetching resource")
+
+		urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx, log)
+		if err != nil {
+			return nil, err
+		}
 
-	previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost)
-	defer close(previewChan)
+		previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost)
+		defer close(previewChan)
+
+		result := <-previewChan
+		return result.preview, result.err
+	})
 
-	result := <-previewChan
-	return result.preview, result.err
+	return v.(*types.UrlPreview), err
 }
 
 func cachedPreviewToReal(cached *types.CachedUrlPreview) (*types.UrlPreview, error) {
diff --git a/controllers/thumbnail_controller/thumbnail_controller.go b/controllers/thumbnail_controller/thumbnail_controller.go
index 616f6637..3c9a3219 100644
--- a/controllers/thumbnail_controller/thumbnail_controller.go
+++ b/controllers/thumbnail_controller/thumbnail_controller.go
@@ -18,6 +18,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/turt2live/matrix-media-repo/common"
 	"github.com/turt2live/matrix-media-repo/common/config"
+	"github.com/turt2live/matrix-media-repo/common/globals"
 	"github.com/turt2live/matrix-media-repo/controllers/download_controller"
 	"github.com/turt2live/matrix-media-repo/internal_cache"
 	"github.com/turt2live/matrix-media-repo/storage"
@@ -108,8 +109,6 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight
 		return nil, common.ErrMediaTooLarge
 	}
 
-	db := storage.GetDatabase().GetThumbnailStore(ctx, log)
-
 	width, height, method, err := pickThumbnailDimensions(desiredWidth, desiredHeight, method)
 	if err != nil {
 		return nil, err
@@ -117,61 +116,84 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight
 
 	cacheKey := fmt.Sprintf("%s/%s?w=%d&h=%d&m=%s&a=%t", media.Origin, media.MediaId, width, height, method, animated)
 
-	var thumbnail *types.Thumbnail
-	item, found := localCache.Get(cacheKey)
-	if found {
-		thumbnail = item.(*types.Thumbnail)
-	} else {
-		log.Info("Getting thumbnail record from database")
-		dbThumb, err := db.Get(media.Origin, media.MediaId, width, height, method, animated)
-		if err != nil {
-			if err == sql.ErrNoRows {
-				log.Info("Thumbnail does not exist, attempting to generate it")
-				genThumb, err2 := GetOrGenerateThumbnail(media, width, height, animated, method, ctx, log)
-				if err2 != nil {
-					return nil, err2
-				}
+	v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) {
+		db := storage.GetDatabase().GetThumbnailStore(ctx, log)
 
-				thumbnail = genThumb
+		var thumbnail *types.Thumbnail
+		item, found := localCache.Get(cacheKey)
+		if found {
+			thumbnail = item.(*types.Thumbnail)
+		} else {
+			log.Info("Getting thumbnail record from database")
+			dbThumb, err := db.Get(media.Origin, media.MediaId, width, height, method, animated)
+			if err != nil {
+				if err == sql.ErrNoRows {
+					log.Info("Thumbnail does not exist, attempting to generate it")
+					genThumb, err2 := GetOrGenerateThumbnail(media, width, height, animated, method, ctx, log)
+					if err2 != nil {
+						return nil, err2
+					}
+
+					thumbnail = genThumb
+				} else {
+					return nil, err
+				}
 			} else {
-				return nil, err
+				thumbnail = dbThumb
 			}
-		} else {
-			thumbnail = dbThumb
 		}
-	}
 
-	if thumbnail == nil {
-		log.Warn("Despite all efforts, a thumbnail record could not be found or generated")
-		return nil, common.ErrMediaNotFound
-	}
+		if thumbnail == nil {
+			log.Warn("Despite all efforts, a thumbnail record could not be found or generated")
+			return nil, common.ErrMediaNotFound
+		}
 
-	err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(thumbnail.Sha256Hash, util.NowMillis())
-	if err != nil {
-		logrus.Warn("Failed to upsert the last access time: ", err)
-	}
+		err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(thumbnail.Sha256Hash, util.NowMillis())
+		if err != nil {
+			logrus.Warn("Failed to upsert the last access time: ", err)
+		}
 
-	localCache.Set(cacheKey, thumbnail, cache.DefaultExpiration)
-	internal_cache.Get().IncrementDownloads(thumbnail.Sha256Hash)
+		localCache.Set(cacheKey, thumbnail, cache.DefaultExpiration)
 
-	cached, err := internal_cache.Get().GetThumbnail(thumbnail, log)
-	if err != nil {
-		return nil, err
-	}
-	if cached != nil && cached.Contents != nil {
-		return &types.StreamedThumbnail{
-			Thumbnail: thumbnail,
-			Stream:    util.BufferToStream(cached.Contents),
-		}, nil
-	}
+		cached, err := internal_cache.Get().GetThumbnail(thumbnail, log)
+		if err != nil {
+			return nil, err
+		}
+		if cached != nil && cached.Contents != nil {
+			return &types.StreamedThumbnail{
+				Thumbnail: thumbnail,
+				Stream:    util.BufferToStream(cached.Contents),
+			}, nil
+		}
 
-	log.Info("Reading thumbnail from disk")
-	mediaStream, err := datastore.DownloadStream(ctx, log, thumbnail.DatastoreId, thumbnail.Location)
-	if err != nil {
-		return nil, err
-	}
+		log.Info("Reading thumbnail from disk")
+		mediaStream, err := datastore.DownloadStream(ctx, log, thumbnail.DatastoreId, thumbnail.Location)
+		if err != nil {
+			return nil, err
+		}
+
+		return &types.StreamedThumbnail{Thumbnail: thumbnail, Stream: mediaStream}, nil
+	}, func(v interface{}, count int, err error) []interface{} {
+		if err != nil {
+			return nil
+		}
+
+		rv := v.(*types.StreamedThumbnail)
+		vals := make([]interface{}, 0)
+		streams := util.CloneReader(rv.Stream, count)
+
+		for i := 0; i < count; i++ {
+			internal_cache.Get().IncrementDownloads(rv.Thumbnail.Sha256Hash)
+			vals = append(vals, &types.StreamedThumbnail{
+				Thumbnail: rv.Thumbnail,
+				Stream:    streams[i],
+			})
+		}
+
+		return vals
+	})
 
-	return &types.StreamedThumbnail{Thumbnail: thumbnail, Stream: mediaStream}, nil
+	return v.(*types.StreamedThumbnail), err
 }
 
 func GetOrGenerateThumbnail(media *types.Media, width int, height int, animated bool, method string, ctx context.Context, log *logrus.Entry) (*types.Thumbnail, error) {
diff --git a/storage/stores/url_store.go b/storage/stores/url_store.go
index cf65bfdb..4e321ac0 100644
--- a/storage/stores/url_store.go
+++ b/storage/stores/url_store.go
@@ -58,7 +58,7 @@ func (s *UrlStore) GetPreview(url string, ts int64) (*types.CachedUrlPreview, er
 	r := &types.CachedUrlPreview{
 		Preview: &types.UrlPreview{},
 	}
-	err := s.statements.selectUrlPreview.QueryRowContext(s.ctx, url, getBucketTs(ts)).Scan(
+	err := s.statements.selectUrlPreview.QueryRowContext(s.ctx, url, GetBucketTs(ts)).Scan(
 		&r.SearchUrl,
 		&r.ErrorCode,
 		&r.FetchedTs,
@@ -82,7 +82,7 @@ func (s *UrlStore) InsertPreview(record *types.CachedUrlPreview) error {
 		s.ctx,
 		record.SearchUrl,
 		record.ErrorCode,
-		getBucketTs(record.FetchedTs),
+		GetBucketTs(record.FetchedTs),
 		record.Preview.Url,
 		record.Preview.SiteName,
 		record.Preview.Type,
@@ -107,7 +107,7 @@ func (s *UrlStore) InsertPreviewError(url string, errorCode string) error {
 	})
 }
 
-func getBucketTs(ts int64) int64 {
+func GetBucketTs(ts int64) int64 {
 	// 1 hour buckets
 	return (ts / 3600000) * 3600000
 }
diff --git a/util/singleflight-counter/singleflight.go b/util/singleflight-counter/singleflight.go
index cf4e5485..b32095c9 100644
--- a/util/singleflight-counter/singleflight.go
+++ b/util/singleflight-counter/singleflight.go
@@ -35,6 +35,12 @@ func (c *call) NextVal() interface{} {
 	return val
 }
 
+func (g *Group) DoWithoutPost(key string, fn func() (interface{}, error)) (interface{}, int, error) {
+	return g.Do(key, fn, func(v interface{}, total int, e error) []interface{} {
+		return nil
+	})
+}
+
 func (g *Group) Do(key string, fn func() (interface{}, error), postprocess func(v interface{}, total int, e error) []interface{}) (interface{}, int, error) {
 	g.mu.Lock()
 	if g.m == nil {
-- 
GitLab