From c80e4e1aa9318b85a8abb67fe623451418079fe6 Mon Sep 17 00:00:00 2001
From: Philipp Heckel <pheckel@datto.com>
Date: Tue, 31 May 2022 23:16:44 -0400
Subject: [PATCH] Make Firebase logic testable, test it

---
 server/server.go               | 72 +++++++++++++++----------------
 server/server_firebase.go      | 77 ++++++++++++++++++++++++++--------
 server/server_firebase_test.go | 38 +++++++++++++++++
 server/server_test.go          | 49 +++++++---------------
 4 files changed, 147 insertions(+), 89 deletions(-)

diff --git a/server/server.go b/server/server.go
index 7384ab4..253a422 100644
--- a/server/server.go
+++ b/server/server.go
@@ -32,22 +32,22 @@ import (
 
 // Server is the main server, providing the UI and API for ntfy
 type Server struct {
-	config       *Config
-	httpServer   *http.Server
-	httpsServer  *http.Server
-	unixListener net.Listener
-	smtpServer   *smtp.Server
-	smtpBackend  *smtpBackend
-	topics       map[string]*topic
-	visitors     map[string]*visitor
-	firebase     subscriber
-	mailer       mailer
-	messages     int64
-	auth         auth.Auther
-	messageCache *messageCache
-	fileCache    *fileCache
-	closeChan    chan bool
-	mu           sync.Mutex
+	config         *Config
+	httpServer     *http.Server
+	httpsServer    *http.Server
+	unixListener   net.Listener
+	smtpServer     *smtp.Server
+	smtpBackend    *smtpBackend
+	topics         map[string]*topic
+	visitors       map[string]*visitor
+	firebaseClient *firebaseClient
+	mailer         mailer
+	messages       int64
+	auth           auth.Auther
+	messageCache   *messageCache
+	fileCache      *fileCache
+	closeChan      chan bool
+	mu             sync.Mutex
 }
 
 // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
@@ -134,23 +134,23 @@ func New(conf *Config) (*Server, error) {
 			return nil, err
 		}
 	}
-	var firebaseSubscriber subscriber
+	var firebaseClient *firebaseClient
 	if conf.FirebaseKeyFile != "" {
-		var err error
-		firebaseSubscriber, err = createFirebaseSubscriber(conf.FirebaseKeyFile, auther)
+		sender, err := newFirebaseSender(conf.FirebaseKeyFile)
 		if err != nil {
 			return nil, err
 		}
+		firebaseClient = newFirebaseClient(sender, auther)
 	}
 	return &Server{
-		config:       conf,
-		messageCache: messageCache,
-		fileCache:    fileCache,
-		firebase:     firebaseSubscriber,
-		mailer:       mailer,
-		topics:       topics,
-		auth:         auther,
-		visitors:     make(map[string]*visitor),
+		config:         conf,
+		messageCache:   messageCache,
+		fileCache:      fileCache,
+		firebaseClient: firebaseClient,
+		mailer:         mailer,
+		topics:         topics,
+		auth:           auther,
+		visitors:       make(map[string]*visitor),
 	}, nil
 }
 
@@ -437,7 +437,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 			return err
 		}
 	}
-	if s.firebase != nil && firebase && !delayed {
+	if s.firebaseClient != nil && firebase && !delayed {
 		go s.sendToFirebase(v, m)
 	}
 	if s.mailer != nil && email != "" && !delayed {
@@ -463,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 }
 
 func (s *Server) sendToFirebase(v *visitor, m *message) {
-	if err := s.firebase(v, m); err != nil {
+	if err := s.firebaseClient.Send(v, m); err != nil {
 		log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
 	}
 }
@@ -1096,20 +1096,16 @@ func (s *Server) runDelayedSender() {
 }
 
 func (s *Server) runFirebaseKeepaliver() {
-	if s.firebase == nil {
+	if s.firebaseClient == nil {
 		return
 	}
-	v := newVisitor(s.config, s.messageCache, "0.0.0.0")
+	v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
 	for {
 		select {
 		case <-time.After(s.config.FirebaseKeepaliveInterval):
-			if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
-				log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
-			}
+			s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
 		case <-time.After(s.config.FirebasePollInterval):
-			if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
-				log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
-			}
+			s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
 		case <-s.closeChan:
 			return
 		}
@@ -1142,7 +1138,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
 			}
 		}()
 	}
