0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 21:30:16 +02:00
This commit is contained in:
Zack Scholl 2018-09-26 14:31:45 -07:00
parent bc4bd2b0a8
commit fdfa7f209d
6 changed files with 48 additions and 97 deletions

View file

@ -1,7 +1,6 @@
package comm package comm
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"net" "net"
@ -13,52 +12,40 @@ import (
// Comm is some basic TCP communication // Comm is some basic TCP communication
type Comm struct { type Comm struct {
connection net.Conn connection net.Conn
writer *bufio.Writer
reader *bufio.Reader
} }
// New returns a new comm // New returns a new comm
func New(n net.Conn) *Comm { func New(c net.Conn) Comm {
c := new(Comm) c.SetReadDeadline(time.Now().Add(3 * time.Hour))
c.connection = n c.SetDeadline(time.Now().Add(3 * time.Hour))
c.connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) c.SetWriteDeadline(time.Now().Add(3 * time.Hour))
c.connection.SetDeadline(time.Now().Add(3 * time.Hour)) return Comm{c}
c.connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
c.writer = bufio.NewWriter(n)
// c.connection = bufio.NewReader(n)
return c
} }
// Connection returns the net.TCPConn connection // Connection returns the net.Conn connection
func (c *Comm) Connection() net.Conn { func (c Comm) Connection() net.Conn {
return c.connection return c.connection
} }
// Close closes the connection // Close closes the connection
func (c *Comm) Close() { func (c Comm) Close() {
c.connection.Close() c.connection.Close()
} }
func (c *Comm) Write(b []byte) (int, error) {
c.connection.Write([]byte(fmt.Sprintf("%0.6d", len(b)))) func (c Comm) Write(b []byte) (int, error) {
c.connection.Write([]byte(fmt.Sprintf("%0.5d", len(b))))
n, err := c.connection.Write(b) n, err := c.connection.Write(b)
if n != len(b) { if n != len(b) {
err = fmt.Errorf("wanted to write %d but wrote %d", n, len(b)) err = fmt.Errorf("wanted to write %d but wrote %d", n, len(b))
} }
// if err == nil {
// c.writer.Flush()
// }
// log.Printf("wanted to write %d but wrote %d", n, len(b)) // log.Printf("wanted to write %d but wrote %d", n, len(b))
return n, err return n, err
} }
// func (c *Comm) Flush() { func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
// c.connection.Flush() // read until we get 5 bytes
// } tmp := make([]byte, 5)
func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
// read until we get 6 bytes
tmp := make([]byte, 6)
n, err := c.connection.Read(tmp) n, err := c.connection.Read(tmp)
if err != nil { if err != nil {
return return
@ -72,7 +59,7 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
for { for {
// see if we have enough bytes // see if we have enough bytes
bs = bytes.Trim(bs, "\x00") bs = bytes.Trim(bs, "\x00")
if len(bs) == 6 { if len(bs) == 5 {
break break
} }
n, err := c.connection.Read(tmp) n, err := c.connection.Read(tmp)
@ -112,13 +99,13 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
} }
// Send a message // Send a message
func (c *Comm) Send(message string) (err error) { func (c Comm) Send(message string) (err error) {
_, err = c.Write([]byte(message)) _, err = c.Write([]byte(message))
return return
} }
// Receive a message // Receive a message
func (c *Comm) Receive() (s string, err error) { func (c Comm) Receive() (s string, err error) {
b, _, _, err := c.Read() b, _, _, err := c.Read()
s = string(b) s = string(b)
return return

View file

@ -9,7 +9,7 @@ import (
// Compress returns a compressed byte slice. // Compress returns a compressed byte slice.
func Compress(src []byte) []byte { func Compress(src []byte) []byte {
compressedData := new(bytes.Buffer) compressedData := new(bytes.Buffer)
compress(src, compressedData, 1) compress(src, compressedData, 9)
return compressedData.Bytes() return compressedData.Bytes()
} }

View file

@ -1,4 +1,4 @@
package models package models
const WEBSOCKET_BUFFER_SIZE = 1024 * 1024 * 32 const WEBSOCKET_BUFFER_SIZE = 1024 * 1024 * 32
const TCP_BUFFER_SIZE = 993280 const TCP_BUFFER_SIZE = 1024 * 64

View file

@ -3,6 +3,7 @@ package recipient
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -13,7 +14,6 @@ import (
"time" "time"
humanize "github.com/dustin/go-humanize" humanize "github.com/dustin/go-humanize"
"github.com/pkg/errors"
log "github.com/cihub/seelog" log "github.com/cihub/seelog"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -50,7 +50,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
var transferTime time.Duration var transferTime time.Duration
var hash256 []byte var hash256 []byte
var otherIP string var otherIP string
var tcpConnections []*comm.Comm var tcpConnections []comm.Comm
dataChan := make(chan []byte, 1024*1024) dataChan := make(chan []byte, 1024*1024)
useWebsockets := true useWebsockets := true
@ -176,7 +176,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
// connect to TCP to receive file // connect to TCP to receive file
if !useWebsockets { if !useWebsockets {
log.Debugf("connecting to server") log.Debugf("connecting to server")
tcpConnections = make([]*comm.Comm, len(tcpPorts)) tcpConnections = make([]comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts { for i, tcpPort := range tcpPorts {
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil { if err != nil {
@ -300,7 +300,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(tcpConnections)) wg.Add(len(tcpConnections))
for i := range tcpConnections { for i := range tcpConnections {
go func(wg *sync.WaitGroup, tcpConnection *comm.Comm) { go func(wg *sync.WaitGroup, tcpConnection comm.Comm) {
defer wg.Done() defer wg.Done()
for { for {
// read from TCP connection // read from TCP connection
@ -405,26 +405,16 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
} }
} }
func connectToTCPServer(room string, address string) (com *comm.Comm, err error) { func connectToTCPServer(room string, address string) (com comm.Comm, err error) {
log.Debugf("recipient connecting to %s", address) log.Debugf("recipient connecting to %s", address)
// rAddr, err := net.ResolveTCPAddr("tcp", address)
// if err != nil {
// return
// }
// connection, err := net.DialTCP("tcp", nil, rAddr)
// if err != nil {
// err = errors.Wrap(err, "bad connection to tcp")
// return
// }
// connection.SetNoDelay(true)
connection, err := net.Dial("tcp", address) connection, err := net.Dial("tcp", address)
if err != nil { if err != nil {
err = errors.Wrap(err, "bad connection to tcp")
return return
} }
connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
connection.SetDeadline(time.Now().Add(3 * time.Hour)) connection.SetDeadline(time.Now().Add(3 * time.Hour))
connection.SetWriteDeadline(time.Now().Add(3 * time.Hour)) connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
com = comm.New(connection) com = comm.New(connection)
log.Debug("waiting for server contact") log.Debug("waiting for server contact")
ok, err := com.Receive() ok, err := com.Receive()

View file

@ -51,7 +51,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
var fileHash []byte var fileHash []byte
var otherIP string var otherIP string
var startTransfer time.Time var startTransfer time.Time
var tcpConnections []*comm.Comm var tcpConnections []comm.Comm
type DataChan struct { type DataChan struct {
b []byte b []byte
@ -302,7 +302,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
// connect to TCP to receive file // connect to TCP to receive file
if !useWebsockets { if !useWebsockets {
log.Debugf("connecting to server") log.Debugf("connecting to server")
tcpConnections = make([]*comm.Comm, len(tcpPorts)) tcpConnections = make([]comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts { for i, tcpPort := range tcpPorts {
log.Debug(tcpPort) log.Debug(tcpPort)
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
@ -346,7 +346,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(tcpConnections)) wg.Add(len(tcpConnections))
for i := range tcpConnections { for i := range tcpConnections {
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection *comm.Comm) { go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection comm.Comm) {
defer wg.Done() defer wg.Done()
for data := range dataChan { for data := range dataChan {
if data.err != nil { if data.err != nil {
@ -407,20 +407,9 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
} }
} }
func connectToTCPServer(room string, address string) (com *comm.Comm, err error) { func connectToTCPServer(room string, address string) (com comm.Comm, err error) {
// rAddr, err := net.ResolveTCPAddr("tcp", address)
// if err != nil {
// return
// }
// connection, err := net.DialTCP("tcp", nil, rAddr)
// if err != nil {
// err = errors.Wrap(err, "bad connection to tcp")
// return
// }
// connection.SetNoDelay(true)
connection, err := net.Dial("tcp", address) connection, err := net.Dial("tcp", address)
if err != nil { if err != nil {
err = errors.Wrap(err, "bad connection to tcp")
return return
} }
connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) connection.SetReadDeadline(time.Now().Add(3 * time.Hour))

View file

@ -13,7 +13,7 @@ import (
) )
type roomInfo struct { type roomInfo struct {
receiver *comm.Comm receiver comm.Comm
opened time.Time opened time.Time
} }
@ -41,22 +41,13 @@ func Run(debugLevel, port string) {
func run(port string) (err error) { func run(port string) (err error) {
log.Debugf("starting TCP server on " + port) log.Debugf("starting TCP server on " + port)
// rAddr, err := net.ResolveTCPAddr("tcp", "0.0.0.0:"+port) server, err := net.Listen("tcp", "0.0.0.0:"+port)
// if err != nil {
// panic(err)
// }
// server, err := net.ListenTCP("tcp", rAddr)
// if err != nil {
// return errors.Wrap(err, "Error listening on :"+port)
// }
server, err := net.Listen("tcp", ":"+port)
if err != nil { if err != nil {
return errors.Wrap(err, "Error listening on :"+port) return errors.Wrap(err, "Error listening on :"+port)
} }
defer server.Close() defer server.Close()
// spawn a new goroutine whenever a client connects // spawn a new goroutine whenever a client connects
for { for {
// connection, err := server.AcceptTCP()
connection, err := server.Accept() connection, err := server.Accept()
if err != nil { if err != nil {
return errors.Wrap(err, "problem accepting connection") return errors.Wrap(err, "problem accepting connection")
@ -71,7 +62,7 @@ func run(port string) (err error) {
} }
} }
func clientCommuncation(port string, c *comm.Comm) (err error) { func clientCommuncation(port string, c comm.Comm) (err error) {
// send ok to tell client they are connected // send ok to tell client they are connected
err = c.Send("ok") err = c.Send("ok")
if err != nil { if err != nil {
@ -107,7 +98,7 @@ func clientCommuncation(port string, c *comm.Comm) (err error) {
wg.Add(1) wg.Add(1)
// start piping // start piping
go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) { go func(com1, com2 comm.Comm, wg *sync.WaitGroup) {
log.Debug("starting pipes") log.Debug("starting pipes")
pipe(com1.Connection(), com2.Connection()) pipe(com1.Connection(), com2.Connection())
wg.Done() wg.Done()
@ -133,14 +124,13 @@ func clientCommuncation(port string, c *comm.Comm) (err error) {
// Read()s from the socket to the channel. // Read()s from the socket to the channel.
func chanFromConn(conn net.Conn) chan []byte { func chanFromConn(conn net.Conn) chan []byte {
c := make(chan []byte) c := make(chan []byte)
// reader := bufio.NewReader(conn)
go func() { go func() {
b := make([]byte, models.TCP_BUFFER_SIZE)
for { for {
b := make([]byte, models.TCP_BUFFER_SIZE)
n, err := conn.Read(b) n, err := conn.Read(b)
if n > 0 { if n > 0 {
// c <- b[:n]
res := make([]byte, n) res := make([]byte, n)
// Copy the buffer so it doesn't get changed while read by the recipient. // Copy the buffer so it doesn't get changed while read by the recipient.
copy(res, b[:n]) copy(res, b[:n])
@ -160,26 +150,21 @@ func chanFromConn(conn net.Conn) chan []byte {
// transfers data from one to the other. // transfers data from one to the other.
func pipe(conn1 net.Conn, conn2 net.Conn) { func pipe(conn1 net.Conn, conn2 net.Conn) {
chan1 := chanFromConn(conn1) chan1 := chanFromConn(conn1)
// chan2 := chanFromConn(conn2) chan2 := chanFromConn(conn2)
// writer1 := bufio.NewWriter(conn1)
// writer2 := bufio.NewWriter(conn2)
for { for {
b1 := <-chan1 select {
if b1 == nil { case b1 := <-chan1:
return if b1 == nil {
return
}
conn2.Write(b1)
case b2 := <-chan2:
if b2 == nil {
return
}
conn1.Write(b2)
} }
conn2.Write(b1)
// writer2.Write(b1)
// writer2.Flush()
// case b2 := <-chan2:
// if b2 == nil {
// return
// }
// writer1.Write(b2)
// writer1.Flush()
// }
} }
} }