From eb5b86ffe20a552c6f9e6a3a523539561741da74 Mon Sep 17 00:00:00 2001
From: Philipp Heckel <pheckel@datto.com>
Date: Sun, 2 Jan 2022 23:56:12 +0100
Subject: [PATCH] WIP: Attachments

---
 .gitignore                       |   1 +
 cmd/serve.go                     |   5 +-
 docs/publish.md                  |   1 +
 server/config.go                 |  10 ++-
 server/message.go                |  23 ++++--
 server/server.go                 | 128 ++++++++++++++++++++++++++-----
 server/server_test.go            |  11 +--
 util/content_type_writer.go      |  41 ++++++++++
 util/content_type_writer_test.go |  50 ++++++++++++
 util/limit.go                    |  41 ++++++++++
 util/limit_test.go               |  63 ++++++++++++++-
 util/peak.go                     |  61 +++++++++++++++
 util/peak_test.go                |  55 +++++++++++++
 13 files changed, 444 insertions(+), 46 deletions(-)
 create mode 100644 util/content_type_writer.go
 create mode 100644 util/content_type_writer_test.go
 create mode 100644 util/peak.go
 create mode 100644 util/peak_test.go

diff --git a/.gitignore b/.gitignore
index 6dffcf5..6d12c73 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@ build/
 .idea/
 server/docs/
 tools/fbsend/fbsend
+playground/
 *.iml
