-
Travis Ralston authoredTravis Ralston authored
webserver.go 2.13 KiB
package api
import (
"context"
"encoding/json"
"net"
"net/http"
"strconv"
"sync"
"time"
"github.com/didip/tollbooth"
"github.com/getsentry/sentry-go"
sentryhttp "github.com/getsentry/sentry-go/http"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/common/config"
)
var srv *http.Server
var waitGroup = &sync.WaitGroup{}
var reload = false
func Init() *sync.WaitGroup {
address := net.JoinHostPort(config.Get().General.BindAddress, strconv.Itoa(config.Get().General.Port))
//defer func() {
// if err := recover(); err != nil {
// logrus.Fatal(err)
// }
//}()
handler := buildRoutes()
if config.Get().RateLimit.Enabled {
logrus.Debug("Enabling rate limit")
limiter := tollbooth.NewLimiter(0, nil)
limiter.SetIPLookups([]string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"})
limiter.SetTokenBucketExpirationTTL(time.Hour)
limiter.SetBurst(config.Get().RateLimit.BurstCount)
limiter.SetMax(config.Get().RateLimit.RequestsPerSecond)
b, _ := json.Marshal(_responses.RateLimitReached())
limiter.SetMessage(string(b))
limiter.SetMessageContentType("application/json")
handler = tollbooth.LimitHandler(limiter, handler)
}
// Note: we bind Sentry here to ensure we capture *everything*
sentryHandler := sentryhttp.New(sentryhttp.Options{})
srv = &http.Server{Addr: address, Handler: sentryHandler.Handle(handler)}
reload = false
go func() {
logrus.WithField("address", address).Info("Started up. Listening at http://" + address)
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
sentry.CaptureException(err)
logrus.Fatal(err)
}
// Only notify the main thread that we're done if we're actually done
srv = nil
if !reload {
waitGroup.Done()
}
}()
return waitGroup
}
func Reload() {
reload = true
// Stop the server first
Stop()
// Reload the web server, ignoring the wait group (because we don't care to wait here)
Init()
}
func Stop() {
if srv != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
panic(err)
}
}
}