Add support for encrypted files

This commit is contained in:
Tulir Asokan
2020-04-29 02:45:54 +03:00
parent fa04323daf
commit a9dff6da73
7 changed files with 82 additions and 47 deletions

View File

@ -17,7 +17,6 @@
package matrix
import (
"bytes"
"context"
"crypto/tls"
"encoding/gob"
@ -38,6 +37,7 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
@ -1061,7 +1061,7 @@ func cp(src, dst string) error {
return out.Close()
}
func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath string, err error) {
func (c *Container) DownloadToDisk(uri id.ContentURI, file *attachment.EncryptedFile, target string) (fullPath string, err error) {
cachePath := c.GetCachePath(uri)
if target == "" {
fullPath = cachePath
@ -1072,21 +1072,27 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s
}
if _, statErr := os.Stat(cachePath); os.IsNotExist(statErr) {
var file *os.File
file, err = os.OpenFile(cachePath, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return
}
defer file.Close()
var body io.ReadCloser
body, err = c.client.Download(uri)
if err != nil {
return
}
defer body.Close()
_, err = io.Copy(file, body)
var data []byte
data, err = ioutil.ReadAll(body)
_ = body.Close()
if err != nil {
return
}
if file != nil {
data, err = file.Decrypt(data)
if err != nil {
return
}
}
err = ioutil.WriteFile(cachePath, data, 0600)
if err != nil {
return
}
@ -1106,7 +1112,7 @@ func (c *Container) DownloadToDisk(uri id.ContentURI, target string) (fullPath s
// Download fetches the given Matrix content (mxc) URL and returns the data, homeserver, file ID and potential errors.
//
// The file will be either read from the media cache (if found) or downloaded from the server.
func (c *Container) Download(uri id.ContentURI) (data []byte, err error) {
func (c *Container) Download(uri id.ContentURI, file *attachment.EncryptedFile) (data []byte, err error) {
cacheFile := c.GetCachePath(uri)
var info os.FileInfo
if info, err = os.Stat(cacheFile); err == nil && !info.IsDir() {
@ -1116,7 +1122,7 @@ func (c *Container) Download(uri id.ContentURI) (data []byte, err error) {
}
}
data, err = c.download(uri, cacheFile)
data, err = c.download(uri, file, cacheFile)
return
}
@ -1124,21 +1130,25 @@ func (c *Container) GetDownloadURL(uri id.ContentURI) string {
return c.client.GetDownloadURL(uri)
}
func (c *Container) download(uri id.ContentURI, cacheFile string) (data []byte, err error) {
func (c *Container) download(uri id.ContentURI, file *attachment.EncryptedFile, cacheFile string) (data []byte, err error) {
var body io.ReadCloser
body, err = c.client.Download(uri)
if err != nil {
return
}
defer body.Close()
var buf bytes.Buffer
_, err = io.Copy(&buf, body)
data, err = ioutil.ReadAll(body)
_ = body.Close()
if err != nil {
return
}
data = buf.Bytes()
if file != nil {
data, err = file.Decrypt(data)
if err != nil {
return
}
}
err = ioutil.WriteFile(cacheFile, data, 0600)
return