1
1
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00
fixed spelling connectionMap.receiver
Added method to detect if sender and receivers are already connected.
Added client code to correctly action "no" returned by the code being in use.
This commit is contained in:
lummie 2017-10-20 21:51:03 +01:00
parent 17a1f097c3
commit e2faa87b59
2 changed files with 720 additions and 673 deletions

View file

@ -1,425 +1,440 @@
package main package main
import ( import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math" "math"
"net" "net"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gosuri/uiprogress" "github.com/gosuri/uiprogress"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type Connection struct { type Connection struct {
Server string Server string
File FileMetaData File FileMetaData
NumberOfConnections int NumberOfConnections int
Code string Code string
HashedCode string HashedCode string
IsSender bool IsSender bool
Debug bool Debug bool
DontEncrypt bool DontEncrypt bool
bars []*uiprogress.Bar bars []*uiprogress.Bar
rate int rate int
} }
type FileMetaData struct { type FileMetaData struct {
Name string Name string
Size int Size int
Hash string Hash string
} }
func NewConnection(flags *Flags) *Connection { func NewConnection(flags *Flags) *Connection {
c := new(Connection) c := new(Connection)
c.Debug = flags.Debug c.Debug = flags.Debug
c.DontEncrypt = flags.DontEncrypt c.DontEncrypt = flags.DontEncrypt
c.Server = flags.Server c.Server = flags.Server
c.Code = flags.Code c.Code = flags.Code
c.NumberOfConnections = flags.NumberOfConnections c.NumberOfConnections = flags.NumberOfConnections
c.rate = flags.Rate c.rate = flags.Rate
if len(flags.File) > 0 { if len(flags.File) > 0 {
c.File.Name = flags.File c.File.Name = flags.File
c.IsSender = true c.IsSender = true
} else { } else {
c.IsSender = false c.IsSender = false
} }
log.SetFormatter(&log.TextFormatter{}) log.SetFormatter(&log.TextFormatter{})
if c.Debug { if c.Debug {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
} else { } else {
log.SetLevel(log.WarnLevel) log.SetLevel(log.WarnLevel)
} }
return c return c
} }
func (c *Connection) Run() error { func (c *Connection) Run() error {
forceSingleThreaded := false forceSingleThreaded := false
if c.IsSender { if c.IsSender {
fsize, err := FileSize(c.File.Name) fsize, err := FileSize(c.File.Name)
if err != nil { if err != nil {
return err return err
} }
if fsize < MAX_NUMBER_THREADS*BUFFERSIZE { if fsize < MAX_NUMBER_THREADS*BUFFERSIZE {
forceSingleThreaded = true forceSingleThreaded = true
log.Debug("forcing single thread") log.Debug("forcing single thread")
} }
} }
log.Debug("checking code validity") log.Debug("checking code validity")
for { for {
// check code // check code
goodCode := true goodCode := true
m := strings.Split(c.Code, "-") m := strings.Split(c.Code, "-")
numThreads, errParse := strconv.Atoi(m[0]) numThreads, errParse := strconv.Atoi(m[0])
if len(m) < 2 { if len(m) < 2 {
goodCode = false goodCode = false
} else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) { } else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) {
c.NumberOfConnections = MAX_NUMBER_THREADS c.NumberOfConnections = MAX_NUMBER_THREADS
goodCode = false goodCode = false
} else if errParse != nil { } else if errParse != nil {
goodCode = false goodCode = false
} }
log.Debug(m) log.Debug(m)
if !goodCode { if !goodCode {
if c.IsSender { if c.IsSender {
if forceSingleThreaded { if forceSingleThreaded {
c.NumberOfConnections = 1 c.NumberOfConnections = 1
} }
c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName() c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName()
} else { } else {
if len(c.Code) != 0 { if len(c.Code) != 0 {
fmt.Println("Code must begin with number of threads (e.g. 3-some-code)") fmt.Println("Code must begin with number of threads (e.g. 3-some-code)")
} }
c.Code = getInput("Enter receive code: ") c.Code = getInput("Enter receive code: ")
} }
} else { } else {
break break
} }
} }
// assign number of connections // assign number of connections
c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0]) c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0])
if c.IsSender { if c.IsSender {
if c.DontEncrypt { if c.DontEncrypt {
// don't encrypt // don't encrypt
CopyFile(c.File.Name, c.File.Name+".enc") CopyFile(c.File.Name, c.File.Name+".enc")
} else { } else {
// encrypt // encrypt
log.Debug("encrypting...") log.Debug("encrypting...")
if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil { if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil {
return err return err
} }
} }
// get file hash // get file hash
var err error var err error
c.File.Hash, err = HashFile(c.File.Name) c.File.Hash, err = HashFile(c.File.Name)
if err != nil { if err != nil {
return err return err
} }
// get file size // get file size
c.File.Size, err = FileSize(c.File.Name + ".enc") c.File.Size, err = FileSize(c.File.Name + ".enc")
if err != nil { if err != nil {
return err return err
} }
} }
return c.runClient() return c.runClient()
} }
// runClient spawns threads for parallel uplink/downlink via TCP // runClient spawns threads for parallel uplink/downlink via TCP
func (c *Connection) runClient() error { func (c *Connection) runClient() error {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"code": c.Code, "code": c.Code,
"sender?": c.IsSender, "sender?": c.IsSender,
}) })
c.HashedCode = Hash(c.Code) c.HashedCode = Hash(c.Code)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(c.NumberOfConnections) wg.Add(c.NumberOfConnections)
uiprogress.Start() uiprogress.Start()
if !c.Debug { if !c.Debug {
c.bars = make([]*uiprogress.Bar, c.NumberOfConnections) c.bars = make([]*uiprogress.Bar, c.NumberOfConnections)
} }
gotOK := false gotOK := false
gotResponse := false gotResponse := false
for id := 0; id < c.NumberOfConnections; id++ { gotConnectionInUse := false
go func(id int) { for id := 0; id < c.NumberOfConnections; id++ {
defer wg.Done() go func(id int) {
port := strconv.Itoa(27001 + id) defer wg.Done()
connection, err := net.Dial("tcp", c.Server+":"+port) port := strconv.Itoa(27001 + id)
if err != nil { connection, err := net.Dial("tcp", c.Server+":"+port)
panic(err) if err != nil {
} panic(err)
defer connection.Close() }
defer connection.Close()
message := receiveMessage(connection)
logger.Debugf("relay says: %s", message) message := receiveMessage(connection)
if c.IsSender { logger.Debugf("relay says: %s", message)
logger.Debugf("telling relay: %s", "s."+c.Code) if c.IsSender {
metaData, err := json.Marshal(c.File) logger.Debugf("telling relay: %s", "s."+c.Code)
if err != nil { metaData, err := json.Marshal(c.File)
log.Error(err) if err != nil {
} log.Error(err)
encryptedMetaData, salt, iv := Encrypt(metaData, c.Code) }
sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection) encryptedMetaData, salt, iv := Encrypt(metaData, c.Code)
} else { sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection)
logger.Debugf("telling relay: %s", "r."+c.Code) } else {
sendMessage("r."+c.HashedCode+".0.0.0", connection) logger.Debugf("telling relay: %s", "r."+c.Code)
} sendMessage("r."+c.HashedCode+".0.0.0", connection)
if c.IsSender { // this is a sender }
logger.Debug("waiting for ok from relay") if c.IsSender { // this is a sender
message = receiveMessage(connection) logger.Debug("waiting for ok from relay")
logger.Debug("got ok from relay") message = receiveMessage(connection)
if id == 0 { if message == "no" {
fmt.Printf("\nSending (->%s)..\n", message) fmt.Println("The specifed code is already in use by a sender.")
} gotConnectionInUse = true
// wait for pipe to be made } else {
time.Sleep(100 * time.Millisecond) logger.Debug("got ok from relay")
// Write data from file if id == 0 {
logger.Debug("send file") fmt.Printf("\nSending (->%s)..\n", message)
c.sendFile(id, connection) }
} else { // this is a receiver // wait for pipe to be made
logger.Debug("waiting for meta data from sender") time.Sleep(100 * time.Millisecond)
message = receiveMessage(connection) // Write data from file
m := strings.Split(message, "-") logger.Debug("send file")
encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3] c.sendFile(id, connection)
encryptedBytes, err := hex.DecodeString(encryptedData) fmt.Println("File sent.")
if err != nil { }
log.Error(err) } else { // this is a receiver
return logger.Debug("waiting for meta data from sender")
} message = receiveMessage(connection)
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt) if message == "no" {
err = json.Unmarshal(decryptedBytes, &c.File) fmt.Println("The specifed code is already in use by a receiver.")
if err != nil { gotConnectionInUse = true
log.Error(err) } else {
return m := strings.Split(message, "-")
} encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3]
log.Debugf("meta data received: %v", c.File) encryptedBytes, err := hex.DecodeString(encryptedData)
// have the main thread ask for the okay if err != nil {
if id == 0 { log.Error(err)
fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name) return
var sentFileNames []string }
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt)
if fileAlreadyExists(sentFileNames, c.File.Name) { err = json.Unmarshal(decryptedBytes, &c.File)
fmt.Printf("Will not overwrite file!") if err != nil {
os.Exit(1) log.Error(err)
} return
getOK := getInput("ok? (y/n): ") }
if getOK == "y" { log.Debugf("meta data received: %v", c.File)
gotOK = true // have the main thread ask for the okay
sentFileNames = append(sentFileNames, c.File.Name) if id == 0 {
} fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name)
gotResponse = true var sentFileNames []string
}
// wait for the main thread to get the okay if fileAlreadyExists(sentFileNames, c.File.Name) {
for limit := 0; limit < 1000; limit++ { fmt.Printf("Will not overwrite file!")
if gotResponse { os.Exit(1)
break }
} getOK := getInput("ok? (y/n): ")
time.Sleep(10 * time.Millisecond) if getOK == "y" {
} gotOK = true
if !gotOK { sentFileNames = append(sentFileNames, c.File.Name)
sendMessage("not ok", connection) }
} else { gotResponse = true
sendMessage("ok", connection) }
logger.Debug("receive file") // wait for the main thread to get the okay
fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress) for limit := 0; limit < 1000; limit++ {
c.receiveFile(id, connection) if gotResponse {
} break
} }
}(id) time.Sleep(10 * time.Millisecond)
} }
wg.Wait() if !gotOK {
sendMessage("not ok", connection)
if !c.IsSender { } else {
if !gotOK { sendMessage("ok", connection)
return errors.New("Transfer interrupted") logger.Debug("receive file")
} fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress)
c.catFile(c.File.Name) c.receiveFile(id, connection)
log.Debugf("Code: [%s]", c.Code) }
if c.DontEncrypt { }
if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil { }
return err }(id)
} }
} else { wg.Wait()
if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil {
return errors.Wrap(err, "Problem decrypting file") if gotConnectionInUse {
} return nil // connection was in use, just quit cleanly
} }
if !c.Debug {
os.Remove(c.File.Name + ".enc") if c.IsSender {
} // TODO: Add confirmation
} else { // Is a Receiver
fileHash, err := HashFile(c.File.Name) if !gotOK {
if err != nil { return errors.New("Transfer interrupted")
log.Error(err) }
} c.catFile(c.File.Name)
log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash) log.Debugf("Code: [%s]", c.Code)
log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash) if c.DontEncrypt {
if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil {
if c.File.Hash != fileHash { return err
return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name) }
} else { } else {
fmt.Printf("\nReceived file written to %s", c.File.Name) if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil {
} return errors.Wrap(err, "Problem decrypting file")
} else { }
fmt.Println("File sent.") }
// TODO: Add confirmation if !c.Debug {
} os.Remove(c.File.Name + ".enc")
return nil }
}
fileHash, err := HashFile(c.File.Name)
func fileAlreadyExists(s []string, f string) bool { if err != nil {
for _, a := range s { log.Error(err)
if a == f { }
return true log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash)
} log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash)
}
return false if c.File.Hash != fileHash {
} return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
} else {
func (c *Connection) catFile(fname string) { fmt.Printf("\nReceived file written to %s", c.File.Name)
// cat the file }
os.Remove(fname) }
finished, err := os.Create(fname + ".enc") return nil
defer finished.Close() }
if err != nil {
log.Fatal(err) func fileAlreadyExists(s []string, f string) bool {
} for _, a := range s {
for id := 0; id < c.NumberOfConnections; id++ { if a == f {
fh, err := os.Open(fname + "." + strconv.Itoa(id)) return true
if err != nil { }
log.Fatal(err) }
} return false
}
_, err = io.Copy(finished, fh)
if err != nil { func (c *Connection) catFile(fname string) {
log.Fatal(err) // cat the file
} os.Remove(fname)
fh.Close() finished, err := os.Create(fname + ".enc")
os.Remove(fname + "." + strconv.Itoa(id)) defer finished.Close()
} if err != nil {
log.Fatal(err)
} }
for id := 0; id < c.NumberOfConnections; id++ {
func (c *Connection) receiveFile(id int, connection net.Conn) error { fh, err := os.Open(fname + "." + strconv.Itoa(id))
logger := log.WithFields(log.Fields{ if err != nil {
"function": "receiveFile #" + strconv.Itoa(id), log.Fatal(err)
}) }
logger.Debug("waiting for chunk size from sender") _, err = io.Copy(finished, fh)
fileSizeBuffer := make([]byte, 10) if err != nil {
connection.Read(fileSizeBuffer) log.Fatal(err)
fileDataString := strings.Trim(string(fileSizeBuffer), ":") }
fileSizeInt, _ := strconv.Atoi(fileDataString) fh.Close()
chunkSize := int64(fileSizeInt) os.Remove(fname + "." + strconv.Itoa(id))
logger.Debugf("chunk size: %d", chunkSize) }
os.Remove(c.File.Name + "." + strconv.Itoa(id)) }
newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
if err != nil { func (c *Connection) receiveFile(id int, connection net.Conn) error {
panic(err) logger := log.WithFields(log.Fields{
} "function": "receiveFile #" + strconv.Itoa(id),
defer newFile.Close() })
if !c.Debug { logger.Debug("waiting for chunk size from sender")
c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed() fileSizeBuffer := make([]byte, 10)
} connection.Read(fileSizeBuffer)
fileDataString := strings.Trim(string(fileSizeBuffer), ":")
logger.Debug("waiting for file") fileSizeInt, _ := strconv.Atoi(fileDataString)
var receivedBytes int64 chunkSize := int64(fileSizeInt)
for { logger.Debugf("chunk size: %d", chunkSize)
if !c.Debug {
c.bars[id].Incr() os.Remove(c.File.Name + "." + strconv.Itoa(id))
} newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
if (chunkSize - receivedBytes) < BUFFERSIZE { if err != nil {
logger.Debug("at the end") panic(err)
io.CopyN(newFile, connection, (chunkSize - receivedBytes)) }
// Empty the remaining bytes that we don't need from the network buffer defer newFile.Close()
if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
logger.Debug("empty remaining bytes from network buffer") if !c.Debug {
connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize)) c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed()
} }
break
} logger.Debug("waiting for file")
io.CopyN(newFile, connection, BUFFERSIZE) var receivedBytes int64
receivedBytes += BUFFERSIZE for {
} if !c.Debug {
logger.Debug("received file") c.bars[id].Incr()
return nil }
} if (chunkSize - receivedBytes) < BUFFERSIZE {
logger.Debug("at the end")
func (c *Connection) sendFile(id int, connection net.Conn) { io.CopyN(newFile, connection, (chunkSize - receivedBytes))
logger := log.WithFields(log.Fields{ // Empty the remaining bytes that we don't need from the network buffer
"function": "sendFile #" + strconv.Itoa(id), if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
}) logger.Debug("empty remaining bytes from network buffer")
defer connection.Close() connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize))
}
var err error break
}
numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE)) io.CopyN(newFile, connection, BUFFERSIZE)
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections))) receivedBytes += BUFFERSIZE
}
chunkSize := int64(chunksPerWorker * BUFFERSIZE) logger.Debug("received file")
if id+1 == c.NumberOfConnections { return nil
chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize }
}
func (c *Connection) sendFile(id int, connection net.Conn) {
if id == 0 || id == c.NumberOfConnections-1 { logger := log.WithFields(log.Fields{
logger.Debugf("numChunks: %v", numChunks) "function": "sendFile #" + strconv.Itoa(id),
logger.Debugf("chunksPerWorker: %v", chunksPerWorker) })
logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize) defer connection.Close()
}
var err error
logger.Debugf("sending chunk size: %d", chunkSize)
connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10))) numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections)))
sendBuffer := make([]byte, BUFFERSIZE)
chunkSize := int64(chunksPerWorker * BUFFERSIZE)
// open encrypted file if id+1 == c.NumberOfConnections {
file, err := os.Open(c.File.Name + ".enc") chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize
if err != nil { }
log.Error(err)
return if id == 0 || id == c.NumberOfConnections-1 {
} logger.Debugf("numChunks: %v", numChunks)
defer file.Close() logger.Debugf("chunksPerWorker: %v", chunksPerWorker)
logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize)
chunkI := 0 }
if !c.Debug {
c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed() logger.Debugf("sending chunk size: %d", chunkSize)
} connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10)))
bufferSizeInKilobytes := BUFFERSIZE / 1024 sendBuffer := make([]byte, BUFFERSIZE)
rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes)
throttle := time.NewTicker(time.Second / time.Duration(rate)) // open encrypted file
defer throttle.Stop() file, err := os.Open(c.File.Name + ".enc")
if err != nil {
for range throttle.C { log.Error(err)
_, err = file.Read(sendBuffer) return
if err == io.EOF { }
//End of file reached, break out of for loop defer file.Close()
logger.Debug("EOF")
break chunkI := 0
} if !c.Debug {
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) { c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed()
connection.Write(sendBuffer) }
if !c.Debug {
c.bars[id].Incr() bufferSizeInKilobytes := BUFFERSIZE / 1024
} rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes)
} throttle := time.NewTicker(time.Second / time.Duration(rate))
chunkI++ defer throttle.Stop()
}
logger.Debug("file is sent") for range throttle.C {
return _, err = file.Read(sendBuffer)
} if err == io.EOF {
//End of file reached, break out of for loop
logger.Debug("EOF")
break
}
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) {
connection.Write(sendBuffer)
if !c.Debug {
c.bars[id].Incr()
}
}
chunkI++
}
logger.Debug("file is sent")
return
}

