diff --git a/src/comm/comm.go b/src/comm/comm.go index 10571f18..5a72547d 100644 --- a/src/comm/comm.go +++ b/src/comm/comm.go @@ -12,22 +12,25 @@ import ( // Comm is some basic TCP communication type Comm struct { - connection *net.TCPConn + connection net.Conn writer *bufio.Writer + reader *bufio.Reader } // New returns a new comm -func New(n *net.TCPConn) *Comm { +func New(n net.Conn) *Comm { c := new(Comm) c.connection = n c.connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) c.connection.SetDeadline(time.Now().Add(3 * time.Hour)) c.connection.SetWriteDeadline(time.Now().Add(3 * time.Hour)) + c.writer = bufio.NewWriter(n) + c.reader = bufio.NewReader(n) return c } // Connection returns the net.TCPConn connection -func (c *Comm) Connection() *net.TCPConn { +func (c *Comm) Connection() net.Conn { return c.connection } @@ -37,14 +40,14 @@ func (c *Comm) Close() { } func (c *Comm) Write(b []byte) (int, error) { - c.connection.Write([]byte(fmt.Sprintf("%0.6d", len(b)))) - n, err := c.connection.Write(b) + c.writer.Write([]byte(fmt.Sprintf("%0.6d", len(b)))) + n, err := c.writer.Write(b) if n != len(b) { err = fmt.Errorf("wanted to write %d but wrote %d", n, len(b)) } - // if err == nil { - // c.connection.Flush() - // } + if err == nil { + c.writer.Flush() + } // log.Printf("wanted to write %d but wrote %d", n, len(b)) return n, err } @@ -56,7 +59,7 @@ func (c *Comm) Write(b []byte) (int, error) { func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { // read until we get 6 bytes tmp := make([]byte, 6) - n, err := c.connection.Read(tmp) + n, err := c.reader.Read(tmp) if err != nil { return } @@ -72,7 +75,7 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { if len(bs) == 6 { break } - n, err := c.connection.Read(tmp) + n, err := c.reader.Read(tmp) if err != nil { return nil, 0, nil, err } diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index 98ebf6db..d22cef68 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -41,23 +41,28 @@ func Run(debugLevel, port string) { func run(port string) (err error) { log.Debugf("starting TCP server on " + port) - rAddr, err := net.ResolveTCPAddr("tcp", "0.0.0.0:"+port) - if err != nil { - panic(err) - } - server, err := net.ListenTCP("tcp", rAddr) + // rAddr, err := net.ResolveTCPAddr("tcp", "0.0.0.0:"+port) + // if err != nil { + // panic(err) + // } + // server, err := net.ListenTCP("tcp", rAddr) + // if err != nil { + // return errors.Wrap(err, "Error listening on :"+port) + // } + server, err := net.Listen("tcp", ":"+port) if err != nil { return errors.Wrap(err, "Error listening on :"+port) } defer server.Close() // spawn a new goroutine whenever a client connects for { - connection, err := server.AcceptTCP() + // connection, err := server.AcceptTCP() + connection, err := server.Accept() if err != nil { return errors.Wrap(err, "problem accepting connection") } log.Debugf("client %s connected", connection.RemoteAddr().String()) - go func(port string, connection *net.TCPConn) { + go func(port string, connection net.Conn) { errCommunication := clientCommuncation(port, comm.New(connection)) if errCommunication != nil { log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error()) @@ -126,7 +131,7 @@ func clientCommuncation(port string, c *comm.Comm) (err error) { // chanFromConn creates a channel from a Conn object, and sends everything it // Read()s from the socket to the channel. -func chanFromConn(conn *net.TCPConn) chan []byte { +func chanFromConn(conn net.Conn) chan []byte { c := make(chan []byte) // reader := bufio.NewReader(conn) @@ -153,7 +158,7 @@ func chanFromConn(conn *net.TCPConn) chan []byte { // pipe creates a full-duplex pipe between the two sockets and // transfers data from one to the other. -func pipe(conn1 *net.TCPConn, conn2 *net.TCPConn) { +func pipe(conn1 net.Conn, conn2 net.Conn) { chan1 := chanFromConn(conn1) // chan2 := chanFromConn(conn2) // writer1 := bufio.NewWriter(conn1)