Skip to content
Snippets Groups Projects
Commit 10f7d207 authored by Travis Ralston's avatar Travis Ralston
Browse files

Reduce usage of io.ReadAll

parent 8a5109d4
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,6 @@ package custom ...@@ -2,7 +2,6 @@ package custom
import ( import (
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
...@@ -41,15 +40,9 @@ func GetFederationInfo(r *http.Request, rctx rcontext.RequestContext, user _apim ...@@ -41,15 +40,9 @@ func GetFederationInfo(r *http.Request, rctx rcontext.RequestContext, user _apim
return _responses.InternalServerError(err.Error()) return _responses.InternalServerError(err.Error())
} }
c, err := io.ReadAll(versionResponse.Body) decoder := json.NewDecoder(versionResponse.Body)
if err != nil {
rctx.Log.Error(err)
sentry.CaptureException(err)
return _responses.InternalServerError(err.Error())
}
out := make(map[string]interface{}) out := make(map[string]interface{})
err = json.Unmarshal(c, &out) err = decoder.Decode(&out)
if err != nil { if err != nil {
rctx.Log.Error(err) rctx.Log.Error(err)
sentry.CaptureException(err) sentry.CaptureException(err)
......
...@@ -89,29 +89,29 @@ func extractPrefixTo(pathName string, destination string) { ...@@ -89,29 +89,29 @@ func extractPrefixTo(pathName string, destination string) {
continue continue
} }
logrus.Infof("Decoding %s", f)
b, err := base64.StdEncoding.DecodeString(b64) b, err := base64.StdEncoding.DecodeString(b64)
if err != nil { if err != nil {
panic(err) panic(err)
} }
logrus.Infof("Decompressing %s", f)
gr, err := gzip.NewReader(bytes.NewBuffer(b)) gr, err := gzip.NewReader(bytes.NewBuffer(b))
if err != nil { if err != nil {
panic(err) panic(err)
} }
//noinspection GoDeferInLoop,GoUnhandledErrorResult
defer gr.Close() dest := path.Join(destination, filepath.Base(f))
uncompressedBytes, err := io.ReadAll(gr) logrus.Debugf("Writing %s to %s", f, dest)
file, err := os.Create(dest)
if err != nil { if err != nil {
panic(err) panic(err)
} }
dest := path.Join(destination, filepath.Base(f)) _, err = io.Copy(file, gr)
logrus.Infof("Writing %s to %s", f, dest)
err = os.WriteFile(dest, uncompressedBytes, 0644)
if err != nil { if err != nil {
panic(err) panic(err)
} }
_ = gr.Close()
file.Close()
} }
} }
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"sync" "sync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/util/stream_util"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
...@@ -109,7 +108,8 @@ func reloadConfig() (*MainRepoConfig, map[string]*DomainRepoConfig, error) { ...@@ -109,7 +108,8 @@ func reloadConfig() (*MainRepoConfig, map[string]*DomainRepoConfig, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer stream_util.DumpAndCloseStream(f) //goland:noinspection GoDeferInLoop
defer f.Close()
buffer, err := io.ReadAll(f) buffer, err := io.ReadAll(f)
if err != nil { if err != nil {
......
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util/stream_util"
) )
// Based in part on https://github.com/matrix-org/gomatrix/blob/072b39f7fa6b40257b4eead8c958d71985c28bdd/client.go#L180-L243 // Based in part on https://github.com/matrix-org/gomatrix/blob/072b39f7fa6b40257b4eead8c958d71985c28bdd/client.go#L180-L243
...@@ -48,7 +47,7 @@ func doRequest(ctx rcontext.RequestContext, method string, urlStr string, body i ...@@ -48,7 +47,7 @@ func doRequest(ctx rcontext.RequestContext, method string, urlStr string, body i
if err != nil { if err != nil {
return err return err
} }
defer stream_util.DumpAndCloseStream(res.Body) defer res.Body.Close()
contents, err := io.ReadAll(res.Body) contents, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"os" "os"
...@@ -105,64 +104,62 @@ func GetServerApiUrl(hostname string) (string, string, error) { ...@@ -105,64 +104,62 @@ func GetServerApiUrl(hostname string) (string, string, error) {
r, err := http.Get(fmt.Sprintf("https://%s/.well-known/matrix/server", h)) r, err := http.Get(fmt.Sprintf("https://%s/.well-known/matrix/server", h))
if err == nil && r.StatusCode == http.StatusOK { if err == nil && r.StatusCode == http.StatusOK {
// Try parsing .well-known // Try parsing .well-known
c, err2 := io.ReadAll(r.Body) decoder := json.NewDecoder(r.Body)
if err2 == nil { wk := &wellknownServerResponse{}
wk := &wellknownServerResponse{} err3 := decoder.Decode(&wk)
err3 := json.Unmarshal(c, wk) if err3 == nil && wk.ServerAddr != "" {
if err3 == nil && wk.ServerAddr != "" { wkHost, wkPort, err4 := net.SplitHostPort(wk.ServerAddr)
wkHost, wkPort, err4 := net.SplitHostPort(wk.ServerAddr) wkDefPort := false
wkDefPort := false if err4 != nil && strings.HasSuffix(err4.Error(), "missing port in address") {
if err4 != nil && strings.HasSuffix(err4.Error(), "missing port in address") { wkHost, wkPort, err4 = net.SplitHostPort(wk.ServerAddr + ":8448")
wkHost, wkPort, err4 = net.SplitHostPort(wk.ServerAddr + ":8448") wkDefPort = true
wkDefPort = true }
if err4 == nil {
// Step 3a: if the delegated host is an IP address, use that (regardless of port)
logrus.Debug("Checking if WK host is an IP: " + wkHost)
if is.IP(wkHost) {
url := fmt.Sprintf("https://%s", net.JoinHostPort(wkHost, wkPort))
server := cachedServer{url, wk.ServerAddr}
apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration)
logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; IP address)")
return url, wk.ServerAddr, nil
} }
if err4 == nil {
// Step 3a: if the delegated host is an IP address, use that (regardless of port)
logrus.Debug("Checking if WK host is an IP: " + wkHost)
if is.IP(wkHost) {
url := fmt.Sprintf("https://%s", net.JoinHostPort(wkHost, wkPort))
server := cachedServer{url, wk.ServerAddr}
apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration)
logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; IP address)")
return url, wk.ServerAddr, nil
}
// Step 3b: if the delegated host is not an IP and an explicit port is given, use that // Step 3b: if the delegated host is not an IP and an explicit port is given, use that
logrus.Debug("Checking if WK is using default port? ", wkDefPort) logrus.Debug("Checking if WK is using default port? ", wkDefPort)
if !wkDefPort { if !wkDefPort {
wkHost = net.JoinHostPort(wkHost, wkPort) wkHost = net.JoinHostPort(wkHost, wkPort)
url := fmt.Sprintf("https://%s", wkHost) url := fmt.Sprintf("https://%s", wkHost)
server := cachedServer{url, wkHost} server := cachedServer{url, wkHost}
apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration)
logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; explicit port)") logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; explicit port)")
return url, wkHost, nil return url, wkHost, nil
} }
// Step 3c: if the delegated host is not an IP and doesn't have a port, start a SRV lookup and use it // Step 3c: if the delegated host is not an IP and doesn't have a port, start a SRV lookup and use it
// Note: we ignore errors here because the hostname will fail elsewhere. // Note: we ignore errors here because the hostname will fail elsewhere.
logrus.Debug("Doing SRV on WK host ", wkHost) logrus.Debug("Doing SRV on WK host ", wkHost)
_, addrs, _ := net.LookupSRV("matrix", "tcp", wkHost) _, addrs, _ := net.LookupSRV("matrix", "tcp", wkHost)
if len(addrs) > 0 { if len(addrs) > 0 {
// Trim off the trailing period if there is one (golang doesn't like this) // Trim off the trailing period if there is one (golang doesn't like this)
realAddr := addrs[0].Target realAddr := addrs[0].Target
if realAddr[len(realAddr)-1:] == "." { if realAddr[len(realAddr)-1:] == "." {
realAddr = realAddr[0 : len(realAddr)-1] realAddr = realAddr[0 : len(realAddr)-1]
}
url := fmt.Sprintf("https://%s", net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port))))
server := cachedServer{url, wkHost}
apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration)
logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; SRV)")
return url, wkHost, nil
} }
url := fmt.Sprintf("https://%s", net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port))))
// Step 3d: use the delegated host as-is
logrus.Debug("Using .well-known as-is for ", wkHost)
url := fmt.Sprintf("https://%s", net.JoinHostPort(wkHost, wkPort))
server := cachedServer{url, wkHost} server := cachedServer{url, wkHost}
apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration)
logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; fallback)") logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; SRV)")
return url, wkHost, nil return url, wkHost, nil
} }
// Step 3d: use the delegated host as-is
logrus.Debug("Using .well-known as-is for ", wkHost)
url := fmt.Sprintf("https://%s", net.JoinHostPort(wkHost, wkPort))
server := cachedServer{url, wkHost}
apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration)
logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; fallback)")
return url, wkHost, nil
} }
} }
} }
......
File moved
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment