From c02b4f12560ae13f19e3c290e7df8ded686f2518 Mon Sep 17 00:00:00 2001 From: Zack Scholl Date: Fri, 16 Apr 2021 17:15:51 -0700 Subject: [PATCH] fix: make sure that only pake messages are unencrypted --- src/croc/croc.go | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/croc/croc.go b/src/croc/croc.go index bc4f3dd0..5ac5012c 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -104,11 +104,11 @@ type Client struct { longestFilename int firstSend bool - mutex *sync.Mutex - fread *os.File - numfinished int - quit chan bool - finishedNum int + mutex *sync.Mutex + fread *os.File + numfinished int + quit chan bool + finishedNum int numberOfTransferedFiles int } @@ -678,7 +678,7 @@ func (c *Client) Receive() (err error) { err = c.transfer(TransferOptions{}) if err == nil { if c.numberOfTransferedFiles == 0 { - fmt.Fprintf(os.Stderr,"\rNo files need transfering.") + fmt.Fprintf(os.Stderr, "\rNo files need transfering.") } } return @@ -931,6 +931,15 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) { return } + // only "pake" messages should be unencrypted + // if a non-"pake" message is received unencrypted something + // is weird + if m.Type != "pake" && c.Key == nil { + err = fmt.Errorf("unencrypted communication rejected") + done = true + return + } + switch m.Type { case "finished": err = message.Send(c.conn[0], c.Key, message.Message{ @@ -1209,7 +1218,7 @@ func (c *Client) updateIfRecipientHasFileInfo() (err error) { err = c.createEmptyFileAndFinish(fileInfo, i) if err != nil { return - } else{ + } else { c.numberOfTransferedFiles++ } continue @@ -1217,12 +1226,12 @@ func (c *Client) updateIfRecipientHasFileInfo() (err error) { log.Debugf("%s %+x %+x %+v", fileInfo.Name, fileHash, fileInfo.Hash, errHash) if !bytes.Equal(fileHash, fileInfo.Hash) { log.Debugf("hashes are not equal %x != %x", fileHash, fileInfo.Hash) - if errHash== nil && !c.Options.Overwrite { - ans := utils.GetInput(fmt.Sprintf("\rOverwrite '%s'? (y/n) ",path.Join(fileInfo.FolderRemote, fileInfo.Name))) + if errHash == nil && !c.Options.Overwrite { + ans := utils.GetInput(fmt.Sprintf("\rOverwrite '%s'? (y/n) ", path.Join(fileInfo.FolderRemote, fileInfo.Name))) if strings.TrimSpace(strings.ToLower(ans)) != "y" { - fmt.Fprintf(os.Stderr,"skipping '%s'",path.Join(fileInfo.FolderRemote, fileInfo.Name)) + fmt.Fprintf(os.Stderr, "skipping '%s'", path.Join(fileInfo.FolderRemote, fileInfo.Name)) continue - } + } } } else { log.Debugf("hashes are equal %x == %x", fileHash, fileInfo.Hash)