0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00

connect to TCP as early as possible

This commit is contained in:
Zack Scholl 2018-10-13 06:09:55 -07:00
parent 70fb9b77ca
commit cc6edd24d1
4 changed files with 128 additions and 93 deletions

View file

@ -7,6 +7,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/pkg/errors"
) )
// Comm is some basic TCP communication // Comm is some basic TCP communication
@ -39,7 +41,11 @@ func (c Comm) Write(b []byte) (int, error) {
copy(tmpCopy[5:], b) copy(tmpCopy[5:], b)
n, err := c.connection.Write(tmpCopy) n, err := c.connection.Write(tmpCopy)
if n != len(tmpCopy) { if n != len(tmpCopy) {
err = fmt.Errorf("wanted to write %d but wrote %d", len(b), n) if err != nil {
err = errors.Wrap(err, fmt.Sprintf("wanted to write %d but wrote %d", len(b), n))
} else {
err = fmt.Errorf("wanted to write %d but wrote %d", len(b), n)
}
} }
// log.Printf("wanted to write %d but wrote %d", n, len(b)) // log.Printf("wanted to write %d but wrote %d", n, len(b))
return n, err return n, err

View file

@ -55,6 +55,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
var resumeFile bool var resumeFile bool
var tcpConnections []comm.Comm var tcpConnections []comm.Comm
dataChan := make(chan []byte, 1024*1024) dataChan := make(chan []byte, 1024*1024)
isConnectedIfUsingTCP := make(chan bool)
blocks := []string{} blocks := []string{}
useWebsockets := true useWebsockets := true
@ -121,19 +122,38 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
if err := Q.Update(message); err != nil { if err := Q.Update(message); err != nil {
return err return err
} }
c.WriteMessage(websocket.BinaryMessage, Q.Bytes())
case 2:
log.Debugf("[%d] Q recieves H(k) from P", step)
if err := Q.Update(message); err != nil {
return err
}
// Q has the session key now, but we will still check if its valid
sessionKey, err = Q.SessionKey() sessionKey, err = Q.SessionKey()
if err != nil { if err != nil {
return err return err
} }
log.Debugf("%x\n", sessionKey) log.Debugf("%x\n", sessionKey)
// initialize TCP connections if using (possible, but unlikely, race condition)
go func() {
if !useWebsockets {
log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts {
log.Debugf("connecting to %d", i)
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
}
}
log.Debugf("fully connected")
}
isConnectedIfUsingTCP <- true
}()
c.WriteMessage(websocket.BinaryMessage, Q.Bytes())
case 2:
log.Debugf("[%d] Q recieves H(k) from P", step)
// check if everything is still kosher with our computed session key
if err := Q.Update(message); err != nil {
return err
}
c.WriteMessage(websocket.BinaryMessage, []byte("ready")) c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
case 3: case 3:
spin.Stop() spin.Stop()
@ -155,7 +175,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
} }
log.Debugf("got file stats: %+v", fstats) log.Debugf("got file stats: %+v", fstats)
// prompt user if its okay to receive file // determine if the file is resuming or not
progressFile = fmt.Sprintf("%s.progress", fstats.SentName) progressFile = fmt.Sprintf("%s.progress", fstats.SentName)
overwritingOrReceiving := "Receiving" overwritingOrReceiving := "Receiving"
if utils.Exists(fstats.Name) || utils.Exists(fstats.SentName) { if utils.Exists(fstats.Name) || utils.Exists(fstats.SentName) {
@ -165,6 +185,22 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
resumeFile = true resumeFile = true
} }
} }
// send blocks
if resumeFile {
fileWithBlocks, _ := os.Open(progressFile)
scanner := bufio.NewScanner(fileWithBlocks)
for scanner.Scan() {
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
}
fileWithBlocks.Close()
}
blocksBytes, _ := json.Marshal(blocks)
// encrypt the block data and send
encblockBytes := crypt.Encrypt(blocksBytes, sessionKey)
c.WriteMessage(websocket.BinaryMessage, encblockBytes.Bytes())
// prompt user about the file
fileOrFolder := "file" fileOrFolder := "file"
if fstats.IsDir { if fstats.IsDir {
fileOrFolder = "folder" fileOrFolder = "folder"
@ -183,21 +219,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
} }
} }
// connect to TCP to receive file
if !useWebsockets {
log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts {
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
return err
}
defer tcpConnections[i].Close()
}
log.Debugf("fully connected")
}
// await file // await file
// erase file if overwriting // erase file if overwriting
if overwritingOrReceiving == "Overwriting" { if overwritingOrReceiving == "Overwriting" {
@ -228,17 +249,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
} }
} }
// append the previous blocks if there was progress previously
if resumeFile {
file, _ := os.Open(progressFile)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
}
file.Close()
}
blocksBytes, _ := json.Marshal(blocks)
blockSize := 0 blockSize := 0
if useWebsockets { if useWebsockets {
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8 blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
@ -364,45 +374,44 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
log.Debug("got magic") log.Debug("got magic")
break break
} }
select { dataChan <- message
case dataChan <- message: // select {
default: // case dataChan <- message:
log.Debug("blocked") // default:
// no message sent // log.Debug("blocked")
// block // // no message sent
dataChan <- message // // block
} // dataChan <- message
// }
} }
} else { } else {
_ = <-isConnectedIfUsingTCP
log.Debugf("starting listening with tcp with %d connections", len(tcpConnections)) log.Debugf("starting listening with tcp with %d connections", len(tcpConnections))
// using TCP // using TCP
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(tcpConnections)) wg.Add(len(tcpConnections))
for i := range tcpConnections { for i := range tcpConnections {
go func(wg *sync.WaitGroup, tcpConnection comm.Comm) { defer func(i int) {
log.Debugf("closing connection %d", i)
tcpConnections[i].Close()
}(i)
go func(wg *sync.WaitGroup, j int) {
defer wg.Done() defer wg.Done()
for { for {
log.Debugf("waiting to read on %d", j)
// read from TCP connection // read from TCP connection
message, _, _, err := tcpConnection.Read() message, _, _, err := tcpConnections[j].Read()
// log.Debugf("message: %s", message) // log.Debugf("message: %s", message)
if err != nil { if err != nil {
log.Error(err) panic(err)
return
} }
if bytes.Equal(message, []byte("magic")) { if bytes.Equal(message, []byte("magic")) {
log.Debug("got magic") log.Debugf("%d got magic, leaving", j)
return return
} }
select { dataChan <- message
case dataChan <- message:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- message
}
} }
}(&wg, tcpConnections[i]) }(&wg, i)
} }
wg.Wait() wg.Wait()
} }

