From e273b83398943a79dd8ee5d95ef7b0e844d618a0 Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Sat, 28 Dec 2019 15:42:08 -0700
Subject: [PATCH] Fix config inheritance using maps

Objects get default values, which override the things we don't want...
---
 common/config/access.go | 82 ++++++++++++++++++++++++++++-------------
 common/config/util.go   | 26 +++++++++++++
 2 files changed, 82 insertions(+), 26 deletions(-)
 create mode 100644 common/config/util.go

diff --git a/common/config/access.go b/common/config/access.go
index 5c5f2574..cefa9b3f 100644
--- a/common/config/access.go
+++ b/common/config/access.go
@@ -25,7 +25,6 @@ var singletonLock = &sync.Once{}
 var domains = make(map[string]*DomainRepoConfig)
 
 func reloadConfig() (*MainRepoConfig, map[string]*DomainRepoConfig, error) {
-	c := NewDefaultMainConfig()
 	domainConfs := make(map[string]*DomainRepoConfig)
 
 	// Write a default config if the one given doesn't exist
@@ -33,7 +32,7 @@ func reloadConfig() (*MainRepoConfig, map[string]*DomainRepoConfig, error) {
 	exists := err == nil || !os.IsNotExist(err)
 	if !exists {
 		fmt.Println("Generating new configuration...")
-		configBytes, err := yaml.Marshal(c)
+		configBytes, err := yaml.Marshal(NewDefaultMainConfig())
 		if err != nil {
 			return nil, nil, err
 		}
@@ -78,7 +77,12 @@ func reloadConfig() (*MainRepoConfig, map[string]*DomainRepoConfig, error) {
 		pathsOrdered = append(pathsOrdered, Path)
 	}
 
+	// Note: the rest of this relies on maps before finalizing on objects because when
+	// the yaml is parsed it causes default values for the types to land in the overridden
+	// config. We don't want this, so we use maps which inherently override only what is
+	// present then we convert that overtop of a default object we create.
 	pendingDomainConfigs := make(map[string][][]byte)
+	cMap := make(map[string]interface{})
 
 	for _, p := range pathsOrdered {
 		logrus.Info("Loading config file: ", p)
@@ -121,50 +125,61 @@ func reloadConfig() (*MainRepoConfig, map[string]*DomainRepoConfig, error) {
 		}
 
 		// Not a domain config - parse into regular config
-		err = yaml.Unmarshal(buffer, &c)
+		err = yaml.Unmarshal(buffer, &cMap)
 		if err != nil {
 			return nil, nil, err
 		}
 	}
 
-	newDomainConfig := func() DomainRepoConfig {
-		dc := NewDefaultDomainConfig()
-		dc.DataStores = c.DataStores
-		dc.Archiving = c.Archiving
-		dc.Uploads = c.Uploads
-		dc.Identicons = c.Identicons
-		dc.Quarantine = c.Quarantine
-		dc.TimeoutSeconds = c.TimeoutSeconds
-		dc.Downloads = c.Downloads.DownloadsConfig
-		dc.Thumbnails = c.Thumbnails.ThumbnailsConfig
-		dc.UrlPreviews = c.UrlPreviews.UrlPreviewsConfig
-		return dc
+	c := NewDefaultMainConfig()
+	err = mapToObjYaml(cMap, &c)
+	if err != nil {
+		return nil, nil, err
 	}
 
 	// Start building domain configs
+	dMaps := make(map[string]map[string]interface{})
 	for _, d := range c.Homeservers {
-		dc := newDomainConfig()
-		domainConfs[d.Name] = &dc
-		domainConfs[d.Name].Name = d.Name
-		domainConfs[d.Name].ClientServerApi = d.ClientServerApi
-		domainConfs[d.Name].BackoffAt = d.BackoffAt
-		domainConfs[d.Name].AdminApiKind = d.AdminApiKind
+		dc := DomainConfigFrom(c)
+		dc.Name = d.Name
+		dc.ClientServerApi = d.ClientServerApi
+		dc.BackoffAt = d.BackoffAt
+		dc.AdminApiKind = d.AdminApiKind
+
+		m, err := objToMapYaml(dc)
+		if err != nil {
+			return nil, nil, err
+		}
+		dMaps[d.Name] = m
 	}
 	for hs, bs := range pendingDomainConfigs {
-		if _, ok := domainConfs[hs]; !ok {
-			dc := newDomainConfig()
-			domainConfs[hs] = &dc
-			domainConfs[hs].Name = hs
+		if _, ok := dMaps[hs]; !ok {
+			dc := DomainConfigFrom(c)
+			dc.Name = hs
+
+			m, err := objToMapYaml(dc)
+			if err != nil {
+				return nil, nil, err
+			}
+			dMaps[hs] = m
 		}
 
 		for _, b := range bs {
-			err = yaml.Unmarshal(b, domainConfs[hs])
+			m := dMaps[hs]
+			err = yaml.Unmarshal(b, &m)
 			if err != nil {
 				return nil, nil, err
 			}
 		}
 
+		c := DomainRepoConfig{}
+		err = mapToObjYaml(dMaps[hs], &c)
+		if err != nil {
+			return nil, nil, err
+		}
+
 		// For good measure...
+		domainConfs[hs] = &c
 		domainConfs[hs].Name = hs
 	}
 
@@ -198,6 +213,21 @@ func GetDomain(domain string) *DomainRepoConfig {
 	return domains[domain]
 }
 
+func DomainConfigFrom(c MainRepoConfig) DomainRepoConfig {
+	// HACK: We should be better at this kind of inheritance
+	dc := NewDefaultDomainConfig()
+	dc.DataStores = c.DataStores
+	dc.Archiving = c.Archiving
+	dc.Uploads = c.Uploads
+	dc.Identicons = c.Identicons
+	dc.Quarantine = c.Quarantine
+	dc.TimeoutSeconds = c.TimeoutSeconds
+	dc.Downloads = c.Downloads.DownloadsConfig
+	dc.Thumbnails = c.Thumbnails.ThumbnailsConfig
+	dc.UrlPreviews = c.UrlPreviews.UrlPreviewsConfig
+	return dc
+}
+
 func UniqueDatastores() []DatastoreConfig {
 	confs := make([]DatastoreConfig, 0)
 
diff --git a/common/config/util.go b/common/config/util.go
new file mode 100644
index 00000000..337e122f
--- /dev/null
+++ b/common/config/util.go
@@ -0,0 +1,26 @@
+package config
+
+import (
+	"gopkg.in/yaml.v2"
+)
+
+func mapToObjYaml(input map[string]interface{}, ref interface{}) error {
+	encoded, err := yaml.Marshal(input)
+	if err != nil {
+		return err
+	}
+
+	err = yaml.Unmarshal(encoded, ref)
+	return err
+}
+
+func objToMapYaml(input interface{}) (map[string]interface{}, error) {
+	encoded, err := yaml.Marshal(input)
+	if err != nil {
+		return nil, err
+	}
+
+	m := make(map[string]interface{})
+	err = yaml.Unmarshal(encoded, &m)
+	return m, err
+}
-- 
GitLab