diff --git a/src/github.com/turt2live/matrix-media-repo/controllers/upload_controller/upload_controller.go b/src/github.com/turt2live/matrix-media-repo/controllers/upload_controller/upload_controller.go index 93c4e38f0d4b84e303bc2b770c01e78947aabf69..f2ef97cd8acd80ac7a0249ed7d1d1975a8be2587 100644 --- a/src/github.com/turt2live/matrix-media-repo/controllers/upload_controller/upload_controller.go +++ b/src/github.com/turt2live/matrix-media-repo/controllers/upload_controller/upload_controller.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "github.com/ryanuber/go-glob" "github.com/sirupsen/logrus" + "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/storage" "github.com/turt2live/matrix-media-repo/storage/datastore" @@ -132,22 +133,29 @@ func StoreDirect(contents io.ReadCloser, contentType string, filename string, us return nil, err } info, err := ds.UploadFile(contents, ctx, log) + if err != nil { + return nil, err + } + + stream, err := ds.DownloadFile(info.Location) + if err != nil { + return nil, err + } - // TODO: Enable mime checking for datastores - //fileMime, err := util.GetMimeType(fileLocation) - //if err != nil { - // log.Error("Error while checking content type of file: ", err.Error()) - // os.Remove(fileLocation) // delete temp file - // return nil, err - //} - // - //allowed := IsAllowed(fileMime, contentType, userId, log) - //if !allowed { - // log.Warn("Content type " + fileMime + " (reported as " + contentType + ") is not allowed to be uploaded") - // - // os.Remove(fileLocation) // delete temp file - // return nil, common.ErrMediaNotAllowed - //} + fileMime, err := util.GetMimeType(stream) + if err != nil { + log.Error("Error while checking content type of file: ", err.Error()) + ds.DeleteObject(info.Location) // delete temp object + return nil, err + } + + allowed := IsAllowed(fileMime, contentType, userId, log) + if !allowed { + log.Warn("Content type " + fileMime + " (reported as " + contentType + ") is not allowed to be uploaded") + + ds.DeleteObject(info.Location) // delete temp object + return nil, common.ErrMediaNotAllowed + } db := storage.GetDatabase().GetMediaStore(ctx, log) records, err := db.GetByHash(info.Sha256Hash) diff --git a/src/github.com/turt2live/matrix-media-repo/util/mime.go b/src/github.com/turt2live/matrix-media-repo/util/mime.go index ea0d49ec9fac65f066416a2f00864947ff136c00..e0f7bb6b24e3f72c962895cd94ed879270b6bc87 100644 --- a/src/github.com/turt2live/matrix-media-repo/util/mime.go +++ b/src/github.com/turt2live/matrix-media-repo/util/mime.go @@ -3,22 +3,17 @@ package util import ( "io" "net/http" - "os" "strings" "github.com/h2non/filetype" ) -func GetMimeType(filePath string) (string, error) { - f, err := os.Open(filePath) - if err != nil { - return "", err - } - defer f.Close() +func GetMimeType(stream io.ReadCloser) (string, error) { + defer stream.Close() // We only need the first 512 bytes at most to determine the file type buf := make([]byte, 512) - _, err = f.Read(buf) + _, err := stream.Read(buf) if err != nil && err != io.EOF { return "", err }