View file

@ -54,6 +54,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
var startTransfer time.Time var startTransfer time.Time
var tcpConnections []comm.Comm var tcpConnections []comm.Comm
blocksToSkip := make(map[int64]struct{}) blocksToSkip := make(map[int64]struct{})
isConnectedIfUsingTCP := make(chan bool)
type DataChan struct { type DataChan struct {
b []byte b []byte
@ -203,6 +204,22 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
return errors.New("recipient refused file") return errors.New("recipient refused file")
} }
// connect to TCP in background
tcpConnections = make([]comm.Comm, len(tcpPorts))
go func() {
if !useWebsockets {
log.Debugf("connecting to server")
for i, tcpPort := range tcpPorts {
log.Debugf("connecting to %s on connection %d", tcpPort, i)
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
}
}
}
isConnectedIfUsingTCP <- true
}()
err = <-fileReady // block until file is ready err = <-fileReady // block until file is ready
if err != nil { if err != nil {
return err return err
@ -217,16 +234,23 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
// send the file meta data // send the file meta data
c.WriteMessage(websocket.BinaryMessage, enc.Bytes()) c.WriteMessage(websocket.BinaryMessage, enc.Bytes())
case 4: case 4:
spin.Stop() log.Debugf("[%d] recipient declares gives blocks", step)
// recipient sends blocks, and sender does not send anything back
log.Debugf("[%d] recipient declares readiness for file data", step) // determine if any blocks were sent to skip
if !bytes.HasPrefix(message, []byte("ready")) { enc, err := crypt.FromBytes(message)
return errors.New("recipient refused file") if err != nil {
log.Error(err)
return err
}
decrypted, err := enc.Decrypt(sessionKey)
if err != nil {
err = errors.Wrap(err, "could not decrypt blocks with session key")
log.Error(err)
return err
} }
// determine if any blocks were sent to skip
var blocks []string var blocks []string
errBlocks := json.Unmarshal(message[5:], &blocks) errBlocks := json.Unmarshal(decrypted, &blocks)
if errBlocks == nil { if errBlocks == nil {
for _, block := range blocks { for _, block := range blocks {
blockInt64, errBlock := strconv.Atoi(block) blockInt64, errBlock := strconv.Atoi(block)
@ -237,6 +261,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
} }
log.Debugf("found blocks: %+v", blocksToSkip) log.Debugf("found blocks: %+v", blocksToSkip)
// start loading the file into memory
// start streaming encryption/compression // start streaming encryption/compression
if fstats.IsDir { if fstats.IsDir {
// remove file if zipped // remove file if zipped
@ -285,21 +310,10 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
return return
} }
select { dataChan <- DataChan{
case dataChan <- DataChan{
b: encBytes, b: encBytes,
bytesRead: bytesread, bytesRead: bytesread,
err: nil, err: nil,
}:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}
} }
currentPostition += int64(bytesread) currentPostition += int64(bytesread)
} }
@ -330,19 +344,12 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
} }
}(dataChan) }(dataChan)
// connect to TCP to receive file case 5:
if !useWebsockets { spin.Stop()
log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts)) log.Debugf("[%d] recipient declares readiness for file data", step)
for i, tcpPort := range tcpPorts { if !bytes.HasPrefix(message, []byte("ready")) {
log.Debug(tcpPort) return errors.New("recipient refused file")
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
return err
}
defer tcpConnections[i].Close()
}
} }
fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", otherIP) fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", otherIP)
@ -381,10 +388,16 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
} }
} }
} else { } else {
_ = <-isConnectedIfUsingTCP
log.Debug("connected and ready to send on tcp")
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(tcpConnections)) wg.Add(len(tcpConnections))
for i := range tcpConnections { for i := range tcpConnections {
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection comm.Comm) { defer func(i int) {
log.Debugf("closing connection %d", i)
tcpConnections[i].Close()
}(i)
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan) {
defer wg.Done() defer wg.Done()
for data := range dataChan { for data := range dataChan {
if data.err != nil { if data.err != nil {
@ -393,7 +406,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
} }
bar.Add(data.bytesRead) bar.Add(data.bytesRead)
// write data to tcp connection // write data to tcp connection
_, err = tcpConnection.Write(data.b) _, err = tcpConnections[i].Write(data.b)
if err != nil { if err != nil {
err = errors.Wrap(err, "problem writing message") err = errors.Wrap(err, "problem writing message")
log.Error(err) log.Error(err)
@ -404,7 +417,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
return return
} }
} }
}(i, &wg, dataChan, tcpConnections[i]) }(i, &wg, dataChan)
} }
wg.Wait() wg.Wait()
} }
@ -415,7 +428,8 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
if err != nil { if err != nil {
return err return err
} }
case 5: case 6:
// recevied something, maybe the file hash
transferTime := time.Since(startTransfer) transferTime := time.Since(startTransfer)
if !bytes.HasPrefix(message, []byte("hash:")) { if !bytes.HasPrefix(message, []byte("hash:")) {
log.Debugf("%s", message) log.Debugf("%s", message)
@ -446,7 +460,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
} }
func connectToTCPServer(room string, address string) (com comm.Comm, err error) { func connectToTCPServer(room string, address string) (com comm.Comm, err error) {
connection, err := net.Dial("tcp", address) connection, err := net.DialTimeout("tcp", address, 3*time.Hour)
if err != nil { if err != nil {
return return
} }

View file

@ -64,12 +64,14 @@ func run(port string) (err error) {
func clientCommuncation(port string, c comm.Comm) (err error) { func clientCommuncation(port string, c comm.Comm) (err error) {
// send ok to tell client they are connected // send ok to tell client they are connected
log.Debug("sending ok")
err = c.Send("ok") err = c.Send("ok")
if err != nil { if err != nil {
return return
} }
// wait for client to tell me which room they want // wait for client to tell me which room they want
log.Debug("waiting for answer")
room, err := c.Receive() room, err := c.Receive()
if err != nil { if err != nil {
return return
@ -86,10 +88,13 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
// tell the client that they got the room // tell the client that they got the room
err = c.Send("recipient") err = c.Send("recipient")
if err != nil { if err != nil {
log.Error(err)
return return
} }
log.Debug("recipient connected")
return nil return nil
} }
log.Debug("sender connected")
receiver := rooms.rooms[room].receiver receiver := rooms.rooms[room].receiver
rooms.Unlock() rooms.Unlock()
@ -137,6 +142,7 @@ func chanFromConn(conn net.Conn) chan []byte {
c <- res c <- res
} }
if err != nil { if err != nil {
log.Debug(err)
c <- nil c <- nil
break break
} }