528
relay.go
View file

@ -1,248 +1,280 @@
package main package main
import ( import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const MAX_NUMBER_THREADS = 8 const MAX_NUMBER_THREADS = 8
type connectionMap struct { type connectionMap struct {
reciever map[string]net.Conn receiver map[string]net.Conn
sender map[string]net.Conn sender map[string]net.Conn
metadata map[string]string metadata map[string]string
sync.RWMutex potentialReceivers map[string]struct{}
} sync.RWMutex
}
type Relay struct {
connections connectionMap func (c *connectionMap) IsSenderConnected(key string) (found bool) {
Debug bool c.RLock()
NumberOfConnections int defer c.RUnlock()
} _, found = c.sender[key]
return
func NewRelay(flags *Flags) *Relay { }
r := new(Relay)
r.Debug = flags.Debug func (c *connectionMap) IsPotentialReceiverConnected(key string) (found bool) {
r.NumberOfConnections = MAX_NUMBER_THREADS c.RLock()
log.SetFormatter(&log.TextFormatter{}) defer c.RUnlock()
if r.Debug { _, found = c.potentialReceivers[key]
log.SetLevel(log.DebugLevel) return
} else { }
log.SetLevel(log.WarnLevel)
} type Relay struct {
return r connections connectionMap
} Debug bool
NumberOfConnections int
func (r *Relay) Run() { }
r.connections = connectionMap{}
r.connections.Lock() func NewRelay(flags *Flags) *Relay {
r.connections.reciever = make(map[string]net.Conn) r := new(Relay)
r.connections.sender = make(map[string]net.Conn) r.Debug = flags.Debug
r.connections.metadata = make(map[string]string) r.NumberOfConnections = MAX_NUMBER_THREADS
r.connections.Unlock() log.SetFormatter(&log.TextFormatter{})
r.runServer() if r.Debug {
} log.SetLevel(log.DebugLevel)
} else {
func (r *Relay) runServer() { log.SetLevel(log.WarnLevel)
logger := log.WithFields(log.Fields{ }
"function": "main", return r
}) }
logger.Debug("Initializing")
var wg sync.WaitGroup func (r *Relay) Run() {
wg.Add(r.NumberOfConnections) r.connections = connectionMap{}
for id := 0; id < r.NumberOfConnections; id++ { r.connections.Lock()
go r.listenerThread(id, &wg) r.connections.receiver = make(map[string]net.Conn)
} r.connections.sender = make(map[string]net.Conn)
wg.Wait() r.connections.metadata = make(map[string]string)
} r.connections.potentialReceivers = make(map[string]struct{})
r.connections.Unlock()
func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) { r.runServer()
logger := log.WithFields(log.Fields{ }
"function": "listenerThread:" + strconv.Itoa(27000+id),
}) func (r *Relay) runServer() {
logger := log.WithFields(log.Fields{
defer wg.Done() "function": "main",
err := r.listener(id) })
if err != nil { logger.Debug("Initializing")
logger.Error(err) var wg sync.WaitGroup
} wg.Add(r.NumberOfConnections)
} for id := 0; id < r.NumberOfConnections; id++ {
go r.listenerThread(id, &wg)
func (r *Relay) listener(id int) (err error) { }
port := strconv.Itoa(27001 + id) wg.Wait()
logger := log.WithFields(log.Fields{ }
"function": "listener" + ":" + port,
}) func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) {
server, err := net.Listen("tcp", "0.0.0.0:"+port) logger := log.WithFields(log.Fields{
if err != nil { "function": "listenerThread:" + strconv.Itoa(27000+id),
return errors.Wrap(err, "Error listening on "+":"+port) })
}
defer server.Close() defer wg.Done()
logger.Debug("waiting for connections") err := r.listener(id)
//Spawn a new goroutine whenever a client connects if err != nil {
for { logger.Error(err)
connection, err := server.Accept() }
if err != nil { }
return errors.Wrap(err, "problem accepting connection")
} func (r *Relay) listener(id int) (err error) {
logger.Debugf("Client %s connected", connection.RemoteAddr().String()) port := strconv.Itoa(27001 + id)
go r.clientCommuncation(id, connection) logger := log.WithFields(log.Fields{
} "function": "listener" + ":" + port,
} })
server, err := net.Listen("tcp", "0.0.0.0:"+port)
func (r *Relay) clientCommuncation(id int, connection net.Conn) { if err != nil {
sendMessage("who?", connection) return errors.Wrap(err, "Error listening on "+":"+port)
}
m := strings.Split(receiveMessage(connection), ".") defer server.Close()
connectionType, codePhrase, metaData := m[0], m[1], m[2] logger.Debug("waiting for connections")
key := codePhrase + "-" + strconv.Itoa(id) //Spawn a new goroutine whenever a client connects
logger := log.WithFields(log.Fields{ for {
"id": id, connection, err := server.Accept()
"codePhrase": codePhrase, if err != nil {
}) return errors.Wrap(err, "problem accepting connection")
}
if connectionType == "s" { logger.Debugf("Client %s connected", connection.RemoteAddr().String())
logger.Debug("got sender") go r.clientCommuncation(id, connection)
r.connections.Lock() }
r.connections.metadata[key] = metaData }
r.connections.sender[key] = connection
r.connections.Unlock() func (r *Relay) clientCommuncation(id int, connection net.Conn) {
// wait for receiver sendMessage("who?", connection)
receiversAddress := ""
for { m := strings.Split(receiveMessage(connection), ".")
r.connections.RLock() connectionType, codePhrase, metaData := m[0], m[1], m[2]
if _, ok := r.connections.reciever[key]; ok { key := codePhrase + "-" + strconv.Itoa(id)
receiversAddress = r.connections.reciever[key].RemoteAddr().String() logger := log.WithFields(log.Fields{
logger.Debug("got reciever") "id": id,
r.connections.RUnlock() "codePhrase": codePhrase,
break })
}
r.connections.RUnlock() if connectionType == "s" { // sender connection
time.Sleep(100 * time.Millisecond) if r.connections.IsSenderConnected(key) {
} sendMessage("no", connection)
logger.Debug("telling sender ok") return
sendMessage(receiversAddress, connection) }
logger.Debug("preparing pipe")
r.connections.Lock() logger.Debug("got sender")
con1 := r.connections.sender[key] r.connections.Lock()
con2 := r.connections.reciever[key] r.connections.metadata[key] = metaData
r.connections.Unlock() r.connections.sender[key] = connection
logger.Debug("piping connections") r.connections.Unlock()
Pipe(con1, con2) // wait for receiver
logger.Debug("done piping") receiversAddress := ""
r.connections.Lock() for {
delete(r.connections.sender, key) r.connections.RLock()
delete(r.connections.reciever, key) if _, ok := r.connections.receiver[key]; ok {
delete(r.connections.metadata, key) receiversAddress = r.connections.receiver[key].RemoteAddr().String()
r.connections.Unlock() logger.Debug("got receiver")
logger.Debug("deleted sender and receiver") r.connections.RUnlock()
} else { break
// wait for sender's metadata }
sendersAddress := "" r.connections.RUnlock()
for { time.Sleep(100 * time.Millisecond)
r.connections.RLock() }
if _, ok := r.connections.metadata[key]; ok { logger.Debug("telling sender ok")
if _, ok2 := r.connections.sender[key]; ok2 { sendMessage(receiversAddress, connection)
sendersAddress = r.connections.sender[key].RemoteAddr().String() logger.Debug("preparing pipe")
logger.Debug("got sender meta data") r.connections.Lock()
r.connections.RUnlock() con1 := r.connections.sender[key]
break con2 := r.connections.receiver[key]
} r.connections.Unlock()
} logger.Debug("piping connections")
r.connections.RUnlock() Pipe(con1, con2)
time.Sleep(100 * time.Millisecond) logger.Debug("done piping")
} r.connections.Lock()
// send meta data delete(r.connections.sender, key)
r.connections.RLock() delete(r.connections.receiver, key)
sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection) delete(r.connections.metadata, key)
r.connections.RUnlock() delete(r.connections.potentialReceivers, key)
// check for receiver's consent r.connections.Unlock()
consent := receiveMessage(connection) logger.Debug("deleted sender and receiver")
logger.Debugf("consent: %s", consent) } else { //receiver connection "r"
if consent == "ok" { if r.connections.IsPotentialReceiverConnected(key) {
logger.Debug("got consent") sendMessage("no", connection)
r.connections.Lock() return
r.connections.reciever[key] = connection }
r.connections.Unlock()
} // add as a potential receiver
} r.connections.Lock()
return r.connections.potentialReceivers[key] = struct{}{}
} r.connections.Unlock()
func sendMessage(message string, connection net.Conn) { // wait for sender's metadata
message = fillString(message, BUFFERSIZE) sendersAddress := ""
connection.Write([]byte(message)) for {
} r.connections.RLock()
if _, ok := r.connections.metadata[key]; ok {
func receiveMessage(connection net.Conn) string { if _, ok2 := r.connections.sender[key]; ok2 {
messageByte := make([]byte, BUFFERSIZE) sendersAddress = r.connections.sender[key].RemoteAddr().String()
connection.Read(messageByte) logger.Debug("got sender meta data")
return strings.Replace(string(messageByte), ":", "", -1) r.connections.RUnlock()
} break
}
func fillString(retunString string, toLength int) string { }
for { r.connections.RUnlock()
lengthString := len(retunString) time.Sleep(100 * time.Millisecond)
if lengthString < toLength { }
retunString = retunString + ":" // send meta data
continue r.connections.RLock()
} sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection)
break r.connections.RUnlock()
} // check for receiver's consent
return retunString consent := receiveMessage(connection)
} logger.Debugf("consent: %s", consent)
if consent == "ok" {
// chanFromConn creates a channel from a Conn object, and sends everything it logger.Debug("got consent")
// Read()s from the socket to the channel. r.connections.Lock()
func chanFromConn(conn net.Conn) chan []byte { r.connections.receiver[key] = connection
c := make(chan []byte) r.connections.Unlock()
}
go func() { }
b := make([]byte, BUFFERSIZE) return
}
for {
n, err := conn.Read(b) func sendMessage(message string, connection net.Conn) {
if n > 0 { message = fillString(message, BUFFERSIZE)
res := make([]byte, n) connection.Write([]byte(message))
// Copy the buffer so it doesn't get changed while read by the recipient. }
copy(res, b[:n])
c <- res func receiveMessage(connection net.Conn) string {
} messageByte := make([]byte, BUFFERSIZE)
if err != nil { connection.Read(messageByte)
c <- nil return strings.Replace(string(messageByte), ":", "", -1)
break }
}
} func fillString(retunString string, toLength int) string {
}() for {
lengthString := len(retunString)
return c if lengthString < toLength {
} retunString = retunString + ":"
continue
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other. }
func Pipe(conn1 net.Conn, conn2 net.Conn) { break
chan1 := chanFromConn(conn1) }
chan2 := chanFromConn(conn2) return retunString
}
for {
select { // chanFromConn creates a channel from a Conn object, and sends everything it
case b1 := <-chan1: // Read()s from the socket to the channel.
if b1 == nil { func chanFromConn(conn net.Conn) chan []byte {
return c := make(chan []byte)
} else {
conn2.Write(b1) go func() {
} b := make([]byte, BUFFERSIZE)
case b2 := <-chan2:
if b2 == nil { for {
return n, err := conn.Read(b)
} else { if n > 0 {
conn1.Write(b2) res := make([]byte, n)
} // Copy the buffer so it doesn't get changed while read by the recipient.
} copy(res, b[:n])
} c <- res
} }
if err != nil {
c <- nil
break
}
}
}()
return c
}
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other.
func Pipe(conn1 net.Conn, conn2 net.Conn) {
chan1 := chanFromConn(conn1)
chan2 := chanFromConn(conn2)
for {
select {
case b1 := <-chan1:
if b1 == nil {
return
} else {
conn2.Write(b1)
}
case b2 := <-chan2:
if b2 == nil {
return
} else {
conn1.Write(b2)
}
}
}
}