Skip to content
Snippets Groups Projects
Commit 7b810acf authored by Philipp Heckel's avatar Philipp Heckel
Browse files

SQLite cache

parent 1c7695c1
No related branches found
No related tags found
No related merge requests found
package server package server
import ( import (
"database/sql"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"time"
) )
const ( type cache interface {
createTableQuery = `CREATE TABLE IF NOT EXISTS messages ( AddMessage(m *message) error
id VARCHAR(20) PRIMARY KEY, Messages(topic string, since time.Time) ([]*message, error)
time INT NOT NULL, MessageCount(topic string) (int, error)
topic VARCHAR(64) NOT NULL, Topics() (map[string]*topic, error)
message VARCHAR(1024) NOT NULL Prune(keep time.Duration) error
)`
insertQuery = `INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`
pruneOlderThanQuery = `DELETE FROM messages WHERE time < ?`
)
type cache struct {
db *sql.DB
insert *sql.Stmt
prune *sql.Stmt
}
func newCache(filename string) (*cache, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if _, err := db.Exec(createTableQuery); err != nil {
return nil, err
}
insert, err := db.Prepare(insertQuery)
if err != nil {
return nil, err
}
prune, err := db.Prepare(pruneOlderThanQuery)
if err != nil {
return nil, err
}
return &cache{
db: db,
insert: insert,
prune: prune,
}, nil
}
func (c *cache) Load() (map[string]*topic, error) {
}
func (c *cache) Add(m *message) error {
_, err := c.insert.Exec(m.ID, m.Time, m.Topic, m.Message)
return err
}
func (c *cache) Prune(olderThan time.Duration) error {
_, err := c.prune.Exec(time.Now().Add(-1 * olderThan).Unix())
return err
} }
package server
import (
_ "github.com/mattn/go-sqlite3" // SQLite driver
"sync"
"time"
)
type memCache struct {
messages map[string][]*message
mu sync.Mutex
}
var _ cache = (*memCache)(nil)
func newMemCache() *memCache {
return &memCache{
messages: make(map[string][]*message),
}
}
func (s *memCache) AddMessage(m *message) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.messages[m.Topic]; !ok {
s.messages[m.Topic] = make([]*message, 0)
}
s.messages[m.Topic] = append(s.messages[m.Topic], m)
return nil
}
func (s *memCache) Messages(topic string, since time.Time) ([]*message, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.messages[topic]; !ok {
return make([]*message, 0), nil
}
messages := make([]*message, 0) // copy!
for _, m := range s.messages[topic] {
msgTime := time.Unix(m.Time, 0)
if msgTime == since || msgTime.After(since) {
messages = append(messages, m)
}
}
return messages, nil
}
func (s *memCache) MessageCount(topic string) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.messages[topic]; !ok {
return 0, nil
}
return len(s.messages[topic]), nil
}
func (s *memCache) Topics() (map[string]*topic, error) {
// Hack since we know when this is called there are no messages!
return make(map[string]*topic), nil
}
func (s *memCache) Prune(keep time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
for topic, _ := range s.messages {
s.pruneTopic(topic, keep)
}
return nil
}
func (s *memCache) pruneTopic(topic string, keep time.Duration) {
for i, m := range s.messages[topic] {
msgTime := time.Unix(m.Time, 0)
if time.Since(msgTime) < keep {
s.messages[topic] = s.messages[topic][i:]
return
}
}
s.messages[topic] = make([]*message, 0) // all messages expired
}
package server
import (
"database/sql"
"errors"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"time"
)
const (
createTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(1024) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
COMMIT;
`
insertMessageQuery = `INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`
pruneMessagesQuery = `DELETE FROM messages WHERE time < ?`
selectMessagesSinceTimeQuery = `
SELECT id, time, message
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time ASC
`
selectMessageCountQuery = `SELECT count(*) FROM messages WHERE topic = ?`
selectTopicsQuery = `SELECT topic, MAX(time) FROM messages GROUP BY TOPIC`
)
type sqliteCache struct {
db *sql.DB
}
var _ cache = (*sqliteCache)(nil)
func newSqliteCache(filename string) (*sqliteCache, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if _, err := db.Exec(createTableQuery); err != nil {
return nil, err
}
return &sqliteCache{
db: db,
}, nil
}
func (c *sqliteCache) AddMessage(m *message) error {
_, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message)
return err
}
func (c *sqliteCache) Messages(topic string, since time.Time) ([]*message, error) {
rows, err := c.db.Query(selectMessagesSinceTimeQuery, topic, since.Unix())
if err != nil {
return nil, err
}
defer rows.Close()
messages := make([]*message, 0)
for rows.Next() {
var timestamp int64
var id, msg string
if err := rows.Scan(&id, &timestamp, &msg); err != nil {
return nil, err
}
messages = append(messages, &message{
ID: id,
Time: timestamp,
Event: messageEvent,
Topic: topic,
Message: msg,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
}
func (c *sqliteCache) MessageCount(topic string) (int, error) {
rows, err := c.db.Query(selectMessageCountQuery, topic)
if err != nil {
return 0, err
}
defer rows.Close()
var count int
if !rows.Next() {
return 0, errors.New("no rows found")
}
if err := rows.Scan(&count); err != nil {
return 0, err
} else if err := rows.Err(); err != nil {
return 0, err
}
return count, nil
}
func (s *sqliteCache) Topics() (map[string]*topic, error) {
rows, err := s.db.Query(selectTopicsQuery)
if err != nil {
return nil, err
}
defer rows.Close()
topics := make(map[string]*topic, 0)
for rows.Next() {
var id string
var last int64
if err := rows.Scan(&id, &last); err != nil {
return nil, err
}
topics[id] = newTopic(id, time.Unix(last, 0))
}
if err := rows.Err(); err != nil {
return nil, err
}
return topics, nil
}
func (c *sqliteCache) Prune(keep time.Duration) error {
_, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1 * keep).Unix())
return err
}
...@@ -32,7 +32,7 @@ type Server struct { ...@@ -32,7 +32,7 @@ type Server struct {
visitors map[string]*visitor visitors map[string]*visitor
firebase subscriber firebase subscriber
messages int64 messages int64
cache *cache cache cache
mu sync.Mutex mu sync.Mutex
} }
...@@ -78,30 +78,28 @@ func New(conf *config.Config) (*Server, error) { ...@@ -78,30 +78,28 @@ func New(conf *config.Config) (*Server, error) {
return nil, err return nil, err
} }
} }
cache, err := maybeCreateCache(conf) cache, err := createCache(conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
topics := make(map[string]*topic) topics, err := cache.Topics()
if cache != nil { if err != nil {
if topics, err = cache.Load(); err != nil { return nil, err
return nil, err
}
} }
return &Server{ return &Server{
config: conf, config: conf,
cache: cache, cache: cache,
firebase: firebaseSubscriber, firebase: firebaseSubscriber,
topics: topics, topics: topics,
visitors: make(map[string]*visitor), visitors: make(map[string]*visitor),
}, nil }, nil
} }
func maybeCreateCache(conf *config.Config) (*cache, error) { func createCache(conf *config.Config) (cache, error) {
if conf.CacheFile == "" { if conf.CacheFile != "" {
return nil, nil return newSqliteCache(conf.CacheFile)
} }
return newCache(conf.CacheFile) return newMemCache(), nil
} }
func createFirebaseSubscriber(conf *config.Config) (subscriber, error) { func createFirebaseSubscriber(conf *config.Config) (subscriber, error) {
...@@ -202,8 +200,8 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito ...@@ -202,8 +200,8 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
if err := t.Publish(m); err != nil { if err := t.Publish(m); err != nil {
return err return err
} }
if s.cache != nil { if err := s.cache.AddMessage(m); err != nil {
s.cache.Add(m) return err
} }
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
s.mu.Lock() s.mu.Lock()
...@@ -277,20 +275,18 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi ...@@ -277,20 +275,18 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Type", contentType)
if poll { if poll {
return sendOldMessages(t, since, sub) return s.sendOldMessages(t, since, sub)
} }
subscriberID := t.Subscribe(sub) subscriberID := t.Subscribe(sub)
defer t.Unsubscribe(subscriberID) defer t.Unsubscribe(subscriberID)
if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message
return err return err
} }
if err := sendOldMessages(t, since, sub); err != nil { if err := s.sendOldMessages(t, since, sub); err != nil {
return err return err
} }
for { for {
select { select {
case <-t.ctx.Done():
return nil
case <-r.Context().Done(): case <-r.Context().Done():
return nil return nil
case <-time.After(s.config.KeepaliveInterval): case <-time.After(s.config.KeepaliveInterval):
...@@ -302,11 +298,15 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi ...@@ -302,11 +298,15 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
} }
} }
func sendOldMessages(t *topic, since time.Time, sub subscriber) error { func (s *Server) sendOldMessages(t *topic, since time.Time, sub subscriber) error {
if since.IsZero() { if since.IsZero() {
return nil return nil
} }
for _, m := range t.Messages(since) { messages, err := s.cache.Messages(t.id, since)
if err != nil {
return err
}
for _, m := range messages {
if err := sub(m); err != nil { if err := sub(m); err != nil {
return err return err
} }
...@@ -340,7 +340,7 @@ func (s *Server) topic(id string) (*topic, error) { ...@@ -340,7 +340,7 @@ func (s *Server) topic(id string) (*topic, error) {
if len(s.topics) >= s.config.GlobalTopicLimit { if len(s.topics) >= s.config.GlobalTopicLimit {
return nil, errHTTPTooManyRequests return nil, errHTTPTooManyRequests
} }
s.topics[id] = newTopic(id) s.topics[id] = newTopic(id, time.Now())
if s.firebase != nil { if s.firebase != nil {
s.topics[id].Subscribe(s.firebase) s.topics[id].Subscribe(s.firebase)
} }
...@@ -360,28 +360,28 @@ func (s *Server) updateStatsAndExpire() { ...@@ -360,28 +360,28 @@ func (s *Server) updateStatsAndExpire() {
} }
// Prune cache // Prune cache
if s.cache != nil { if err := s.cache.Prune(s.config.MessageBufferDuration); err != nil {
if err := s.cache.Prune(s.config.MessageBufferDuration); err != nil { log.Printf("error pruning cache: %s", err.Error())
log.Printf("error pruning cache: %s", err.Error())
}
} }
// Prune old messages, remove subscriptions without subscribers // Prune old messages, remove subscriptions without subscribers
var subscribers, messages int
for _, t := range s.topics { for _, t := range s.topics {
t.Prune(s.config.MessageBufferDuration) subs := t.Subscribers()
subs, msgs := t.Stats() msgs, err := s.cache.MessageCount(t.id)
if err != nil {
log.Printf("cannot get stats for topic %s: %s", t.id, err.Error())
continue
}
if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber! if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
delete(s.topics, t.id) delete(s.topics, t.id)
continue
} }
}
// Print stats
var subscribers, messages int
for _, t := range s.topics {
subs, msgs := t.Stats()
subscribers += subs subscribers += subs
messages += msgs messages += msgs
} }
// Print stats
log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)", log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)",
s.messages, len(s.topics), subscribers, messages, len(s.visitors)) s.messages, len(s.topics), subscribers, messages, len(s.visitors))
} }
......
package server package server
import ( import (
"context"
"log" "log"
"math/rand" "math/rand"
"sync" "sync"
...@@ -12,11 +11,8 @@ import ( ...@@ -12,11 +11,8 @@ import (
// can publish a message // can publish a message
type topic struct { type topic struct {
id string id string
subscribers map[int]subscriber
messages []*message
last time.Time last time.Time
ctx context.Context subscribers map[int]subscriber
cancel context.CancelFunc
mu sync.Mutex mu sync.Mutex
} }
...@@ -24,15 +20,11 @@ type topic struct { ...@@ -24,15 +20,11 @@ type topic struct {
type subscriber func(msg *message) error type subscriber func(msg *message) error
// newTopic creates a new topic // newTopic creates a new topic
func newTopic(id string) *topic { func newTopic(id string, last time.Time) *topic {
ctx, cancel := context.WithCancel(context.Background())
return &topic{ return &topic{
id: id, id: id,
last: last,
subscribers: make(map[int]subscriber), subscribers: make(map[int]subscriber),
messages: make([]*message, 0),
last: time.Now(),
ctx: ctx,
cancel: cancel,
} }
} }
...@@ -55,7 +47,6 @@ func (t *topic) Publish(m *message) error { ...@@ -55,7 +47,6 @@ func (t *topic) Publish(m *message) error {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
t.last = time.Now() t.last = time.Now()
t.messages = append(t.messages, m)
for _, s := range t.subscribers { for _, s := range t.subscribers {
if err := s(m); err != nil { if err := s(m); err != nil {
log.Printf("error publishing message to subscriber") log.Printf("error publishing message to subscriber")
...@@ -64,38 +55,8 @@ func (t *topic) Publish(m *message) error { ...@@ -64,38 +55,8 @@ func (t *topic) Publish(m *message) error {
return nil return nil
} }
func (t *topic) Messages(since time.Time) []*message { func (t *topic) Subscribers() int {
t.mu.Lock()
defer t.mu.Unlock()
messages := make([]*message, 0) // copy!
for _, m := range t.messages {
msgTime := time.Unix(m.Time, 0)
if msgTime == since || msgTime.After(since) {
messages = append(messages, m)
}
}
return messages
}
func (t *topic) Prune(keep time.Duration) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
for i, m := range t.messages { return len(t.subscribers)
msgTime := time.Unix(m.Time, 0)
if time.Since(msgTime) < keep {
t.messages = t.messages[i:]
return
}
}
t.messages = make([]*message, 0)
}
func (t *topic) Stats() (subscribers int, messages int) {
t.mu.Lock()
defer t.mu.Unlock()
return len(t.subscribers), len(t.messages)
}
func (t *topic) Close() {
t.cancel()
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment