0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 21:30:16 +02:00

Encryption works, cleanup is good

This commit is contained in:
Zack Scholl 2017-10-17 21:15:48 -06:00
parent e59df2e617
commit 0cf680fd66
4 changed files with 73 additions and 27 deletions

View file

@ -3,6 +3,7 @@ package main
import ( import (
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"net" "net"
"os" "os"
@ -29,7 +30,7 @@ func runClient(connectionType string, codePhrase string) {
uiprogress.Start() uiprogress.Start()
bars = make([]*uiprogress.Bar, numberConnections) bars = make([]*uiprogress.Bar, numberConnections)
fileNameToReceive := "" var iv, salt, fileNameToReceive string
for id := 0; id < numberConnections; id++ { for id := 0; id < numberConnections; id++ {
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
@ -60,7 +61,7 @@ func runClient(connectionType string, codePhrase string) {
} else { // this is a receiver } else { // this is a receiver
// receive file // receive file
logger.Debug("receive file") logger.Debug("receive file")
fileNameToReceive = receiveFile(id, connection, codePhrase) fileNameToReceive, iv, salt = receiveFile(id, connection, codePhrase)
} }
}(id) }(id)
@ -69,13 +70,32 @@ func runClient(connectionType string, codePhrase string) {
if connectionType == "r" { if connectionType == "r" {
catFile(fileNameToReceive) catFile(fileNameToReceive)
encrypted, err := ioutil.ReadFile(fileNameToReceive + ".encrypted")
if err != nil {
log.Error(err)
return
}
fmt.Println("\n\ndecrypting...")
decrypted, err := Decrypt(encrypted, codePhrase, salt, iv)
if err != nil {
log.Error(err)
return
}
ioutil.WriteFile(fileNameToReceive, decrypted, 0644)
os.Remove(fileNameToReceive + ".encrypted")
fmt.Println("\nDownloaded " + fileNameToReceive + "!")
} else {
log.Info("cleaning up")
os.Remove(fileName + ".encrypted")
os.Remove(fileName + ".iv")
os.Remove(fileName + ".salt")
} }
} }
func catFile(fileNameToReceive string) { func catFile(fileNameToReceive string) {
// cat the file // cat the file
os.Remove(fileNameToReceive) os.Remove(fileNameToReceive)
finished, err := os.Create(fileNameToReceive) finished, err := os.Create(fileNameToReceive + ".encrypted")
defer finished.Close() defer finished.Close()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -94,24 +114,35 @@ func catFile(fileNameToReceive string) {
os.Remove(fileNameToReceive + "." + strconv.Itoa(id)) os.Remove(fileNameToReceive + "." + strconv.Itoa(id))
} }
fmt.Println("\n\n\nDownloaded " + fileNameToReceive + "!")
} }
func receiveFile(id int, connection net.Conn, codePhrase string) string { func receiveFile(id int, connection net.Conn, codePhrase string) (fileNameToReceive string, iv string, salt string) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "receiveFile #" + strconv.Itoa(id), "function": "receiveFile #" + strconv.Itoa(id),
}) })
bufferFileName := make([]byte, 64)
bufferFileSize := make([]byte, 10)
logger.Debug("waiting for file size") logger.Debug("waiting for file size")
bufferFileSize := make([]byte, 10)
connection.Read(bufferFileSize) connection.Read(bufferFileSize)
fileSize, _ := strconv.ParseInt(strings.Trim(string(bufferFileSize), ":"), 10, 64) fileSize, _ := strconv.ParseInt(strings.Trim(string(bufferFileSize), ":"), 10, 64)
logger.Debugf("filesize: %d", fileSize) logger.Debugf("filesize: %d", fileSize)
bufferFileName := make([]byte, 64)
connection.Read(bufferFileName) connection.Read(bufferFileName)
fileNameToReceive := strings.Trim(string(bufferFileName), ":") fileNameToReceive = strings.Trim(string(bufferFileName), ":")
logger.Debugf("fileName: %v", fileNameToReceive) logger.Debugf("fileName: %v", fileNameToReceive)
ivHex := make([]byte, BUFFERSIZE)
connection.Read(ivHex)
iv = strings.Trim(string(ivHex), ":")
logger.Debugf("iv: %v", iv)
saltHex := make([]byte, BUFFERSIZE)
connection.Read(saltHex)
salt = strings.Trim(string(saltHex), ":")
logger.Debugf("salt: %v", salt)
os.Remove(fileNameToReceive + "." + strconv.Itoa(id)) os.Remove(fileNameToReceive + "." + strconv.Itoa(id))
newFile, err := os.Create(fileNameToReceive + "." + strconv.Itoa(id)) newFile, err := os.Create(fileNameToReceive + "." + strconv.Itoa(id))
if err != nil { if err != nil {
@ -140,7 +171,7 @@ func receiveFile(id int, connection net.Conn, codePhrase string) string {
receivedBytes += BUFFERSIZE receivedBytes += BUFFERSIZE
} }
logger.Debug("received file") logger.Debug("received file")
return fileNameToReceive return
} }
func sendFile(id int, connection net.Conn, codePhrase string) { func sendFile(id int, connection net.Conn, codePhrase string) {
@ -185,6 +216,25 @@ func sendFile(id int, connection net.Conn, codePhrase string) {
connection.Write([]byte(fileSize)) connection.Write([]byte(fileSize))
logger.Debugf("sending fileNameToSend: %s", fileNameToSend) logger.Debugf("sending fileNameToSend: %s", fileNameToSend)
connection.Write([]byte(fileNameToSend)) connection.Write([]byte(fileNameToSend))
// send iv
iv, err := ioutil.ReadFile(fileName + ".iv")
if err != nil {
log.Error(err)
return
}
logger.Debugf("sending iv: %s", iv)
connection.Write([]byte(fillString(string(iv), BUFFERSIZE)))
// send salt
salt, err := ioutil.ReadFile(fileName + ".salt")
if err != nil {
log.Error(err)
return
}
logger.Debugf("sending salt: %s", salt)
connection.Write([]byte(fillString(string(salt), BUFFERSIZE)))
sendBuffer := make([]byte, BUFFERSIZE) sendBuffer := make([]byte, BUFFERSIZE)
chunkI := 0 chunkI := 0

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
@ -29,7 +28,7 @@ func GetRandomName() string {
return strings.Join(result, "-") return strings.Join(result, "-")
} }
func Encrypt(plaintext []byte, passphrase string) (ciphertext []byte, err error) { func Encrypt(plaintext []byte, passphrase string) ([]byte, string, string) {
key, salt := deriveKey(passphrase, nil) key, salt := deriveKey(passphrase, nil)
iv := make([]byte, 12) iv := make([]byte, 12)
// http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf // http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf
@ -38,19 +37,16 @@ func Encrypt(plaintext []byte, passphrase string) (ciphertext []byte, err error)
b, _ := aes.NewCipher(key) b, _ := aes.NewCipher(key)
aesgcm, _ := cipher.NewGCM(b) aesgcm, _ := cipher.NewGCM(b)
data := aesgcm.Seal(nil, iv, plaintext, nil) data := aesgcm.Seal(nil, iv, plaintext, nil)
ciphertext = []byte(hex.EncodeToString(salt) + "-" + hex.EncodeToString(iv) + "-" + hex.EncodeToString(data)) return data, hex.EncodeToString(salt), hex.EncodeToString(iv)
return
} }
func Decrypt(ciphertext []byte, passphrase string) (plaintext []byte, err error) { func Decrypt(data []byte, passphrase string, salt string, iv string) (plaintext []byte, err error) {
arr := bytes.Split(ciphertext, []byte("-")) saltBytes, _ := hex.DecodeString(salt)
salt, _ := hex.DecodeString(string(arr[0])) ivBytes, _ := hex.DecodeString(iv)
iv, _ := hex.DecodeString(string(arr[1])) key, _ := deriveKey(passphrase, saltBytes)
data, _ := hex.DecodeString(string(arr[2]))
key, _ := deriveKey(passphrase, salt)
b, _ := aes.NewCipher(key) b, _ := aes.NewCipher(key)
aesgcm, _ := cipher.NewGCM(b) aesgcm, _ := cipher.NewGCM(b)
plaintext, err = aesgcm.Open(nil, iv, data, nil) plaintext, err = aesgcm.Open(nil, ivBytes, data, nil)
return return
} }

View file

@ -8,19 +8,16 @@ import (
func TestEncrypt(t *testing.T) { func TestEncrypt(t *testing.T) {
key := GetRandomName() key := GetRandomName()
fmt.Println(key) fmt.Println(key)
encrypted, err := Encrypt([]byte("hello, world"), key) salt, iv, encrypted := Encrypt([]byte("hello, world"), key)
if err != nil {
t.Error(err)
}
fmt.Println(len(encrypted)) fmt.Println(len(encrypted))
decrypted, err := Decrypt(encrypted, key) decrypted, err := Decrypt(salt, iv, encrypted, key)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if string(decrypted) != "hello, world" { if string(decrypted) != "hello, world" {
t.Error("problem decrypting") t.Error("problem decrypting")
} }
_, err = Decrypt(encrypted, "wrong passphrase") _, err = Decrypt(salt, iv, encrypted, "wrong passphrase")
if err == nil { if err == nil {
t.Error("should not work!") t.Error("should not work!")
} }

View file

@ -56,17 +56,20 @@ func main() {
if connectionTypeFlag == "s" { if connectionTypeFlag == "s" {
// encrypt the file // encrypt the file
fmt.Println("encrypting...")
fdata, err := ioutil.ReadFile(fileName) fdata, err := ioutil.ReadFile(fileName)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
return return
} }
encrypted, err := Encrypt(fdata, codePhraseFlag) encrypted, salt, iv := Encrypt(fdata, codePhraseFlag)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
return return
} }
ioutil.WriteFile(fileName+".encrypted", encrypted, 0644) ioutil.WriteFile(fileName+".encrypted", encrypted, 0644)
ioutil.WriteFile(fileName+".salt", []byte(salt), 0644)
ioutil.WriteFile(fileName+".iv", []byte(iv), 0644)
} }
log.SetFormatter(&log.TextFormatter{}) log.SetFormatter(&log.TextFormatter{})