0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00

connect to TCP as early as possible

This commit is contained in:
Zack Scholl 2018-10-13 06:09:55 -07:00
parent 70fb9b77ca
commit cc6edd24d1
4 changed files with 128 additions and 93 deletions

View file

@ -7,6 +7,8 @@ import (
"strconv"
"strings"
"time"
"github.com/pkg/errors"
)
// Comm is some basic TCP communication
@ -39,7 +41,11 @@ func (c Comm) Write(b []byte) (int, error) {
copy(tmpCopy[5:], b)
n, err := c.connection.Write(tmpCopy)
if n != len(tmpCopy) {
err = fmt.Errorf("wanted to write %d but wrote %d", len(b), n)
if err != nil {
err = errors.Wrap(err, fmt.Sprintf("wanted to write %d but wrote %d", len(b), n))
} else {
err = fmt.Errorf("wanted to write %d but wrote %d", len(b), n)
}
}
// log.Printf("wanted to write %d but wrote %d", n, len(b))
return n, err

View file

@ -55,6 +55,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
var resumeFile bool
var tcpConnections []comm.Comm
dataChan := make(chan []byte, 1024*1024)
isConnectedIfUsingTCP := make(chan bool)
blocks := []string{}
useWebsockets := true
@ -121,19 +122,38 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
if err := Q.Update(message); err != nil {
return err
}
c.WriteMessage(websocket.BinaryMessage, Q.Bytes())
case 2:
log.Debugf("[%d] Q recieves H(k) from P", step)
if err := Q.Update(message); err != nil {
return err
}
// Q has the session key now, but we will still check if its valid
sessionKey, err = Q.SessionKey()
if err != nil {
return err
}
log.Debugf("%x\n", sessionKey)
// initialize TCP connections if using (possible, but unlikely, race condition)
go func() {
if !useWebsockets {
log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts {
log.Debugf("connecting to %d", i)
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
}
}
log.Debugf("fully connected")
}
isConnectedIfUsingTCP <- true
}()
c.WriteMessage(websocket.BinaryMessage, Q.Bytes())
case 2:
log.Debugf("[%d] Q recieves H(k) from P", step)
// check if everything is still kosher with our computed session key
if err := Q.Update(message); err != nil {
return err
}
c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
case 3:
spin.Stop()
@ -155,7 +175,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
}
log.Debugf("got file stats: %+v", fstats)
// prompt user if its okay to receive file
// determine if the file is resuming or not
progressFile = fmt.Sprintf("%s.progress", fstats.SentName)
overwritingOrReceiving := "Receiving"
if utils.Exists(fstats.Name) || utils.Exists(fstats.SentName) {
@ -165,6 +185,22 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
resumeFile = true
}
}
// send blocks
if resumeFile {
fileWithBlocks, _ := os.Open(progressFile)
scanner := bufio.NewScanner(fileWithBlocks)
for scanner.Scan() {
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
}
fileWithBlocks.Close()
}
blocksBytes, _ := json.Marshal(blocks)
// encrypt the block data and send
encblockBytes := crypt.Encrypt(blocksBytes, sessionKey)
c.WriteMessage(websocket.BinaryMessage, encblockBytes.Bytes())
// prompt user about the file
fileOrFolder := "file"
if fstats.IsDir {
fileOrFolder = "folder"
@ -183,21 +219,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
}
}
// connect to TCP to receive file
if !useWebsockets {
log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts {
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
return err
}
defer tcpConnections[i].Close()
}
log.Debugf("fully connected")
}
// await file
// erase file if overwriting
if overwritingOrReceiving == "Overwriting" {
@ -228,17 +249,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
}
}
// append the previous blocks if there was progress previously
if resumeFile {
file, _ := os.Open(progressFile)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
}
file.Close()
}
blocksBytes, _ := json.Marshal(blocks)
blockSize := 0
if useWebsockets {
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
@ -364,45 +374,44 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
log.Debug("got magic")
break
}
select {
case dataChan <- message:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- message
}
dataChan <- message
// select {
// case dataChan <- message:
// default:
// log.Debug("blocked")
// // no message sent
// // block
// dataChan <- message
// }
}
} else {
_ = <-isConnectedIfUsingTCP
log.Debugf("starting listening with tcp with %d connections", len(tcpConnections))
// using TCP
var wg sync.WaitGroup
wg.Add(len(tcpConnections))
for i := range tcpConnections {
go func(wg *sync.WaitGroup, tcpConnection comm.Comm) {
defer func(i int) {
log.Debugf("closing connection %d", i)
tcpConnections[i].Close()
}(i)
go func(wg *sync.WaitGroup, j int) {
defer wg.Done()
for {
log.Debugf("waiting to read on %d", j)
// read from TCP connection
message, _, _, err := tcpConnection.Read()
message, _, _, err := tcpConnections[j].Read()
// log.Debugf("message: %s", message)
if err != nil {
log.Error(err)
return
panic(err)
}
if bytes.Equal(message, []byte("magic")) {
log.Debug("got magic")
log.Debugf("%d got magic, leaving", j)
return
}
select {
case dataChan <- message:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- message
}
dataChan <- message
}
}(&wg, tcpConnections[i])
}(&wg, i)
}
wg.Wait()
}

