1
1
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00

fix: make sure that only pake messages are unencrypted

This commit is contained in:
Zack Scholl 2021-04-16 17:15:51 -07:00
parent babfd5f35f
commit c02b4f1256

View file

@ -678,7 +678,7 @@ func (c *Client) Receive() (err error) {
err = c.transfer(TransferOptions{}) err = c.transfer(TransferOptions{})
if err == nil { if err == nil {
if c.numberOfTransferedFiles == 0 { if c.numberOfTransferedFiles == 0 {
fmt.Fprintf(os.Stderr,"\rNo files need transfering.") fmt.Fprintf(os.Stderr, "\rNo files need transfering.")
} }
} }
return return
@ -931,6 +931,15 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) {
return return
} }
// only "pake" messages should be unencrypted
// if a non-"pake" message is received unencrypted something
// is weird
if m.Type != "pake" && c.Key == nil {
err = fmt.Errorf("unencrypted communication rejected")
done = true
return
}
switch m.Type { switch m.Type {
case "finished": case "finished":
err = message.Send(c.conn[0], c.Key, message.Message{ err = message.Send(c.conn[0], c.Key, message.Message{
@ -1209,7 +1218,7 @@ func (c *Client) updateIfRecipientHasFileInfo() (err error) {
err = c.createEmptyFileAndFinish(fileInfo, i) err = c.createEmptyFileAndFinish(fileInfo, i)
if err != nil { if err != nil {
return return
} else{ } else {
c.numberOfTransferedFiles++ c.numberOfTransferedFiles++
} }
continue continue
@ -1217,10 +1226,10 @@ func (c *Client) updateIfRecipientHasFileInfo() (err error) {
log.Debugf("%s %+x %+x %+v", fileInfo.Name, fileHash, fileInfo.Hash, errHash) log.Debugf("%s %+x %+x %+v", fileInfo.Name, fileHash, fileInfo.Hash, errHash)
if !bytes.Equal(fileHash, fileInfo.Hash) { if !bytes.Equal(fileHash, fileInfo.Hash) {
log.Debugf("hashes are not equal %x != %x", fileHash, fileInfo.Hash) log.Debugf("hashes are not equal %x != %x", fileHash, fileInfo.Hash)
if errHash== nil && !c.Options.Overwrite { if errHash == nil && !c.Options.Overwrite {
ans := utils.GetInput(fmt.Sprintf("\rOverwrite '%s'? (y/n) ",path.Join(fileInfo.FolderRemote, fileInfo.Name))) ans := utils.GetInput(fmt.Sprintf("\rOverwrite '%s'? (y/n) ", path.Join(fileInfo.FolderRemote, fileInfo.Name)))
if strings.TrimSpace(strings.ToLower(ans)) != "y" { if strings.TrimSpace(strings.ToLower(ans)) != "y" {
fmt.Fprintf(os.Stderr,"skipping '%s'",path.Join(fileInfo.FolderRemote, fileInfo.Name)) fmt.Fprintf(os.Stderr, "skipping '%s'", path.Join(fileInfo.FolderRemote, fileInfo.Name))
continue continue
} }
} }