diff --git a/api/r0/download.go b/api/r0/download.go index 7855f73e7ce1804fc118a5578159114a6a571769..ed81ad4a0cef77fcaf05f16c2c5e82dce99b26f7 100644 --- a/api/r0/download.go +++ b/api/r0/download.go @@ -3,7 +3,6 @@ package r0 import ( "net/http" "strconv" - "time" "github.com/getsentry/sentry-go" "github.com/turt2live/matrix-media-repo/api/_apimeta" @@ -48,19 +47,9 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta. downloadRemote = parsedFlag } - blockFor := 20 * time.Second - if timeoutMs != "" { - parsed, err := strconv.Atoi(timeoutMs) - if err != nil { - return _responses.BadRequest("timeout_ms does not appear to be an integer") - } - if parsed > 0 { - // Limit to 60 seconds - if parsed > 60000 { - parsed = 60000 - } - blockFor = time.Duration(parsed) * time.Millisecond - } + blockFor, err := util.CalcBlockForDuration(timeoutMs) + if err != nil { + return _responses.BadRequest("timeout_ms does not appear to be an integer") } rctx = rctx.LogWithFields(logrus.Fields{ diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index 86694315dad38f2ac079191c71d52808081b62e2..552fe0b8ddb9155cc763ef6265fbad7ab3d68af0 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -1,25 +1,27 @@ package r0 import ( + "net/http" + "strconv" + "github.com/getsentry/sentry-go" "github.com/turt2live/matrix-media-repo/api/_apimeta" "github.com/turt2live/matrix-media-repo/api/_responses" "github.com/turt2live/matrix-media-repo/api/_routers" + "github.com/turt2live/matrix-media-repo/pipelines/pipeline_download" + "github.com/turt2live/matrix-media-repo/pipelines/pipeline_thumbnail" "github.com/turt2live/matrix-media-repo/util" - "net/http" - "strconv" - "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/rcontext" - "github.com/turt2live/matrix-media-repo/controllers/thumbnail_controller" ) func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { server := _routers.GetParam("server", r) mediaId := _routers.GetParam("mediaId", r) allowRemote := r.URL.Query().Get("allow_remote") + timeoutMs := r.URL.Query().Get("timeout_ms") if !_routers.ServerNameRegex.MatchString(server) { return _responses.BadRequest("invalid server ID") @@ -34,6 +36,11 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta downloadRemote = parsedFlag } + blockFor, err := util.CalcBlockForDuration(timeoutMs) + if err != nil { + return _responses.BadRequest("timeout_ms does not appear to be an integer") + } + rctx = rctx.LogWithFields(logrus.Fields{ "mediaId": mediaId, "server": server, @@ -97,12 +104,26 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta return _responses.BadRequest("Width and height must be greater than zero") } - streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, rctx) + thumbnail, stream, err := pipeline_thumbnail.Execute(rctx, server, mediaId, pipeline_thumbnail.ThumbnailOpts{ + DownloadOpts: pipeline_download.DownloadOpts{ + FetchRemoteIfNeeded: downloadRemote, + StartByte: -1, + EndByte: -1, + BlockForReadUntil: blockFor, + RecordOnly: false, // overridden + }, + Width: width, + Height: height, + Method: method, + Animated: animated, + }) if err != nil { if err == common.ErrMediaNotFound { return _responses.NotFoundError() } else if err == common.ErrMediaTooLarge { return _responses.RequestTooLarge() + } else if err == common.ErrMediaQuarantined { + return _responses.NotFoundError() // We lie for security } rctx.Log.Error("Unexpected error locating media: " + err.Error()) sentry.CaptureException(err) @@ -110,9 +131,10 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta } return &DownloadMediaResponse{ - ContentType: streamedThumbnail.Thumbnail.ContentType, - SizeBytes: streamedThumbnail.Thumbnail.SizeBytes, - Data: streamedThumbnail.Stream, - Filename: "thumbnail.png", + ContentType: thumbnail.ContentType, + Filename: "thumbnail" + util.ExtensionForContentType(thumbnail.ContentType), + SizeBytes: thumbnail.SizeBytes, + Data: stream, + TargetDisposition: "infer", } } diff --git a/database/db.go b/database/db.go index 86eceae276c2f88d1e09f19a345c7cba6d9c9c33..006242abe0e4d441627c31dc10b9b364f3535a35 100644 --- a/database/db.go +++ b/database/db.go @@ -22,6 +22,7 @@ type Database struct { MetadataView *metadataVirtualTableStatements Blurhashes *blurhashesTableStatements HeldMedia *heldMediaTableStatements + Thumbnails *thumbnailsTableStatements } var instance *Database @@ -96,6 +97,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error if d.HeldMedia, err = prepareHeldMediaTables(d.conn); err != nil { return errors.New("failed to create held media table accessor: " + err.Error()) } + if d.Thumbnails, err = prepareThumbnailsTables(d.conn); err != nil { + return errors.New("failed to create thumbnails table accessor: " + err.Error()) + } instance = d return nil diff --git a/database/table_media.go b/database/table_media.go index 2b24b27b6a7399b0fba47bf6fb25f3c4b0f7e04a..fde55e210ff0311f9e2521181b15d1c272f4932f 100644 --- a/database/table_media.go +++ b/database/table_media.go @@ -7,18 +7,25 @@ import ( "github.com/turt2live/matrix-media-repo/common/rcontext" ) +type Locatable struct { + Sha256Hash string + DatastoreId string + Location string +} + type DbMedia struct { + *Locatable Origin string MediaId string UploadName string ContentType string UserId string - Sha256Hash string + //Sha256Hash string SizeBytes int64 CreationTs int64 Quarantined bool - DatastoreId string - Location string + //DatastoreId string + //Location string } const selectDistinctMediaDatastoreIds = "SELECT DISTINCT datastore_id FROM media;" @@ -118,7 +125,7 @@ func (s *mediaTableWithContext) GetByHash(sha256hash string) ([]*DbMedia, error) return nil, err } for rows.Next() { - val := &DbMedia{} + val := &DbMedia{Locatable: &Locatable{}} if err = rows.Scan(&val.Origin, &val.MediaId, &val.UploadName, &val.ContentType, &val.UserId, &val.Sha256Hash, &val.SizeBytes, &val.CreationTs, &val.Quarantined, &val.DatastoreId, &val.Location); err != nil { return nil, err } @@ -130,7 +137,7 @@ func (s *mediaTableWithContext) GetByHash(sha256hash string) ([]*DbMedia, error) func (s *mediaTableWithContext) GetById(origin string, mediaId string) (*DbMedia, error) { row := s.statements.selectMediaById.QueryRowContext(s.ctx, origin, mediaId) - val := &DbMedia{} + val := &DbMedia{Locatable: &Locatable{}} err := row.Scan(&val.Origin, &val.MediaId, &val.UploadName, &val.ContentType, &val.UserId, &val.Sha256Hash, &val.SizeBytes, &val.CreationTs, &val.Quarantined, &val.DatastoreId, &val.Location) if err == sql.ErrNoRows { err = nil diff --git a/database/table_thumbnails.go b/database/table_thumbnails.go new file mode 100644 index 0000000000000000000000000000000000000000..1a537b91ba864163b3dd66ef6435a4d238854949 --- /dev/null +++ b/database/table_thumbnails.go @@ -0,0 +1,76 @@ +package database + +import ( + "database/sql" + "errors" + + "github.com/turt2live/matrix-media-repo/common/rcontext" +) + +type DbMethod string + +type DbThumbnail struct { + *Locatable + Origin string + MediaId string + ContentType string + Width int + Height int + Method string + Animated bool + //Sha256Hash string + SizeBytes int64 + CreationTs int64 + //DatastoreId string + //Location string +} + +const selectThumbnailByParams = "SELECT origin, media_id, content_type, width, height, method, animated, sha256_hash, size_bytes, creation_ts, datastore_id, location FROM thumbnails WHERE origin = $1 AND media_id = $2 AND width = $3 AND height = $4 AND method = $5 AND animated = $6;" +const insertThumbnail = "INSERT INTO thumbnails (origin, media_id, content_type, width, height, method, animated, sha256_hash, size_bytes, creation_ts, datastore_id, location) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12);" + +type thumbnailsTableStatements struct { + selectThumbnailByParams *sql.Stmt + insertThumbnail *sql.Stmt +} + +type thumbnailsTableWithContext struct { + statements *thumbnailsTableStatements + ctx rcontext.RequestContext +} + +func prepareThumbnailsTables(db *sql.DB) (*thumbnailsTableStatements, error) { + var err error + var stmts = &thumbnailsTableStatements{} + + if stmts.selectThumbnailByParams, err = db.Prepare(selectThumbnailByParams); err != nil { + return nil, errors.New("error preparing selectThumbnailByParams: " + err.Error()) + } + if stmts.insertThumbnail, err = db.Prepare(insertThumbnail); err != nil { + return nil, errors.New("error preparing insertThumbnail: " + err.Error()) + } + + return stmts, nil +} + +func (s *thumbnailsTableStatements) Prepare(ctx rcontext.RequestContext) *thumbnailsTableWithContext { + return &thumbnailsTableWithContext{ + statements: s, + ctx: ctx, + } +} + +func (s *thumbnailsTableWithContext) GetByParams(origin string, mediaId string, width int, height int, method string, animated bool) (*DbThumbnail, error) { + row := s.statements.selectThumbnailByParams.QueryRowContext(s.ctx, origin, mediaId, width, height, method, animated) + val := &DbThumbnail{Locatable: &Locatable{}} + err := row.Scan(&val.Origin, &val.MediaId, &val.ContentType, &val.Width, &val.Height, &val.Method, &val.Animated, &val.Sha256Hash, &val.SizeBytes, &val.CreationTs, &val.DatastoreId, &val.Location) + if err == sql.ErrNoRows { + err = nil + val = nil + } + return val, err +} + +func (s *thumbnailsTableWithContext) Insert(record *DbThumbnail) error { + _, err := s.statements.insertThumbnail.ExecContext(s.ctx, record.Origin, record.MediaId, record.ContentType, record.Width, record.Height, record.Method, record.Animated, record.Sha256Hash, record.SizeBytes, record.CreationTs, record.DatastoreId, record.Location) + return err +} diff --git a/go.mod b/go.mod index 3af2dad1b68b41ce7df111394f981e202790b1a4..d8f38acf4c7cebb182124dad02821056ddc3f06d 100644 --- a/go.mod +++ b/go.mod @@ -54,7 +54,7 @@ require ( github.com/minio/minio-go/v7 v7.0.55 github.com/panjf2000/ants/v2 v2.7.4 github.com/redis/go-redis/v9 v9.0.4 - github.com/t2bot/go-singleflight-streams v0.0.1 + github.com/t2bot/go-singleflight-streams v0.0.2 ) require ( diff --git a/go.sum b/go.sum index 01e2d1210e2a8b0a5067dd07684c6f527ee19b64..f7ab76e56b128d016fcbdcb92169e10194cef8f7 100644 --- a/go.sum +++ b/go.sum @@ -335,8 +335,8 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08YuiTGPZLls0Wq99X9bWd0Q5ZSBesM= github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8= -github.com/t2bot/go-singleflight-streams v0.0.1 h1:UWGe3ud5thKP/oqCcUM8eyH1YIC/5ESpF5JvFDoTfH4= -github.com/t2bot/go-singleflight-streams v0.0.1/go.mod h1:oiZ5Zj2o6p3SuMWTPXAmye7Ou86WCrQ2oGcr32bEEq8= +github.com/t2bot/go-singleflight-streams v0.0.2 h1:N60e6rZvuf5CybfuGREgMDxKbFtvHqr3FzawuPLaGOI= +github.com/t2bot/go-singleflight-streams v0.0.2/go.mod h1:oiZ5Zj2o6p3SuMWTPXAmye7Ou86WCrQ2oGcr32bEEq8= github.com/tebeka/strftime v0.1.3 h1:5HQXOqWKYRFfNyBMNVc9z5+QzuBtIXy03psIhtdJYto= github.com/tebeka/strftime v0.1.3/go.mod h1:7wJm3dZlpr4l/oVK0t1HYIc4rMzQ2XJlOMIUJUJH6XQ= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/pipelines/_steps/datastore_op/put_and_return_stream.go b/pipelines/_steps/datastore_op/put_and_return_stream.go new file mode 100644 index 0000000000000000000000000000000000000000..975c1d8075368af7ffa52ea55b614afb043b92cb --- /dev/null +++ b/pipelines/_steps/datastore_op/put_and_return_stream.go @@ -0,0 +1,80 @@ +package datastore_op + +import ( + "errors" + "io" + "sync" + + "github.com/turt2live/matrix-media-repo/common/config" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/datastores" + "github.com/turt2live/matrix-media-repo/pipelines/pipeline_upload" +) + +type downloadResult struct { + r io.ReadCloser + filename string + contentType string + err error +} + +type uploadResult struct { + m *database.DbMedia + err error +} + +func PutAndReturnStream(ctx rcontext.RequestContext, origin string, mediaId string, input io.ReadCloser, contentType string, fileName string, kind datastores.Kind) (*database.DbMedia, io.ReadCloser, error) { + dsConf, err := datastores.Pick(ctx, kind) + if err != nil { + return nil, nil, err + } + + pr, pw := io.Pipe() + tee := io.TeeReader(input, pw) + defer pw.CloseWithError(errors.New("failed to finish write")) + + wg := new(sync.WaitGroup) + wg.Add(2) + + bufferCh := make(chan downloadResult) + uploadCh := make(chan uploadResult) + defer close(bufferCh) + defer close(uploadCh) + + upstreamClose := func() error { return pw.Close() } + + go func(dsConf config.DatastoreConfig, pr io.ReadCloser, bufferCh chan downloadResult) { + _, _, retReader, err2 := datastores.BufferTemp(dsConf, pr) + // async the channel update to avoid deadlocks + go func(bufferCh chan downloadResult, err2 error, retReader io.ReadCloser) { + bufferCh <- downloadResult{err: err2, r: retReader} + }(bufferCh, err2, retReader) + wg.Done() + }(dsConf, pr, bufferCh) + + go func(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, upstreamClose func() error, contentType string, fileName string, kind datastores.Kind, uploadCh chan uploadResult) { + m, err2 := pipeline_upload.Execute(ctx, origin, mediaId, r, contentType, fileName, "", kind) + // async the channel update to avoid deadlocks + go func(uploadCh chan uploadResult, err2 error, m *database.DbMedia) { + uploadCh <- uploadResult{err: err2, m: m} + }(uploadCh, err2, m) + if err3 := upstreamClose(); err3 != nil { + ctx.Log.Warn("Failed to close non-tee writer during remote download: ", err3) + } + wg.Done() + }(ctx, origin, mediaId, io.NopCloser(tee), upstreamClose, contentType, fileName, kind, uploadCh) + + wg.Wait() + bufferRes := <-bufferCh + uploadRes := <-uploadCh + if bufferRes.err != nil { + return nil, nil, bufferRes.err + } + if uploadRes.err != nil { + defer bufferRes.r.Close() + return nil, nil, uploadRes.err + } + + return uploadRes.m, bufferRes.r, nil +} diff --git a/pipelines/_steps/download/open_stream.go b/pipelines/_steps/download/open_stream.go index 67efe03b1f85e76028633e8fd5cde18b4f8692fb..7012b91a8328efa33220fd4d000fc5d82b74745c 100644 --- a/pipelines/_steps/download/open_stream.go +++ b/pipelines/_steps/download/open_stream.go @@ -25,7 +25,7 @@ func (r limitedCloser) Close() error { return r.rs.Close() } -func OpenStream(ctx rcontext.RequestContext, media *database.DbMedia, startByte int64, endByte int64) (io.ReadCloser, error) { +func OpenStream(ctx rcontext.RequestContext, media *database.Locatable, startByte int64, endByte int64) (io.ReadCloser, error) { reader, err := redislib.TryGetMedia(ctx, media.Sha256Hash, startByte, endByte) if err != nil || reader != nil { return io.NopCloser(reader), err diff --git a/pipelines/_steps/download/try_download.go b/pipelines/_steps/download/try_download.go index a892add75c8490c1ed55cbfbf19a6e2ccd0bcec2..b6de06d9cccdd2ca1fb3f9c2e9ef804739e8aa1b 100644 --- a/pipelines/_steps/download/try_download.go +++ b/pipelines/_steps/download/try_download.go @@ -8,18 +8,16 @@ import ( "net/http" "net/url" "strconv" - "sync" "github.com/prometheus/client_golang/prometheus" "github.com/turt2live/matrix-media-repo/common" - "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/database" "github.com/turt2live/matrix-media-repo/datastores" "github.com/turt2live/matrix-media-repo/errcache" "github.com/turt2live/matrix-media-repo/matrix" "github.com/turt2live/matrix-media-repo/metrics" - "github.com/turt2live/matrix-media-repo/pipelines/pipeline_upload" + "github.com/turt2live/matrix-media-repo/pipelines/_steps/datastore_op" "github.com/turt2live/matrix-media-repo/pool" "github.com/turt2live/matrix-media-repo/util" ) @@ -117,57 +115,7 @@ func TryDownload(ctx rcontext.RequestContext, origin string, mediaId string) (*d return nil, nil, res.err } - // At this point, res.r is our http response body. We'll first cache it (getting a temporary stream we'll return - // later), then upload it to persist the record. + // At this point, res.r is our http response body. - dsConf, err := datastores.Pick(ctx, datastores.RemoteMediaKind) - if err != nil { - return nil, nil, err - } - - pr, pw := io.Pipe() - tee := io.TeeReader(res.r, pw) - defer pw.CloseWithError(errors.New("failed to finish write")) - wg := new(sync.WaitGroup) - wg.Add(2) - bufferCh := make(chan downloadResult) - uploadCh := make(chan uploadResult) - defer close(bufferCh) - defer close(uploadCh) - - upstreamClose := func() error { return pw.Close() } - - go func(dsConf config.DatastoreConfig, pr io.ReadCloser, bufferCh chan downloadResult) { - _, _, retReader, err2 := datastores.BufferTemp(dsConf, pr) - // async the channel update to avoid deadlocks - go func(bufferCh chan downloadResult, err2 error, retReader io.ReadCloser) { - bufferCh <- downloadResult{err: err2, r: retReader} - }(bufferCh, err2, retReader) - wg.Done() - }(dsConf, pr, bufferCh) - - go func(ctx rcontext.RequestContext, origin string, mediaId string, r io.ReadCloser, upstreamClose func() error, contentType string, fileName string, uploadCh chan uploadResult) { - m, err2 := pipeline_upload.Execute(ctx, origin, mediaId, r, contentType, fileName, "", datastores.RemoteMediaKind) - // async the channel update to avoid deadlocks - go func(uploadCh chan uploadResult, err2 error, m *database.DbMedia) { - uploadCh <- uploadResult{err: err2, m: m} - }(uploadCh, err2, m) - if err3 := upstreamClose(); err3 != nil { - ctx.Log.Warn("Failed to close non-tee writer during remote download: ", err3) - } - wg.Done() - }(ctx, origin, mediaId, io.NopCloser(tee), upstreamClose, res.contentType, res.filename, uploadCh) - - wg.Wait() - bufferRes := <-bufferCh - uploadRes := <-uploadCh - if bufferRes.err != nil { - return nil, nil, bufferRes.err - } - if uploadRes.err != nil { - defer bufferRes.r.Close() - return nil, nil, uploadRes.err - } - - return uploadRes.m, bufferRes.r, nil + return datastore_op.PutAndReturnStream(ctx, origin, mediaId, res.r, res.contentType, res.filename, datastores.RemoteMediaKind) } diff --git a/pipelines/_steps/thumbnails/generate.go b/pipelines/_steps/thumbnails/generate.go new file mode 100644 index 0000000000000000000000000000000000000000..4fe71b0dac6c38c971c467cda74febcef4bd3802 --- /dev/null +++ b/pipelines/_steps/thumbnails/generate.go @@ -0,0 +1,86 @@ +package thumbnails + +import ( + "io" + + "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/datastores" + "github.com/turt2live/matrix-media-repo/pipelines/_steps/datastore_op" + "github.com/turt2live/matrix-media-repo/pipelines/_steps/download" + "github.com/turt2live/matrix-media-repo/pool" + "github.com/turt2live/matrix-media-repo/thumbnailing" + "github.com/turt2live/matrix-media-repo/thumbnailing/m" + "github.com/turt2live/matrix-media-repo/util" +) + +type generateResult struct { + i *m.Thumbnail + err error +} + +func Generate(ctx rcontext.RequestContext, mediaRecord *database.DbMedia, width int, height int, method string, animated bool) (*database.DbThumbnail, io.ReadCloser, error) { + ch := make(chan generateResult) + defer close(ch) + fn := func() { + mediaStream, err := download.OpenStream(ctx, mediaRecord.Locatable, -1, -1) + if err != nil { + ch <- generateResult{err: err} + return + } + fixedContentType := util.FixContentType(mediaRecord.ContentType) + + i, err := thumbnailing.GenerateThumbnail(mediaStream, fixedContentType, width, height, method, animated, ctx) + if err != nil { + ch <- generateResult{err: err} + return + } + + ch <- generateResult{i: i} + } + + if err := pool.ThumbnailQueue.Schedule(fn); err != nil { + return nil, nil, err + } + res := <-ch + if res.err != nil { + return nil, nil, res.err + } + if res.i == nil { + // Couldn't generate a thumbnail + return nil, nil, common.ErrMediaNotFound + } + + // At this point, res.i is our thumbnail + + thumbMediaRecord, thumbStream, err := datastore_op.PutAndReturnStream(ctx, ctx.Request.Host, "", res.i.Reader, res.i.ContentType, "", datastores.ThumbnailsKind) + if err != nil { + return nil, nil, err + } + + // Create a DbThumbnail + newRecord := &database.DbThumbnail{ + Origin: thumbMediaRecord.Origin, + MediaId: thumbMediaRecord.MediaId, + ContentType: thumbMediaRecord.ContentType, + Width: width, + Height: height, + Method: method, + Animated: res.i.Animated, + SizeBytes: thumbMediaRecord.SizeBytes, + CreationTs: thumbMediaRecord.CreationTs, + Locatable: &database.Locatable{ + Sha256Hash: thumbMediaRecord.Sha256Hash, + DatastoreId: thumbMediaRecord.DatastoreId, + Location: thumbMediaRecord.Location, + }, + } + err = database.GetInstance().Thumbnails.Prepare(ctx).Insert(newRecord) + if err != nil { + defer thumbStream.Close() + return nil, nil, err + } + + return newRecord, thumbStream, nil +} diff --git a/pipelines/_steps/thumbnails/pick_dimensions.go b/pipelines/_steps/thumbnails/pick_dimensions.go new file mode 100644 index 0000000000000000000000000000000000000000..d811ff9ce8fddfb9b362a79a1e147b441c4b8443 --- /dev/null +++ b/pipelines/_steps/thumbnails/pick_dimensions.go @@ -0,0 +1,69 @@ +package thumbnails + +import ( + "errors" + + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/util" +) + +func PickNewDimensions(ctx rcontext.RequestContext, desiredWidth int, desiredHeight int, desiredMethod string) (int, int, string, error) { + if desiredWidth <= 0 { + return 0, 0, "", errors.New("width must be positive") + } + if desiredHeight <= 0 { + return 0, 0, "", errors.New("height must be positive") + } + if desiredMethod != "crop" && desiredMethod != "scale" { + return 0, 0, "", errors.New("method must be crop or scale") + } + + foundSize := false + targetWidth := 0 + targetHeight := 0 + largestWidth := 0 + largestHeight := 0 + desiredAspectRatio := float32(desiredWidth) / float32(desiredHeight) + + for _, size := range ctx.Config.Thumbnails.Sizes { + largestWidth = util.MaxInt(largestWidth, size.Width) + largestHeight = util.MaxInt(largestHeight, size.Height) + + // Unlikely, but if we get the exact dimensions then just use that + if desiredWidth == size.Width && desiredHeight == size.Height { + return size.Width, size.Height, desiredMethod, nil + } + + // If we come across a size that's larger than requested, try and use that + if desiredWidth <= size.Width && desiredHeight <= size.Height { + // Only use our new found size if it's smaller than one we've already picked + if !foundSize || (targetWidth > size.Width && targetHeight > size.Height) { + targetWidth = size.Width + targetHeight = size.Height + foundSize = true + } + } + } + + if ctx.Config.Thumbnails.DynamicSizing { + return util.MinInt(largestWidth, desiredWidth), util.MinInt(largestHeight, desiredHeight), desiredMethod, nil + } + + // Use the largest dimensions available if we didn't find anything + if !foundSize { + targetWidth = largestWidth + targetHeight = largestHeight + } + + if desiredMethod == "crop" { + // We need to maintain the aspect ratio of the request + sizeAspect := float32(targetWidth) / float32(targetHeight) + if sizeAspect != desiredAspectRatio { // it's unlikely to match, but we can dream + ratio := util.MinFloat32(float32(targetWidth)/float32(desiredWidth), float32(targetHeight)/float32(desiredHeight)) + targetWidth = int(float32(desiredWidth) * ratio) + targetHeight = int(float32(desiredHeight) * ratio) + } + } + + return targetWidth, targetHeight, desiredMethod, nil +} diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go index f678b7488f4f002c6bdd2d90a874efefbaf22064..e4eecf300abf0c8af8eb5482c1afe8ad7abc06d7 100644 --- a/pipelines/pipeline_download/pipeline.go +++ b/pipelines/pipeline_download/pipeline.go @@ -2,15 +2,18 @@ package pipeline_download import ( "context" + "errors" "fmt" "io" "time" + "github.com/getsentry/sentry-go" "github.com/t2bot/go-singleflight-streams" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/database" "github.com/turt2live/matrix-media-repo/pipelines/_steps/download" + "github.com/turt2live/matrix-media-repo/util" ) var sf = new(sfstreams.Group) @@ -20,27 +23,13 @@ type DownloadOpts struct { StartByte int64 EndByte int64 BlockForReadUntil time.Duration + RecordOnly bool } func (o DownloadOpts) String() string { return fmt.Sprintf("f=%t,s=%d,e=%d,b=%s", o.FetchRemoteIfNeeded, o.StartByte, o.EndByte, o.BlockForReadUntil.String()) } -type cancelCloser struct { - io.ReadCloser - r io.ReadCloser - cancel func() -} - -func (c *cancelCloser) Read(p []byte) (int, error) { - return c.r.Read(p) -} - -func (c *cancelCloser) Close() error { - c.cancel() - return c.r.Close() -} - func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) { // Step 1: Make our context a timeout context var cancel context.CancelFunc @@ -63,7 +52,10 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do } if record != nil { go serveRecord(recordCh, record) // async function to prevent deadlock - return download.OpenStream(ctx, record, opts.StartByte, opts.EndByte) + if opts.RecordOnly { + return nil, nil + } + return download.OpenStream(ctx, record.Locatable, opts.StartByte, opts.EndByte) } // Step 4: Media record unknown - download it (if possible) @@ -75,6 +67,10 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do return nil, err } go serveRecord(recordCh, record) // async function to prevent deadlock + if opts.RecordOnly { + r.Close() + return nil, nil + } // Step 5: Limit the stream if needed r, err = download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte) @@ -85,11 +81,19 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do return r, nil }) if err != nil { + cancel() return nil, nil, err } record := <-recordCh - return record, &cancelCloser{ - r: r, - cancel: cancel, - }, nil + if opts.RecordOnly { + if r != nil { + devErr := errors.New("expected no download stream, but got one anyways") + ctx.Log.Warn(devErr) + sentry.CaptureException(devErr) + r.Close() + } + cancel() + return record, nil, nil + } + return record, util.NewCancelCloser(r, cancel), nil } diff --git a/pipelines/pipeline_thumbnail/pipeline.go b/pipelines/pipeline_thumbnail/pipeline.go new file mode 100644 index 0000000000000000000000000000000000000000..4d845f81940191732c3c2cd15b0039c4e4df2e2e --- /dev/null +++ b/pipelines/pipeline_thumbnail/pipeline.go @@ -0,0 +1,126 @@ +package pipeline_thumbnail + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/getsentry/sentry-go" + sfstreams "github.com/t2bot/go-singleflight-streams" + "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/pipelines/_steps/download" + "github.com/turt2live/matrix-media-repo/pipelines/_steps/thumbnails" + "github.com/turt2live/matrix-media-repo/pipelines/pipeline_download" + "github.com/turt2live/matrix-media-repo/util" +) + +var sf = new(sfstreams.Group) + +// ThumbnailOpts are options for generating a thumbnail +type ThumbnailOpts struct { + pipeline_download.DownloadOpts + Width int + Height int + Method string + Animated bool +} + +func (o ThumbnailOpts) String() string { + return fmt.Sprintf("%s,w=%d,h=%d,m=%s,a=%t", o.DownloadOpts.String(), o.Width, o.Height, o.Method, o.Animated) +} + +func (o ThumbnailOpts) ImpliedDownloadOpts() pipeline_download.DownloadOpts { + return pipeline_download.DownloadOpts{ + FetchRemoteIfNeeded: o.FetchRemoteIfNeeded, + BlockForReadUntil: o.BlockForReadUntil, + RecordOnly: true, + + // We remove the range parameters to ensure we get a useful download stream + StartByte: -1, + EndByte: -1, + } +} + +func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts ThumbnailOpts) (*database.DbThumbnail, io.ReadCloser, error) { + // Step 1: Fix the request parameters + w, h, method, err1 := thumbnails.PickNewDimensions(ctx, opts.Width, opts.Height, opts.Method) + if err1 != nil { + return nil, nil, err1 + } + opts.Width = w + opts.Height = h + opts.Method = method + + // Step 2: Make our context a timeout context + var cancel context.CancelFunc + //goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct + ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil) + + // Step 3: Join the singleflight queue + recordCh := make(chan *database.DbThumbnail) + defer close(recordCh) + r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) { + serveRecord := func(recordCh chan *database.DbThumbnail, record *database.DbThumbnail) { + recordCh <- record + } + + // Step 4: Get the associated media record (without stream) + mediaRecord, _, err := pipeline_download.Execute(ctx, origin, mediaId, opts.ImpliedDownloadOpts()) + if err != nil { + return nil, err + } + if mediaRecord == nil { + return nil, common.ErrMediaNotFound + } + + // Step 5: Check for quarantine + // TODO: Quarantine + + // Step 6: See if we're lucky enough to already have this thumbnail + thumbDb := database.GetInstance().Thumbnails.Prepare(ctx) + record, err := thumbDb.GetByParams(origin, mediaId, opts.Width, opts.Height, opts.Method, opts.Animated) + if err != nil { + return nil, err + } + if record != nil { + go serveRecord(recordCh, record) // async function to prevent deadlock + if opts.RecordOnly { + return nil, nil + } + return download.OpenStream(ctx, record.Locatable, opts.StartByte, opts.EndByte) + } + + // Step 7: Generate the thumbnail and return that + record, r, err := thumbnails.Generate(ctx, mediaRecord, opts.Width, opts.Height, opts.Method, opts.Animated) + if err != nil { + return nil, err + } + go serveRecord(recordCh, record) + if opts.RecordOnly { + defer r.Close() + return nil, nil + } + + // Step 8: Create a limited stream + return download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte) + }) + if err != nil { + cancel() + return nil, nil, err + } + record := <-recordCh + if opts.RecordOnly { + if r != nil { + devErr := errors.New("expected no thumbnail stream, but got one anyways") + ctx.Log.Warn(devErr) + sentry.CaptureException(devErr) + r.Close() + } + cancel() + return record, nil, nil + } + return record, util.NewCancelCloser(r, cancel), nil +} diff --git a/pipelines/pipeline_upload/pipeline.go b/pipelines/pipeline_upload/pipeline.go index 15a924561155b1146c83ffc75c210dbe1cb6477f..1f095a67e89da836e786e147bd7583d216f3d00c 100644 --- a/pipelines/pipeline_upload/pipeline.go +++ b/pipelines/pipeline_upload/pipeline.go @@ -79,12 +79,14 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re UploadName: fileName, ContentType: contentType, UserId: userId, - Sha256Hash: sha256hash, SizeBytes: sizeBytes, CreationTs: util.NowMillis(), Quarantined: false, - DatastoreId: "", // Populated later - Location: "", // Populated later + Locatable: &database.Locatable{ + Sha256Hash: sha256hash, + DatastoreId: "", // Populated later + Location: "", // Populated later + }, } record, perfect, err := upload.FindRecord(ctx, sha256hash, userId, contentType, fileName) if record != nil { diff --git a/pool/init.go b/pool/init.go index 4b683249f4720f21178cb2e7ca40ecdd83ada568..96a5a5c822268d2e27ef6b7bbebcd0a47554f72f 100644 --- a/pool/init.go +++ b/pool/init.go @@ -7,6 +7,7 @@ import ( ) var DownloadQueue *Queue +var ThumbnailQueue *Queue func Init() { var err error @@ -15,12 +16,19 @@ func Init() { logrus.Error("Error setting up downloads queue") logrus.Fatal(err) } + if ThumbnailQueue, err = NewQueue(config.Get().Thumbnails.NumWorkers, "thumbnails"); err != nil { + sentry.CaptureException(err) + logrus.Error("Error setting up thumbnails queue") + logrus.Fatal(err) + } } func AdjustSize() { DownloadQueue.pool.Tune(config.Get().Downloads.NumWorkers) + ThumbnailQueue.pool.Tune(config.Get().Thumbnails.NumWorkers) } func Drain() { DownloadQueue.pool.Release() + ThumbnailQueue.pool.Release() } diff --git a/util/cancel_closer.go b/util/cancel_closer.go new file mode 100644 index 0000000000000000000000000000000000000000..11021243f2f67ab53222465e62093b6a11c208f9 --- /dev/null +++ b/util/cancel_closer.go @@ -0,0 +1,25 @@ +package util + +import "io" + +type CancelCloser struct { + io.ReadCloser + r io.ReadCloser + cancel func() +} + +func NewCancelCloser(r io.ReadCloser, cancel func()) *CancelCloser { + return &CancelCloser{ + r: r, + cancel: cancel, + } +} + +func (c *CancelCloser) Read(p []byte) (int, error) { + return c.r.Read(p) +} + +func (c *CancelCloser) Close() error { + c.cancel() + return c.r.Close() +} diff --git a/util/mime.go b/util/mime.go index 46bc24b4f9af3ff3b9f2268b4ef7767a8a0d7f92..6005858da191c9a3761e430b255d6fe930b1ce0f 100644 --- a/util/mime.go +++ b/util/mime.go @@ -1,9 +1,18 @@ package util import ( + "mime" "strings" ) func FixContentType(ct string) string { return strings.Split(ct, ";")[0] } + +func ExtensionForContentType(ct string) string { + exts, _ := mime.ExtensionsByType(ct) + if exts != nil && len(exts) > 0 { + return exts[0] + } + return ".bin" +} diff --git a/util/time.go b/util/time.go index f9b496da2e221ee87e8d3578540b3b0e1f2cbfe1..d972bafb11012998e7d0bde5f0ffd8bdd3b31cfc 100644 --- a/util/time.go +++ b/util/time.go @@ -1,6 +1,9 @@ package util -import "time" +import ( + "strconv" + "time" +) func NowMillis() int64 { return time.Now().UnixNano() / 1000000 @@ -9,3 +12,21 @@ func NowMillis() int64 { func FromMillis(m int64) time.Time { return time.Unix(0, m*int64(time.Millisecond)) } + +func CalcBlockForDuration(timeoutMs string) (time.Duration, error) { + blockFor := 20 * time.Second + if timeoutMs != "" { + parsed, err := strconv.Atoi(timeoutMs) + if err != nil { + return 0, err + } + if parsed > 0 { + // Limit to 60 seconds + if parsed > 60000 { + parsed = 60000 + } + blockFor = time.Duration(parsed) * time.Millisecond + } + } + return blockFor, nil +}