View file

@ -54,6 +54,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
var startTransfer time.Time
var tcpConnections []comm.Comm
blocksToSkip := make(map[int64]struct{})
isConnectedIfUsingTCP := make(chan bool)
type DataChan struct {
b []byte
@ -203,6 +204,22 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
return errors.New("recipient refused file")
}
// connect to TCP in background
tcpConnections = make([]comm.Comm, len(tcpPorts))
go func() {
if !useWebsockets {
log.Debugf("connecting to server")
for i, tcpPort := range tcpPorts {
log.Debugf("connecting to %s on connection %d", tcpPort, i)
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
}
}
}
isConnectedIfUsingTCP <- true
}()
err = <-fileReady // block until file is ready
if err != nil {
return err
@ -217,16 +234,23 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
// send the file meta data
c.WriteMessage(websocket.BinaryMessage, enc.Bytes())
case 4:
spin.Stop()
log.Debugf("[%d] recipient declares readiness for file data", step)
if !bytes.HasPrefix(message, []byte("ready")) {
return errors.New("recipient refused file")
log.Debugf("[%d] recipient declares gives blocks", step)
// recipient sends blocks, and sender does not send anything back
// determine if any blocks were sent to skip
enc, err := crypt.FromBytes(message)
if err != nil {
log.Error(err)
return err
}
decrypted, err := enc.Decrypt(sessionKey)
if err != nil {
err = errors.Wrap(err, "could not decrypt blocks with session key")
log.Error(err)
return err
}
// determine if any blocks were sent to skip
var blocks []string
errBlocks := json.Unmarshal(message[5:], &blocks)
errBlocks := json.Unmarshal(decrypted, &blocks)
if errBlocks == nil {
for _, block := range blocks {
blockInt64, errBlock := strconv.Atoi(block)
@ -237,6 +261,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
}
log.Debugf("found blocks: %+v", blocksToSkip)
// start loading the file into memory
// start streaming encryption/compression
if fstats.IsDir {
// remove file if zipped
@ -285,21 +310,10 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
return
}
select {
case dataChan <- DataChan{
dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}
}
currentPostition += int64(bytesread)
}
@ -330,19 +344,12 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
}
}(dataChan)
// connect to TCP to receive file
if !useWebsockets {
log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts {
log.Debug(tcpPort)
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
return err
}
defer tcpConnections[i].Close()
}
case 5:
spin.Stop()
log.Debugf("[%d] recipient declares readiness for file data", step)
if !bytes.HasPrefix(message, []byte("ready")) {
return errors.New("recipient refused file")
}
fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", otherIP)
@ -381,10 +388,16 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
}
}
} else {
_ = <-isConnectedIfUsingTCP
log.Debug("connected and ready to send on tcp")
var wg sync.WaitGroup
wg.Add(len(tcpConnections))
for i := range tcpConnections {
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection comm.Comm) {
defer func(i int) {
log.Debugf("closing connection %d", i)
tcpConnections[i].Close()
}(i)
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan) {
defer wg.Done()
for data := range dataChan {
if data.err != nil {
@ -393,7 +406,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
}
bar.Add(data.bytesRead)
// write data to tcp connection
_, err = tcpConnection.Write(data.b)
_, err = tcpConnections[i].Write(data.b)
if err != nil {
err = errors.Wrap(err, "problem writing message")
log.Error(err)
@ -404,7 +417,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
return
}
}
}(i, &wg, dataChan, tcpConnections[i])
}(i, &wg, dataChan)
}
wg.Wait()
}
@ -415,7 +428,8 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
if err != nil {
return err
}
case 5:
case 6:
// recevied something, maybe the file hash
transferTime := time.Since(startTransfer)
if !bytes.HasPrefix(message, []byte("hash:")) {
log.Debugf("%s", message)
@ -446,7 +460,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
}
func connectToTCPServer(room string, address string) (com comm.Comm, err error) {
connection, err := net.Dial("tcp", address)
connection, err := net.DialTimeout("tcp", address, 3*time.Hour)
if err != nil {
return
}

View file

@ -64,12 +64,14 @@ func run(port string) (err error) {
func clientCommuncation(port string, c comm.Comm) (err error) {
// send ok to tell client they are connected
log.Debug("sending ok")
err = c.Send("ok")
if err != nil {
return
}
// wait for client to tell me which room they want
log.Debug("waiting for answer")
room, err := c.Receive()
if err != nil {
return
@ -86,10 +88,13 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
// tell the client that they got the room
err = c.Send("recipient")
if err != nil {
log.Error(err)
return
}
log.Debug("recipient connected")
return nil
}
log.Debug("sender connected")
receiver := rooms.rooms[room].receiver
rooms.Unlock()
@ -137,6 +142,7 @@ func chanFromConn(conn net.Conn) chan []byte {
c <- res
}
if err != nil {
log.Debug(err)
c <- nil
break
}