1
1
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 05:11:06 +02:00

Merge pull request #708 from schollz/fix10

fix message passing for initial secure layer
This commit is contained in:
Zack 2024-05-23 09:46:40 -07:00 committed by GitHub
commit b0920bbe70
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -702,26 +702,33 @@ func (c *Client) Send(filesInfo []FileInfo, emptyFoldersToTransfer []FileInfo, t
log.Debugf("banner: %s", banner) log.Debugf("banner: %s", banner)
log.Debugf("connection established: %+v", conn) log.Debugf("connection established: %+v", conn)
var kB []byte var kB []byte
var dataMessage SimpleMessage
B, _ := pake.InitCurve([]byte(c.Options.SharedSecret[5:]), 1, c.Options.Curve) B, _ := pake.InitCurve([]byte(c.Options.SharedSecret[5:]), 1, c.Options.Curve)
for { for {
log.Debug("waiting for bytes") var dataMessage SimpleMessage
log.Trace("waiting for bytes")
data, errConn := conn.Receive() data, errConn := conn.Receive()
if errConn != nil { if errConn != nil {
log.Debugf("[%+v] had error: %s", conn, errConn.Error()) log.Tracef("[%+v] had error: %s", conn, errConn.Error())
}
err = json.Unmarshal(data, &dataMessage)
if err == nil {
log.Debugf("dataMessage: %s", dataMessage)
} }
json.Unmarshal(data, &dataMessage)
log.Tracef("data: %+v '%s'", data, data)
log.Tracef("dataMessage: %s", dataMessage)
log.Tracef("kB: %x", kB)
// if kB not null, then use it to decrypt // if kB not null, then use it to decrypt
if kB != nil { if kB != nil {
data, err = crypt.Decrypt(data, kB) var decryptErr error
if err != nil { var dataDecrypt []byte
log.Debugf("error decrypting: %v", err) dataDecrypt, decryptErr = crypt.Decrypt(data, kB)
if decryptErr != nil {
log.Tracef("error decrypting: %v: '%s'", decryptErr, data)
} else {
// copy dataDecrypt to data
data = dataDecrypt
log.Tracef("decrypted: %s", data)
} }
} }
if bytes.Equal(data, ipRequest) { if bytes.Equal(data, ipRequest) {
log.Tracef("got ipRequest")
// recipient wants to try to connect to local ips // recipient wants to try to connect to local ips
var ips []string var ips []string
// only get local ips if the local is enabled // only get local ips if the local is enabled
@ -729,38 +736,48 @@ func (c *Client) Send(filesInfo []FileInfo, emptyFoldersToTransfer []FileInfo, t
// get list of local ips // get list of local ips
ips, err = utils.GetLocalIPs() ips, err = utils.GetLocalIPs()
if err != nil { if err != nil {
log.Debugf("error getting local ips: %v", err) log.Tracef("error getting local ips: %v", err)
} }
// prepend the port that is being listened to // prepend the port that is being listened to
ips = append([]string{c.Options.RelayPorts[0]}, ips...) ips = append([]string{c.Options.RelayPorts[0]}, ips...)
} }
bips, _ := json.Marshal(ips) log.Tracef("sending ips: %+v", ips)
bips, _ = crypt.Encrypt(bips, kB) bips, errIps := json.Marshal(ips)
if errIps != nil {
log.Tracef("error marshalling ips: %v", errIps)
}
bips, errIps = crypt.Encrypt(bips, kB)
if errIps != nil {
log.Tracef("error encrypting ips: %v", errIps)
}
if err = conn.Send(bips); err != nil { if err = conn.Send(bips); err != nil {
log.Errorf("error sending: %v", err) log.Errorf("error sending: %v", err)
} }
} else if dataMessage.Kind == "pake1" { } else if dataMessage.Kind == "pake1" {
err = B.Update(dataMessage.Bytes) log.Trace("got pake1")
if err == nil { var pakeError error
kB, err = B.SessionKey() pakeError = B.Update(dataMessage.Bytes)
if err == nil { if pakeError == nil {
log.Debugf("dataMessage kB: %x", kB) kB, pakeError = B.SessionKey()
if pakeError == nil {
log.Tracef("dataMessage kB: %x", kB)
dataMessage.Bytes = B.Bytes() dataMessage.Bytes = B.Bytes()
dataMessage.Kind = "pake2" dataMessage.Kind = "pake2"
data, _ = json.Marshal(dataMessage) data, _ = json.Marshal(dataMessage)
if err = conn.Send(data); err != nil { if pakeError = conn.Send(data); err != nil {
log.Errorf("dataMessage error sending: %v", err) log.Errorf("dataMessage error sending: %v", err)
} }
} }
} }
} else if bytes.Equal(data, handshakeRequest) { } else if bytes.Equal(data, handshakeRequest) {
log.Trace("got handshake")
break break
} else if bytes.Equal(data, []byte{1}) { } else if bytes.Equal(data, []byte{1}) {
log.Debug("got ping") log.Trace("got ping")
continue continue
} else { } else {
log.Debugf("[%+v] got weird bytes: %+v", conn, data) log.Tracef("[%+v] got weird bytes: %+v", conn, data)
// throttle the reading // throttle the reading
errchan <- fmt.Errorf("gracefully refusing using the public relay") errchan <- fmt.Errorf("gracefully refusing using the public relay")
return return
@ -1905,14 +1922,14 @@ func (c *Client) setBar() {
} }
func (c *Client) receiveData(i int) { func (c *Client) receiveData(i int) {
log.Debugf("%d receiving data", i) log.Tracef("%d receiving data", i)
for { for {
data, err := c.conn[i+1].Receive() data, err := c.conn[i+1].Receive()
if err != nil { if err != nil {
break break
} }
if bytes.Equal(data, []byte{1}) { if bytes.Equal(data, []byte{1}) {
log.Debug("got ping") log.Trace("got ping")
continue continue
} }