0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 21:30:16 +02:00

make encryption functional

This commit is contained in:
Zack Scholl 2019-11-17 15:17:06 -08:00
parent 382ef3157a
commit 7b68bcaea2
5 changed files with 58 additions and 93 deletions

View file

@ -63,7 +63,7 @@ type Options struct {
type Client struct { type Client struct {
Options Options Options Options
Pake *pake.Pake Pake *pake.Pake
Key crypt.Encryption Key []byte
ExternalIP, ExternalIPConnected string ExternalIP, ExternalIPConnected string
// steps involved in forming relationship // steps involved in forming relationship
@ -147,12 +147,6 @@ func New(ops Options) (c *Client, err error) {
c.conn = make([]*comm.Comm, 16) c.conn = make([]*comm.Comm, 16)
// use default key (no encryption, until PAKE succeeds)
c.Key, err = crypt.New(nil, nil)
if err != nil {
return
}
// initialize pake // initialize pake
if c.Options.IsSender { if c.Options.IsSender {
c.Pake, err = pake.Init([]byte(c.Options.SharedSecret), 1, elliptic.P521(), 1*time.Microsecond) c.Pake, err = pake.Init([]byte(c.Options.SharedSecret), 1, elliptic.P521(), 1*time.Microsecond)
@ -648,10 +642,11 @@ func (c *Client) processMessageSalt(m message.Message) (done bool, err error) {
if err != nil { if err != nil {
return true, err return true, err
} }
c.Key, err = crypt.New(key, m.Bytes) c.Key, _, err = crypt.New(key, m.Bytes)
if err != nil { if err != nil {
return true, err return true, err
} }
log.Debugf("key = %+x", c.Key)
if c.ExternalIPConnected == "" { if c.ExternalIPConnected == "" {
// it can be preset by the local relay // it can be preset by the local relay
c.ExternalIPConnected = m.Message c.ExternalIPConnected = m.Message
@ -1040,7 +1035,7 @@ func (c *Client) receiveData(i int) {
break break
} }
data, err = c.Key.Decrypt(data) data, err = crypt.Decrypt(data, c.Key)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -1126,10 +1121,11 @@ func (c *Client) sendData(i int) {
posByte := make([]byte, 8) posByte := make([]byte, 8)
binary.LittleEndian.PutUint64(posByte, pos) binary.LittleEndian.PutUint64(posByte, pos)
dataToSend, err := c.Key.Encrypt( dataToSend, err := crypt.Encrypt(
compress.Compress( compress.Compress(
append(posByte, data[:n]...), append(posByte, data[:n]...),
), ),
c.Key,
) )
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -20,7 +20,7 @@ func TestCroc(t *testing.T) {
go tcp.Run("debug", "8083") go tcp.Run("debug", "8083")
go tcp.Run("debug", "8084") go tcp.Run("debug", "8084")
go tcp.Run("debug", "8085") go tcp.Run("debug", "8085")
time.Sleep(1 * time.Second) time.Sleep(3 * time.Second)
log.Debug("setting up sender") log.Debug("setting up sender")
sender, err := New(Options{ sender, err := New(Options{
@ -36,6 +36,7 @@ func TestCroc(t *testing.T) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
time.Sleep(3 * time.Second)
log.Debug("setting up receiver") log.Debug("setting up receiver")
receiver, err := New(Options{ receiver, err := New(Options{

View file

@ -5,56 +5,37 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"fmt"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
) )
// Encryption is the basic type for storing // New generates a new key based on a passphrase and salt
// the key, passphrase and salt func New(passphrase []byte, usersalt []byte) (key []byte, salt []byte, err error) {
type Encryption struct { if len(passphrase) < 1 {
key []byte err = fmt.Errorf("need more than that for passphrase")
passphrase []byte
salt []byte
}
// New generates a new Encryption, using the supplied passphrase and
// an optional supplied salt.
// Passing nil passphrase will not use decryption.
func New(passphrase []byte, salt []byte) (e Encryption, err error) {
if passphrase == nil {
e = Encryption{nil, nil, nil}
return return
} }
e.passphrase = passphrase if usersalt == nil {
if salt == nil { salt = make([]byte, 8)
e.salt = make([]byte, 8)
// http://www.ietf.org/rfc/rfc2898.txt // http://www.ietf.org/rfc/rfc2898.txt
// Salt. // Salt.
rand.Read(e.salt) rand.Read(salt)
} else { } else {
e.salt = salt salt = usersalt
} }
e.key = pbkdf2.Key([]byte(passphrase), e.salt, 100, 32, sha256.New) key = pbkdf2.Key([]byte(passphrase), salt, 100, 32, sha256.New)
return return
} }
// Salt returns the salt bytes // Encrypt will encrypt using the pre-generated key
func (e Encryption) Salt() []byte { func Encrypt(plaintext []byte, key []byte) (encrypted []byte, err error) {
return e.salt
}
// Encrypt will generate an Encryption, prefixed with the IV
func (e Encryption) Encrypt(plaintext []byte) (encrypted []byte, err error) {
if e.passphrase == nil {
encrypted = plaintext
return
}
// generate a random iv each time // generate a random iv each time
// http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf // http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf
// Section 8.2 // Section 8.2
ivBytes := make([]byte, 12) ivBytes := make([]byte, 12)
rand.Read(ivBytes) rand.Read(ivBytes)
b, err := aes.NewCipher(e.key) b, err := aes.NewCipher(key)
if err != nil { if err != nil {
return return
} }
@ -67,13 +48,9 @@ func (e Encryption) Encrypt(plaintext []byte) (encrypted []byte, err error) {
return return
} }
// Decrypt an Encryption // Decrypt using the pre-generated key
func (e Encryption) Decrypt(encrypted []byte) (plaintext []byte, err error) { func Decrypt(encrypted []byte, key []byte) (plaintext []byte, err error) {
if e.passphrase == nil { b, err := aes.NewCipher(key)
plaintext = encrypted
return
}
b, err := aes.NewCipher(e.key)
if err != nil { if err != nil {
return return
} }

View file

@ -6,55 +6,42 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func BenchmarkEncryptionNew(b *testing.B) { func BenchmarkEncrypt(b *testing.B) {
bob, _, _ := New([]byte("password"), nil)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
bob, _ := New([]byte("password"), nil) Encrypt([]byte("hello, world"), bob)
bob.Encrypt([]byte("hello, world"))
} }
} }
func BenchmarkEncryption(b *testing.B) { func BenchmarkDecrypt(b *testing.B) {
bob, _ := New([]byte("password"), nil) key, _, _ := New([]byte("password"), nil)
msg := []byte("hello, world")
enc, _ := Encrypt(msg, key)
b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
bob.Encrypt([]byte("hello, world")) Decrypt(enc, key)
} }
} }
func TestEncryption(t *testing.T) { func TestEncryption(t *testing.T) {
bob, err := New([]byte("password"), nil) key, salt, err := New([]byte("password"), nil)
assert.Nil(t, err) assert.Nil(t, err)
jane, err := New([]byte("password"), bob.Salt()) msg := []byte("hello, world")
enc, err := Encrypt(msg, key)
assert.Nil(t, err) assert.Nil(t, err)
enc, err := bob.Encrypt([]byte("hello, world")) dec, err := Decrypt(enc, key)
assert.Nil(t, err) assert.Nil(t, err)
dec, err := jane.Decrypt(enc) assert.Equal(t, msg, dec)
assert.Nil(t, err)
assert.Equal(t, dec, []byte("hello, world"))
jane2, err := New([]byte("password"), nil) // check reusing the salt
key2, _, err := New([]byte("password"), salt)
dec, err = Decrypt(enc, key2)
assert.Nil(t, err) assert.Nil(t, err)
dec, err = jane2.Decrypt(enc) assert.Equal(t, msg, dec)
// check reusing the salt
key2, _, err = New([]byte("wrong password"), salt)
dec, err = Decrypt(enc, key2)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.NotEqual(t, dec, []byte("hello, world")) assert.NotEqual(t, msg, dec)
jane3, err := New([]byte("passwordwrong"), bob.Salt())
assert.Nil(t, err)
dec, err = jane3.Decrypt(enc)
assert.NotNil(t, err)
assert.NotEqual(t, dec, []byte("hello, world"))
}
func TestNoEncryption(t *testing.T) {
bob, err := New(nil, nil)
assert.Nil(t, err)
jane, err := New(nil, nil)
assert.Nil(t, err)
enc, err := bob.Encrypt([]byte("hello, world"))
assert.Nil(t, err)
dec, err := jane.Decrypt(enc)
assert.Nil(t, err)
assert.Equal(t, dec, []byte("hello, world"))
assert.Equal(t, enc, []byte("hello, world"))
} }

View file

@ -23,7 +23,7 @@ func (m Message) String() string {
} }
// Send will send out // Send will send out
func Send(c *comm.Comm, key crypt.Encryption, m Message) (err error) { func Send(c *comm.Comm, key []byte, m Message) (err error) {
mSend, err := Encode(key, m) mSend, err := Encode(key, m)
if err != nil { if err != nil {
return return
@ -34,21 +34,25 @@ func Send(c *comm.Comm, key crypt.Encryption, m Message) (err error) {
} }
// Encode will convert to bytes // Encode will convert to bytes
func Encode(key crypt.Encryption, m Message) (b []byte, err error) { func Encode(key []byte, m Message) (b []byte, err error) {
b, err = json.Marshal(m) b, err = json.Marshal(m)
if err != nil { if err != nil {
return return
} }
b = compress.Compress(b) b = compress.Compress(b)
b, err = key.Encrypt(b) if key != nil {
b, err = crypt.Encrypt(b, key)
}
return return
} }
// Decode will convert from bytes // Decode will convert from bytes
func Decode(key crypt.Encryption, b []byte) (m Message, err error) { func Decode(key []byte, b []byte) (m Message, err error) {
b, err = key.Decrypt(b) if key != nil {
if err != nil { b, err = crypt.Decrypt(b, key)
return if err != nil {
return
}
} }
b = compress.Decompress(b) b = compress.Decompress(b)
err = json.Unmarshal(b, &m) err = json.Unmarshal(b, &m)