0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00

use netconn

This commit is contained in:
Zack Scholl 2018-09-26 10:43:38 -07:00
parent 1261a16b55
commit 420d2be271
2 changed files with 27 additions and 19 deletions

View file

@ -12,22 +12,25 @@ import (
// Comm is some basic TCP communication // Comm is some basic TCP communication
type Comm struct { type Comm struct {
connection *net.TCPConn connection net.Conn
writer *bufio.Writer writer *bufio.Writer
reader *bufio.Reader
} }
// New returns a new comm // New returns a new comm
func New(n *net.TCPConn) *Comm { func New(n net.Conn) *Comm {
c := new(Comm) c := new(Comm)
c.connection = n c.connection = n
c.connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) c.connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
c.connection.SetDeadline(time.Now().Add(3 * time.Hour)) c.connection.SetDeadline(time.Now().Add(3 * time.Hour))
c.connection.SetWriteDeadline(time.Now().Add(3 * time.Hour)) c.connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
c.writer = bufio.NewWriter(n)
c.reader = bufio.NewReader(n)
return c return c
} }
// Connection returns the net.TCPConn connection // Connection returns the net.TCPConn connection
func (c *Comm) Connection() *net.TCPConn { func (c *Comm) Connection() net.Conn {
return c.connection return c.connection
} }
@ -37,14 +40,14 @@ func (c *Comm) Close() {
} }
func (c *Comm) Write(b []byte) (int, error) { func (c *Comm) Write(b []byte) (int, error) {
c.connection.Write([]byte(fmt.Sprintf("%0.6d", len(b)))) c.writer.Write([]byte(fmt.Sprintf("%0.6d", len(b))))
n, err := c.connection.Write(b) n, err := c.writer.Write(b)
if n != len(b) { if n != len(b) {
err = fmt.Errorf("wanted to write %d but wrote %d", n, len(b)) err = fmt.Errorf("wanted to write %d but wrote %d", n, len(b))
} }
// if err == nil { if err == nil {
// c.connection.Flush() c.writer.Flush()
// } }
// log.Printf("wanted to write %d but wrote %d", n, len(b)) // log.Printf("wanted to write %d but wrote %d", n, len(b))
return n, err return n, err
} }
@ -56,7 +59,7 @@ func (c *Comm) Write(b []byte) (int, error) {
func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
// read until we get 6 bytes // read until we get 6 bytes
tmp := make([]byte, 6) tmp := make([]byte, 6)
n, err := c.connection.Read(tmp) n, err := c.reader.Read(tmp)
if err != nil { if err != nil {
return return
} }
@ -72,7 +75,7 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
if len(bs) == 6 { if len(bs) == 6 {
break break
} }
n, err := c.connection.Read(tmp) n, err := c.reader.Read(tmp)
if err != nil { if err != nil {
return nil, 0, nil, err return nil, 0, nil, err
} }

View file

@ -41,23 +41,28 @@ func Run(debugLevel, port string) {
func run(port string) (err error) { func run(port string) (err error) {
log.Debugf("starting TCP server on " + port) log.Debugf("starting TCP server on " + port)
rAddr, err := net.ResolveTCPAddr("tcp", "0.0.0.0:"+port) // rAddr, err := net.ResolveTCPAddr("tcp", "0.0.0.0:"+port)
if err != nil { // if err != nil {
panic(err) // panic(err)
} // }
server, err := net.ListenTCP("tcp", rAddr) // server, err := net.ListenTCP("tcp", rAddr)
// if err != nil {
// return errors.Wrap(err, "Error listening on :"+port)
// }
server, err := net.Listen("tcp", ":"+port)
if err != nil { if err != nil {
return errors.Wrap(err, "Error listening on :"+port) return errors.Wrap(err, "Error listening on :"+port)
} }
defer server.Close() defer server.Close()
// spawn a new goroutine whenever a client connects // spawn a new goroutine whenever a client connects
for { for {
connection, err := server.AcceptTCP() // connection, err := server.AcceptTCP()
connection, err := server.Accept()
if err != nil { if err != nil {
return errors.Wrap(err, "problem accepting connection") return errors.Wrap(err, "problem accepting connection")
} }
log.Debugf("client %s connected", connection.RemoteAddr().String()) log.Debugf("client %s connected", connection.RemoteAddr().String())
go func(port string, connection *net.TCPConn) { go func(port string, connection net.Conn) {
errCommunication := clientCommuncation(port, comm.New(connection)) errCommunication := clientCommuncation(port, comm.New(connection))
if errCommunication != nil { if errCommunication != nil {
log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error()) log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error())
@ -126,7 +131,7 @@ func clientCommuncation(port string, c *comm.Comm) (err error) {
// chanFromConn creates a channel from a Conn object, and sends everything it // chanFromConn creates a channel from a Conn object, and sends everything it
// Read()s from the socket to the channel. // Read()s from the socket to the channel.
func chanFromConn(conn *net.TCPConn) chan []byte { func chanFromConn(conn net.Conn) chan []byte {
c := make(chan []byte) c := make(chan []byte)
// reader := bufio.NewReader(conn) // reader := bufio.NewReader(conn)
@ -153,7 +158,7 @@ func chanFromConn(conn *net.TCPConn) chan []byte {
// pipe creates a full-duplex pipe between the two sockets and // pipe creates a full-duplex pipe between the two sockets and
// transfers data from one to the other. // transfers data from one to the other.
func pipe(conn1 *net.TCPConn, conn2 *net.TCPConn) { func pipe(conn1 net.Conn, conn2 net.Conn) {
chan1 := chanFromConn(conn1) chan1 := chanFromConn(conn1)
// chan2 := chanFromConn(conn2) // chan2 := chanFromConn(conn2)
// writer1 := bufio.NewWriter(conn1) // writer1 := bufio.NewWriter(conn1)