diff --git a/cmd/serve.go b/cmd/serve.go
index 5545206..f4161e1 100644
--- a/cmd/serve.go
+++ b/cmd/serve.go
@@ -20,6 +20,7 @@ var flagsServe = []cli.Flag{
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "cache-duration", Aliases: []string{"b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: server.DefaultCacheDuration, Usage: "buffer messages for this time to allow `since` requests"}),
+	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}),
@@ -69,6 +70,7 @@ func execServe(c *cli.Context) error {
 	firebaseKeyFile := c.String("firebase-key-file")
 	cacheFile := c.String("cache-file")
 	cacheDuration := c.Duration("cache-duration")
+	attachmentCacheDir := c.String("attachment-cache-dir")
 	keepaliveInterval := c.Duration("keepalive-interval")
 	managerInterval := c.Duration("manager-interval")
 	smtpSenderAddr := c.String("smtp-sender-addr")
@@ -117,6 +119,7 @@ func execServe(c *cli.Context) error {
 	conf.FirebaseKeyFile = firebaseKeyFile
 	conf.CacheFile = cacheFile
 	conf.CacheDuration = cacheDuration
+	conf.AttachmentCacheDir = attachmentCacheDir
 	conf.KeepaliveInterval = keepaliveInterval
 	conf.ManagerInterval = managerInterval
 	conf.SMTPSenderAddr = smtpSenderAddr
@@ -126,7 +129,7 @@ func execServe(c *cli.Context) error {
 	conf.SMTPServerListen = smtpServerListen
 	conf.SMTPServerDomain = smtpServerDomain
 	conf.SMTPServerAddrPrefix = smtpServerAddrPrefix
-	conf.GlobalTopicLimit = globalTopicLimit
+	conf.TotalTopicLimit = globalTopicLimit
 	conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
 	conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
 	conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
diff --git a/docs/publish.md b/docs/publish.md
index ec017e0..46a5a33 100644
--- a/docs/publish.md
+++ b/docs/publish.md
@@ -886,3 +886,4 @@ and can be passed as **HTTP headers** or **query parameters in the URL**. They a
 | `X-Email` | `X-E-Mail`, `Email`, `E-Mail`, `mail`, `e` | E-mail address for [e-mail notifications](#e-mail-notifications) |
 | `X-Cache` | `Cache` | Allows disabling [message caching](#message-caching) |
 | `X-Firebase` | `Firebase` | Allows disabling [sending to Firebase](#disable-firebase) |
+| `X-UnifiedPush` | `UnifiedPush`, `up` | UnifiedPush option, only to be used by UnifiedPush apps |
diff --git a/server/config.go b/server/config.go
index 30f937e..68a911f 100644
--- a/server/config.go
+++ b/server/config.go
@@ -14,6 +14,7 @@ const (
 	DefaultMinDelay                  = 10 * time.Second
 	DefaultMaxDelay                  = 3 * 24 * time.Hour
 	DefaultMessageLimit              = 4096
+	DefaultAttachmentSizeLimit       = 5 * 1024 * 1024
 	DefaultFirebaseKeepaliveInterval = 3 * time.Hour // Not too frequently to save battery
 )
 
@@ -41,6 +42,8 @@ type Config struct {
 	FirebaseKeyFile              string
 	CacheFile                    string
 	CacheDuration                time.Duration
+	AttachmentCacheDir           string
+	AttachmentSizeLimit          int64
 	KeepaliveInterval            time.Duration
 	ManagerInterval              time.Duration
 	AtSenderInterval             time.Duration
@@ -55,7 +58,8 @@ type Config struct {
 	MessageLimit                 int
 	MinDelay                     time.Duration
 	MaxDelay                     time.Duration
-	GlobalTopicLimit             int
+	TotalTopicLimit              int
+	TotalAttachmentSizeLimit     int64
 	VisitorRequestLimitBurst     int
 	VisitorRequestLimitReplenish time.Duration
 	VisitorEmailLimitBurst       int
@@ -75,6 +79,8 @@ func NewConfig() *Config {
 		FirebaseKeyFile:              "",
 		CacheFile:                    "",
 		CacheDuration:                DefaultCacheDuration,
+		AttachmentCacheDir:           "",
+		AttachmentSizeLimit:          DefaultAttachmentSizeLimit,
 		KeepaliveInterval:            DefaultKeepaliveInterval,
 		ManagerInterval:              DefaultManagerInterval,
 		MessageLimit:                 DefaultMessageLimit,
@@ -82,7 +88,7 @@ func NewConfig() *Config {
 		MaxDelay:                     DefaultMaxDelay,
 		AtSenderInterval:             DefaultAtSenderInterval,
 		FirebaseKeepaliveInterval:    DefaultFirebaseKeepaliveInterval,
-		GlobalTopicLimit:             DefaultGlobalTopicLimit,
+		TotalTopicLimit:              DefaultGlobalTopicLimit,
 		VisitorRequestLimitBurst:     DefaultVisitorRequestLimitBurst,
 		VisitorRequestLimitReplenish: DefaultVisitorRequestLimitReplenish,
 		VisitorEmailLimitBurst:       DefaultVisitorEmailLimitBurst,
diff --git a/server/message.go b/server/message.go
index ad870e0..2c3fb19 100644
--- a/server/message.go
+++ b/server/message.go
@@ -18,14 +18,21 @@ const (
 
 // message represents a message published to a topic
 type message struct {
-	ID       string   `json:"id"`    // Random message ID
-	Time     int64    `json:"time"`  // Unix time in seconds
-	Event    string   `json:"event"` // One of the above
-	Topic    string   `json:"topic"`
-	Priority int      `json:"priority,omitempty"`
-	Tags     []string `json:"tags,omitempty"`
-	Title    string   `json:"title,omitempty"`
-	Message  string   `json:"message,omitempty"`
+	ID         string      `json:"id"`    // Random message ID
+	Time       int64       `json:"time"`  // Unix time in seconds
+	Event      string      `json:"event"` // One of the above
+	Topic      string      `json:"topic"`
+	Priority   int         `json:"priority,omitempty"`
+	Tags       []string    `json:"tags,omitempty"`
+	Title      string      `json:"title,omitempty"`
+	Message    string      `json:"message,omitempty"`
+	Attachment *attachment `json:"attachment,omitempty"`
+}
+
+type attachment struct {
+	Name string `json:"name"`
+	Type string `json:"type"`
+	URL  string `json:"url"`
 }
 
 // messageEncoder is a function that knows how to encode a message
diff --git a/server/server.go b/server/server.go
index 9cf76de..b8ca70f 100644
--- a/server/server.go
+++ b/server/server.go
@@ -15,14 +15,18 @@ import (
 	"html/template"
 	"io"
 	"log"
+	"mime"
 	"net"
 	"net/http"
 	"net/http/httptest"
+	"os"
+	"path/filepath"
 	"regexp"
 	"strconv"
 	"strings"
 	"sync"
 	"time"
+	"unicode/utf8"
 )
 
 // TODO add "max messages in a topic" limit
@@ -96,7 +100,8 @@ var (
 
 	staticRegex      = regexp.MustCompile(`^/static/.+`)
 	docsRegex        = regexp.MustCompile(`^/docs(|/.*)$`)
-	disallowedTopics = []string{"docs", "static"}
+	fileRegex        = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
+	disallowedTopics = []string{"docs", "static", "file"}
 
 	templateFnMap = template.FuncMap{
 		"durationToHuman": util.DurationToHuman,
@@ -117,22 +122,26 @@ var (
 	docsStaticFs     embed.FS
 	docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
 
-	errHTTPNotFound                          = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
-	errHTTPTooManyRequestsLimitRequests      = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPTooManyRequestsLimitEmails        = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPTooManyRequestsLimitGlobalTopics  = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPBadRequestEmailDisabled           = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
-	errHTTPBadRequestDelayNoCache            = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
-	errHTTPBadRequestDelayNoEmail            = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
-	errHTTPBadRequestDelayCannotParse        = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
-	errHTTPBadRequestDelayTooSmall           = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
-	errHTTPBadRequestDelayTooLarge           = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
-	errHTTPBadRequestPriorityInvalid         = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
-	errHTTPBadRequestSinceInvalid            = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
-	errHTTPBadRequestTopicInvalid            = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
-	errHTTPBadRequestTopicDisallowed         = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
-	errHTTPInternalError                     = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
+	errHTTPNotFound                               = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
+	errHTTPTooManyRequestsLimitRequests           = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitEmails             = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitSubscriptions      = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitGlobalTopics       = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPBadRequestEmailDisabled                = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
+	errHTTPBadRequestDelayNoCache                 = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
+	errHTTPBadRequestDelayNoEmail                 = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
+	errHTTPBadRequestDelayCannotParse             = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+	errHTTPBadRequestDelayTooSmall                = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+	errHTTPBadRequestDelayTooLarge                = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+	errHTTPBadRequestPriorityInvalid              = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
+	errHTTPBadRequestSinceInvalid                 = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
+	errHTTPBadRequestTopicInvalid                 = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
+	errHTTPBadRequestTopicDisallowed              = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
+	errHTTPBadRequestAttachmentsDisallowed        = &errHTTP{40011, http.StatusBadRequest, "attachments disallowed", ""}
+	errHTTPBadRequestAttachmentsPublishDisallowed = &errHTTP{40012, http.StatusBadRequest, "invalid message: invalid encoding or too large, and attachments are not allowed", ""}
+	errHTTPBadRequestMessageTooLarge              = &errHTTP{40013, http.StatusBadRequest, "invalid message: too large", ""}
+	errHTTPInternalError                          = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
+	errHTTPInternalErrorInvalidFilePath           = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
 )
 
 const (
@@ -163,6 +172,11 @@ func New(conf *Config) (*Server, error) {
 	if err != nil {
 		return nil, err
 	}
+	if conf.AttachmentCacheDir != "" {
+		if err := os.MkdirAll(conf.AttachmentCacheDir, 0700); err != nil {
+			return nil, err
+		}
+	}
 	return &Server{
 		config:   conf,
 		cache:    cache,
@@ -302,6 +316,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
 		return s.handleStatic(w, r)
 	} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
 		return s.handleDocs(w, r)
+	} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) {
+		return s.handleFile(w, r)
 	} else if r.Method == http.MethodOptions {
 		return s.handleOptions(w, r)
 	} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
@@ -357,17 +373,45 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request) error {
 	return nil
 }
 
+func (s *Server) handleFile(w http.ResponseWriter, r *http.Request) error {
+	if s.config.AttachmentCacheDir == "" {
+		return errHTTPBadRequestAttachmentsDisallowed
+	}
+	matches := fileRegex.FindStringSubmatch(r.URL.Path)
+	if len(matches) != 2 {
+		return errHTTPInternalErrorInvalidFilePath
+	}
+	messageID := matches[1]
+	file := filepath.Join(s.config.AttachmentCacheDir, messageID)
+	stat, err := os.Stat(file)
+	if err != nil {
+		return errHTTPNotFound
+	}
+	w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
+	f, err := os.Open(file)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	_, err = io.Copy(util.NewContentTypeWriter(w), f)
+	return err
+}
+
 func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	t, err := s.topicFromPath(r.URL.Path)
 	if err != nil {
 		return err
 	}
-	reader := io.LimitReader(r.Body, int64(s.config.MessageLimit))
-	b, err := io.ReadAll(reader)
+	body, err := util.Peak(r.Body, s.config.MessageLimit)
 	if err != nil {
 		return err
 	}
-	m := newDefaultMessage(t.ID, strings.TrimSpace(string(b)))
+	m := newDefaultMessage(t.ID, "")
+	if !body.LimitReached && utf8.Valid(body.PeakedBytes) {
+		m.Message = strings.TrimSpace(string(body.PeakedBytes))
+	} else if err := s.writeAttachment(v, m, body); err != nil {
+		return err
+	}
 	cache, firebase, email, err := s.parsePublishParams(r, m)
 	if err != nil {
 		return err
@@ -478,6 +522,48 @@ func readParam(r *http.Request, names ...string) string {
 	return ""
 }
 
+func (s *Server) writeAttachment(v *visitor, m *message, body *util.PeakedReadCloser) error {
+	if s.config.AttachmentCacheDir == "" || !util.FileExists(s.config.AttachmentCacheDir) {
+		return errHTTPBadRequestAttachmentsPublishDisallowed
+	}
+	contentType := http.DetectContentType(body.PeakedBytes)
+	exts, err := mime.ExtensionsByType(contentType)
+	if err != nil {
+		return err
+	}
+	ext := ".bin"
+	if len(exts) > 0 {
+		ext = exts[0]
+	}
+	filename := fmt.Sprintf("attachment%s", ext)
+	file := filepath.Join(s.config.AttachmentCacheDir, m.ID)
+	f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	fileSizeLimiter := util.NewLimiter(s.config.AttachmentSizeLimit)
+	limitWriter := util.NewLimitWriter(f, fileSizeLimiter)
+	if _, err := io.Copy(limitWriter, body); err != nil {
+		os.Remove(file)
+		if err == util.ErrLimitReached {
+			return errHTTPBadRequestMessageTooLarge
+		}
+		return err
+	}
+	if err := f.Close(); err != nil {
+		os.Remove(file)
+		return err
+	}
+	m.Message = fmt.Sprintf("You received a file: %s", filename)
+	m.Attachment = &attachment{
+		Name: filename,
+		Type: contentType,
+		URL:  fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext),
+	}
+	return nil
+}
+
 func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	encoder := func(msg *message) (string, error) {
 		var buf bytes.Buffer
@@ -691,7 +777,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
 			return nil, errHTTPBadRequestTopicDisallowed
 		}
 		if _, ok := s.topics[id]; !ok {
-			if len(s.topics) >= s.config.GlobalTopicLimit {
+			if len(s.topics) >= s.config.TotalTopicLimit {
 				return nil, errHTTPTooManyRequestsLimitGlobalTopics
 			}
 			s.topics[id] = newTopic(id)
diff --git a/server/server_test.go b/server/server_test.go
index e713e60..f8a1a8a 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -165,17 +165,8 @@ func TestServer_PublishLargeMessage(t *testing.T) {
 	s := newTestServer(t, newTestConfig(t))
 
 	body := strings.Repeat("this is a large message", 5000)
-	truncated := body[0:4096]
 	response := request(t, s, "PUT", "/mytopic", body, nil)
-	msg := toMessage(t, response.Body.String())
-	require.NotEmpty(t, msg.ID)
-	require.Equal(t, truncated, msg.Message)
-	require.Equal(t, 4096, len(msg.Message))
-
-	response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
-	messages := toMessages(t, response.Body.String())
-	require.Equal(t, 1, len(messages))
-	require.Equal(t, truncated, messages[0].Message)
+	require.Equal(t, 400, response.Code)
 }
 
 func TestServer_PublishPriority(t *testing.T) {
diff --git a/util/content_type_writer.go b/util/content_type_writer.go
new file mode 100644
index 0000000..fb3c43f
--- /dev/null
+++ b/util/content_type_writer.go
@@ -0,0 +1,41 @@
+package util
+
+import (
+	"net/http"
+	"strings"
+)
+
+// ContentTypeWriter is an implementation of http.ResponseWriter that will detect the content type and set the
+// Content-Type and (optionally) Content-Disposition headers accordingly.
+//
+// It will always set a Content-Type based on http.DetectContentType, but will never send the "text/html"
+// content type.
+type ContentTypeWriter struct {
+	w       http.ResponseWriter
+	sniffed bool
+}
+
+// NewContentTypeWriter creates a new ContentTypeWriter
+func NewContentTypeWriter(w http.ResponseWriter) *ContentTypeWriter {
+	return &ContentTypeWriter{w, false}
+}
+
+func (w *ContentTypeWriter) Write(p []byte) (n int, err error) {
+	if w.sniffed {
+		return w.w.Write(p)
+	}
+	// Detect and set Content-Type header
+	// Fix content types that we don't want to inline-render in the browser. In particular,
+	// we don't want to render HTML in the browser for security reasons.
+	contentType := http.DetectContentType(p)
+	if strings.HasPrefix(contentType, "text/html") {
+		contentType = strings.ReplaceAll(contentType, "text/html", "text/plain")
+	} else if contentType == "application/octet-stream" {
+		contentType = "" // Reset to let downstream http.ResponseWriter take care of it
+	}
+	if contentType != "" {
+		w.w.Header().Set("Content-Type", contentType)
+	}
+	w.sniffed = true
+	return w.w.Write(p)
+}
diff --git a/util/content_type_writer_test.go b/util/content_type_writer_test.go
new file mode 100644
index 0000000..08dd751
--- /dev/null
+++ b/util/content_type_writer_test.go
@@ -0,0 +1,50 @@
+package util
+
+import (
+	"crypto/rand"
+	"github.com/stretchr/testify/require"
+	"net/http/httptest"
+	"testing"
+)
+
+func TestSniffWriter_WriteHTML(t *testing.T) {
+	rr := httptest.NewRecorder()
+	sw := NewContentTypeWriter(rr)
+	sw.Write([]byte("<script>alert('hi')</script>"))
+	require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
+}
+
+func TestSniffWriter_WriteTwoWriteCalls(t *testing.T) {
+	rr := httptest.NewRecorder()
+	sw := NewContentTypeWriter(rr)
+	sw.Write([]byte{0x25, 0x50, 0x44, 0x46, 0x2d, 0x11, 0x22, 0x33})
+	sw.Write([]byte("<script>alert('hi')</script>"))
+	require.Equal(t, "application/pdf", rr.Header().Get("Content-Type"))
+}
+
+func TestSniffWriter_NoSniffWriterWriteHTML(t *testing.T) {
+	// This test just makes sure that without the sniff-w, we would get text/html
+
+	rr := httptest.NewRecorder()
+	rr.Write([]byte("<script>alert('hi')</script>"))
+	require.Equal(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type"))
+}
+
+func TestSniffWriter_WriteHTMLSplitIntoTwoWrites(t *testing.T) {
+	// This test shows how splitting the HTML into two Write() calls will still yield text/plain
+
+	rr := httptest.NewRecorder()
+	sw := NewContentTypeWriter(rr)
+	sw.Write([]byte("<scr"))
+	sw.Write([]byte("ipt>alert('hi')</script>"))
+	require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
+}
+
+func TestSniffWriter_WriteUnknownMimeType(t *testing.T) {
+	rr := httptest.NewRecorder()
+	sw := NewContentTypeWriter(rr)
+	randomBytes := make([]byte, 199)
+	rand.Read(randomBytes)
+	sw.Write(randomBytes)
+	require.Equal(t, "application/octet-stream", rr.Header().Get("Content-Type"))
+}
diff --git a/util/limit.go b/util/limit.go
index e556124..bac3c15 100644
--- a/util/limit.go
+++ b/util/limit.go
@@ -2,6 +2,7 @@ package util
 
 import (
 	"errors"
+	"io"
 	"sync"
 )
 
@@ -58,3 +59,43 @@ func (l *Limiter) Value() int64 {
 	defer l.mu.Unlock()
 	return l.value
 }
+
+// Limit returns the defined limit
+func (l *Limiter) Limit() int64 {
+	return l.limit
+}
+
+// LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
+// writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
+// Each limiter's value is increased with every write.
+type LimitWriter struct {
+	w        io.Writer
+	written  int64
+	limiters []*Limiter
+	mu       sync.Mutex
+}
+
+// NewLimitWriter creates a new LimitWriter
+func NewLimitWriter(w io.Writer, limiters ...*Limiter) *LimitWriter {
+	return &LimitWriter{
+		w:        w,
+		limiters: limiters,
+	}
+}
+
+// Write passes through all writes to the underlying writer until any of the given limiter's limit is reached
+func (w *LimitWriter) Write(p []byte) (n int, err error) {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	for i := 0; i < len(w.limiters); i++ {
+		if err := w.limiters[i].Add(int64(len(p))); err != nil {
+			for j := i - 1; j >= 0; j-- {
+				w.limiters[j].Sub(int64(len(p)))
+			}
+			return 0, ErrLimitReached
+		}
+	}
+	n, err = w.w.Write(p)
+	w.written += int64(n)
+	return
+}
diff --git a/util/limit_test.go b/util/limit_test.go
index f6d56c6..4f07e00 100644
--- a/util/limit_test.go
+++ b/util/limit_test.go
@@ -1,6 +1,7 @@
 package util
 
 import (
+	"bytes"
 	"testing"
 )
 
@@ -17,14 +18,68 @@ func TestLimiter_Add(t *testing.T) {
 	}
 }
 
-func TestLimiter_AddSub(t *testing.T) {
+func TestLimiter_AddSet(t *testing.T) {
 	l := NewLimiter(10)
 	l.Add(5)
 	if l.Value() != 5 {
 		t.Fatalf("expected value to be %d, got %d", 5, l.Value())
 	}
-	l.Sub(2)
-	if l.Value() != 3 {
-		t.Fatalf("expected value to be %d, got %d", 3, l.Value())
+	l.Set(7)
+	if l.Value() != 7 {
+		t.Fatalf("expected value to be %d, got %d", 7, l.Value())
+	}
+}
+
+func TestLimitWriter_WriteNoLimiter(t *testing.T) {
+	var buf bytes.Buffer
+	lw := NewLimitWriter(&buf)
+	if _, err := lw.Write(make([]byte, 10)); err != nil {
+		t.Fatal(err)
+	}
+	if _, err := lw.Write(make([]byte, 1)); err != nil {
+		t.Fatal(err)
+	}
+	if buf.Len() != 11 {
+		t.Fatalf("expected buffer length to be %d, got %d", 11, buf.Len())
+	}
+}
+
+func TestLimitWriter_WriteOneLimiter(t *testing.T) {
+	var buf bytes.Buffer
+	l := NewLimiter(10)
+	lw := NewLimitWriter(&buf, l)
+	if _, err := lw.Write(make([]byte, 10)); err != nil {
+		t.Fatal(err)
+	}
+	if _, err := lw.Write(make([]byte, 1)); err != ErrLimitReached {
+		t.Fatalf("expected ErrLimitReached, got %#v", err)
+	}
+	if buf.Len() != 10 {
+		t.Fatalf("expected buffer length to be %d, got %d", 10, buf.Len())
+	}
+	if l.Value() != 10 {
+		t.Fatalf("expected limiter value to be %d, got %d", 10, l.Value())
+	}
+}
+
+func TestLimitWriter_WriteTwoLimiters(t *testing.T) {
+	var buf bytes.Buffer
+	l1 := NewLimiter(11)
+	l2 := NewLimiter(9)
+	lw := NewLimitWriter(&buf, l1, l2)
+	if _, err := lw.Write(make([]byte, 8)); err != nil {
+		t.Fatal(err)
+	}
+	if _, err := lw.Write(make([]byte, 2)); err != ErrLimitReached {
+		t.Fatalf("expected ErrLimitReached, got %#v", err)
+	}
+	if buf.Len() != 8 {
+		t.Fatalf("expected buffer length to be %d, got %d", 8, buf.Len())
+	}
+	if l1.Value() != 8 {
+		t.Fatalf("expected limiter 1 value to be %d, got %d", 8, l1.Value())
+	}
+	if l2.Value() != 8 {
+		t.Fatalf("expected limiter 2 value to be %d, got %d", 8, l2.Value())
 	}
 }
diff --git a/util/peak.go b/util/peak.go
new file mode 100644
index 0000000..100c269
--- /dev/null
+++ b/util/peak.go
@@ -0,0 +1,61 @@
+package util
+
+import (
+	"bytes"
+	"io"
+	"strings"
+)
+
+// PeakedReadCloser is a ReadCloser that allows peaking into a stream and buffering it in memory.
+// It can be instantiated using the Peak function. After a stream has been peaked, it can still be fully
+// read by reading the PeakedReadCloser. It first drained from the memory buffer, and then from the remaining
+// underlying reader.
+type PeakedReadCloser struct {
+	PeakedBytes  []byte
+	LimitReached bool
+	peaked       io.Reader
+	underlying   io.ReadCloser
+	closed       bool
+}
+
+// Peak reads the underlying ReadCloser into memory up until the limit and returns a PeakedReadCloser
+func Peak(underlying io.ReadCloser, limit int) (*PeakedReadCloser, error) {
+	if underlying == nil {
+		underlying = io.NopCloser(strings.NewReader(""))
+	}
+	peaked := make([]byte, limit)
+	read, err := io.ReadFull(underlying, peaked)
+	if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
+		return nil, err
+	}
+	return &PeakedReadCloser{
+		PeakedBytes:  peaked[:read],
+		LimitReached: read == limit,
+		underlying:   underlying,
+		peaked:       bytes.NewReader(peaked[:read]),
+		closed:       false,
+	}, nil
+}
+
+// Read reads from the peaked bytes and then from the underlying stream
+func (r *PeakedReadCloser) Read(p []byte) (n int, err error) {
+	if r.closed {
+		return 0, io.EOF
+	}
+	n, err = r.peaked.Read(p)
+	if err == io.EOF {
+		return r.underlying.Read(p)
+	} else if err != nil {
+		return 0, err
+	}
+	return
+}
+
+// Close closes the underlying stream
+func (r *PeakedReadCloser) Close() error {
+	if r.closed {
+		return io.EOF
+	}
+	r.closed = true
+	return r.underlying.Close()
+}
diff --git a/util/peak_test.go b/util/peak_test.go
new file mode 100644
index 0000000..7699517
--- /dev/null
+++ b/util/peak_test.go
@@ -0,0 +1,55 @@
+package util
+
+import (
+	"github.com/stretchr/testify/require"
+	"io"
+	"strings"
+	"testing"
+)
+
+func TestPeak_LimitReached(t *testing.T) {
+	underlying := io.NopCloser(strings.NewReader("1234567890"))
+	peaked, err := Peak(underlying, 5)
+	if err != nil {
+		t.Fatal(err)
+	}
+	require.Equal(t, []byte("12345"), peaked.PeakedBytes)
+	require.Equal(t, true, peaked.LimitReached)
+
+	all, err := io.ReadAll(peaked)
+	if err != nil {
+		t.Fatal(err)
+	}
+	require.Equal(t, []byte("1234567890"), all)
+	require.Equal(t, []byte("12345"), peaked.PeakedBytes)
+	require.Equal(t, true, peaked.LimitReached)
+}
+
+func TestPeak_LimitNotReached(t *testing.T) {
+	underlying := io.NopCloser(strings.NewReader("1234567890"))
+	peaked, err := Peak(underlying, 15)
+	if err != nil {
+		t.Fatal(err)
+	}
+	all, err := io.ReadAll(peaked)
+	if err != nil {
+		t.Fatal(err)
+	}
+	require.Equal(t, []byte("1234567890"), all)
+	require.Equal(t, []byte("1234567890"), peaked.PeakedBytes)
+	require.Equal(t, false, peaked.LimitReached)
+}
+
+func TestPeak_Nil(t *testing.T) {
+	peaked, err := Peak(nil, 15)
+	if err != nil {
+		t.Fatal(err)
+	}
+	all, err := io.ReadAll(peaked)
+	if err != nil {
+		t.Fatal(err)
+	}
+	require.Equal(t, []byte(""), all)
+	require.Equal(t, []byte(""), peaked.PeakedBytes)
+	require.Equal(t, false, peaked.LimitReached)
+}
-- 
GitLab