diff --git a/cmd/media_repo/inits.go b/cmd/media_repo/inits.go index 814c032089b0c600be90d867654c67382d533505..2bc77dfa7a2a9a762962ab71fb02a3bd730d1434 100644 --- a/cmd/media_repo/inits.go +++ b/cmd/media_repo/inits.go @@ -91,7 +91,7 @@ func loadDatastores() { mediaStore := storage.GetDatabase().GetMediaStore(rcontext.Initial()) logrus.Info("Initializing datastores...") - for _, ds := range config.Get().DataStores { + for _, ds := range config.UniqueDatastores() { if !ds.Enabled { continue } @@ -114,7 +114,7 @@ func loadDatastores() { logrus.Info(fmt.Sprintf("\t%s (%s): %s", ds.Type, ds.DatastoreId, ds.Uri)) if ds.Type == "s3" { - conf, err := datastore.GetDatastoreConfig(ds, rcontext.Initial()) + conf, err := datastore.GetDatastoreConfig(ds) if err != nil { continue } diff --git a/cmd/media_repo/main.go b/cmd/media_repo/main.go index 4ce51521d836198443743258c9472b974de57675..28d2d366d84a0ba4e53d5d438e33efbb8c3a92d1 100644 --- a/cmd/media_repo/main.go +++ b/cmd/media_repo/main.go @@ -35,6 +35,7 @@ func main() { logrus.Info("Starting up...") + config.PrintDomainInfo() loadDatabase() loadDatastores() diff --git a/common/config/access.go b/common/config/access.go index 4afa9fa9a5e40605d0d1291107b6d02e461588d9..bac89e69617bbceb91c50b50d2981c0b668bd525 100644 --- a/common/config/access.go +++ b/common/config/access.go @@ -22,10 +22,11 @@ var Path = "media-repo.yaml" var instance *MainRepoConfig var singletonLock = &sync.Once{} -var domains = make(map[string]DomainRepoConfig) +var domains = make(map[string]*DomainRepoConfig) -func reloadConfig() (*MainRepoConfig, error) { +func reloadConfig() (*MainRepoConfig, map[string]*DomainRepoConfig, error) { c := NewDefaultMainConfig() + domainConfs := make(map[string]*DomainRepoConfig) // Write a default config if the one given doesn't exist info, err := os.Stat(Path) @@ -34,29 +35,29 @@ func reloadConfig() (*MainRepoConfig, error) { fmt.Println("Generating new configuration...") configBytes, err := yaml.Marshal(c) if err != nil { - return nil, err + return nil, nil, err } newFile, err := os.Create(Path) if err != nil { - return nil, err + return nil, nil, err } _, err = newFile.Write(configBytes) if err != nil { - return nil, err + return nil, nil, err } err = newFile.Close() if err != nil { - return nil, err + return nil, nil, err } } // Get new info about the possible directory after creating info, err = os.Stat(Path) if err != nil { - return nil, err + return nil, nil, err } pathsOrdered := make([]string, 0) @@ -65,7 +66,7 @@ func reloadConfig() (*MainRepoConfig, error) { files, err := ioutil.ReadDir(Path) if err != nil { - return nil, err + return nil, nil, err } for _, f := range files { @@ -77,35 +78,146 @@ func reloadConfig() (*MainRepoConfig, error) { pathsOrdered = append(pathsOrdered, Path) } + pendingDomainConfigs := make(map[string][][]byte) + for _, p := range pathsOrdered { logrus.Info("Loading config file: ", p) + + s, err := os.Stat(p) + if err != nil { + return nil, nil, err + } + if s.IsDir() { + continue // skip directories + } + f, err := os.Open(p) if err != nil { - return nil, err + return nil, nil, err } //noinspection GoDeferInLoop defer f.Close() buffer, err := ioutil.ReadAll(f) + if err != nil { + return nil, nil, err + } + + testMap := make(map[string]interface{}) + err = yaml.Unmarshal(buffer, &testMap) + if err != nil { + return nil, nil, err + } + + if hsRaw, ok := testMap["homeserver"]; ok { + if hs, ok := hsRaw.(string); ok { + if _, ok = pendingDomainConfigs[hs]; !ok { + pendingDomainConfigs[hs] = make([][]byte, 0) + } + pendingDomainConfigs[hs] = append(pendingDomainConfigs[hs], buffer) + continue // skip parsing - we'll do this in a moment + } + } + + // Not a domain config - parse into regular config err = yaml.Unmarshal(buffer, &c) if err != nil { - return nil, err + return nil, nil, err + } + } + + // Start building domain configs + for _, d := range c.Homeservers { + dc := NewDefaultDomainConfig() + domainConfs[d.Name] = &dc + domainConfs[d.Name].Name = d.Name + domainConfs[d.Name].ClientServerApi = d.ClientServerApi + domainConfs[d.Name].BackoffAt = d.BackoffAt + domainConfs[d.Name].AdminApiKind = d.AdminApiKind + } + for hs, bs := range pendingDomainConfigs { + if _, ok := domainConfs[hs]; !ok { + dc := NewDefaultDomainConfig() + domainConfs[hs] = &dc + domainConfs[hs].Name = hs + } + + for _, b := range bs { + err = yaml.Unmarshal(b, domainConfs[hs]) + if err != nil { + return nil, nil, err + } } + + // For good measure... + domainConfs[hs].Name = hs } - return &c, nil + return &c, domainConfs, nil } func Get() *MainRepoConfig { if instance == nil { singletonLock.Do(func() { - c, err := reloadConfig() + c, d, err := reloadConfig() if err != nil { logrus.Fatal(err) } instance = c + domains = d }) } return instance } + +func AllDomains() []*DomainRepoConfig { + vals := make([]*DomainRepoConfig, 0) + for _, v := range domains { + vals = append(vals, v) + } + return vals +} + +func GetDomain(domain string) *DomainRepoConfig { + Get() // Ensure we generate a main config + return domains[domain] +} + +func UniqueDatastores() []DatastoreConfig { + confs := make([]DatastoreConfig, 0) + + for _, dsc := range Get().DataStores { + confs = append(confs, dsc) + } + + for _, d := range AllDomains() { + for _, dsc := range d.DataStores { + found := false + for _, edsc := range confs { + if edsc.Type == dsc.Type { + if dsc.Type == "file" && edsc.Options["path"] == dsc.Options["path"] { + found = true + break + } else if dsc.Type == "s3" && edsc.Options["endpoint"] == dsc.Options["endpoint"] && edsc.Options["bucketName"] == dsc.Options["bucketName"] { + found = true + break + } + } + } + if found { + continue + } + confs = append(confs, dsc) + } + } + + return confs +} + +func PrintDomainInfo() { + logrus.Info("Domains loaded:") + for _, d := range domains { + logrus.Info(fmt.Sprintf("\t%s (%s)", d.Name, d.ClientServerApi)) + } +} diff --git a/common/config/conf_domain.go b/common/config/conf_domain.go index 45af7772a04004d2582df3423df19222022e68fc..2f2bcfcbb0eee6bffc2e4575cc7217113abb7ad6 100644 --- a/common/config/conf_domain.go +++ b/common/config/conf_domain.go @@ -2,6 +2,7 @@ package config type DomainRepoConfig struct { MinimumRepoConfig `yaml:",inline"` + HomeserverConfig `yaml:",inline"` Downloads DownloadsConfig `yaml:"downloads"` Thumbnails ThumbnailsConfig `yaml:"thumbnails"` UrlPreviews UrlPreviewsConfig `yaml:"urlPreviews"` @@ -10,6 +11,12 @@ type DomainRepoConfig struct { func NewDefaultDomainConfig() DomainRepoConfig { return DomainRepoConfig{ MinimumRepoConfig: NewDefaultMinimumRepoConfig(), + HomeserverConfig: HomeserverConfig{ + Name: "UNDEFINED", + ClientServerApi: "https://UNDEFINED", + BackoffAt: 10, + AdminApiKind: "matrix", + }, Downloads: DownloadsConfig{ MaxSizeBytes: 104857600, // 100mb FailureCacheMinutes: 15, diff --git a/common/config/watch.go b/common/config/watch.go index 08ec19099fc16de7f7b5e030fd2526264c388111..8b6dc80c358632f02d304bb6cb9138cf703ee053 100644 --- a/common/config/watch.go +++ b/common/config/watch.go @@ -44,7 +44,7 @@ func Watch() *fsnotify.Watcher { func onFileChanged() { logrus.Info("Config file change detected - reloading") configNow := Get() - configNew, err := reloadConfig() + configNew, domainsNew, err := reloadConfig() if err != nil { logrus.Error("Error reloading configuration - ignoring") logrus.Error(err) @@ -53,6 +53,8 @@ func onFileChanged() { logrus.Info("Applying reloaded config live") instance = configNew + domains = domainsNew + PrintDomainInfo() bindAddressChange := configNew.General.BindAddress != configNow.General.BindAddress bindPortChange := configNew.General.Port != configNow.General.Port diff --git a/matrix/matrix.go b/matrix/matrix.go index 7dc76fb84b4dbc103d9be91163c20c8edf75fba0..70417b6cec4a6734940c511b580fc0b4d0125e45 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -6,13 +6,12 @@ import ( "github.com/rubyist/circuitbreaker" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" - "github.com/turt2live/matrix-media-repo/util" ) var breakers = &sync.Map{} -func getBreakerAndConfig(serverName string) (*config.HomeserverConfig, *circuit.Breaker) { - hs := util.GetHomeserverConfig(serverName) +func getBreakerAndConfig(serverName string) (*config.DomainRepoConfig, *circuit.Breaker) { + hs := config.GetDomain(serverName) var cb *circuit.Breaker cbRaw, hasCb := breakers.Load(hs.Name) diff --git a/storage/datastore/datastore.go b/storage/datastore/datastore.go index 1b8016b3870046ee733c404bbc16ef215415a518..81438cf509c62d9ae1b85eb8400e3bb442901ff1 100644 --- a/storage/datastore/datastore.go +++ b/storage/datastore/datastore.go @@ -39,7 +39,7 @@ func LocateDatastore(ctx rcontext.RequestContext, datastoreId string) (*Datastor return nil, err } - conf, err := GetDatastoreConfig(ds, ctx) + conf, err := GetDatastoreConfig(ds) if err != nil { return nil, err } @@ -55,8 +55,8 @@ func DownloadStream(ctx rcontext.RequestContext, datastoreId string, location st return ref.DownloadFile(location) } -func GetDatastoreConfig(ds *types.Datastore, ctx rcontext.RequestContext) (config.DatastoreConfig, error) { - for _, dsConf := range ctx.Config.DataStores { +func GetDatastoreConfig(ds *types.Datastore) (config.DatastoreConfig, error) { + for _, dsConf := range config.UniqueDatastores() { if dsConf.Type == ds.Type && GetUriForDatastore(dsConf) == ds.Uri { return dsConf, nil } diff --git a/util/config.go b/util/config.go index 04980a6af268ebfbd4a1dfe073d235a4d00fdfd8..18bd40e1feedeaa8a1b67b56a20916a3c95a6ea5 100644 --- a/util/config.go +++ b/util/config.go @@ -5,22 +5,10 @@ import ( ) func IsServerOurs(server string) bool { - hs := GetHomeserverConfig(server) + hs := config.GetDomain(server) return hs != nil } -// TODO: Replace with per-domain config -func GetHomeserverConfig(server string) *config.HomeserverConfig { - for i := 0; i < len(config.Get().Homeservers); i++ { - hs := config.Get().Homeservers[i] - if hs.Name == server { - return &hs - } - } - - return nil -} - func IsGlobalAdmin(userId string) bool { for _, admin := range config.Get().Admins { if admin == userId {