diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index 39ee66f3a6820430f86f669de5b40510a62bc92e..afae5e4fe23320498e6515e7d7241e2eaaa92924 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -1,6 +1,7 @@ package r0 import ( + "fmt" "net/http" "strconv" @@ -10,8 +11,13 @@ import ( "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/controllers/thumbnail_controller" + "github.com/turt2live/matrix-media-repo/types" + "github.com/turt2live/matrix-media-repo/util" + "github.com/turt2live/matrix-media-repo/util/singleflight-counter" ) +var thumbnailRequestGroup singleflight_counter.Group + func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { params := mux.Vars(r) @@ -75,21 +81,55 @@ func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter "requestedAnimated": animated, }) - streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log) - if err != nil { - if err == common.ErrMediaNotFound { - return api.NotFoundError() - } else if err == common.ErrMediaTooLarge { - return api.RequestTooLarge() + // TODO: Move this to a lower layer (somewhere where the thumbnail dimensions are known, before media is downloaded) + requestKey := fmt.Sprintf("thumbnail_%s_%s_%d_%d_%s_%t", server, mediaId, width, height, method, animated) + v, count, err := thumbnailRequestGroup.Do(requestKey, func() (interface{}, error) { + streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, r.Context(), log) + if err != nil { + if err == common.ErrMediaNotFound { + return api.NotFoundError(), nil + } else if err == common.ErrMediaTooLarge { + return api.RequestTooLarge(), nil + } + log.Error("Unexpected error locating media: " + err.Error()) + return api.InternalServerError("Unexpected Error"), nil + } + + return streamedThumbnail, nil + }, func(v interface{}, count int, err error) []interface{} { + if err != nil { + return nil + } + + rv := v.(*types.StreamedThumbnail) + vals := make([]interface{}, 0) + streams := util.CloneReader(rv.Stream, count) + + for i := 0; i < count; i++ { + vals = append(vals, &types.StreamedThumbnail{ + Thumbnail: rv.Thumbnail, + Stream: streams[i], + }) } - log.Error("Unexpected error locating media: " + err.Error()) + + return vals + }) + + if err != nil { + log.Error("Unexpected error handling request: " + err.Error()) return api.InternalServerError("Unexpected Error") } + rv := v.(*types.StreamedThumbnail) + + if count > 0 { + log.Info("Request response was shared ", count, " times") + } + return &DownloadMediaResponse{ - ContentType: streamedThumbnail.Thumbnail.ContentType, - SizeBytes: streamedThumbnail.Thumbnail.SizeBytes, - Data: streamedThumbnail.Stream, + ContentType: rv.Thumbnail.ContentType, + SizeBytes: rv.Thumbnail.SizeBytes, + Data: rv.Stream, Filename: "thumbnail", } } diff --git a/go.mod b/go.mod index 546fdb2734cc8f09b81a09272f56616058b584b7..2898b48527ac20c72a378b70de4ea5ded55a5be0 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613 golang.org/x/image v0.0.0-20171214225156-12117c17ca67 golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3 + golang.org/x/sync v0.0.0-20181108010431-42b317875d0f golang.org/x/sys v0.0.0-20190203050204-7ae0202eb74c golang.org/x/text v0.3.0 golang.org/x/time v0.0.0-20170927054726-6dc17368e09b diff --git a/go.sum b/go.sum index 4255c4939d7fa8c206b0e0648b254ded9caf828f..d0e06302830e50262c0c95b9fab018d6fad5fc10 100644 --- a/go.sum +++ b/go.sum @@ -93,6 +93,7 @@ golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3 h1:ulvT7fqt0yHWzpJwI57MezWnYDVpCAYBVuYst/L+fAY= golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190124100055-b90733256f2e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190203050204-7ae0202eb74c h1:YeMXU0KQqExdpG959DFhAhfpY8myIsnfqj8lhNFRzzE= diff --git a/types/thumbnail.go b/types/thumbnail.go index 16b171c22a439b75f0d99368106411647ee2054a..de71be3445d6be63af0bc25d4db791307cff8031 100644 --- a/types/thumbnail.go +++ b/types/thumbnail.go @@ -1,6 +1,8 @@ package types -import "io" +import ( + "io" +) type Thumbnail struct { Origin string diff --git a/util/singleflight-counter/singleflight.go b/util/singleflight-counter/singleflight.go new file mode 100644 index 0000000000000000000000000000000000000000..cf4e548535cfae3e417e34d736be703a50b812b3 --- /dev/null +++ b/util/singleflight-counter/singleflight.go @@ -0,0 +1,72 @@ +package singleflight_counter + +import ( + "sync" +) + +// Largely inspired by Go's singleflight package. +// https://github.com/golang/sync/blob/112230192c580c3556b8cee6403af37a4fc5f28c/singleflight/singleflight.go + +type call struct { + wg sync.WaitGroup + + valsMu sync.Mutex + nextIndex int + vals []interface{} + + val interface{} + err error + count int +} + +type Group struct { + mu sync.Mutex + m map[string]*call +} + +func (c *call) NextVal() interface{} { + c.valsMu.Lock() + val := c.val + if c.vals != nil && len(c.vals) >= c.count { + val = c.vals[c.nextIndex] + c.nextIndex++ + } + c.valsMu.Unlock() + return val +} + +func (g *Group) Do(key string, fn func() (interface{}, error), postprocess func(v interface{}, total int, e error) []interface{}) (interface{}, int, error) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.count++ + g.mu.Unlock() + c.wg.Wait() + + return c.NextVal(), c.count, c.err + } + + c := new(call) + c.count = 1 // Always start at 1 (for us) + c.nextIndex = 0 + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + c.val, c.err = fn() + + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + + c.vals = nil + if postprocess != nil { + c.vals = postprocess(c.val, c.count, c.err) + } + + c.wg.Done() + + return c.NextVal(), c.count, c.err +}