-	if s.firebase != nil { // Firebase subscribers may not show up in topics map
+	if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
 		go s.sendToFirebase(v, m)
 	}
 	if s.config.UpstreamBaseURL != "" {
diff --git a/server/server_firebase.go b/server/server_firebase.go
index 8368337..47d2755 100644
--- a/server/server_firebase.go
+++ b/server/server_firebase.go
@@ -3,6 +3,7 @@ package server
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"log"
 	"strings"
@@ -18,33 +19,75 @@ const (
 	fcmApnsBodyMessageLimit = 100
 )
 
-func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subscriber, error) {
+var (
+	errFirebaseQuotaExceeded = errors.New("Firebase quota exceeded")
+)
+
+// firebaseClient is a generic client that formats and sends messages to Firebase.
+// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable.
+type firebaseClient struct {
+	sender firebaseSender
+	auther auth.Auther
+}
+
+func newFirebaseClient(sender firebaseSender, auther auth.Auther) *firebaseClient {
+	return &firebaseClient{
+		sender: sender,
+		auther: auther,
+	}
+}
+
+func (c *firebaseClient) Send(v *visitor, m *message) error {
+	if err := v.FirebaseAllowed(); err != nil {
+		return errFirebaseQuotaExceeded
+	}
+	fbm, err := toFirebaseMessage(m, c.auther)
+	if err != nil {
+		return err
+	}
+	err = c.sender.Send(fbm)
+	if err == errFirebaseQuotaExceeded {
+		log.Printf("[%s] FB quota exceeded for topic %s, temporarily denying FB access to visitor", v.ip, m.Topic)
+		v.FirebaseTemporarilyDeny()
+	}
+	return err
+}
+
+// firebaseSender is an interface that represents a client that can send to Firebase Cloud Messaging.
+// In tests, this can be implemented with a mock.
+type firebaseSender interface {
+	// Send sends a message to Firebase, or returns an error. It returns errFirebaseQuotaExceeded
+	// if a rate limit has reached.
+	Send(m *messaging.Message) error
+}
+
+// firebaseSenderImpl is a firebaseSender that actually talks to Firebase
+type firebaseSenderImpl struct {
+	client *messaging.Client
+}
+
+func newFirebaseSender(credentialsFile string) (*firebaseSenderImpl, error) {
 	fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile))
 	if err != nil {
 		return nil, err
 	}
-	msg, err := fb.Messaging(context.Background())
+	client, err := fb.Messaging(context.Background())
 	if err != nil {
 		return nil, err
 	}
-	return func(v *visitor, m *message) error {
-		if err := v.FirebaseAllowed(); err != nil {
-			return errHTTPTooManyRequestsFirebaseQuotaReached
-		}
-		fbm, err := toFirebaseMessage(m, auther)
-		if err != nil {
-			return err
-		}
-		_, err = msg.Send(context.Background(), fbm)
-		if err != nil && messaging.IsQuotaExceeded(err) {
-			log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
-			v.FirebaseTemporarilyDeny()
-			return errHTTPTooManyRequestsFirebaseQuotaReached
-		}
-		return err
+	return &firebaseSenderImpl{
+		client: client,
 	}, nil
 }
 
+func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
+	_, err := c.client.Send(context.Background(), m)
+	if err != nil && messaging.IsQuotaExceeded(err) {
+		return errFirebaseQuotaExceeded
+	}
+	return err
+}
+
 // toFirebaseMessage converts a message to a Firebase message.
 //
 // Normal messages ("message"):
diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go
index 6ad6fde..8e08b0d 100644
--- a/server/server_firebase_test.go
+++ b/server/server_firebase_test.go
@@ -26,6 +26,25 @@ func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error {
 	return errors.New("unauthorized")
 }
 
