From d6baaf8b97d282bcb96d39fc40a5ea2191703605 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Fri, 30 Apr 2021 14:55:38 -0600
Subject: [PATCH] Fix crashes when workers panic

This is probably super ugly Go code, but it works.
---
 .../download_resource_handler.go              | 59 ++++++++++++-------
 .../preview_resource_handler.go               | 23 ++++++--
 .../thumbnail_resource_handler.go             | 23 ++++++--
 util/errors.go                                | 14 +++++
 4 files changed, 91 insertions(+), 28 deletions(-)
 create mode 100644 util/errors.go

diff --git a/controllers/download_controller/download_resource_handler.go b/controllers/download_controller/download_resource_handler.go
index 122ace06..35358ecd 100644
--- a/controllers/download_controller/download_resource_handler.go
+++ b/controllers/download_controller/download_resource_handler.go
@@ -3,6 +3,7 @@ package download_controller
 import (
 	"errors"
 	"github.com/getsentry/sentry-go"
+	"github.com/turt2live/matrix-media-repo/util"
 	"io"
 	"io/ioutil"
 	"mime"
@@ -74,7 +75,9 @@ var downloadErrorCacheSingletonLock = &sync.Once{}
 func getResourceHandler() *mediaResourceHandler {
 	if resHandler == nil {
 		resHandlerLock.Do(func() {
-			handler, err := resource_handler.New(config.Get().Downloads.NumWorkers, downloadResourceWorkFn)
+			handler, err := resource_handler.New(config.Get().Downloads.NumWorkers, func(r *resource_handler.WorkRequest) interface{} {
+				return downloadResourceWorkFn(r)
+			})
 			if err != nil {
 				sentry.CaptureException(err)
 				panic(err)
@@ -119,7 +122,7 @@ func (h *mediaResourceHandler) DownloadRemoteMedia(origin string, mediaId string
 	return resultChan
 }
 
-func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} {
+func downloadResourceWorkFn(request *resource_handler.WorkRequest) (resp *workerDownloadResponse) {
 	info := request.Metadata.(*downloadRequest)
 	ctx := rcontext.Initial().LogWithFields(logrus.Fields{
 		"worker_requestId":      request.Id,
@@ -127,14 +130,29 @@ func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} {
 		"worker_requestMediaId": info.mediaId,
 		"worker_blockForMedia":  info.blockForMedia,
 	})
+
+	resp = &workerDownloadResponse{}
+	defer func() {
+		if err := recover(); err != nil {
+			ctx.Log.Error("Caught panic: ", err)
+			sentry.CurrentHub().Recover(err)
+			resp.stream = nil
+			resp.filename = ""
+			resp.contentType = ""
+			resp.media = nil
+			resp.err = util.PanicToError(err)
+		}
+	}()
+
 	ctx.Log.Info("Downloading remote media")
 
 	downloaded, err := DownloadRemoteMediaDirect(info.origin, info.mediaId, ctx)
 	if err != nil {
-		return &workerDownloadResponse{err: err}
+		resp.err = err
+		return resp
 	}
 
-	persistFile := func(fileStream io.ReadCloser) *workerDownloadResponse {
+	persistFile := func(fileStream io.ReadCloser, r *workerDownloadResponse) *workerDownloadResponse {
 		defer cleanup.DumpAndCloseStream(fileStream)
 		userId := upload_controller.NoApplicableUploadUser
 
@@ -145,27 +163,29 @@ func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} {
 		st, err := ms.NextReader()
 		if err != nil {
 			ctx.Log.Error("Unexpected error persisting file: ", err)
-			return &workerDownloadResponse{err: err}
+			r.err = err
+			return r
 		}
 
 		media, err := upload_controller.StoreDirect(nil, st, downloaded.ContentLength, downloaded.ContentType, downloaded.DesiredFilename, userId, info.origin, info.mediaId, common.KindRemoteMedia, ctx, true)
 		if err != nil {
 			ctx.Log.Error("Error persisting file: ", err)
-			return &workerDownloadResponse{err: err}
+			r.err = err
+			return r
 		}
 
 		ctx.Log.Info("Remote media persisted under datastore ", media.DatastoreId, " at ", media.Location)
-		return &workerDownloadResponse{
-			media: media,
-			contentType: media.ContentType,
-			filename: media.UploadName,
-			stream: ms,
-		}
+		r.media = media
+		r.contentType = media.ContentType
+		r.filename = media.UploadName
+		r.stream = ms
+		return r
 	}
 
 	if info.blockForMedia {
 		ctx.Log.Warn("Not streaming remote media download request due to request for a block")
-		return persistFile(downloaded.Contents)
+		persistFile(downloaded.Contents, resp)
+		return resp
 	}
 
 	ctx.Log.Info("Streaming remote media to filesystem and requesting party at the same time")
@@ -173,18 +193,17 @@ func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} {
 	reader, writer := io.Pipe()
 	tr := io.TeeReader(downloaded.Contents, writer)
 
-	go persistFile(ioutil.NopCloser(tr))
+	go persistFile(ioutil.NopCloser(tr), &workerDownloadResponse{})
 
 	ms := stream.NewMemStream()
 	defer ms.Close()
 	io.Copy(ms, reader)
 
-	return &workerDownloadResponse{
-		err:         nil,
-		contentType: downloaded.ContentType,
-		filename:    downloaded.DesiredFilename,
-		stream:      ms,
-	}
+	resp.err = nil
+	resp.contentType = downloaded.ContentType
+	resp.filename = downloaded.DesiredFilename
+	resp.stream = ms
+	return resp
 }
 
 func DownloadRemoteMediaDirect(server string, mediaId string, ctx rcontext.RequestContext) (*downloadedMedia, error) {
diff --git a/controllers/preview_controller/preview_resource_handler.go b/controllers/preview_controller/preview_resource_handler.go
index c5517915..86c42b9d 100644
--- a/controllers/preview_controller/preview_resource_handler.go
+++ b/controllers/preview_controller/preview_resource_handler.go
@@ -44,7 +44,9 @@ var resHandlerSingletonLock = &sync.Once{}
 func getResourceHandler() *urlResourceHandler {
 	if resHandlerInstance == nil {
 		resHandlerSingletonLock.Do(func() {
-			handler, err := resource_handler.New(config.Get().UrlPreviews.NumWorkers, urlPreviewWorkFn)
+			handler, err := resource_handler.New(config.Get().UrlPreviews.NumWorkers, func(r *resource_handler.WorkRequest) interface{} {
+				return urlPreviewWorkFn(r)
+			})
 			if err != nil {
 				sentry.CaptureException(err)
 				panic(err)
@@ -57,12 +59,23 @@ func getResourceHandler() *urlResourceHandler {
 	return resHandlerInstance
 }
 
-func urlPreviewWorkFn(request *resource_handler.WorkRequest) interface{} {
+func urlPreviewWorkFn(request *resource_handler.WorkRequest) (resp *urlPreviewResponse) {
 	info := request.Metadata.(*urlPreviewRequest)
 	ctx := rcontext.Initial().LogWithFields(logrus.Fields{
 		"worker_requestId": request.Id,
 		"worker_url":       info.urlPayload.UrlString,
 	})
+
+	resp = &urlPreviewResponse{}
+	defer func() {
+		if err := recover(); err != nil {
+			ctx.Log.Error("Caught panic: ", err)
+			sentry.CurrentHub().Recover(err)
+			resp.preview = nil
+			resp.err = util.PanicToError(err)
+		}
+	}()
+
 	ctx.Log.Info("Processing url preview request")
 
 	db := storage.GetDatabase().GetUrlStore(ctx)
@@ -102,7 +115,8 @@ func urlPreviewWorkFn(request *resource_handler.WorkRequest) interface{} {
 		} else {
 			db.InsertPreviewError(info.urlPayload.UrlString, common.ErrCodeUnknown)
 		}
-		return &urlPreviewResponse{err: err}
+		resp.err = err
+		return resp
 	}
 
 	result := &types.UrlPreview{
@@ -158,7 +172,8 @@ func urlPreviewWorkFn(request *resource_handler.WorkRequest) interface{} {
 		// Non-fatal: Just report it and move on. The worst that happens is we re-cache it.
 	}
 
-	return &urlPreviewResponse{preview: result}
+	resp.preview = result
+	return resp
 }
 
 func (h *urlResourceHandler) GeneratePreview(urlPayload *preview_types.UrlPayload, forUserId string, onHost string, languageHeader string, allowOEmbed bool) chan *urlPreviewResponse {
diff --git a/controllers/thumbnail_controller/thumbnail_resource_handler.go b/controllers/thumbnail_controller/thumbnail_resource_handler.go
index b03b353c..7d084f9e 100644
--- a/controllers/thumbnail_controller/thumbnail_resource_handler.go
+++ b/controllers/thumbnail_controller/thumbnail_resource_handler.go
@@ -54,7 +54,9 @@ var resHandlerSingletonLock = &sync.Once{}
 func getResourceHandler() *thumbnailResourceHandler {
 	if resHandlerInstance == nil {
 		resHandlerSingletonLock.Do(func() {
-			handler, err := resource_handler.New(config.Get().Thumbnails.NumWorkers, thumbnailWorkFn)
+			handler, err := resource_handler.New(config.Get().Thumbnails.NumWorkers, func(r *resource_handler.WorkRequest) interface{} {
+				return thumbnailWorkFn(r)
+			})
 			if err != nil {
 				sentry.CaptureException(err)
 				panic(err)
@@ -67,7 +69,7 @@ func getResourceHandler() *thumbnailResourceHandler {
 	return resHandlerInstance
 }
 
-func thumbnailWorkFn(request *resource_handler.WorkRequest) interface{} {
+func thumbnailWorkFn(request *resource_handler.WorkRequest) (resp *thumbnailResponse) {
 	info := request.Metadata.(*thumbnailRequest)
 	ctx := rcontext.Initial().LogWithFields(logrus.Fields{
 		"worker_requestId": request.Id,
@@ -77,6 +79,17 @@ func thumbnailWorkFn(request *resource_handler.WorkRequest) interface{} {
 		"worker_method":    info.method,
 		"worker_animated":  info.animated,
 	})
+
+	resp = &thumbnailResponse{}
+	defer func() {
+		if err := recover(); err != nil {
+			ctx.Log.Error("Caught panic: ", err)
+			sentry.CurrentHub().Recover(err)
+			resp.thumbnail = nil
+			resp.err = util.PanicToError(err)
+		}
+	}()
+
 	ctx.Log.Info("Processing thumbnail request")
 
 	generated, err := GenerateThumbnail(info.media, info.width, info.height, info.method, info.animated, ctx)
@@ -111,10 +124,12 @@ func thumbnailWorkFn(request *resource_handler.WorkRequest) interface{} {
 	err = db.Insert(newThumb)
 	if err != nil {
 		ctx.Log.Error("Unexpected error caching thumbnail: " + err.Error())
-		return &thumbnailResponse{err: err}
+		resp.err = err
+	} else {
+		resp.thumbnail = newThumb
 	}
 
-	return &thumbnailResponse{thumbnail: newThumb}
+	return resp
 }
 
 func (h *thumbnailResourceHandler) GenerateThumbnail(media *types.Media, width int, height int, method string, animated bool) chan *thumbnailResponse {
diff --git a/util/errors.go b/util/errors.go
new file mode 100644
index 00000000..8393f82b
--- /dev/null
+++ b/util/errors.go
@@ -0,0 +1,14 @@
+package util
+
+import "errors"
+
+func PanicToError(err interface{}) error {
+	switch x := err.(type) {
+	case string:
+		return errors.New(x)
+	case error:
+		return x
+	default:
+		return errors.New("unknown panic")
+	}
+}
-- 
GitLab