diff --git a/src/comm/comm.go b/src/comm/comm.go index 44c0c6e9..47267b20 100644 --- a/src/comm/comm.go +++ b/src/comm/comm.go @@ -7,6 +7,8 @@ import ( "strconv" "strings" "time" + + "github.com/pkg/errors" ) // Comm is some basic TCP communication @@ -39,7 +41,11 @@ func (c Comm) Write(b []byte) (int, error) { copy(tmpCopy[5:], b) n, err := c.connection.Write(tmpCopy) if n != len(tmpCopy) { - err = fmt.Errorf("wanted to write %d but wrote %d", len(b), n) + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("wanted to write %d but wrote %d", len(b), n)) + } else { + err = fmt.Errorf("wanted to write %d but wrote %d", len(b), n) + } } // log.Printf("wanted to write %d but wrote %d", n, len(b)) return n, err diff --git a/src/recipient/recipient.go b/src/recipient/recipient.go index e0e83ff2..ed0a6427 100644 --- a/src/recipient/recipient.go +++ b/src/recipient/recipient.go @@ -55,6 +55,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo var resumeFile bool var tcpConnections []comm.Comm dataChan := make(chan []byte, 1024*1024) + isConnectedIfUsingTCP := make(chan bool) blocks := []string{} useWebsockets := true @@ -121,19 +122,38 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo if err := Q.Update(message); err != nil { return err } - c.WriteMessage(websocket.BinaryMessage, Q.Bytes()) - case 2: - log.Debugf("[%d] Q recieves H(k) from P", step) - if err := Q.Update(message); err != nil { - return err - } + // Q has the session key now, but we will still check if its valid sessionKey, err = Q.SessionKey() if err != nil { return err } log.Debugf("%x\n", sessionKey) + // initialize TCP connections if using (possible, but unlikely, race condition) + go func() { + if !useWebsockets { + log.Debugf("connecting to server") + tcpConnections = make([]comm.Comm, len(tcpPorts)) + for i, tcpPort := range tcpPorts { + log.Debugf("connecting to %d", i) + tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) + if err != nil { + log.Error(err) + } + } + log.Debugf("fully connected") + } + isConnectedIfUsingTCP <- true + }() + + c.WriteMessage(websocket.BinaryMessage, Q.Bytes()) + case 2: + log.Debugf("[%d] Q recieves H(k) from P", step) + // check if everything is still kosher with our computed session key + if err := Q.Update(message); err != nil { + return err + } c.WriteMessage(websocket.BinaryMessage, []byte("ready")) case 3: spin.Stop() @@ -155,7 +175,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo } log.Debugf("got file stats: %+v", fstats) - // prompt user if its okay to receive file + // determine if the file is resuming or not progressFile = fmt.Sprintf("%s.progress", fstats.SentName) overwritingOrReceiving := "Receiving" if utils.Exists(fstats.Name) || utils.Exists(fstats.SentName) { @@ -165,6 +185,22 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo resumeFile = true } } + + // send blocks + if resumeFile { + fileWithBlocks, _ := os.Open(progressFile) + scanner := bufio.NewScanner(fileWithBlocks) + for scanner.Scan() { + blocks = append(blocks, strings.TrimSpace(scanner.Text())) + } + fileWithBlocks.Close() + } + blocksBytes, _ := json.Marshal(blocks) + // encrypt the block data and send + encblockBytes := crypt.Encrypt(blocksBytes, sessionKey) + c.WriteMessage(websocket.BinaryMessage, encblockBytes.Bytes()) + + // prompt user about the file fileOrFolder := "file" if fstats.IsDir { fileOrFolder = "folder" @@ -183,21 +219,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo } } - // connect to TCP to receive file - if !useWebsockets { - log.Debugf("connecting to server") - tcpConnections = make([]comm.Comm, len(tcpPorts)) - for i, tcpPort := range tcpPorts { - tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) - if err != nil { - log.Error(err) - return err - } - defer tcpConnections[i].Close() - } - log.Debugf("fully connected") - } - // await file // erase file if overwriting if overwritingOrReceiving == "Overwriting" { @@ -228,17 +249,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo } } - // append the previous blocks if there was progress previously - if resumeFile { - file, _ := os.Open(progressFile) - scanner := bufio.NewScanner(file) - for scanner.Scan() { - blocks = append(blocks, strings.TrimSpace(scanner.Text())) - } - file.Close() - } - blocksBytes, _ := json.Marshal(blocks) - blockSize := 0 if useWebsockets { blockSize = models.WEBSOCKET_BUFFER_SIZE / 8 @@ -364,45 +374,44 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo log.Debug("got magic") break } - select { - case dataChan <- message: - default: - log.Debug("blocked") - // no message sent - // block - dataChan <- message - } + dataChan <- message + // select { + // case dataChan <- message: + // default: + // log.Debug("blocked") + // // no message sent + // // block + // dataChan <- message + // } } } else { + _ = <-isConnectedIfUsingTCP log.Debugf("starting listening with tcp with %d connections", len(tcpConnections)) // using TCP var wg sync.WaitGroup wg.Add(len(tcpConnections)) for i := range tcpConnections { - go func(wg *sync.WaitGroup, tcpConnection comm.Comm) { + defer func(i int) { + log.Debugf("closing connection %d", i) + tcpConnections[i].Close() + }(i) + go func(wg *sync.WaitGroup, j int) { defer wg.Done() for { + log.Debugf("waiting to read on %d", j) // read from TCP connection - message, _, _, err := tcpConnection.Read() + message, _, _, err := tcpConnections[j].Read() // log.Debugf("message: %s", message) if err != nil { - log.Error(err) - return + panic(err) } if bytes.Equal(message, []byte("magic")) { - log.Debug("got magic") + log.Debugf("%d got magic, leaving", j) return } - select { - case dataChan <- message: - default: - log.Debug("blocked") - // no message sent - // block - dataChan <- message - } + dataChan <- message } - }(&wg, tcpConnections[i]) + }(&wg, i) } wg.Wait() } diff --git a/src/sender/sender.go b/src/sender/sender.go index 08f09b5d..7e10af8a 100644 --- a/src/sender/sender.go +++ b/src/sender/sender.go @@ -54,6 +54,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, var startTransfer time.Time var tcpConnections []comm.Comm blocksToSkip := make(map[int64]struct{}) + isConnectedIfUsingTCP := make(chan bool) type DataChan struct { b []byte @@ -203,6 +204,22 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, return errors.New("recipient refused file") } + // connect to TCP in background + tcpConnections = make([]comm.Comm, len(tcpPorts)) + go func() { + if !useWebsockets { + log.Debugf("connecting to server") + for i, tcpPort := range tcpPorts { + log.Debugf("connecting to %s on connection %d", tcpPort, i) + tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) + if err != nil { + log.Error(err) + } + } + } + isConnectedIfUsingTCP <- true + }() + err = <-fileReady // block until file is ready if err != nil { return err @@ -217,16 +234,23 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, // send the file meta data c.WriteMessage(websocket.BinaryMessage, enc.Bytes()) case 4: - spin.Stop() - - log.Debugf("[%d] recipient declares readiness for file data", step) - if !bytes.HasPrefix(message, []byte("ready")) { - return errors.New("recipient refused file") + log.Debugf("[%d] recipient declares gives blocks", step) + // recipient sends blocks, and sender does not send anything back + // determine if any blocks were sent to skip + enc, err := crypt.FromBytes(message) + if err != nil { + log.Error(err) + return err + } + decrypted, err := enc.Decrypt(sessionKey) + if err != nil { + err = errors.Wrap(err, "could not decrypt blocks with session key") + log.Error(err) + return err } - // determine if any blocks were sent to skip var blocks []string - errBlocks := json.Unmarshal(message[5:], &blocks) + errBlocks := json.Unmarshal(decrypted, &blocks) if errBlocks == nil { for _, block := range blocks { blockInt64, errBlock := strconv.Atoi(block) @@ -237,6 +261,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, } log.Debugf("found blocks: %+v", blocksToSkip) + // start loading the file into memory // start streaming encryption/compression if fstats.IsDir { // remove file if zipped @@ -285,21 +310,10 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, return } - select { - case dataChan <- DataChan{ + dataChan <- DataChan{ b: encBytes, bytesRead: bytesread, err: nil, - }: - default: - log.Debug("blocked") - // no message sent - // block - dataChan <- DataChan{ - b: encBytes, - bytesRead: bytesread, - err: nil, - } } currentPostition += int64(bytesread) } @@ -330,19 +344,12 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, } }(dataChan) - // connect to TCP to receive file - if !useWebsockets { - log.Debugf("connecting to server") - tcpConnections = make([]comm.Comm, len(tcpPorts)) - for i, tcpPort := range tcpPorts { - log.Debug(tcpPort) - tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) - if err != nil { - log.Error(err) - return err - } - defer tcpConnections[i].Close() - } + case 5: + spin.Stop() + + log.Debugf("[%d] recipient declares readiness for file data", step) + if !bytes.HasPrefix(message, []byte("ready")) { + return errors.New("recipient refused file") } fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", otherIP) @@ -381,10 +388,16 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, } } } else { + _ = <-isConnectedIfUsingTCP + log.Debug("connected and ready to send on tcp") var wg sync.WaitGroup wg.Add(len(tcpConnections)) for i := range tcpConnections { - go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection comm.Comm) { + defer func(i int) { + log.Debugf("closing connection %d", i) + tcpConnections[i].Close() + }(i) + go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan) { defer wg.Done() for data := range dataChan { if data.err != nil { @@ -393,7 +406,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, } bar.Add(data.bytesRead) // write data to tcp connection - _, err = tcpConnection.Write(data.b) + _, err = tcpConnections[i].Write(data.b) if err != nil { err = errors.Wrap(err, "problem writing message") log.Error(err) @@ -404,7 +417,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, return } } - }(i, &wg, dataChan, tcpConnections[i]) + }(i, &wg, dataChan) } wg.Wait() } @@ -415,7 +428,8 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, if err != nil { return err } - case 5: + case 6: + // recevied something, maybe the file hash transferTime := time.Since(startTransfer) if !bytes.HasPrefix(message, []byte("hash:")) { log.Debugf("%s", message) @@ -446,7 +460,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, } func connectToTCPServer(room string, address string) (com comm.Comm, err error) { - connection, err := net.Dial("tcp", address) + connection, err := net.DialTimeout("tcp", address, 3*time.Hour) if err != nil { return } diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index 4177f77e..d5999e30 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -64,12 +64,14 @@ func run(port string) (err error) { func clientCommuncation(port string, c comm.Comm) (err error) { // send ok to tell client they are connected + log.Debug("sending ok") err = c.Send("ok") if err != nil { return } // wait for client to tell me which room they want + log.Debug("waiting for answer") room, err := c.Receive() if err != nil { return @@ -86,10 +88,13 @@ func clientCommuncation(port string, c comm.Comm) (err error) { // tell the client that they got the room err = c.Send("recipient") if err != nil { + log.Error(err) return } + log.Debug("recipient connected") return nil } + log.Debug("sender connected") receiver := rooms.rooms[room].receiver rooms.Unlock() @@ -137,6 +142,7 @@ func chanFromConn(conn net.Conn) chan []byte { c <- res } if err != nil { + log.Debug(err) c <- nil break }