+type testFirebaseSender struct {
+	allowed  int
+	messages []*messaging.Message
+}
+
+func newTestFirebaseSender(allowed int) *testFirebaseSender {
+	return &testFirebaseSender{
+		allowed:  allowed,
+		messages: make([]*messaging.Message, 0),
+	}
+}
+func (s *testFirebaseSender) Send(m *messaging.Message) error {
+	if len(s.messages)+1 > s.allowed {
+		return errFirebaseQuotaExceeded
+	}
+	s.messages = append(s.messages, m)
+	return nil
+}
+
 func TestToFirebaseMessage_Keepalive(t *testing.T) {
 	m := newKeepaliveMessage("mytopic")
 	fbm, err := toFirebaseMessage(m, nil)
@@ -285,3 +304,22 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
 	require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
 	require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
 }
+
+func TestToFirebaseSender_Abuse(t *testing.T) {
+	sender := &testFirebaseSender{allowed: 2}
+	client := newFirebaseClient(sender, &testAuther{})
+	visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
+
+	require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
+	require.Equal(t, 1, len(sender.messages))
+
+	require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
+	require.Equal(t, 2, len(sender.messages))
+
+	require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
+	require.Equal(t, 2, len(sender.messages))
+
+	sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
+	require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
+	require.Equal(t, 0, len(sender.messages))
+}
diff --git a/server/server_test.go b/server/server_test.go
index 1fec1f5..d05075f 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -9,7 +9,6 @@ import (
 	"math/rand"
 	"net/http"
 	"net/http/httptest"
-	"os"
 	"path/filepath"
 	"strings"
 	"sync"
@@ -55,6 +54,21 @@ func TestServer_PublishAndPoll(t *testing.T) {
 	require.Equal(t, "my second  message", lines[1]) // \n -> " "
 }
 
+func TestServer_PublishWithFirebase(t *testing.T) {
+	sender := newTestFirebaseSender(10)
+	s := newTestServer(t, newTestConfig(t))
+	s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
+
+	response := request(t, s, "PUT", "/mytopic", "my first message", nil)
+	msg1 := toMessage(t, response.Body.String())
+	require.NotEmpty(t, msg1.ID)
+	require.Equal(t, "my first message", msg1.Message)
+	require.Equal(t, 1, len(sender.messages))
+	require.Equal(t, "my first message", sender.messages[0].Data["message"])
+	require.Equal(t, "my first message", sender.messages[0].APNS.Payload.Aps.Alert.Body)
+	require.Equal(t, "my first message", sender.messages[0].APNS.Payload.CustomData["message"])
+}
+
 func TestServer_SubscribeOpenAndKeepalive(t *testing.T) {
 	c := newTestConfig(t)
 	c.KeepaliveInterval = time.Second
@@ -461,27 +475,6 @@ func TestServer_PublishMessageInHeaderWithNewlines(t *testing.T) {
 	require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n !
 }
 
-func TestServer_PublishFirebase(t *testing.T) {
-	// This is unfortunately not much of a test, since it merely fires the messages towards Firebase,
-	// but cannot re-read them. There is no way from Go to read the messages back, or even get an error back.
-	// I tried everything. I already had written the test, and it increases the code coverage, so I'll leave it ... :shrug: ...
-
-	c := newTestConfig(t)
-	c.FirebaseKeyFile = firebaseServiceAccountFile(t) // May skip the test!
-	s := newTestServer(t, c)
-
-	// Normal message
-	response := request(t, s, "PUT", "/mytopic", "This is a message for firebase", nil)
-	msg := toMessage(t, response.Body.String())
-	require.NotEmpty(t, msg.ID)
-
-	// Keepalive message
-	v := newVisitor(s.config, s.messageCache, "1.2.3.4")
-	require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
-
-	time.Sleep(500 * time.Millisecond) // Time for sends
-}
-
 func TestServer_PublishInvalidTopic(t *testing.T) {
 	s := newTestServer(t, newTestConfig(t))
 	s.mailer = &testMailer{}
@@ -1341,18 +1334,6 @@ func toHTTPError(t *testing.T, s string) *errHTTP {
 	return &e
 }
 
-func firebaseServiceAccountFile(t *testing.T) string {
-	if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" {
-		return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE")
-	} else if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT") != "" {
-		filename := filepath.Join(t.TempDir(), "firebase.json")
-		require.NotNil(t, os.WriteFile(filename, []byte(os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT")), 0o600))
-		return filename
-	}
-	t.SkipNow()
-	return ""
-}
-
 func basicAuth(s string) string {
 	return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
 }
-- 
GitLab