diff --git a/crypto.go b/crypto.go index 091a4bed..77bf3cb2 100644 --- a/crypto.go +++ b/crypto.go @@ -4,15 +4,19 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "crypto/sha1" "crypto/sha256" "encoding/binary" "encoding/hex" "fmt" mathrand "math/rand" + "os" "strings" "time" + "github.com/mars9/crypt" "github.com/schollz/mnemonicode" + log "github.com/sirupsen/logrus" "golang.org/x/crypto/pbkdf2" ) @@ -76,3 +80,46 @@ func HashBytes(data []byte) string { sum := sha256.Sum256(data) return fmt.Sprintf("%x", sum) } + +func EncryptFile(inputFilename string, outputFilename string, password string) error { + return cryptFile(inputFilename, outputFilename, password, true) +} + +func DecryptFile(inputFilename string, outputFilename string, password string) error { + return cryptFile(inputFilename, outputFilename, password, false) +} + +func cryptFile(inputFilename string, outputFilename string, password string, encrypt bool) error { + in, err := os.Open(inputFilename) + if err != nil { + return err + } + defer in.Close() + out, err := os.Create(outputFilename) + if err != nil { + return err + } + defer func() { + if err := out.Sync(); err != nil { + log.Error(err) + } + if err := out.Close(); err != nil { + log.Error(err) + } + }() + c := &crypt.Crypter{ + HashFunc: sha1.New, + HashSize: sha1.Size, + Key: crypt.NewPbkdf2Key([]byte(password), 32), + } + if encrypt { + if err := c.Encrypt(out, in); err != nil { + return err + } + } else { + if err := c.Decrypt(out, in); err != nil { + return err + } + } + return nil +} diff --git a/crypto_test.go b/crypto_test.go index 43465c1b..c475d5ba 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -1,6 +1,8 @@ package main import ( + "io/ioutil" + "os" "testing" ) @@ -19,3 +21,29 @@ func TestEncrypt(t *testing.T) { t.Error("should not work!") } } + +func TestEncryptFiles(t *testing.T) { + key := GetRandomName() + if err := ioutil.WriteFile("temp", []byte("hello, world!"), 0644); err != nil { + t.Error(err) + } + if err := EncryptFile("temp", "temp.enc", key); err != nil { + t.Error(err) + } + if err := DecryptFile("temp.enc", "temp.dec", key); err != nil { + t.Error(err) + } + data, err := ioutil.ReadFile("temp.dec") + if string(data) != "hello, world!" { + t.Errorf("Got something weird: " + string(data)) + } + if err != nil { + t.Error(err) + } + if err := DecryptFile("temp.enc", "temp.dec", key+"wrong password"); err == nil { + t.Error("should throw error!") + } + os.Remove("temp.dec") + os.Remove("temp.enc") + +}