Skip to content
Snippets Groups Projects
Commit 0205efa9 authored by Travis Ralston's avatar Travis Ralston
Browse files

Use custom singleflight everywhere

parent b94cfafc
No related branches found
No related tags found
No related merge requests found
package r0
import (
"fmt"
"net/http"
"strconv"
......@@ -11,13 +10,8 @@ import (
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/controllers/thumbnail_controller"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
"github.com/turt2live/matrix-media-repo/util/singleflight-counter"
)
var thumbnailRequestGroup singleflight_counter.Group
func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
params := mux.Vars(r)
......@@ -81,55 +75,21 @@ func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter
"requestedAnimated": animated,
})
// TODO: Move this to a lower layer (somewhere where the thumbnail dimensions are known, before media is downloaded)
requestKey := fmt.Sprintf("thumbnail_%s_%s_%d_%d_%s_%t", server, mediaId, width, height, method, animated)
v, count, err := thumbnailRequestGroup.Do(requestKey, func() (interface{}, error) {
streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log)
if err != nil {
if err == common.ErrMediaNotFound {
return api.NotFoundError(), nil
} else if err == common.ErrMediaTooLarge {
return api.RequestTooLarge(), nil
}
log.Error("Unexpected error locating media: " + err.Error())
return api.InternalServerError("Unexpected Error"), nil
}
return streamedThumbnail, nil
}, func(v interface{}, count int, err error) []interface{} {
if err != nil {
return nil
}
rv := v.(*types.StreamedThumbnail)
vals := make([]interface{}, 0)
streams := util.CloneReader(rv.Stream, count)
for i := 0; i < count; i++ {
vals = append(vals, &types.StreamedThumbnail{
Thumbnail: rv.Thumbnail,
Stream: streams[i],
})
}
return vals
})
streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log)
if err != nil {
log.Error("Unexpected error handling request: " + err.Error())
if err == common.ErrMediaNotFound {
return api.NotFoundError()
} else if err == common.ErrMediaTooLarge {
return api.RequestTooLarge()
}
log.Error("Unexpected error locating media: " + err.Error())
return api.InternalServerError("Unexpected Error")
}
rv := v.(*types.StreamedThumbnail)
if count > 0 {
log.Info("Request response was shared ", count, " times")
}
return &DownloadMediaResponse{
ContentType: rv.Thumbnail.ContentType,
SizeBytes: rv.Thumbnail.SizeBytes,
Data: rv.Stream,
ContentType: streamedThumbnail.Thumbnail.ContentType,
SizeBytes: streamedThumbnail.Thumbnail.SizeBytes,
Data: streamedThumbnail.Stream,
Filename: "thumbnail",
}
}
package globals
import (
"github.com/turt2live/matrix-media-repo/util/singleflight-counter"
)
var DefaultRequestGroup singleflight_counter.Group
......@@ -4,11 +4,13 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/patrickmn/go-cache"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/globals"
"github.com/turt2live/matrix-media-repo/internal_cache"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/storage/datastore"
......@@ -19,86 +21,116 @@ import (
var localCache = cache.New(30*time.Second, 60*time.Second)
func GetMedia(origin string, mediaId string, downloadRemote bool, blockForMedia bool, ctx context.Context, log *logrus.Entry) (*types.MinimalMedia, error) {
var media *types.Media
var minMedia *types.MinimalMedia
var err error
if blockForMedia {
media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx, log)
if media != nil {
minMedia = &types.MinimalMedia{
Origin: media.Origin,
MediaId: media.MediaId,
ContentType: media.ContentType,
UploadName: media.UploadName,
SizeBytes: media.SizeBytes,
Stream: nil, // we'll populate this later if we need to
KnownMedia: media,
cacheKey := fmt.Sprintf("%s/%s?r=%t&b=%t", origin, mediaId, downloadRemote, blockForMedia)
v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) {
var media *types.Media
var minMedia *types.MinimalMedia
var err error
if blockForMedia {
media, err = FindMediaRecord(origin, mediaId, downloadRemote, ctx, log)
if media != nil {
minMedia = &types.MinimalMedia{
Origin: media.Origin,
MediaId: media.MediaId,
ContentType: media.ContentType,
UploadName: media.UploadName,
SizeBytes: media.SizeBytes,
Stream: nil, // we'll populate this later if we need to
KnownMedia: media,
}
}
} else {
minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx, log)
if minMedia != nil {
media = minMedia.KnownMedia
}
}
} else {
minMedia, err = FindMinimalMediaRecord(origin, mediaId, downloadRemote, ctx, log)
if minMedia != nil {
media = minMedia.KnownMedia
if err != nil {
return nil, err
}
}
if err != nil {
return nil, err
}
if minMedia == nil {
log.Warn("Unexpected error while fetching media: no minimal media record")
return nil, common.ErrMediaNotFound
}
if media == nil && blockForMedia {
log.Warn("Unexpected error while fetching media: no regular media record (block for media in place)")
return nil, common.ErrMediaNotFound
}
if minMedia == nil {
log.Warn("Unexpected error while fetching media: no minimal media record")
return nil, common.ErrMediaNotFound
}
if media == nil && blockForMedia {
log.Warn("Unexpected error while fetching media: no regular media record (block for media in place)")
return nil, common.ErrMediaNotFound
}
// if we have a known media record, we might as well set it
// if we don't, this won't do anything different
minMedia.KnownMedia = media
if media != nil {
if media.Quarantined {
log.Warn("Quarantined media accessed")
return nil, common.ErrMediaQuarantined
}
// if we have a known media record, we might as well set it
// if we don't, this won't do anything different
minMedia.KnownMedia = media
err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(media.Sha256Hash, util.NowMillis())
if err != nil {
logrus.Warn("Failed to upsert the last access time: ", err)
}
if media != nil {
if media.Quarantined {
log.Warn("Quarantined media accessed")
return nil, common.ErrMediaQuarantined
localCache.Set(origin+"/"+mediaId, media, cache.DefaultExpiration)
internal_cache.Get().IncrementDownloads(media.Sha256Hash)
cached, err := internal_cache.Get().GetMedia(media, log)
if err != nil {
return nil, err
}
if cached != nil && cached.Contents != nil {
minMedia.Stream = util.BufferToStream(cached.Contents)
return minMedia, nil
}
}
err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(media.Sha256Hash, util.NowMillis())
if err != nil {
logrus.Warn("Failed to upsert the last access time: ", err)
if minMedia.Stream != nil {
log.Info("Returning minimal media record with a viable stream")
return minMedia, nil
}
localCache.Set(origin+"/"+mediaId, media, cache.DefaultExpiration)
internal_cache.Get().IncrementDownloads(media.Sha256Hash)
if media == nil {
log.Error("Failed to locate media")
return nil, errors.New("failed to locate media")
}
cached, err := internal_cache.Get().GetMedia(media, log)
log.Info("Reading media from disk")
mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location)
if err != nil {
return nil, err
}
if cached != nil && cached.Contents != nil {
minMedia.Stream = util.BufferToStream(cached.Contents)
return minMedia, nil
}
}
if minMedia.Stream != nil {
log.Info("Returning minimal media record with a viable stream")
minMedia.Stream = mediaStream
return minMedia, nil
}
}, func(v interface{}, count int, err error) []interface{} {
if err != nil {
return nil
}
if media == nil {
log.Error("Failed to locate media")
return nil, errors.New("failed to locate media")
}
rv := v.(*types.MinimalMedia)
vals := make([]interface{}, 0)
streams := util.CloneReader(rv.Stream, count)
log.Info("Reading media from disk")
mediaStream, err := datastore.DownloadStream(ctx, log, media.DatastoreId, media.Location)
if err != nil {
return nil, err
}
for i := 0; i < count; i++ {
if rv.KnownMedia != nil {
internal_cache.Get().IncrementDownloads(rv.KnownMedia.Sha256Hash)
}
vals = append(vals, &types.MinimalMedia{
Origin: rv.Origin,
MediaId: rv.MediaId,
UploadName: rv.UploadName,
ContentType: rv.ContentType,
SizeBytes: rv.SizeBytes,
KnownMedia: rv.KnownMedia,
Stream: streams[i],
})
}
minMedia.Stream = mediaStream
return minMedia, nil
return vals
})
return v.(*types.MinimalMedia), err
}
func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.MinimalMedia, error) {
......@@ -181,45 +213,50 @@ func FindMinimalMediaRecord(origin string, mediaId string, downloadRemote bool,
}
func FindMediaRecord(origin string, mediaId string, downloadRemote bool, ctx context.Context, log *logrus.Entry) (*types.Media, error) {
db := storage.GetDatabase().GetMediaStore(ctx, log)
cacheKey := origin + "/" + mediaId
v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) {
db := storage.GetDatabase().GetMediaStore(ctx, log)
var media *types.Media
item, found := localCache.Get(origin + "/" + mediaId)
if found {
media = item.(*types.Media)
} else {
log.Info("Getting media record from database")
dbMedia, err := db.Get(origin, mediaId)
if err != nil {
if err == sql.ErrNoRows {
if util.IsServerOurs(origin) {
log.Warn("Media not found")
return nil, common.ErrMediaNotFound
var media *types.Media
item, found := localCache.Get(cacheKey)
if found {
media = item.(*types.Media)
} else {
log.Info("Getting media record from database")
dbMedia, err := db.Get(origin, mediaId)
if err != nil {
if err == sql.ErrNoRows {
if util.IsServerOurs(origin) {
log.Warn("Media not found")
return nil, common.ErrMediaNotFound
}
}
}
if !downloadRemote {
log.Warn("Remote media not being downloaded")
return nil, common.ErrMediaNotFound
}
if !downloadRemote {
log.Warn("Remote media not being downloaded")
return nil, common.ErrMediaNotFound
}
mediaChan := getResourceHandler().DownloadRemoteMedia(origin, mediaId, true)
defer close(mediaChan)
mediaChan := getResourceHandler().DownloadRemoteMedia(origin, mediaId, true)
defer close(mediaChan)
result := <-mediaChan
if result.err != nil {
return nil, result.err
result := <-mediaChan
if result.err != nil {
return nil, result.err
}
media = result.media
} else {
media = dbMedia
}
media = result.media
} else {
media = dbMedia
}
}
if media == nil {
log.Warn("Despite all efforts, a media record could not be found")
return nil, common.ErrMediaNotFound
}
if media == nil {
log.Warn("Despite all efforts, a media record could not be found")
return nil, common.ErrMediaNotFound
}
return media, nil
})
return media, nil
return v.(*types.Media), err
}
......@@ -4,52 +4,62 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/globals"
"github.com/turt2live/matrix-media-repo/controllers/preview_controller/acl"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/storage/stores"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
)
func GetPreview(urlStr string, onHost string, forUserId string, atTs int64, ctx context.Context, log *logrus.Entry) (*types.UrlPreview, error) {
log = log.WithFields(logrus.Fields{
"preview_controller_at_ts": atTs,
})
atTs = stores.GetBucketTs(atTs)
cacheKey := fmt.Sprintf("%d_%s/%s", atTs, onHost, urlStr)
v, _, err := globals.DefaultRequestGroup.DoWithoutPost(cacheKey, func() (interface{}, error) {
db := storage.GetDatabase().GetUrlStore(ctx, log)
log = log.WithFields(logrus.Fields{
"preview_controller_at_ts": atTs,
})
cached, err := db.GetPreview(urlStr, atTs)
if err != nil && err != sql.ErrNoRows {
log.Error("Error getting cached URL preview: ", err.Error())
return nil, err
}
if err != sql.ErrNoRows {
log.Info("Returning cached URL preview")
return cachedPreviewToReal(cached)
}
db := storage.GetDatabase().GetUrlStore(ctx, log)
now := util.NowMillis()
if (now - atTs) > 60000 {
// Because we don't have a cached preview, we'll use the current time as the preview time.
// We also give a 60 second buffer so we don't cause an infinite loop (considering we're
// calling ourselves), and to give a lenient opportunity for slow execution.
return GetPreview(urlStr, onHost, forUserId, now, ctx, log)
}
cached, err := db.GetPreview(urlStr, atTs)
if err != nil && err != sql.ErrNoRows {
log.Error("Error getting cached URL preview: ", err.Error())
return nil, err
}
if err != sql.ErrNoRows {
log.Info("Returning cached URL preview")
return cachedPreviewToReal(cached)
}
log.Info("Preview not cached - fetching resource")
now := util.NowMillis()
if (now - atTs) > 60000 {
// Because we don't have a cached preview, we'll use the current time as the preview time.
// We also give a 60 second buffer so we don't cause an infinite loop (considering we're
// calling ourselves), and to give a lenient opportunity for slow execution.
return GetPreview(urlStr, onHost, forUserId, now, ctx, log)
}
urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx, log)
if err != nil {
return nil, err
}
log.Info("Preview not cached - fetching resource")
urlToPreview, err := acl.ValidateUrlForPreview(urlStr, ctx, log)
if err != nil {
return nil, err
}
previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost)
defer close(previewChan)
previewChan := getResourceHandler().GeneratePreview(urlToPreview, forUserId, onHost)
defer close(previewChan)
result := <-previewChan
return result.preview, result.err
})
result := <-previewChan
return result.preview, result.err
return v.(*types.UrlPreview), err
}
func cachedPreviewToReal(cached *types.CachedUrlPreview) (*types.UrlPreview, error) {
......
......@@ -18,6 +18,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/common/globals"
"github.com/turt2live/matrix-media-repo/controllers/download_controller"
"github.com/turt2live/matrix-media-repo/internal_cache"
"github.com/turt2live/matrix-media-repo/storage"
......@@ -108,8 +109,6 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight
return nil, common.ErrMediaTooLarge
}
db := storage.GetDatabase().GetThumbnailStore(ctx, log)
width, height, method, err := pickThumbnailDimensions(desiredWidth, desiredHeight, method)
if err != nil {
return nil, err
......@@ -117,61 +116,84 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight
cacheKey := fmt.Sprintf("%s/%s?w=%d&h=%d&m=%s&a=%t", media.Origin, media.MediaId, width, height, method, animated)
var thumbnail *types.Thumbnail
item, found := localCache.Get(cacheKey)
if found {
thumbnail = item.(*types.Thumbnail)
} else {
log.Info("Getting thumbnail record from database")
dbThumb, err := db.Get(media.Origin, media.MediaId, width, height, method, animated)
if err != nil {
if err == sql.ErrNoRows {
log.Info("Thumbnail does not exist, attempting to generate it")
genThumb, err2 := GetOrGenerateThumbnail(media, width, height, animated, method, ctx, log)
if err2 != nil {
return nil, err2
}
v, _, err := globals.DefaultRequestGroup.Do(cacheKey, func() (interface{}, error) {
db := storage.GetDatabase().GetThumbnailStore(ctx, log)
thumbnail = genThumb
var thumbnail *types.Thumbnail
item, found := localCache.Get(cacheKey)
if found {
thumbnail = item.(*types.Thumbnail)
} else {
log.Info("Getting thumbnail record from database")
dbThumb, err := db.Get(media.Origin, media.MediaId, width, height, method, animated)
if err != nil {
if err == sql.ErrNoRows {
log.Info("Thumbnail does not exist, attempting to generate it")
genThumb, err2 := GetOrGenerateThumbnail(media, width, height, animated, method, ctx, log)
if err2 != nil {
return nil, err2
}
thumbnail = genThumb
} else {
return nil, err
}
} else {
return nil, err
thumbnail = dbThumb
}
} else {
thumbnail = dbThumb
}
}
if thumbnail == nil {
log.Warn("Despite all efforts, a thumbnail record could not be found or generated")
return nil, common.ErrMediaNotFound
}
if thumbnail == nil {
log.Warn("Despite all efforts, a thumbnail record could not be found or generated")
return nil, common.ErrMediaNotFound
}
err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(thumbnail.Sha256Hash, util.NowMillis())
if err != nil {
logrus.Warn("Failed to upsert the last access time: ", err)
}
err = storage.GetDatabase().GetMetadataStore(ctx, log).UpsertLastAccess(thumbnail.Sha256Hash, util.NowMillis())
if err != nil {
logrus.Warn("Failed to upsert the last access time: ", err)
}
localCache.Set(cacheKey, thumbnail, cache.DefaultExpiration)
internal_cache.Get().IncrementDownloads(thumbnail.Sha256Hash)
localCache.Set(cacheKey, thumbnail, cache.DefaultExpiration)
cached, err := internal_cache.Get().GetThumbnail(thumbnail, log)
if err != nil {
return nil, err
}
if cached != nil && cached.Contents != nil {
return &types.StreamedThumbnail{
Thumbnail: thumbnail,
Stream: util.BufferToStream(cached.Contents),
}, nil
}
cached, err := internal_cache.Get().GetThumbnail(thumbnail, log)
if err != nil {
return nil, err
}
if cached != nil && cached.Contents != nil {
return &types.StreamedThumbnail{
Thumbnail: thumbnail,
Stream: util.BufferToStream(cached.Contents),
}, nil
}
log.Info("Reading thumbnail from disk")
mediaStream, err := datastore.DownloadStream(ctx, log, thumbnail.DatastoreId, thumbnail.Location)
if err != nil {
return nil, err
}
log.Info("Reading thumbnail from disk")
mediaStream, err := datastore.DownloadStream(ctx, log, thumbnail.DatastoreId, thumbnail.Location)
if err != nil {
return nil, err
}
return &types.StreamedThumbnail{Thumbnail: thumbnail, Stream: mediaStream}, nil
}, func(v interface{}, count int, err error) []interface{} {
if err != nil {
return nil
}
rv := v.(*types.StreamedThumbnail)
vals := make([]interface{}, 0)
streams := util.CloneReader(rv.Stream, count)
for i := 0; i < count; i++ {
internal_cache.Get().IncrementDownloads(rv.Thumbnail.Sha256Hash)
vals = append(vals, &types.StreamedThumbnail{
Thumbnail: rv.Thumbnail,
Stream: streams[i],
})
}
return vals
})
return &types.StreamedThumbnail{Thumbnail: thumbnail, Stream: mediaStream}, nil
return v.(*types.StreamedThumbnail), err
}
func GetOrGenerateThumbnail(media *types.Media, width int, height int, animated bool, method string, ctx context.Context, log *logrus.Entry) (*types.Thumbnail, error) {
......
......@@ -58,7 +58,7 @@ func (s *UrlStore) GetPreview(url string, ts int64) (*types.CachedUrlPreview, er
r := &types.CachedUrlPreview{
Preview: &types.UrlPreview{},
}
err := s.statements.selectUrlPreview.QueryRowContext(s.ctx, url, getBucketTs(ts)).Scan(
err := s.statements.selectUrlPreview.QueryRowContext(s.ctx, url, GetBucketTs(ts)).Scan(
&r.SearchUrl,
&r.ErrorCode,
&r.FetchedTs,
......@@ -82,7 +82,7 @@ func (s *UrlStore) InsertPreview(record *types.CachedUrlPreview) error {
s.ctx,
record.SearchUrl,
record.ErrorCode,
getBucketTs(record.FetchedTs),
GetBucketTs(record.FetchedTs),
record.Preview.Url,
record.Preview.SiteName,
record.Preview.Type,
......@@ -107,7 +107,7 @@ func (s *UrlStore) InsertPreviewError(url string, errorCode string) error {
})
}
func getBucketTs(ts int64) int64 {
func GetBucketTs(ts int64) int64 {
// 1 hour buckets
return (ts / 3600000) * 3600000
}
......@@ -35,6 +35,12 @@ func (c *call) NextVal() interface{} {
return val
}
func (g *Group) DoWithoutPost(key string, fn func() (interface{}, error)) (interface{}, int, error) {
return g.Do(key, fn, func(v interface{}, total int, e error) []interface{} {
return nil
})
}
func (g *Group) Do(key string, fn func() (interface{}, error), postprocess func(v interface{}, total int, e error) []interface{}) (interface{}, int, error) {
g.mu.Lock()
if g.m == nil {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment