diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index 2495ee29b60732ff4309e1914847601667c13438..7b89fb1668cf1e879c03c8a274938e18ff43653f 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -2,6 +2,7 @@ package r0 import ( "fmt" + "io/ioutil" "net/http" "strconv" @@ -11,6 +12,7 @@ import ( "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/controllers/thumbnail_controller" + "github.com/turt2live/matrix-media-repo/types" "golang.org/x/sync/singleflight" ) @@ -93,12 +95,7 @@ func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter return api.InternalServerError("Unexpected Error"), nil } - return &DownloadMediaResponse{ - ContentType: streamedThumbnail.Thumbnail.ContentType, - SizeBytes: streamedThumbnail.Thumbnail.SizeBytes, - Data: streamedThumbnail.Stream, - Filename: "thumbnail", - }, nil + return streamedThumbnail, nil }) if err != nil { @@ -106,9 +103,16 @@ func ThumbnailMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) inter return api.InternalServerError("Unexpected Error") } + rv := v.(*types.StreamedThumbnail) + if shared { log.Info("Request response was shared") } - return v + return &DownloadMediaResponse{ + ContentType: rv.Thumbnail.ContentType, + SizeBytes: rv.Thumbnail.SizeBytes, + Data: ioutil.NopCloser(rv.Stream.GetReader()), + Filename: "thumbnail", + } } diff --git a/controllers/thumbnail_controller/thumbnail_controller.go b/controllers/thumbnail_controller/thumbnail_controller.go index 616f66377cce208dc367efc5e3934149ac7460a9..b5a96578c719812bf26cc9fbad9fbbe61f1eb61f 100644 --- a/controllers/thumbnail_controller/thumbnail_controller.go +++ b/controllers/thumbnail_controller/thumbnail_controller.go @@ -73,7 +73,7 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight data := &bytes.Buffer{} imaging.Encode(data, img, imaging.PNG) return &types.StreamedThumbnail{ - Stream: util.BufferToStream(data), + Stream: util.NewManyReader(util.BufferToStream(data)), Thumbnail: &types.Thumbnail{ // We lie about the details to ensure we keep our contract Width: img.Bounds().Max.X, @@ -161,7 +161,7 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight if cached != nil && cached.Contents != nil { return &types.StreamedThumbnail{ Thumbnail: thumbnail, - Stream: util.BufferToStream(cached.Contents), + Stream: util.NewManyReader(util.BufferToStream(cached.Contents)), }, nil } @@ -171,7 +171,7 @@ func GetThumbnail(origin string, mediaId string, desiredWidth int, desiredHeight return nil, err } - return &types.StreamedThumbnail{Thumbnail: thumbnail, Stream: mediaStream}, nil + return &types.StreamedThumbnail{Thumbnail: thumbnail, Stream: util.NewManyReader(mediaStream)}, nil } func GetOrGenerateThumbnail(media *types.Media, width int, height int, animated bool, method string, ctx context.Context, log *logrus.Entry) (*types.Thumbnail, error) { diff --git a/types/thumbnail.go b/types/thumbnail.go index 16b171c22a439b75f0d99368106411647ee2054a..e769059b9bc7c985da5c26c485eab4a4a5752e0b 100644 --- a/types/thumbnail.go +++ b/types/thumbnail.go @@ -1,6 +1,8 @@ package types -import "io" +import ( + "github.com/turt2live/matrix-media-repo/util" +) type Thumbnail struct { Origin string @@ -19,5 +21,5 @@ type Thumbnail struct { type StreamedThumbnail struct { Thumbnail *Thumbnail - Stream io.ReadCloser + Stream *util.ManyReader } diff --git a/util/many_reader.go b/util/many_reader.go new file mode 100644 index 0000000000000000000000000000000000000000..cb0935d31af71784668bc46743bac4d92cc852d1 --- /dev/null +++ b/util/many_reader.go @@ -0,0 +1,74 @@ +package util + +import ( + "bytes" + "io" + "io/ioutil" + + "github.com/sirupsen/logrus" +) + +type ManyReader struct { + buf *bytes.Buffer + eof bool +} + +type ManyReaderReader struct { + manyReader *ManyReader + pos int +} + +func NewManyReader(input io.Reader) *ManyReader { + var buf bytes.Buffer + tr := io.TeeReader(input, &buf) + + mr := &ManyReader{&buf, false} + go func() { + ioutil.ReadAll(tr) + mr.eof = true + }() + + return mr +} + +func (r *ManyReader) GetReader() *ManyReaderReader { + return &ManyReaderReader{r, 0} +} + +func (r *ManyReaderReader) Read(p []byte) (int, error) { + b := r.manyReader.buf.Bytes() + available := len(b) + if r.pos >= available - 1 && r.manyReader.eof { + return 0, io.EOF + } + + limit := len(p) + end := r.pos + limit + if end > available { + end = available + } + + if end == r.pos || end <= 0 { + return 0, nil + } + + limit = end - r.pos + if limit <= 0 { + return 0, nil + } + + logrus.Info("Available: ", available) + logrus.Info("Position: ", r.pos) + logrus.Info("End: ", end) + logrus.Info("Limit: ", limit) + + for i := 0; i < limit; i++ { + p[i] = b[r.pos + i] + } + r.pos += limit + + logrus.Info("Read: ", limit) + logrus.Info("Final Position: ", r.pos) + + return limit, nil +}