mirror of
https://github.com/schollz/croc.git
synced 2025-10-11 13:21:00 +02:00
connect to TCP as early as possible
This commit is contained in:
parent
70fb9b77ca
commit
cc6edd24d1
4 changed files with 128 additions and 93 deletions
|
@ -7,6 +7,8 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Comm is some basic TCP communication
|
||||
|
@ -39,7 +41,11 @@ func (c Comm) Write(b []byte) (int, error) {
|
|||
copy(tmpCopy[5:], b)
|
||||
n, err := c.connection.Write(tmpCopy)
|
||||
if n != len(tmpCopy) {
|
||||
err = fmt.Errorf("wanted to write %d but wrote %d", len(b), n)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("wanted to write %d but wrote %d", len(b), n))
|
||||
} else {
|
||||
err = fmt.Errorf("wanted to write %d but wrote %d", len(b), n)
|
||||
}
|
||||
}
|
||||
// log.Printf("wanted to write %d but wrote %d", n, len(b))
|
||||
return n, err
|
||||
|
|
|
@ -55,6 +55,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
|
|||
var resumeFile bool
|
||||
var tcpConnections []comm.Comm
|
||||
dataChan := make(chan []byte, 1024*1024)
|
||||
isConnectedIfUsingTCP := make(chan bool)
|
||||
blocks := []string{}
|
||||
|
||||
useWebsockets := true
|
||||
|
@ -121,19 +122,38 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
|
|||
if err := Q.Update(message); err != nil {
|
||||
return err
|
||||
}
|
||||
c.WriteMessage(websocket.BinaryMessage, Q.Bytes())
|
||||
case 2:
|
||||
log.Debugf("[%d] Q recieves H(k) from P", step)
|
||||
if err := Q.Update(message); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Q has the session key now, but we will still check if its valid
|
||||
sessionKey, err = Q.SessionKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("%x\n", sessionKey)
|
||||
|
||||
// initialize TCP connections if using (possible, but unlikely, race condition)
|
||||
go func() {
|
||||
if !useWebsockets {
|
||||
log.Debugf("connecting to server")
|
||||
tcpConnections = make([]comm.Comm, len(tcpPorts))
|
||||
for i, tcpPort := range tcpPorts {
|
||||
log.Debugf("connecting to %d", i)
|
||||
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
log.Debugf("fully connected")
|
||||
}
|
||||
isConnectedIfUsingTCP <- true
|
||||
}()
|
||||
|
||||
c.WriteMessage(websocket.BinaryMessage, Q.Bytes())
|
||||
case 2:
|
||||
log.Debugf("[%d] Q recieves H(k) from P", step)
|
||||
// check if everything is still kosher with our computed session key
|
||||
if err := Q.Update(message); err != nil {
|
||||
return err
|
||||
}
|
||||
c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
|
||||
case 3:
|
||||
spin.Stop()
|
||||
|
@ -155,7 +175,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
|
|||
}
|
||||
log.Debugf("got file stats: %+v", fstats)
|
||||
|
||||
// prompt user if its okay to receive file
|
||||
// determine if the file is resuming or not
|
||||
progressFile = fmt.Sprintf("%s.progress", fstats.SentName)
|
||||
overwritingOrReceiving := "Receiving"
|
||||
if utils.Exists(fstats.Name) || utils.Exists(fstats.SentName) {
|
||||
|
@ -165,6 +185,22 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
|
|||
resumeFile = true
|
||||
}
|
||||
}
|
||||
|
||||
// send blocks
|
||||
if resumeFile {
|
||||
fileWithBlocks, _ := os.Open(progressFile)
|
||||
scanner := bufio.NewScanner(fileWithBlocks)
|
||||
for scanner.Scan() {
|
||||
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
|
||||
}
|
||||
fileWithBlocks.Close()
|
||||
}
|
||||
blocksBytes, _ := json.Marshal(blocks)
|
||||
// encrypt the block data and send
|
||||
encblockBytes := crypt.Encrypt(blocksBytes, sessionKey)
|
||||
c.WriteMessage(websocket.BinaryMessage, encblockBytes.Bytes())
|
||||
|
||||
// prompt user about the file
|
||||
fileOrFolder := "file"
|
||||
if fstats.IsDir {
|
||||
fileOrFolder = "folder"
|
||||
|
@ -183,21 +219,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
|
|||
}
|
||||
}
|
||||
|
||||
// connect to TCP to receive file
|
||||
if !useWebsockets {
|
||||
log.Debugf("connecting to server")
|
||||
tcpConnections = make([]comm.Comm, len(tcpPorts))
|
||||
for i, tcpPort := range tcpPorts {
|
||||
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer tcpConnections[i].Close()
|
||||
}
|
||||
log.Debugf("fully connected")
|
||||
}
|
||||
|
||||
// await file
|
||||
// erase file if overwriting
|
||||
if overwritingOrReceiving == "Overwriting" {
|
||||
|
@ -228,17 +249,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
|
|||
}
|
||||
}
|
||||
|
||||
// append the previous blocks if there was progress previously
|
||||
if resumeFile {
|
||||
file, _ := os.Open(progressFile)
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
blocksBytes, _ := json.Marshal(blocks)
|
||||
|
||||
blockSize := 0
|
||||
if useWebsockets {
|
||||
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
|
||||
|
@ -364,45 +374,44 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
|
|||
log.Debug("got magic")
|
||||
break
|
||||
}
|
||||
select {
|
||||
case dataChan <- message:
|
||||
default:
|
||||
log.Debug("blocked")
|
||||
// no message sent
|
||||
// block
|
||||
dataChan <- message
|
||||
}
|
||||
dataChan <- message
|
||||
// select {
|
||||
// case dataChan <- message:
|
||||
// default:
|
||||
// log.Debug("blocked")
|
||||
// // no message sent
|
||||
// // block
|
||||
// dataChan <- message
|
||||
// }
|
||||
}
|
||||
} else {
|
||||
_ = <-isConnectedIfUsingTCP
|
||||
log.Debugf("starting listening with tcp with %d connections", len(tcpConnections))
|
||||
// using TCP
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(tcpConnections))
|
||||
for i := range tcpConnections {
|
||||
go func(wg *sync.WaitGroup, tcpConnection comm.Comm) {
|
||||
defer func(i int) {
|
||||
log.Debugf("closing connection %d", i)
|
||||
tcpConnections[i].Close()
|
||||
}(i)
|
||||
go func(wg *sync.WaitGroup, j int) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
log.Debugf("waiting to read on %d", j)
|
||||
// read from TCP connection
|
||||
message, _, _, err := tcpConnection.Read()
|
||||
message, _, _, err := tcpConnections[j].Read()
|
||||
// log.Debugf("message: %s", message)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
panic(err)
|
||||
}
|
||||
if bytes.Equal(message, []byte("magic")) {
|
||||
log.Debug("got magic")
|
||||
log.Debugf("%d got magic, leaving", j)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case dataChan <- message:
|
||||
default:
|
||||
log.Debug("blocked")
|
||||
// no message sent
|
||||
// block
|
||||
dataChan <- message
|
||||
}
|
||||
dataChan <- message
|
||||
}
|
||||
}(&wg, tcpConnections[i])
|
||||
}(&wg, i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
@ -54,6 +54,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
var startTransfer time.Time
|
||||
var tcpConnections []comm.Comm
|
||||
blocksToSkip := make(map[int64]struct{})
|
||||
isConnectedIfUsingTCP := make(chan bool)
|
||||
|
||||
type DataChan struct {
|
||||
b []byte
|
||||
|
@ -203,6 +204,22 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
return errors.New("recipient refused file")
|
||||
}
|
||||
|
||||
// connect to TCP in background
|
||||
tcpConnections = make([]comm.Comm, len(tcpPorts))
|
||||
go func() {
|
||||
if !useWebsockets {
|
||||
log.Debugf("connecting to server")
|
||||
for i, tcpPort := range tcpPorts {
|
||||
log.Debugf("connecting to %s on connection %d", tcpPort, i)
|
||||
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
isConnectedIfUsingTCP <- true
|
||||
}()
|
||||
|
||||
err = <-fileReady // block until file is ready
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -217,16 +234,23 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
// send the file meta data
|
||||
c.WriteMessage(websocket.BinaryMessage, enc.Bytes())
|
||||
case 4:
|
||||
spin.Stop()
|
||||
|
||||
log.Debugf("[%d] recipient declares readiness for file data", step)
|
||||
if !bytes.HasPrefix(message, []byte("ready")) {
|
||||
return errors.New("recipient refused file")
|
||||
log.Debugf("[%d] recipient declares gives blocks", step)
|
||||
// recipient sends blocks, and sender does not send anything back
|
||||
// determine if any blocks were sent to skip
|
||||
enc, err := crypt.FromBytes(message)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
decrypted, err := enc.Decrypt(sessionKey)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "could not decrypt blocks with session key")
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// determine if any blocks were sent to skip
|
||||
var blocks []string
|
||||
errBlocks := json.Unmarshal(message[5:], &blocks)
|
||||
errBlocks := json.Unmarshal(decrypted, &blocks)
|
||||
if errBlocks == nil {
|
||||
for _, block := range blocks {
|
||||
blockInt64, errBlock := strconv.Atoi(block)
|
||||
|
@ -237,6 +261,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
}
|
||||
log.Debugf("found blocks: %+v", blocksToSkip)
|
||||
|
||||
// start loading the file into memory
|
||||
// start streaming encryption/compression
|
||||
if fstats.IsDir {
|
||||
// remove file if zipped
|
||||
|
@ -285,21 +310,10 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case dataChan <- DataChan{
|
||||
dataChan <- DataChan{
|
||||
b: encBytes,
|
||||
bytesRead: bytesread,
|
||||
err: nil,
|
||||
}:
|
||||
default:
|
||||
log.Debug("blocked")
|
||||
// no message sent
|
||||
// block
|
||||
dataChan <- DataChan{
|
||||
b: encBytes,
|
||||
bytesRead: bytesread,
|
||||
err: nil,
|
||||
}
|
||||
}
|
||||
currentPostition += int64(bytesread)
|
||||
}
|
||||
|
@ -330,19 +344,12 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
}
|
||||
}(dataChan)
|
||||
|
||||
// connect to TCP to receive file
|
||||
if !useWebsockets {
|
||||
log.Debugf("connecting to server")
|
||||
tcpConnections = make([]comm.Comm, len(tcpPorts))
|
||||
for i, tcpPort := range tcpPorts {
|
||||
log.Debug(tcpPort)
|
||||
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer tcpConnections[i].Close()
|
||||
}
|
||||
case 5:
|
||||
spin.Stop()
|
||||
|
||||
log.Debugf("[%d] recipient declares readiness for file data", step)
|
||||
if !bytes.HasPrefix(message, []byte("ready")) {
|
||||
return errors.New("recipient refused file")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", otherIP)
|
||||
|
@ -381,10 +388,16 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
}
|
||||
}
|
||||
} else {
|
||||
_ = <-isConnectedIfUsingTCP
|
||||
log.Debug("connected and ready to send on tcp")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(tcpConnections))
|
||||
for i := range tcpConnections {
|
||||
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection comm.Comm) {
|
||||
defer func(i int) {
|
||||
log.Debugf("closing connection %d", i)
|
||||
tcpConnections[i].Close()
|
||||
}(i)
|
||||
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan) {
|
||||
defer wg.Done()
|
||||
for data := range dataChan {
|
||||
if data.err != nil {
|
||||
|
@ -393,7 +406,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
}
|
||||
bar.Add(data.bytesRead)
|
||||
// write data to tcp connection
|
||||
_, err = tcpConnection.Write(data.b)
|
||||
_, err = tcpConnections[i].Write(data.b)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "problem writing message")
|
||||
log.Error(err)
|
||||
|
@ -404,7 +417,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
return
|
||||
}
|
||||
}
|
||||
}(i, &wg, dataChan, tcpConnections[i])
|
||||
}(i, &wg, dataChan)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
@ -415,7 +428,8 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case 5:
|
||||
case 6:
|
||||
// recevied something, maybe the file hash
|
||||
transferTime := time.Since(startTransfer)
|
||||
if !bytes.HasPrefix(message, []byte("hash:")) {
|
||||
log.Debugf("%s", message)
|
||||
|
@ -446,7 +460,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
|
|||
}
|
||||
|
||||
func connectToTCPServer(room string, address string) (com comm.Comm, err error) {
|
||||
connection, err := net.Dial("tcp", address)
|
||||
connection, err := net.DialTimeout("tcp", address, 3*time.Hour)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -64,12 +64,14 @@ func run(port string) (err error) {
|
|||
|
||||
func clientCommuncation(port string, c comm.Comm) (err error) {
|
||||
// send ok to tell client they are connected
|
||||
log.Debug("sending ok")
|
||||
err = c.Send("ok")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// wait for client to tell me which room they want
|
||||
log.Debug("waiting for answer")
|
||||
room, err := c.Receive()
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -86,10 +88,13 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
|
|||
// tell the client that they got the room
|
||||
err = c.Send("recipient")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
log.Debug("recipient connected")
|
||||
return nil
|
||||
}
|
||||
log.Debug("sender connected")
|
||||
receiver := rooms.rooms[room].receiver
|
||||
rooms.Unlock()
|
||||
|
||||
|
@ -137,6 +142,7 @@ func chanFromConn(conn net.Conn) chan []byte {
|
|||
c <- res
|
||||
}
|
||||
if err != nil {
|
||||
log.Debug(err)
|
||||
c <- nil
|
||||
break
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue