From 7b68bcaea2db9a3b1020e1c43d226d790db4e59a Mon Sep 17 00:00:00 2001 From: Zack Scholl Date: Sun, 17 Nov 2019 15:17:06 -0800 Subject: [PATCH] make encryption functional --- src/croc/croc.go | 16 +++++------ src/croc/croc_test.go | 3 ++- src/crypt/crypt.go | 55 +++++++++++--------------------------- src/crypt/crypt_test.go | 59 ++++++++++++++++------------------------- src/message/message.go | 18 ++++++++----- 5 files changed, 58 insertions(+), 93 deletions(-) diff --git a/src/croc/croc.go b/src/croc/croc.go index 63358d5e..52970609 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -63,7 +63,7 @@ type Options struct { type Client struct { Options Options Pake *pake.Pake - Key crypt.Encryption + Key []byte ExternalIP, ExternalIPConnected string // steps involved in forming relationship @@ -147,12 +147,6 @@ func New(ops Options) (c *Client, err error) { c.conn = make([]*comm.Comm, 16) - // use default key (no encryption, until PAKE succeeds) - c.Key, err = crypt.New(nil, nil) - if err != nil { - return - } - // initialize pake if c.Options.IsSender { c.Pake, err = pake.Init([]byte(c.Options.SharedSecret), 1, elliptic.P521(), 1*time.Microsecond) @@ -648,10 +642,11 @@ func (c *Client) processMessageSalt(m message.Message) (done bool, err error) { if err != nil { return true, err } - c.Key, err = crypt.New(key, m.Bytes) + c.Key, _, err = crypt.New(key, m.Bytes) if err != nil { return true, err } + log.Debugf("key = %+x", c.Key) if c.ExternalIPConnected == "" { // it can be preset by the local relay c.ExternalIPConnected = m.Message @@ -1040,7 +1035,7 @@ func (c *Client) receiveData(i int) { break } - data, err = c.Key.Decrypt(data) + data, err = crypt.Decrypt(data, c.Key) if err != nil { panic(err) } @@ -1126,10 +1121,11 @@ func (c *Client) sendData(i int) { posByte := make([]byte, 8) binary.LittleEndian.PutUint64(posByte, pos) - dataToSend, err := c.Key.Encrypt( + dataToSend, err := crypt.Encrypt( compress.Compress( append(posByte, data[:n]...), ), + c.Key, ) if err != nil { panic(err) diff --git a/src/croc/croc_test.go b/src/croc/croc_test.go index c731914a..45aa5895 100644 --- a/src/croc/croc_test.go +++ b/src/croc/croc_test.go @@ -20,7 +20,7 @@ func TestCroc(t *testing.T) { go tcp.Run("debug", "8083") go tcp.Run("debug", "8084") go tcp.Run("debug", "8085") - time.Sleep(1 * time.Second) + time.Sleep(3 * time.Second) log.Debug("setting up sender") sender, err := New(Options{ @@ -36,6 +36,7 @@ func TestCroc(t *testing.T) { if err != nil { panic(err) } + time.Sleep(3 * time.Second) log.Debug("setting up receiver") receiver, err := New(Options{ diff --git a/src/crypt/crypt.go b/src/crypt/crypt.go index 65ba77c4..b864d467 100644 --- a/src/crypt/crypt.go +++ b/src/crypt/crypt.go @@ -5,56 +5,37 @@ import ( "crypto/cipher" "crypto/rand" "crypto/sha256" + "fmt" "golang.org/x/crypto/pbkdf2" ) -// Encryption is the basic type for storing -// the key, passphrase and salt -type Encryption struct { - key []byte - passphrase []byte - salt []byte -} - -// New generates a new Encryption, using the supplied passphrase and -// an optional supplied salt. -// Passing nil passphrase will not use decryption. -func New(passphrase []byte, salt []byte) (e Encryption, err error) { - if passphrase == nil { - e = Encryption{nil, nil, nil} +// New generates a new key based on a passphrase and salt +func New(passphrase []byte, usersalt []byte) (key []byte, salt []byte, err error) { + if len(passphrase) < 1 { + err = fmt.Errorf("need more than that for passphrase") return } - e.passphrase = passphrase - if salt == nil { - e.salt = make([]byte, 8) + if usersalt == nil { + salt = make([]byte, 8) // http://www.ietf.org/rfc/rfc2898.txt // Salt. - rand.Read(e.salt) + rand.Read(salt) } else { - e.salt = salt + salt = usersalt } - e.key = pbkdf2.Key([]byte(passphrase), e.salt, 100, 32, sha256.New) + key = pbkdf2.Key([]byte(passphrase), salt, 100, 32, sha256.New) return } -// Salt returns the salt bytes -func (e Encryption) Salt() []byte { - return e.salt -} - -// Encrypt will generate an Encryption, prefixed with the IV -func (e Encryption) Encrypt(plaintext []byte) (encrypted []byte, err error) { - if e.passphrase == nil { - encrypted = plaintext - return - } +// Encrypt will encrypt using the pre-generated key +func Encrypt(plaintext []byte, key []byte) (encrypted []byte, err error) { // generate a random iv each time // http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf // Section 8.2 ivBytes := make([]byte, 12) rand.Read(ivBytes) - b, err := aes.NewCipher(e.key) + b, err := aes.NewCipher(key) if err != nil { return } @@ -67,13 +48,9 @@ func (e Encryption) Encrypt(plaintext []byte) (encrypted []byte, err error) { return } -// Decrypt an Encryption -func (e Encryption) Decrypt(encrypted []byte) (plaintext []byte, err error) { - if e.passphrase == nil { - plaintext = encrypted - return - } - b, err := aes.NewCipher(e.key) +// Decrypt using the pre-generated key +func Decrypt(encrypted []byte, key []byte) (plaintext []byte, err error) { + b, err := aes.NewCipher(key) if err != nil { return } diff --git a/src/crypt/crypt_test.go b/src/crypt/crypt_test.go index 14ac6397..d03c4bd8 100644 --- a/src/crypt/crypt_test.go +++ b/src/crypt/crypt_test.go @@ -6,55 +6,42 @@ import ( "github.com/stretchr/testify/assert" ) -func BenchmarkEncryptionNew(b *testing.B) { +func BenchmarkEncrypt(b *testing.B) { + bob, _, _ := New([]byte("password"), nil) for i := 0; i < b.N; i++ { - bob, _ := New([]byte("password"), nil) - bob.Encrypt([]byte("hello, world")) + Encrypt([]byte("hello, world"), bob) } } -func BenchmarkEncryption(b *testing.B) { - bob, _ := New([]byte("password"), nil) +func BenchmarkDecrypt(b *testing.B) { + key, _, _ := New([]byte("password"), nil) + msg := []byte("hello, world") + enc, _ := Encrypt(msg, key) + b.ResetTimer() for i := 0; i < b.N; i++ { - bob.Encrypt([]byte("hello, world")) + Decrypt(enc, key) } } func TestEncryption(t *testing.T) { - bob, err := New([]byte("password"), nil) + key, salt, err := New([]byte("password"), nil) assert.Nil(t, err) - jane, err := New([]byte("password"), bob.Salt()) + msg := []byte("hello, world") + enc, err := Encrypt(msg, key) assert.Nil(t, err) - enc, err := bob.Encrypt([]byte("hello, world")) + dec, err := Decrypt(enc, key) assert.Nil(t, err) - dec, err := jane.Decrypt(enc) - assert.Nil(t, err) - assert.Equal(t, dec, []byte("hello, world")) + assert.Equal(t, msg, dec) - jane2, err := New([]byte("password"), nil) + // check reusing the salt + key2, _, err := New([]byte("password"), salt) + dec, err = Decrypt(enc, key2) assert.Nil(t, err) - dec, err = jane2.Decrypt(enc) + assert.Equal(t, msg, dec) + + // check reusing the salt + key2, _, err = New([]byte("wrong password"), salt) + dec, err = Decrypt(enc, key2) assert.NotNil(t, err) - assert.NotEqual(t, dec, []byte("hello, world")) - - jane3, err := New([]byte("passwordwrong"), bob.Salt()) - assert.Nil(t, err) - dec, err = jane3.Decrypt(enc) - assert.NotNil(t, err) - assert.NotEqual(t, dec, []byte("hello, world")) - -} - -func TestNoEncryption(t *testing.T) { - bob, err := New(nil, nil) - assert.Nil(t, err) - jane, err := New(nil, nil) - assert.Nil(t, err) - enc, err := bob.Encrypt([]byte("hello, world")) - assert.Nil(t, err) - dec, err := jane.Decrypt(enc) - assert.Nil(t, err) - assert.Equal(t, dec, []byte("hello, world")) - assert.Equal(t, enc, []byte("hello, world")) - + assert.NotEqual(t, msg, dec) } diff --git a/src/message/message.go b/src/message/message.go index 2f0abb6c..d46f9835 100644 --- a/src/message/message.go +++ b/src/message/message.go @@ -23,7 +23,7 @@ func (m Message) String() string { } // Send will send out -func Send(c *comm.Comm, key crypt.Encryption, m Message) (err error) { +func Send(c *comm.Comm, key []byte, m Message) (err error) { mSend, err := Encode(key, m) if err != nil { return @@ -34,21 +34,25 @@ func Send(c *comm.Comm, key crypt.Encryption, m Message) (err error) { } // Encode will convert to bytes -func Encode(key crypt.Encryption, m Message) (b []byte, err error) { +func Encode(key []byte, m Message) (b []byte, err error) { b, err = json.Marshal(m) if err != nil { return } b = compress.Compress(b) - b, err = key.Encrypt(b) + if key != nil { + b, err = crypt.Encrypt(b, key) + } return } // Decode will convert from bytes -func Decode(key crypt.Encryption, b []byte) (m Message, err error) { - b, err = key.Decrypt(b) - if err != nil { - return +func Decode(key []byte, b []byte) (m Message, err error) { + if key != nil { + b, err = crypt.Decrypt(b, key) + if err != nil { + return + } } b = compress.Decompress(b) err = json.Unmarshal(b, &m)