Skip to content
Snippets Groups Projects
Commit eabb6c0f authored by Travis Ralston's avatar Travis Ralston
Browse files

Re-add mime upload check

parent 1ca66cc5
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ import (
"github.com/pkg/errors"
"github.com/ryanuber/go-glob"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/storage/datastore"
......@@ -132,22 +133,29 @@ func StoreDirect(contents io.ReadCloser, contentType string, filename string, us
return nil, err
}
info, err := ds.UploadFile(contents, ctx, log)
if err != nil {
return nil, err
}
stream, err := ds.DownloadFile(info.Location)
if err != nil {
return nil, err
}
// TODO: Enable mime checking for datastores
//fileMime, err := util.GetMimeType(fileLocation)
//if err != nil {
// log.Error("Error while checking content type of file: ", err.Error())
// os.Remove(fileLocation) // delete temp file
// return nil, err
//}
//
//allowed := IsAllowed(fileMime, contentType, userId, log)
//if !allowed {
// log.Warn("Content type " + fileMime + " (reported as " + contentType + ") is not allowed to be uploaded")
//
// os.Remove(fileLocation) // delete temp file
// return nil, common.ErrMediaNotAllowed
//}
fileMime, err := util.GetMimeType(stream)
if err != nil {
log.Error("Error while checking content type of file: ", err.Error())
ds.DeleteObject(info.Location) // delete temp object
return nil, err
}
allowed := IsAllowed(fileMime, contentType, userId, log)
if !allowed {
log.Warn("Content type " + fileMime + " (reported as " + contentType + ") is not allowed to be uploaded")
ds.DeleteObject(info.Location) // delete temp object
return nil, common.ErrMediaNotAllowed
}
db := storage.GetDatabase().GetMediaStore(ctx, log)
records, err := db.GetByHash(info.Sha256Hash)
......
......@@ -3,22 +3,17 @@ package util
import (
"io"
"net/http"
"os"
"strings"
"github.com/h2non/filetype"
)
func GetMimeType(filePath string) (string, error) {
f, err := os.Open(filePath)
if err != nil {
return "", err
}
defer f.Close()
func GetMimeType(stream io.ReadCloser) (string, error) {
defer stream.Close()
// We only need the first 512 bytes at most to determine the file type
buf := make([]byte, 512)
_, err = f.Read(buf)
_, err := stream.Read(buf)
if err != nil && err != io.EOF {
return "", err
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment