diff --git a/src/croc/sender.go b/src/croc/sender.go index 1ae0bee8..50fb71b3 100644 --- a/src/croc/sender.go +++ b/src/croc/sender.go @@ -429,6 +429,32 @@ func (cr *Croc) send(forceSend int, serverAddress string, tcpPorts []string, isL } else { _ = <-isConnectedIfUsingTCP log.Debug("connected and ready to send on tcp") + + // check to see if any messages are sent + stopMessageSignal := make(chan bool, 1) + errorsDuringTransfer := make(chan error, 24) + go func() { + for { + select { + case sig := <-stopMessageSignal: + errorsDuringTransfer <- nil + log.Debugf("got message signal: %+v", sig) + return + case wsMessage := <-websocketMessages: + log.Debugf("got message: %s", wsMessage.message) + if bytes.HasPrefix(wsMessage.message, []byte("error")) { + log.Debug("stopping transfer") + for i := 0; i < len(tcpConnections)+1; i++ { + errorsDuringTransfer <- fmt.Errorf("%s", wsMessage.message) + } + return + } + default: + continue + } + } + }() + var wg sync.WaitGroup wg.Add(len(tcpConnections)) for i := range tcpConnections { @@ -439,16 +465,23 @@ func (cr *Croc) send(forceSend int, serverAddress string, tcpPorts []string, isL go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan) { defer wg.Done() for data := range dataChan { + select { + case _ = <-errorsDuringTransfer: + log.Debugf("%d got stop", i) + return + default: + } if data.err != nil { log.Error(data.err) return } cr.Bar.Add(data.bytesRead) // write data to tcp connection - _, err = tcpConnections[i].Write(data.b) - if err != nil { - err = errors.Wrap(err, "problem writing message") - log.Error(err) + _, errTcp := tcpConnections[i].Write(data.b) + if errTcp != nil { + errTcp = errors.Wrap(errTcp, "problem writing message") + log.Debug(errTcp) + errorsDuringTransfer <- errTcp return } if bytes.Equal(data.b, []byte("magic")) { @@ -458,7 +491,17 @@ func (cr *Croc) send(forceSend int, serverAddress string, tcpPorts []string, isL } }(i, &wg, dataChan) } + // block until this is done + log.Debug("waiting for tcp goroutines") wg.Wait() + log.Debug("sending stop message signal") + stopMessageSignal <- true + log.Debug("waiting for error") + errorDuringTransfer := <-errorsDuringTransfer + if errorDuringTransfer != nil { + log.Debugf("got error during transfer: %s", errorDuringTransfer.Error()) + return errorDuringTransfer + } } cr.Bar.Finish()