Skip to content
Snippets Groups Projects
Commit c80e4e1a authored by Philipp Heckel's avatar Philipp Heckel
Browse files

Make Firebase logic testable, test it

parent f9284a09
No related branches found
No related tags found
No related merge requests found
......@@ -32,22 +32,22 @@ import (
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *Config
httpServer *http.Server
httpsServer *http.Server
unixListener net.Listener
smtpServer *smtp.Server
smtpBackend *smtpBackend
topics map[string]*topic
visitors map[string]*visitor
firebase subscriber
mailer mailer
messages int64
auth auth.Auther
messageCache *messageCache
fileCache *fileCache
closeChan chan bool
mu sync.Mutex
config *Config
httpServer *http.Server
httpsServer *http.Server
unixListener net.Listener
smtpServer *smtp.Server
smtpBackend *smtpBackend
topics map[string]*topic
visitors map[string]*visitor
firebaseClient *firebaseClient
mailer mailer
messages int64
auth auth.Auther
messageCache *messageCache
fileCache *fileCache
closeChan chan bool
mu sync.Mutex
}
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
......@@ -134,23 +134,23 @@ func New(conf *Config) (*Server, error) {
return nil, err
}
}
var firebaseSubscriber subscriber
var firebaseClient *firebaseClient
if conf.FirebaseKeyFile != "" {
var err error
firebaseSubscriber, err = createFirebaseSubscriber(conf.FirebaseKeyFile, auther)
sender, err := newFirebaseSender(conf.FirebaseKeyFile)
if err != nil {
return nil, err
}
firebaseClient = newFirebaseClient(sender, auther)
}
return &Server{
config: conf,
messageCache: messageCache,
fileCache: fileCache,
firebase: firebaseSubscriber,
mailer: mailer,
topics: topics,
auth: auther,
visitors: make(map[string]*visitor),
config: conf,
messageCache: messageCache,
fileCache: fileCache,
firebaseClient: firebaseClient,
mailer: mailer,
topics: topics,
auth: auther,
visitors: make(map[string]*visitor),
}, nil
}
......@@ -437,7 +437,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
return err
}
}
if s.firebase != nil && firebase && !delayed {
if s.firebaseClient != nil && firebase && !delayed {
go s.sendToFirebase(v, m)
}
if s.mailer != nil && email != "" && !delayed {
......@@ -463,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
}
func (s *Server) sendToFirebase(v *visitor, m *message) {
if err := s.firebase(v, m); err != nil {
if err := s.firebaseClient.Send(v, m); err != nil {
log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
}
}
......@@ -1096,20 +1096,16 @@ func (s *Server) runDelayedSender() {
}
func (s *Server) runFirebaseKeepaliver() {
if s.firebase == nil {
if s.firebaseClient == nil {
return
}
v := newVisitor(s.config, s.messageCache, "0.0.0.0")
v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
for {
select {
case <-time.After(s.config.FirebaseKeepaliveInterval):
if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
}
s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
case <-time.After(s.config.FirebasePollInterval):
if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
}
s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
case <-s.closeChan:
return
}
......@@ -1142,7 +1138,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
}
}()
}
if s.firebase != nil { // Firebase subscribers may not show up in topics map
if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
go s.sendToFirebase(v, m)
}
if s.config.UpstreamBaseURL != "" {
......
......@@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
......@@ -18,33 +19,75 @@ const (
fcmApnsBodyMessageLimit = 100
)
func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subscriber, error) {
var (
errFirebaseQuotaExceeded = errors.New("Firebase quota exceeded")
)
// firebaseClient is a generic client that formats and sends messages to Firebase.
// The actual Firebase implementation is implemented in firebaseSenderImpl, to make it testable.
type firebaseClient struct {
sender firebaseSender
auther auth.Auther
}
func newFirebaseClient(sender firebaseSender, auther auth.Auther) *firebaseClient {
return &firebaseClient{
sender: sender,
auther: auther,
}
}
func (c *firebaseClient) Send(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil {
return errFirebaseQuotaExceeded
}
fbm, err := toFirebaseMessage(m, c.auther)
if err != nil {
return err
}
err = c.sender.Send(fbm)
if err == errFirebaseQuotaExceeded {
log.Printf("[%s] FB quota exceeded for topic %s, temporarily denying FB access to visitor", v.ip, m.Topic)
v.FirebaseTemporarilyDeny()
}
return err
}
// firebaseSender is an interface that represents a client that can send to Firebase Cloud Messaging.
// In tests, this can be implemented with a mock.
type firebaseSender interface {
// Send sends a message to Firebase, or returns an error. It returns errFirebaseQuotaExceeded
// if a rate limit has reached.
Send(m *messaging.Message) error
}
// firebaseSenderImpl is a firebaseSender that actually talks to Firebase
type firebaseSenderImpl struct {
client *messaging.Client
}
func newFirebaseSender(credentialsFile string) (*firebaseSenderImpl, error) {
fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(credentialsFile))
if err != nil {
return nil, err
}
msg, err := fb.Messaging(context.Background())
client, err := fb.Messaging(context.Background())
if err != nil {
return nil, err
}
return func(v *visitor, m *message) error {
if err := v.FirebaseAllowed(); err != nil {
return errHTTPTooManyRequestsFirebaseQuotaReached
}
fbm, err := toFirebaseMessage(m, auther)
if err != nil {
return err
}
_, err = msg.Send(context.Background(), fbm)
if err != nil && messaging.IsQuotaExceeded(err) {
log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
v.FirebaseTemporarilyDeny()
return errHTTPTooManyRequestsFirebaseQuotaReached
}
return err
return &firebaseSenderImpl{
client: client,
}, nil
}
func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
_, err := c.client.Send(context.Background(), m)
if err != nil && messaging.IsQuotaExceeded(err) {
return errFirebaseQuotaExceeded
}
return err
}
// toFirebaseMessage converts a message to a Firebase message.
//
// Normal messages ("message"):
......
......@@ -26,6 +26,25 @@ func (t testAuther) Authorize(_ *auth.User, _ string, _ auth.Permission) error {
return errors.New("unauthorized")
}
type testFirebaseSender struct {
allowed int
messages []*messaging.Message
}
func newTestFirebaseSender(allowed int) *testFirebaseSender {
return &testFirebaseSender{
allowed: allowed,
messages: make([]*messaging.Message, 0),
}
}
func (s *testFirebaseSender) Send(m *messaging.Message) error {
if len(s.messages)+1 > s.allowed {
return errFirebaseQuotaExceeded
}
s.messages = append(s.messages, m)
return nil
}
func TestToFirebaseMessage_Keepalive(t *testing.T) {
m := newKeepaliveMessage("mytopic")
fbm, err := toFirebaseMessage(m, nil)
......@@ -285,3 +304,22 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
}
func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.messages))
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.messages))
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.messages))
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, 0, len(sender.messages))
}
......@@ -9,7 +9,6 @@ import (
"math/rand"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
......@@ -55,6 +54,21 @@ func TestServer_PublishAndPoll(t *testing.T) {
require.Equal(t, "my second message", lines[1]) // \n -> " "
}
func TestServer_PublishWithFirebase(t *testing.T) {
sender := newTestFirebaseSender(10)
s := newTestServer(t, newTestConfig(t))
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
response := request(t, s, "PUT", "/mytopic", "my first message", nil)
msg1 := toMessage(t, response.Body.String())
require.NotEmpty(t, msg1.ID)
require.Equal(t, "my first message", msg1.Message)
require.Equal(t, 1, len(sender.messages))
require.Equal(t, "my first message", sender.messages[0].Data["message"])
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.Aps.Alert.Body)
require.Equal(t, "my first message", sender.messages[0].APNS.Payload.CustomData["message"])
}
func TestServer_SubscribeOpenAndKeepalive(t *testing.T) {
c := newTestConfig(t)
c.KeepaliveInterval = time.Second
......@@ -461,27 +475,6 @@ func TestServer_PublishMessageInHeaderWithNewlines(t *testing.T) {
require.Equal(t, "Line 1\nLine 2", msg.Message) // \\n -> \n !
}
func TestServer_PublishFirebase(t *testing.T) {
// This is unfortunately not much of a test, since it merely fires the messages towards Firebase,
// but cannot re-read them. There is no way from Go to read the messages back, or even get an error back.
// I tried everything. I already had written the test, and it increases the code coverage, so I'll leave it ... :shrug: ...
c := newTestConfig(t)
c.FirebaseKeyFile = firebaseServiceAccountFile(t) // May skip the test!
s := newTestServer(t, c)
// Normal message
response := request(t, s, "PUT", "/mytopic", "This is a message for firebase", nil)
msg := toMessage(t, response.Body.String())
require.NotEmpty(t, msg.ID)
// Keepalive message
v := newVisitor(s.config, s.messageCache, "1.2.3.4")
require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
time.Sleep(500 * time.Millisecond) // Time for sends
}
func TestServer_PublishInvalidTopic(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
s.mailer = &testMailer{}
......@@ -1341,18 +1334,6 @@ func toHTTPError(t *testing.T, s string) *errHTTP {
return &e
}
func firebaseServiceAccountFile(t *testing.T) string {
if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" {
return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE")
} else if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT") != "" {
filename := filepath.Join(t.TempDir(), "firebase.json")
require.NotNil(t, os.WriteFile(filename, []byte(os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT")), 0o600))
return filename
}
t.SkipNow()
return ""
}
func basicAuth(s string) string {
return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment