Skip to content
Snippets Groups Projects
Commit 466f95f7 authored by Travis Ralston
Browse files

Fix record + stream handling in download pipeline

parent 8ecdd0e1
No related branches found
No related tags found
No related merge requests found
...@@ -364,6 +364,8 @@ github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08Yu ...@@ -364,6 +364,8 @@ github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08Yu
github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8= github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8=
github.com/t2bot/go-singleflight-streams v0.0.5 h1:SAjmq+tLOn5K9hKCnrrgnXNdbRI56WqTT9gkUt/P1bI= github.com/t2bot/go-singleflight-streams v0.0.5 h1:SAjmq+tLOn5K9hKCnrrgnXNdbRI56WqTT9gkUt/P1bI=
github.com/t2bot/go-singleflight-streams v0.0.5/go.mod h1:pEIFm6l/utrXZBeP4tkIuMdYwBBN0TTw7feSEhozNzg= github.com/t2bot/go-singleflight-streams v0.0.5/go.mod h1:pEIFm6l/utrXZBeP4tkIuMdYwBBN0TTw7feSEhozNzg=
github.com/t2bot/go-typed-singleflight v0.0.3 h1:TAQyjhfa5BA63BwFTEVY1a4NF07ekX9JRgite5Cbq0A=
github.com/t2bot/go-typed-singleflight v0.0.3/go.mod h1:0SOeDgjEtLYEy1InNhFBCgDx0UjKAqsLzk5MXek/JNw=
github.com/tebeka/strftime v0.1.3 h1:5HQXOqWKYRFfNyBMNVc9z5+QzuBtIXy03psIhtdJYto= github.com/tebeka/strftime v0.1.3 h1:5HQXOqWKYRFfNyBMNVc9z5+QzuBtIXy03psIhtdJYto=
github.com/tebeka/strftime v0.1.3/go.mod h1:7wJm3dZlpr4l/oVK0t1HYIc4rMzQ2XJlOMIUJUJH6XQ= github.com/tebeka/strftime v0.1.3/go.mod h1:7wJm3dZlpr4l/oVK0t1HYIc4rMzQ2XJlOMIUJUJH6XQ=
github.com/testcontainers/testcontainers-go v0.23.0 h1:ERYTSikX01QczBLPZpqsETTBO7lInqEP349phDOVJVs= github.com/testcontainers/testcontainers-go v0.23.0 h1:ERYTSikX01QczBLPZpqsETTBO7lInqEP349phDOVJVs=
......
...@@ -16,9 +16,11 @@ import ( ...@@ -16,9 +16,11 @@ import (
"github.com/turt2live/matrix-media-repo/pipelines/_steps/meta" "github.com/turt2live/matrix-media-repo/pipelines/_steps/meta"
"github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine" "github.com/turt2live/matrix-media-repo/pipelines/_steps/quarantine"
"github.com/turt2live/matrix-media-repo/util/readers" "github.com/turt2live/matrix-media-repo/util/readers"
"github.com/turt2live/matrix-media-repo/util/sfcache"
) )
var sf = new(sfstreams.Group) var streamSf = new(sfstreams.Group)
var recordSf = sfcache.NewSingleflightCache[*database.DbMedia]()
type DownloadOpts struct { type DownloadOpts struct {
FetchRemoteIfNeeded bool FetchRemoteIfNeeded bool
...@@ -38,28 +40,28 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do ...@@ -38,28 +40,28 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
//goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct //goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct
ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil) ctx.Context, cancel = context.WithTimeout(ctx.Context, opts.BlockForReadUntil)
// Step 2: Join the singleflight queue // Step 2: Join the singleflight queue for stream and DB record
recordCh := make(chan *database.DbMedia, 1) sfKey := fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String())
defer close(recordCh) fetchRecordFn := func() (*database.DbMedia, error) {
r, err, _ := sf.Do(fmt.Sprintf("%s/%s?%s", origin, mediaId, opts.String()), func() (io.ReadCloser, error) {
serveRecord := func(recordCh chan *database.DbMedia, record *database.DbMedia) {
defer func() {
// Don't crash when we send to a closed channel
recover()
}()
recordCh <- record
}
// Step 3: Do we already have the media? Serve it if yes.
mediaDb := database.GetInstance().Media.Prepare(ctx) mediaDb := database.GetInstance().Media.Prepare(ctx)
record, err := mediaDb.GetById(origin, mediaId) record, err := mediaDb.GetById(origin, mediaId)
didAsyncWait := false
handlePossibleRecord:
if err != nil { if err != nil {
return nil, err return nil, err
} }
if record == nil {
return download.WaitForAsyncMedia(ctx, origin, mediaId)
}
return record, nil
}
record, err := recordSf.Do(sfKey, fetchRecordFn)
defer recordSf.ForgetCacheKey(sfKey)
if err != nil {
cancel()
return nil, nil, err
}
r, err, _ := streamSf.Do(sfKey, func() (io.ReadCloser, error) {
// Step 3: Do we already have the media? Serve it if yes.
if record != nil { if record != nil {
go serveRecord(recordCh, record) // async function to prevent deadlock
if record.Quarantined { if record.Quarantined {
return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512, opts.StartByte, opts.EndByte) return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512, opts.StartByte, opts.EndByte)
} }
...@@ -70,14 +72,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do ...@@ -70,14 +72,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
return download.OpenStream(ctx, record.Locatable, opts.StartByte, opts.EndByte) return download.OpenStream(ctx, record.Locatable, opts.StartByte, opts.EndByte)
} }
// Step 4: Wait for the media, if we can // Step 4: Media record unknown - download it (if possible)
if !didAsyncWait {
record, err = download.WaitForAsyncMedia(ctx, origin, mediaId)
didAsyncWait = true
goto handlePossibleRecord
}
// Step 5: Media record unknown - download it (if possible)
if !opts.FetchRemoteIfNeeded { if !opts.FetchRemoteIfNeeded {
return nil, common.ErrMediaNotFound return nil, common.ErrMediaNotFound
} }
...@@ -85,7 +80,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do ...@@ -85,7 +80,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
if err != nil { if err != nil {
return nil, err return nil, err
} }
go serveRecord(recordCh, record) // async function to prevent deadlock recordSf.OverwriteCacheKey(sfKey, record)
if record.Quarantined { if record.Quarantined {
return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512, opts.StartByte, opts.EndByte) return quarantine.ReturnAppropriateThing(ctx, true, opts.RecordOnly, 512, 512, opts.StartByte, opts.EndByte)
} }
...@@ -95,7 +90,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do ...@@ -95,7 +90,7 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
return nil, nil return nil, nil
} }
// Step 6: Limit the stream if needed // Step 5: Limit the stream if needed
r, err = download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte) r, err = download.CreateLimitedStream(ctx, r, opts.StartByte, opts.EndByte)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -111,7 +106,18 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do ...@@ -111,7 +106,18 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
cancel() cancel()
return nil, nil, err return nil, nil, err
} }
record := <-recordCh if record == nil {
// Re-fetch, hopefully from cache
record, err = recordSf.Do(sfKey, fetchRecordFn)
if err != nil {
cancel()
return nil, nil, err
}
if record == nil {
cancel()
return nil, nil, errors.New("unexpected error: no viable record and no error condition")
}
}
if opts.RecordOnly { if opts.RecordOnly {
if r != nil { if r != nil {
devErr := errors.New("expected no download stream, but got one anyways") devErr := errors.New("expected no download stream, but got one anyways")
......
package sfcache
import (
"sync"
"github.com/t2bot/go-typed-singleflight"
)
// SingleflightCache holds a typed singleflight group alongside a sync.Map that
// is used as a string-keyed cache of T values. Values enter the cache through
// Do or OverwriteCacheKey and are removed with ForgetCacheKey.
//
// T is constrained to comparable because Do compares results against T's zero
// value before deciding whether to store them.
type SingleflightCache[T comparable] struct {
	// sf runs the caller-supplied function for a key on a cache miss.
	sf *typedsf.Group[T]
	// cache maps a key to a previously stored T.
	cache *sync.Map
}
// NewSingleflightCache builds an empty SingleflightCache for values of type T,
// with a fresh singleflight group and a fresh, empty cache map.
func NewSingleflightCache[T comparable]() *SingleflightCache[T] {
	c := new(SingleflightCache[T])
	c.sf = &typedsf.Group[T]{}
	c.cache = &sync.Map{}
	return c
}
// Do returns the value cached under key when one exists. On a cache miss it
// runs fn through the underlying singleflight group using the same key and
// returns whatever that call produced.
//
// A result is stored in the cache only when fn succeeded (nil error) and the
// value is not T's zero value. Errors and zero values are never cached, so a
// failed or empty lookup is attempted again by the next caller. Cached values
// stay in place until ForgetCacheKey is called for the key.
func (c *SingleflightCache[T]) Do(key string, fn func() (T, error)) (T, error) {
	if cached, ok := c.cache.Load(key); ok {
		// Safe cast: every Store on this map is given a value of type T.
		return cached.(T), nil
	}

	var zero T
	v, err, _ := c.sf.Do(key, fn)
	if err != nil {
		// Never remember a value that came back with an error; it would be
		// served later with a nil error and mask the failure.
		return v, err
	}
	if v != zero {
		c.cache.Store(key, v)
	}
	return v, nil
}
// OverwriteCacheKey stores val in the cache under key, replacing any value
// already held for that key. Later Do calls for the same key return val from
// the cache without invoking their function, until the key is forgotten.
//
// Unlike Do, no zero-value check is applied here: val is stored as given.
func (c *SingleflightCache[T]) OverwriteCacheKey(key string, val T) {
	c.cache.Store(key, val)
}
// ForgetCacheKey removes the cached value for key, if any, so that the next Do
// call for that key runs its function again. Only the cache entry is deleted;
// the singleflight group is not touched.
func (c *SingleflightCache[T]) ForgetCacheKey(key string) {
	c.cache.Delete(key)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment