0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00

Merge pull request #26 from lummie/master

Issue #25
This commit is contained in:
Zack 2017-10-20 15:38:24 -06:00 committed by GitHub
commit 28ea514725
4 changed files with 724 additions and 687 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
# IDEs
.idea/

View file

@ -1,439 +1,440 @@
package main package main
import ( import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math" "math"
"net" "net"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gosuri/uiprogress" "github.com/gosuri/uiprogress"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type Connection struct { type Connection struct {
Server string Server string
File FileMetaData File FileMetaData
NumberOfConnections int NumberOfConnections int
Code string Code string
HashedCode string HashedCode string
IsSender bool IsSender bool
Debug bool Debug bool
DontEncrypt bool DontEncrypt bool
bars []*uiprogress.Bar bars []*uiprogress.Bar
rate int rate int
} }
type FileMetaData struct { type FileMetaData struct {
Name string Name string
Size int Size int
Hash string Hash string
} }
func NewConnection(flags *Flags) *Connection { func NewConnection(flags *Flags) *Connection {
c := new(Connection) c := new(Connection)
c.Debug = flags.Debug c.Debug = flags.Debug
c.DontEncrypt = flags.DontEncrypt c.DontEncrypt = flags.DontEncrypt
c.Server = flags.Server c.Server = flags.Server
c.Code = flags.Code c.Code = flags.Code
c.NumberOfConnections = flags.NumberOfConnections c.NumberOfConnections = flags.NumberOfConnections
c.rate = flags.Rate c.rate = flags.Rate
if len(flags.File) > 0 { if len(flags.File) > 0 {
c.File.Name = flags.File c.File.Name = flags.File
c.IsSender = true c.IsSender = true
} else { } else {
c.IsSender = false c.IsSender = false
} }
log.SetFormatter(&log.TextFormatter{}) log.SetFormatter(&log.TextFormatter{})
if c.Debug { if c.Debug {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
} else { } else {
log.SetLevel(log.WarnLevel) log.SetLevel(log.WarnLevel)
} }
return c return c
} }
func (c *Connection) Run() error { func (c *Connection) Run() error {
forceSingleThreaded := false forceSingleThreaded := false
if c.IsSender { if c.IsSender {
fsize, err := FileSize(c.File.Name) fsize, err := FileSize(c.File.Name)
if err != nil { if err != nil {
return err return err
} }
if fsize < MAX_NUMBER_THREADS*BUFFERSIZE { if fsize < MAX_NUMBER_THREADS*BUFFERSIZE {
forceSingleThreaded = true forceSingleThreaded = true
log.Debug("forcing single thread") log.Debug("forcing single thread")
} }
} }
log.Debug("checking code validity") log.Debug("checking code validity")
for { for {
// check code // check code
goodCode := true goodCode := true
m := strings.Split(c.Code, "-") m := strings.Split(c.Code, "-")
log.Debug(m) numThreads, errParse := strconv.Atoi(m[0])
numThreads, errParse := strconv.Atoi(m[0]) if len(m) < 2 {
if len(m) < 2 { goodCode = false
goodCode = false } else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) {
log.Debug("code too short") c.NumberOfConnections = MAX_NUMBER_THREADS
} else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) { goodCode = false
c.NumberOfConnections = MAX_NUMBER_THREADS } else if errParse != nil {
goodCode = false goodCode = false
log.Debug("incorrect number of threads") }
} else if errParse != nil { log.Debug(m)
goodCode = false if !goodCode {
log.Debug("problem parsing threads") if c.IsSender {
} if forceSingleThreaded {
log.Debug(m) c.NumberOfConnections = 1
log.Debug(goodCode) }
if !goodCode { c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName()
if c.IsSender { } else {
if forceSingleThreaded { if len(c.Code) != 0 {
c.NumberOfConnections = 1 fmt.Println("Code must begin with number of threads (e.g. 3-some-code)")
} }
c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName() c.Code = getInput("Enter receive code: ")
} else { }
if len(c.Code) != 0 { } else {
fmt.Println("Code must begin with number of threads (e.g. 3-some-code)") break
} }
c.Code = getInput("Enter receive code: ") }
} // assign number of connections
} else { c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0])
break
} if c.IsSender {
} if c.DontEncrypt {
// assign number of connections // don't encrypt
c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0]) CopyFile(c.File.Name, c.File.Name+".enc")
} else {
if c.IsSender { // encrypt
if c.DontEncrypt { log.Debug("encrypting...")
// don't encrypt if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil {
CopyFile(c.File.Name, c.File.Name+".enc") return err
} else { }
// encrypt }
log.Debug("encrypting...") // get file hash
if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil { var err error
return err c.File.Hash, err = HashFile(c.File.Name)
} if err != nil {
} return err
// get file hash }
var err error // get file size
c.File.Hash, err = HashFile(c.File.Name) c.File.Size, err = FileSize(c.File.Name + ".enc")
if err != nil { if err != nil {
return err return err
} }
// get file size }
c.File.Size, err = FileSize(c.File.Name + ".enc")
if err != nil { return c.runClient()
return err }
}
fmt.Printf("Sending %d byte file named '%s'\n", c.File.Size, c.File.Name) // runClient spawns threads for parallel uplink/downlink via TCP
fmt.Printf("Code is: %s\n", c.Code) func (c *Connection) runClient() error {
} logger := log.WithFields(log.Fields{
"code": c.Code,
return c.runClient() "sender?": c.IsSender,
} })
// runClient spawns threads for parallel uplink/downlink via TCP c.HashedCode = Hash(c.Code)
func (c *Connection) runClient() error {
logger := log.WithFields(log.Fields{ var wg sync.WaitGroup
"code": c.Code, wg.Add(c.NumberOfConnections)
"sender?": c.IsSender,
}) uiprogress.Start()
if !c.Debug {
c.HashedCode = Hash(c.Code) c.bars = make([]*uiprogress.Bar, c.NumberOfConnections)
}
var wg sync.WaitGroup gotOK := false
wg.Add(c.NumberOfConnections) gotResponse := false
gotConnectionInUse := false
uiprogress.Start() for id := 0; id < c.NumberOfConnections; id++ {
if !c.Debug { go func(id int) {
c.bars = make([]*uiprogress.Bar, c.NumberOfConnections) defer wg.Done()
} port := strconv.Itoa(27001 + id)
gotOK := false connection, err := net.Dial("tcp", c.Server+":"+port)
gotResponse := false if err != nil {
for id := 0; id < c.NumberOfConnections; id++ { panic(err)
go func(id int) { }
defer wg.Done() defer connection.Close()
port := strconv.Itoa(27001 + id)
connection, err := net.Dial("tcp", c.Server+":"+port) message := receiveMessage(connection)
if err != nil { logger.Debugf("relay says: %s", message)
panic(err) if c.IsSender {
} logger.Debugf("telling relay: %s", "s."+c.Code)
defer connection.Close() metaData, err := json.Marshal(c.File)
if err != nil {
message := receiveMessage(connection) log.Error(err)
logger.Debugf("relay says: %s", message) }
if c.IsSender { encryptedMetaData, salt, iv := Encrypt(metaData, c.Code)
logger.Debugf("telling relay: %s", "s."+c.Code) sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection)
metaData, err := json.Marshal(c.File) } else {
if err != nil { logger.Debugf("telling relay: %s", "r."+c.Code)
log.Error(err) sendMessage("r."+c.HashedCode+".0.0.0", connection)
} }
encryptedMetaData, salt, iv := Encrypt(metaData, c.Code) if c.IsSender { // this is a sender
sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection) logger.Debug("waiting for ok from relay")
} else { message = receiveMessage(connection)
logger.Debugf("telling relay: %s", "r."+c.Code) if message == "no" {
sendMessage("r."+c.HashedCode+".0.0.0", connection) fmt.Println("The specifed code is already in use by a sender.")
} gotConnectionInUse = true
if c.IsSender { // this is a sender } else {
logger.Debug("waiting for ok from relay") logger.Debug("got ok from relay")
message = receiveMessage(connection) if id == 0 {
logger.Debug("got ok from relay") fmt.Printf("\nSending (->%s)..\n", message)
if id == 0 { }
fmt.Printf("\nSending (->%s)..\n", message) // wait for pipe to be made
} time.Sleep(100 * time.Millisecond)
// wait for pipe to be made // Write data from file
time.Sleep(100 * time.Millisecond) logger.Debug("send file")
// Write data from file c.sendFile(id, connection)
logger.Debug("send file") fmt.Println("File sent.")
c.sendFile(id, connection) }
} else { // this is a receiver } else { // this is a receiver
logger.Debug("waiting for meta data from sender") logger.Debug("waiting for meta data from sender")
message = receiveMessage(connection) message = receiveMessage(connection)
m := strings.Split(message, "-") if message == "no" {
encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3] fmt.Println("The specifed code is already in use by a receiver.")
encryptedBytes, err := hex.DecodeString(encryptedData) gotConnectionInUse = true
if err != nil { } else {
log.Error(err) m := strings.Split(message, "-")
return encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3]
} encryptedBytes, err := hex.DecodeString(encryptedData)
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt) if err != nil {
err = json.Unmarshal(decryptedBytes, &c.File) log.Error(err)
if err != nil { return
log.Error(err) }
return decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt)
} err = json.Unmarshal(decryptedBytes, &c.File)
log.Debugf("meta data received: %v", c.File) if err != nil {
// have the main thread ask for the okay log.Error(err)
if id == 0 { return
fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name) }
var sentFileNames []string log.Debugf("meta data received: %v", c.File)
// have the main thread ask for the okay
if fileAlreadyExists(sentFileNames, c.File.Name) { if id == 0 {
fmt.Printf("Will not overwrite file!") fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name)
os.Exit(1) var sentFileNames []string
}
getOK := getInput("ok? (y/n): ") if fileAlreadyExists(sentFileNames, c.File.Name) {
if getOK == "y" { fmt.Printf("Will not overwrite file!")
gotOK = true os.Exit(1)
sentFileNames = append(sentFileNames, c.File.Name) }
} getOK := getInput("ok? (y/n): ")
gotResponse = true if getOK == "y" {
} gotOK = true
// wait for the main thread to get the okay sentFileNames = append(sentFileNames, c.File.Name)
for limit := 0; limit < 1000; limit++ { }
if gotResponse { gotResponse = true
break }
} // wait for the main thread to get the okay
time.Sleep(10 * time.Millisecond) for limit := 0; limit < 1000; limit++ {
} if gotResponse {
if !gotOK { break
sendMessage("not ok", connection) }
} else { time.Sleep(10 * time.Millisecond)
sendMessage("ok", connection) }
logger.Debug("receive file") if !gotOK {
if id == 0 { sendMessage("not ok", connection)
fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress) } else {
} sendMessage("ok", connection)
c.receiveFile(id, connection) logger.Debug("receive file")
} fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress)
} c.receiveFile(id, connection)
}(id) }
} }
wg.Wait() }
}(id)
if !c.IsSender { }
if !gotOK { wg.Wait()
return errors.New("Transfer interrupted")
} if gotConnectionInUse {
c.catFile(c.File.Name) return nil // connection was in use, just quit cleanly
log.Debugf("Code: [%s]", c.Code) }
if c.DontEncrypt {
if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil { if c.IsSender {
return err // TODO: Add confirmation
} } else { // Is a Receiver
} else { if !gotOK {
if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil { return errors.New("Transfer interrupted")
return errors.Wrap(err, "Problem decrypting file") }
} c.catFile(c.File.Name)
} log.Debugf("Code: [%s]", c.Code)
if !c.Debug { if c.DontEncrypt {
os.Remove(c.File.Name + ".enc") if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil {
} return err
}
fileHash, err := HashFile(c.File.Name) } else {
if err != nil { if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil {
log.Error(err) return errors.Wrap(err, "Problem decrypting file")
} }
log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash) }
log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash) if !c.Debug {
os.Remove(c.File.Name + ".enc")
if c.File.Hash != fileHash { }
return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
} else { fileHash, err := HashFile(c.File.Name)
fmt.Printf("\nReceived file written to %s", c.File.Name) if err != nil {
} log.Error(err)
} else { }
fmt.Println("File sent.") log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash)
// TODO: Add confirmation log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash)
}
return nil if c.File.Hash != fileHash {
} return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
} else {
func fileAlreadyExists(s []string, f string) bool { fmt.Printf("\nReceived file written to %s", c.File.Name)
for _, a := range s { }
if a == f { }
return true return nil
} }
}
return false func fileAlreadyExists(s []string, f string) bool {
} for _, a := range s {
if a == f {
func (c *Connection) catFile(fname string) { return true
// cat the file }
os.Remove(fname) }
finished, err := os.Create(fname + ".enc") return false
defer finished.Close() }
if err != nil {
log.Fatal(err) func (c *Connection) catFile(fname string) {
} // cat the file
for id := 0; id < c.NumberOfConnections; id++ { os.Remove(fname)
fh, err := os.Open(fname + "." + strconv.Itoa(id)) finished, err := os.Create(fname + ".enc")
if err != nil { defer finished.Close()
log.Fatal(err) if err != nil {
} log.Fatal(err)
}
_, err = io.Copy(finished, fh) for id := 0; id < c.NumberOfConnections; id++ {
if err != nil { fh, err := os.Open(fname + "." + strconv.Itoa(id))
log.Fatal(err) if err != nil {
} log.Fatal(err)
fh.Close() }
os.Remove(fname + "." + strconv.Itoa(id))
} _, err = io.Copy(finished, fh)
if err != nil {
} log.Fatal(err)
}
func (c *Connection) receiveFile(id int, connection net.Conn) error { fh.Close()
logger := log.WithFields(log.Fields{ os.Remove(fname + "." + strconv.Itoa(id))
"function": "receiveFile #" + strconv.Itoa(id), }
})
}
logger.Debug("waiting for chunk size from sender")
fileSizeBuffer := make([]byte, 10) func (c *Connection) receiveFile(id int, connection net.Conn) error {
connection.Read(fileSizeBuffer) logger := log.WithFields(log.Fields{
fileDataString := strings.Trim(string(fileSizeBuffer), ":") "function": "receiveFile #" + strconv.Itoa(id),
fileSizeInt, _ := strconv.Atoi(fileDataString) })
chunkSize := int64(fileSizeInt)
logger.Debugf("chunk size: %d", chunkSize) logger.Debug("waiting for chunk size from sender")
fileSizeBuffer := make([]byte, 10)
os.Remove(c.File.Name + "." + strconv.Itoa(id)) connection.Read(fileSizeBuffer)
newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id)) fileDataString := strings.Trim(string(fileSizeBuffer), ":")
if err != nil { fileSizeInt, _ := strconv.Atoi(fileDataString)
panic(err) chunkSize := int64(fileSizeInt)
} logger.Debugf("chunk size: %d", chunkSize)
defer newFile.Close()
os.Remove(c.File.Name + "." + strconv.Itoa(id))
if !c.Debug { newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed() if err != nil {
} panic(err)
}
logger.Debug("waiting for file") defer newFile.Close()
var receivedBytes int64
receivedFirstBytes := false if !c.Debug {
for { c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed()
if !c.Debug { }
c.bars[id].Incr()
} logger.Debug("waiting for file")
if (chunkSize - receivedBytes) < BUFFERSIZE { var receivedBytes int64
logger.Debug("at the end") for {
io.CopyN(newFile, connection, (chunkSize - receivedBytes)) if !c.Debug {
// Empty the remaining bytes that we don't need from the network buffer c.bars[id].Incr()
if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE { }
logger.Debug("empty remaining bytes from network buffer") if (chunkSize - receivedBytes) < BUFFERSIZE {
connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize)) logger.Debug("at the end")
} io.CopyN(newFile, connection, (chunkSize - receivedBytes))
break // Empty the remaining bytes that we don't need from the network buffer
} if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
io.CopyN(newFile, connection, BUFFERSIZE) logger.Debug("empty remaining bytes from network buffer")
receivedBytes += BUFFERSIZE connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize))
if !receivedFirstBytes { }
receivedFirstBytes = true break
logger.Debug("Receieved first bytes!") }
} io.CopyN(newFile, connection, BUFFERSIZE)
} receivedBytes += BUFFERSIZE
logger.Debug("received file") }
return nil logger.Debug("received file")
} return nil
}
func (c *Connection) sendFile(id int, connection net.Conn) {
logger := log.WithFields(log.Fields{ func (c *Connection) sendFile(id int, connection net.Conn) {
"function": "sendFile #" + strconv.Itoa(id), logger := log.WithFields(log.Fields{
}) "function": "sendFile #" + strconv.Itoa(id),
defer connection.Close() })
defer connection.Close()
var err error
var err error
numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections))) numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections)))
chunkSize := int64(chunksPerWorker * BUFFERSIZE)
if id+1 == c.NumberOfConnections { chunkSize := int64(chunksPerWorker * BUFFERSIZE)
chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize if id+1 == c.NumberOfConnections {
} chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize
}
if id == 0 || id == c.NumberOfConnections-1 {
logger.Debugf("numChunks: %v", numChunks) if id == 0 || id == c.NumberOfConnections-1 {
logger.Debugf("chunksPerWorker: %v", chunksPerWorker) logger.Debugf("numChunks: %v", numChunks)
logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize) logger.Debugf("chunksPerWorker: %v", chunksPerWorker)
} logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize)
}
logger.Debugf("sending chunk size: %d", chunkSize)
connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10))) logger.Debugf("sending chunk size: %d", chunkSize)
connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10)))
sendBuffer := make([]byte, BUFFERSIZE)
sendBuffer := make([]byte, BUFFERSIZE)
// open encrypted file
file, err := os.OpenFile(c.File.Name+".enc", os.O_RDONLY, 0755) // open encrypted file
if err != nil { file, err := os.Open(c.File.Name + ".enc")
log.Error(err) if err != nil {
return log.Error(err)
} return
defer file.Close() }
defer file.Close()
chunkI := 0
if !c.Debug { chunkI := 0
c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed() if !c.Debug {
} c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed()
}
bufferSizeInKilobytes := BUFFERSIZE / 1024
rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes) bufferSizeInKilobytes := BUFFERSIZE / 1024
throttle := time.NewTicker(time.Second / time.Duration(rate)) rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes)
defer throttle.Stop() throttle := time.NewTicker(time.Second / time.Duration(rate))
defer throttle.Stop()
for range throttle.C {
_, err = file.Read(sendBuffer) for range throttle.C {
if err == io.EOF { _, err = file.Read(sendBuffer)
//End of file reached, break out of for loop if err == io.EOF {
logger.Debug("EOF") //End of file reached, break out of for loop
break logger.Debug("EOF")
} break
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) { }
connection.Write(sendBuffer) if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) {
if !c.Debug { connection.Write(sendBuffer)
c.bars[id].Incr() if !c.Debug {
} c.bars[id].Incr()
} }
chunkI++ }
} chunkI++
logger.Debug("file is sent") }
return logger.Debug("file is sent")
} return
}

View file

@ -45,4 +45,5 @@ func TestEncryptFiles(t *testing.T) {
} }
os.Remove("temp.dec") os.Remove("temp.dec")
os.Remove("temp.enc") os.Remove("temp.enc")
os.Remove("temp")
} }

528
relay.go
View file

@ -1,248 +1,280 @@
package main package main
import ( import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const MAX_NUMBER_THREADS = 8 const MAX_NUMBER_THREADS = 8
type connectionMap struct { type connectionMap struct {
reciever map[string]net.Conn receiver map[string]net.Conn
sender map[string]net.Conn sender map[string]net.Conn
metadata map[string]string metadata map[string]string
sync.RWMutex potentialReceivers map[string]struct{}
} sync.RWMutex
}
type Relay struct {
connections connectionMap func (c *connectionMap) IsSenderConnected(key string) (found bool) {
Debug bool c.RLock()
NumberOfConnections int defer c.RUnlock()
} _, found = c.sender[key]
return
func NewRelay(flags *Flags) *Relay { }
r := new(Relay)
r.Debug = flags.Debug func (c *connectionMap) IsPotentialReceiverConnected(key string) (found bool) {
r.NumberOfConnections = MAX_NUMBER_THREADS c.RLock()
log.SetFormatter(&log.TextFormatter{}) defer c.RUnlock()
if r.Debug { _, found = c.potentialReceivers[key]
log.SetLevel(log.DebugLevel) return
} else { }
log.SetLevel(log.WarnLevel)
} type Relay struct {
return r connections connectionMap
} Debug bool
NumberOfConnections int
func (r *Relay) Run() { }
r.connections = connectionMap{}
r.connections.Lock() func NewRelay(flags *Flags) *Relay {
r.connections.reciever = make(map[string]net.Conn) r := new(Relay)
r.connections.sender = make(map[string]net.Conn) r.Debug = flags.Debug
r.connections.metadata = make(map[string]string) r.NumberOfConnections = MAX_NUMBER_THREADS
r.connections.Unlock() log.SetFormatter(&log.TextFormatter{})
r.runServer() if r.Debug {
} log.SetLevel(log.DebugLevel)
} else {
func (r *Relay) runServer() { log.SetLevel(log.WarnLevel)
logger := log.WithFields(log.Fields{ }
"function": "main", return r
}) }
logger.Debug("Initializing")
var wg sync.WaitGroup func (r *Relay) Run() {
wg.Add(r.NumberOfConnections) r.connections = connectionMap{}
for id := 0; id < r.NumberOfConnections; id++ { r.connections.Lock()
go r.listenerThread(id, &wg) r.connections.receiver = make(map[string]net.Conn)
} r.connections.sender = make(map[string]net.Conn)
wg.Wait() r.connections.metadata = make(map[string]string)
} r.connections.potentialReceivers = make(map[string]struct{})
r.connections.Unlock()
func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) { r.runServer()
logger := log.WithFields(log.Fields{ }
"function": "listenerThread:" + strconv.Itoa(27000+id),
}) func (r *Relay) runServer() {
logger := log.WithFields(log.Fields{
defer wg.Done() "function": "main",
err := r.listener(id) })
if err != nil { logger.Debug("Initializing")
logger.Error(err) var wg sync.WaitGroup
} wg.Add(r.NumberOfConnections)
} for id := 0; id < r.NumberOfConnections; id++ {
go r.listenerThread(id, &wg)
func (r *Relay) listener(id int) (err error) { }
port := strconv.Itoa(27001 + id) wg.Wait()
logger := log.WithFields(log.Fields{ }
"function": "listener" + ":" + port,
}) func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) {
server, err := net.Listen("tcp", "0.0.0.0:"+port) logger := log.WithFields(log.Fields{
if err != nil { "function": "listenerThread:" + strconv.Itoa(27000+id),
return errors.Wrap(err, "Error listening on "+":"+port) })
}
defer server.Close() defer wg.Done()
logger.Debug("waiting for connections") err := r.listener(id)
//Spawn a new goroutine whenever a client connects if err != nil {
for { logger.Error(err)
connection, err := server.Accept() }
if err != nil { }
return errors.Wrap(err, "problem accepting connection")
} func (r *Relay) listener(id int) (err error) {
logger.Debugf("Client %s connected", connection.RemoteAddr().String()) port := strconv.Itoa(27001 + id)
go r.clientCommuncation(id, connection) logger := log.WithFields(log.Fields{
} "function": "listener" + ":" + port,
} })
server, err := net.Listen("tcp", "0.0.0.0:"+port)
func (r *Relay) clientCommuncation(id int, connection net.Conn) { if err != nil {
sendMessage("who?", connection) return errors.Wrap(err, "Error listening on "+":"+port)
}
m := strings.Split(receiveMessage(connection), ".") defer server.Close()
connectionType, codePhrase, metaData := m[0], m[1], m[2] logger.Debug("waiting for connections")
key := codePhrase + "-" + strconv.Itoa(id) //Spawn a new goroutine whenever a client connects
logger := log.WithFields(log.Fields{ for {
"id": id, connection, err := server.Accept()
"codePhrase": codePhrase, if err != nil {
}) return errors.Wrap(err, "problem accepting connection")
}
if connectionType == "s" { logger.Debugf("Client %s connected", connection.RemoteAddr().String())
logger.Debug("got sender") go r.clientCommuncation(id, connection)
r.connections.Lock() }
r.connections.metadata[key] = metaData }
r.connections.sender[key] = connection
r.connections.Unlock() func (r *Relay) clientCommuncation(id int, connection net.Conn) {
// wait for receiver sendMessage("who?", connection)
receiversAddress := ""
for { m := strings.Split(receiveMessage(connection), ".")
r.connections.RLock() connectionType, codePhrase, metaData := m[0], m[1], m[2]
if _, ok := r.connections.reciever[key]; ok { key := codePhrase + "-" + strconv.Itoa(id)
receiversAddress = r.connections.reciever[key].RemoteAddr().String() logger := log.WithFields(log.Fields{
logger.Debug("got reciever") "id": id,
r.connections.RUnlock() "codePhrase": codePhrase,
break })
}
r.connections.RUnlock() if connectionType == "s" { // sender connection
time.Sleep(100 * time.Millisecond) if r.connections.IsSenderConnected(key) {
} sendMessage("no", connection)
logger.Debug("telling sender ok") return
sendMessage(receiversAddress, connection) }
logger.Debug("preparing pipe")
r.connections.Lock() logger.Debug("got sender")
con1 := r.connections.sender[key] r.connections.Lock()
con2 := r.connections.reciever[key] r.connections.metadata[key] = metaData
r.connections.Unlock() r.connections.sender[key] = connection
logger.Debug("piping connections") r.connections.Unlock()
Pipe(con1, con2) // wait for receiver
logger.Debug("done piping") receiversAddress := ""
r.connections.Lock() for {
delete(r.connections.sender, key) r.connections.RLock()
delete(r.connections.reciever, key) if _, ok := r.connections.receiver[key]; ok {
delete(r.connections.metadata, key) receiversAddress = r.connections.receiver[key].RemoteAddr().String()
r.connections.Unlock() logger.Debug("got receiver")
logger.Debug("deleted sender and receiver") r.connections.RUnlock()
} else { break
// wait for sender's metadata }
sendersAddress := "" r.connections.RUnlock()
for { time.Sleep(100 * time.Millisecond)
r.connections.RLock() }
if _, ok := r.connections.metadata[key]; ok { logger.Debug("telling sender ok")
if _, ok2 := r.connections.sender[key]; ok2 { sendMessage(receiversAddress, connection)
sendersAddress = r.connections.sender[key].RemoteAddr().String() logger.Debug("preparing pipe")
logger.Debug("got sender meta data") r.connections.Lock()
r.connections.RUnlock() con1 := r.connections.sender[key]
break con2 := r.connections.receiver[key]
} r.connections.Unlock()
} logger.Debug("piping connections")
r.connections.RUnlock() Pipe(con1, con2)
time.Sleep(100 * time.Millisecond) logger.Debug("done piping")
} r.connections.Lock()
// send meta data delete(r.connections.sender, key)
r.connections.RLock() delete(r.connections.receiver, key)
sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection) delete(r.connections.metadata, key)
r.connections.RUnlock() delete(r.connections.potentialReceivers, key)
// check for receiver's consent r.connections.Unlock()
consent := receiveMessage(connection) logger.Debug("deleted sender and receiver")
logger.Debugf("consent: %s", consent) } else { //receiver connection "r"
if consent == "ok" { if r.connections.IsPotentialReceiverConnected(key) {
logger.Debug("got consent") sendMessage("no", connection)
r.connections.Lock() return
r.connections.reciever[key] = connection }
r.connections.Unlock()
} // add as a potential receiver
} r.connections.Lock()
return r.connections.potentialReceivers[key] = struct{}{}
} r.connections.Unlock()
func sendMessage(message string, connection net.Conn) { // wait for sender's metadata
message = fillString(message, BUFFERSIZE) sendersAddress := ""
connection.Write([]byte(message)) for {
} r.connections.RLock()
if _, ok := r.connections.metadata[key]; ok {
func receiveMessage(connection net.Conn) string { if _, ok2 := r.connections.sender[key]; ok2 {
messageByte := make([]byte, BUFFERSIZE) sendersAddress = r.connections.sender[key].RemoteAddr().String()
connection.Read(messageByte) logger.Debug("got sender meta data")
return strings.Replace(string(messageByte), ":", "", -1) r.connections.RUnlock()
} break
}
func fillString(retunString string, toLength int) string { }
for { r.connections.RUnlock()
lengthString := len(retunString) time.Sleep(100 * time.Millisecond)
if lengthString < toLength { }
retunString = retunString + ":" // send meta data
continue r.connections.RLock()
} sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection)
break r.connections.RUnlock()
} // check for receiver's consent
return retunString consent := receiveMessage(connection)
} logger.Debugf("consent: %s", consent)
if consent == "ok" {
// chanFromConn creates a channel from a Conn object, and sends everything it logger.Debug("got consent")
// Read()s from the socket to the channel. r.connections.Lock()
func chanFromConn(conn net.Conn) chan []byte { r.connections.receiver[key] = connection
c := make(chan []byte) r.connections.Unlock()
}
go func() { }
b := make([]byte, BUFFERSIZE) return
}
for {
n, err := conn.Read(b) func sendMessage(message string, connection net.Conn) {
if n > 0 { message = fillString(message, BUFFERSIZE)
res := make([]byte, n) connection.Write([]byte(message))
// Copy the buffer so it doesn't get changed while read by the recipient. }
copy(res, b[:n])
c <- res func receiveMessage(connection net.Conn) string {
} messageByte := make([]byte, BUFFERSIZE)
if err != nil { connection.Read(messageByte)
c <- nil return strings.Replace(string(messageByte), ":", "", -1)
break }
}
} func fillString(retunString string, toLength int) string {
}() for {
lengthString := len(retunString)
return c if lengthString < toLength {
} retunString = retunString + ":"
continue
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other. }
func Pipe(conn1 net.Conn, conn2 net.Conn) { break
chan1 := chanFromConn(conn1) }
chan2 := chanFromConn(conn2) return retunString
}
for {
select { // chanFromConn creates a channel from a Conn object, and sends everything it
case b1 := <-chan1: // Read()s from the socket to the channel.
if b1 == nil { func chanFromConn(conn net.Conn) chan []byte {
return c := make(chan []byte)
} else {
conn2.Write(b1) go func() {
} b := make([]byte, BUFFERSIZE)
case b2 := <-chan2:
if b2 == nil { for {
return n, err := conn.Read(b)
} else { if n > 0 {
conn1.Write(b2) res := make([]byte, n)
} // Copy the buffer so it doesn't get changed while read by the recipient.
} copy(res, b[:n])
} c <- res
} }
if err != nil {
c <- nil
break
}
}
}()
return c
}
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other.
func Pipe(conn1 net.Conn, conn2 net.Conn) {
chan1 := chanFromConn(conn1)
chan2 := chanFromConn(conn2)
for {
select {
case b1 := <-chan1:
if b1 == nil {
return
} else {
conn2.Write(b1)
}
case b2 := <-chan2:
if b2 == nil {
return
} else {
conn1.Write(b2)
}
}
}
}