From 1e49d6b69f125bbdf95a965168346201c5a81f3b Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Thu, 10 Feb 2022 22:13:17 -0700
Subject: [PATCH] Initial support for content ranges

Fixes https://github.com/turt2live/matrix-media-repo/issues/73
---
 CHANGELOG.md                   |   1 +
 api/webserver/route_handler.go | 125 ++++++++++++++++++++++++++++++++-
 common/config/models_domain.go |   5 +-
 config.sample.yaml             |   6 ++
 4 files changed, 132 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 173e1801..e992d1ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 * New config option to set user agent when requesting URL previews.
 * Added support for `image/jxl` thumbnailing.
+* Added early support for content ranges (being able to skip around in audio and video). Currently only available when Redis is enabled.
 
 ### Removed
 
diff --git a/api/webserver/route_handler.go b/api/webserver/route_handler.go
index f3227348..3d04714b 100644
--- a/api/webserver/route_handler.go
+++ b/api/webserver/route_handler.go
@@ -8,6 +8,8 @@ import (
 	"fmt"
 	"github.com/getsentry/sentry-go"
 	"io"
+	"io/ioutil"
+	"math"
 	"mime"
 	"net"
 	"net/http"
@@ -85,6 +87,7 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 	// Process response
 	var res interface{} = api.AuthFailed()
+	var rctx rcontext.RequestContext
 	if util.IsServerOurs(r.Host) || h.ignoreHost {
 		contextLog.Info("Host is valid - processing request")
 		cfg := config.GetDomain(r.Host)
@@ -100,7 +103,7 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		ctx = context.WithValue(ctx, "mr.logger", contextLog)
 		ctx = context.WithValue(ctx, "mr.serverConfig", cfg)
 		ctx = context.WithValue(ctx, "mr.request", r)
-		rctx := rcontext.RequestContext{Context: ctx, Log: contextLog, Config: *cfg, Request: r}
+		rctx = rcontext.RequestContext{Context: ctx, Log: contextLog, Config: *cfg, Request: r}
 		r = r.WithContext(rctx)
 
 		metrics.HttpRequests.With(prometheus.Labels{
@@ -164,6 +167,91 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}
 		break
 	case *r0.DownloadMediaResponse:
+		// XXX: This range parsing isn't perfect, but works fine enough for now
+		rangeStart := int64(0)
+		rangeEnd := int64(0)
+		grabBytes := int64(0)
+		doRange := false
+		if r.Header.Get("Range") != "" && result.SizeBytes > 0 && rctx.Request != nil && config.Get().Redis.Enabled {
+			rnge := r.Header.Get("Range")
+			if !strings.HasPrefix(rnge, "bytes=") {
+				statusCode = http.StatusRequestedRangeNotSatisfiable
+				res = api.BadRequest("Improper range units")
+				break
+			}
+			if !strings.Contains(rnge, ",") && !strings.HasPrefix(rnge, "bytes=-") {
+				parts := strings.Split(rnge[len("bytes="):], "-")
+				if len(parts) <= 2 {
+					rstart, err := strconv.ParseInt(parts[0], 10, 64)
+					if err != nil {
+						statusCode = http.StatusRequestedRangeNotSatisfiable
+						res = api.BadRequest("Improper start of range")
+						break
+					}
+
+					if rstart < 0 {
+						statusCode = http.StatusRequestedRangeNotSatisfiable
+						res = api.BadRequest("Improper start of range: negative")
+						break
+					}
+
+					rend := int64(-1)
+					if len(parts) > 1 && parts[1] != "" {
+						rend, err = strconv.ParseInt(parts[1], 10, 64)
+						if err != nil {
+							statusCode = http.StatusRequestedRangeNotSatisfiable
+							res = api.BadRequest("Improper end of range")
+							break
+						}
+
+						if rend < 1 {
+							statusCode = http.StatusRequestedRangeNotSatisfiable
+							res = api.BadRequest("Improper end of range: negative")
+							break
+						}
+
+						if rend >= result.SizeBytes {
+							statusCode = http.StatusRequestedRangeNotSatisfiable
+							res = api.BadRequest("Improper end of range: out of bounds")
+							break
+						}
+
+						if rend <= rstart {
+							statusCode = http.StatusRequestedRangeNotSatisfiable
+							res = api.BadRequest("Start must be before end")
+							break
+						}
+
+						if (rstart + rend) >= result.SizeBytes {
+							statusCode = http.StatusRequestedRangeNotSatisfiable
+							res = api.BadRequest("Range too large")
+							break
+						}
+
+						grabBytes = rend - rstart
+					} else {
+						add := int64(10485760) // 10mb default
+						if rctx.Config.Downloads.DefaultRangeChunkSizeBytes > 0 {
+							add = rctx.Config.Downloads.DefaultRangeChunkSizeBytes
+						}
+						rend = int64(math.Min(float64(rstart+add), float64(result.SizeBytes-1)))
+						grabBytes = (rend - rstart) + 1
+					}
+
+					rangeStart = rstart
+					rangeEnd = rend
+
+					if (rangeEnd-rangeStart) <= 0 || grabBytes <= 0 {
+						statusCode = http.StatusRequestedRangeNotSatisfiable
+						res = api.BadRequest("Range invalid at last pass")
+						break
+					}
+
+					doRange = true
+				}
+			}
+		}
+
 		metrics.HttpResponses.With(prometheus.Labels{
 			"host":       r.Host,
 			"action":     h.action,
@@ -187,6 +275,9 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Cache-Control", "private, max-age=259200") // 3 days
 		w.Header().Set("Content-Type", contentType)
 		if result.SizeBytes > 0 {
+			if config.Get().Redis.Enabled {
+				w.Header().Set("Accept-Ranges", "bytes")
+			}
 			w.Header().Set("Content-Length", fmt.Sprint(result.SizeBytes))
 		}
 		disposition := result.TargetDisposition
@@ -222,8 +313,36 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		} else {
 			w.Header().Set("Content-Disposition", disposition+"; filename*=utf-8''"+url.QueryEscape(fname))
 		}
+
 		defer result.Data.Close()
-		writeResponseData(w, result.Data, result.SizeBytes)
+
+		if doRange {
+			_, err = io.CopyN(ioutil.Discard, result.Data, rangeStart)
+			if err != nil {
+				// Should only blow up this request
+				panic(err)
+			}
+
+			expectedBytes := grabBytes
+			w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", rangeStart, rangeEnd, result.SizeBytes))
+			w.Header().Set("Content-Length", fmt.Sprint(expectedBytes))
+			w.WriteHeader(http.StatusPartialContent)
+			b, err := io.CopyN(w, result.Data, expectedBytes)
+			if err != nil {
+				// Should only blow up this request
+				panic(err)
+			}
+
+			// Discard anything that remains
+			_, _ = io.Copy(ioutil.Discard, result.Data)
+
+			if expectedBytes > 0 && b != expectedBytes {
+				// Should only blow up this request
+				panic(errors.New("mismatch transfer size"))
+			}
+		} else {
+			writeResponseData(w, result.Data, result.SizeBytes)
+		}
 		return // Prevent sending conflicting responses
 	case *r0.IdenticonResponse:
 		metrics.HttpResponses.With(prometheus.Labels{
@@ -245,7 +364,7 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}).Inc()
 		w.Header().Set("Cache-Control", "private, max-age=259200") // 3 days
 		w.Header().Set("Content-Type", "text/html; charset=UTF-8")
-		w.Header().Set("Content-Security-Policy", "") // We're serving HTML, so take away the CSP
+		w.Header().Set("Content-Security-Policy", "")   // We're serving HTML, so take away the CSP
 		w.Header().Set("X-Content-Security-Policy", "") // We're serving HTML, so take away the CSP
 		io.Copy(w, bytes.NewBuffer([]byte(result.HTML)))
 		return
diff --git a/common/config/models_domain.go b/common/config/models_domain.go
index c58ef8b0..32427247 100644
--- a/common/config/models_domain.go
+++ b/common/config/models_domain.go
@@ -31,8 +31,9 @@ type DatastoreConfig struct {
 }
 
 type DownloadsConfig struct {
-	MaxSizeBytes        int64 `yaml:"maxBytes"`
-	FailureCacheMinutes int   `yaml:"failureCacheMinutes"`
+	MaxSizeBytes               int64 `yaml:"maxBytes"`
+	FailureCacheMinutes        int   `yaml:"failureCacheMinutes"`
+	DefaultRangeChunkSizeBytes int64 `yaml:"defaultRangeChunkSizeBytes"`
 }
 
 type ThumbnailsConfig struct {
diff --git a/config.sample.yaml b/config.sample.yaml
index 8fc5c62c..670aaf84 100644
--- a/config.sample.yaml
+++ b/config.sample.yaml
@@ -262,6 +262,12 @@ downloads:
   # negative to disable. Defaults to disabled.
   expireAfterDays: 0
 
+  # The default size, in bytes, to return for range requests on media when the client does not
+  # specify an end to the range. Range requests are used by clients when they only need part of
+  # a file, such as a video or audio element. Note that the entire file will still be cached (if
+  # enabled), but only part of it will be returned. Explicit ranges from the client are honoured.
+  defaultRangeChunkSizeBytes: 10485760 # 10MB default
+
 # URL Preview settings
 urlPreviews:
   enabled: true # If enabled, the preview_url routes will be accessible
-- 
GitLab