1
1
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00

ran 'go fmt *.go' to (hopefully) get rid of commit issues

This commit is contained in:
Brad Lunsford 2017-10-20 15:18:06 -07:00
parent 798a0d2c52
commit 6bdbdce655
3 changed files with 774 additions and 774 deletions

View file

@ -1,455 +1,455 @@
package main package main
import ( import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math" "math"
"net" "net"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gosuri/uiprogress" "github.com/gosuri/uiprogress"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type Connection struct { type Connection struct {
Server string Server string
File FileMetaData File FileMetaData
NumberOfConnections int NumberOfConnections int
Code string Code string
HashedCode string HashedCode string
IsSender bool IsSender bool
Debug bool Debug bool
DontEncrypt bool DontEncrypt bool
Wait bool Wait bool
bars []*uiprogress.Bar bars []*uiprogress.Bar
rate int rate int
} }
type FileMetaData struct { type FileMetaData struct {
Name string Name string
Size int Size int
Hash string Hash string
} }
func NewConnection(flags *Flags) *Connection { func NewConnection(flags *Flags) *Connection {
c := new(Connection) c := new(Connection)
c.Debug = flags.Debug c.Debug = flags.Debug
c.DontEncrypt = flags.DontEncrypt c.DontEncrypt = flags.DontEncrypt
c.Wait = flags.Wait c.Wait = flags.Wait
c.Server = flags.Server c.Server = flags.Server
c.Code = flags.Code c.Code = flags.Code
c.NumberOfConnections = flags.NumberOfConnections c.NumberOfConnections = flags.NumberOfConnections
c.rate = flags.Rate c.rate = flags.Rate
if len(flags.File) > 0 { if len(flags.File) > 0 {
c.File.Name = flags.File c.File.Name = flags.File
c.IsSender = true c.IsSender = true
} else { } else {
c.IsSender = false c.IsSender = false
} }
log.SetFormatter(&log.TextFormatter{}) log.SetFormatter(&log.TextFormatter{})
if c.Debug { if c.Debug {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
} else { } else {
log.SetLevel(log.WarnLevel) log.SetLevel(log.WarnLevel)
} }
return c return c
} }
func (c *Connection) Run() error { func (c *Connection) Run() error {
forceSingleThreaded := false forceSingleThreaded := false
if c.IsSender { if c.IsSender {
fsize, err := FileSize(c.File.Name) fsize, err := FileSize(c.File.Name)
if err != nil { if err != nil {
return err return err
} }
if fsize < MAX_NUMBER_THREADS*BUFFERSIZE { if fsize < MAX_NUMBER_THREADS*BUFFERSIZE {
forceSingleThreaded = true forceSingleThreaded = true
log.Debug("forcing single thread") log.Debug("forcing single thread")
} }
} }
log.Debug("checking code validity") log.Debug("checking code validity")
for { for {
// check code // check code
goodCode := true goodCode := true
m := strings.Split(c.Code, "-") m := strings.Split(c.Code, "-")
log.Debug(m) log.Debug(m)
numThreads, errParse := strconv.Atoi(m[0]) numThreads, errParse := strconv.Atoi(m[0])
if len(m) < 2 { if len(m) < 2 {
goodCode = false goodCode = false
log.Debug("code too short") log.Debug("code too short")
} else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) { } else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) {
c.NumberOfConnections = MAX_NUMBER_THREADS c.NumberOfConnections = MAX_NUMBER_THREADS
goodCode = false goodCode = false
log.Debug("incorrect number of threads") log.Debug("incorrect number of threads")
} else if errParse != nil { } else if errParse != nil {
goodCode = false goodCode = false
log.Debug("problem parsing threads") log.Debug("problem parsing threads")
} }
log.Debug(m) log.Debug(m)
log.Debug(goodCode) log.Debug(goodCode)
if !goodCode { if !goodCode {
if c.IsSender { if c.IsSender {
if forceSingleThreaded { if forceSingleThreaded {
c.NumberOfConnections = 1 c.NumberOfConnections = 1
} }
c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName() c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName()
} else { } else {
if len(c.Code) != 0 { if len(c.Code) != 0 {
fmt.Println("Code must begin with number of threads (e.g. 3-some-code)") fmt.Println("Code must begin with number of threads (e.g. 3-some-code)")
} }
c.Code = getInput("Enter receive code: ") c.Code = getInput("Enter receive code: ")
} }
} else { } else {
break break
} }
} }
// assign number of connections // assign number of connections
c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0]) c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0])
if c.IsSender { if c.IsSender {
if c.DontEncrypt { if c.DontEncrypt {
// don't encrypt // don't encrypt
CopyFile(c.File.Name, c.File.Name+".enc") CopyFile(c.File.Name, c.File.Name+".enc")
} else { } else {
// encrypt // encrypt
log.Debug("encrypting...") log.Debug("encrypting...")
if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil { if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil {
return err return err
} }
} }
// get file hash // get file hash
var err error var err error
c.File.Hash, err = HashFile(c.File.Name) c.File.Hash, err = HashFile(c.File.Name)
if err != nil { if err != nil {
return err return err
} }
// get file size // get file size
c.File.Size, err = FileSize(c.File.Name + ".enc") c.File.Size, err = FileSize(c.File.Name + ".enc")
if err != nil { if err != nil {
return err return err
} }
fmt.Printf("Sending %d byte file named '%s'\n", c.File.Size, c.File.Name) fmt.Printf("Sending %d byte file named '%s'\n", c.File.Size, c.File.Name)
fmt.Printf("Code is: %s\n", c.Code) fmt.Printf("Code is: %s\n", c.Code)
} }
return c.runClient() return c.runClient()
} }
// runClient spawns threads for parallel uplink/downlink via TCP // runClient spawns threads for parallel uplink/downlink via TCP
func (c *Connection) runClient() error { func (c *Connection) runClient() error {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"code": c.Code, "code": c.Code,
"sender?": c.IsSender, "sender?": c.IsSender,
}) })
c.HashedCode = Hash(c.Code) c.HashedCode = Hash(c.Code)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(c.NumberOfConnections) wg.Add(c.NumberOfConnections)
uiprogress.Start() uiprogress.Start()
if !c.Debug { if !c.Debug {
c.bars = make([]*uiprogress.Bar, c.NumberOfConnections) c.bars = make([]*uiprogress.Bar, c.NumberOfConnections)
} }
gotOK := false gotOK := false
gotResponse := false gotResponse := false
notPresent := false notPresent := false
for id := 0; id < c.NumberOfConnections; id++ { for id := 0; id < c.NumberOfConnections; id++ {
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
port := strconv.Itoa(27001 + id) port := strconv.Itoa(27001 + id)
connection, err := net.Dial("tcp", c.Server+":"+port) connection, err := net.Dial("tcp", c.Server+":"+port)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer connection.Close() defer connection.Close()
message := receiveMessage(connection) message := receiveMessage(connection)
logger.Debugf("relay says: %s", message) logger.Debugf("relay says: %s", message)
if c.IsSender { if c.IsSender {
logger.Debugf("telling relay: %s", "s."+c.Code) logger.Debugf("telling relay: %s", "s."+c.Code)
metaData, err := json.Marshal(c.File) metaData, err := json.Marshal(c.File)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
encryptedMetaData, salt, iv := Encrypt(metaData, c.Code) encryptedMetaData, salt, iv := Encrypt(metaData, c.Code)
sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection) sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection)
} else { } else {
logger.Debugf("telling relay: %s", "r."+c.Code) logger.Debugf("telling relay: %s", "r."+c.Code)
if c.Wait { if c.Wait {
sendMessage("r."+c.HashedCode+".0.0.0", connection) sendMessage("r."+c.HashedCode+".0.0.0", connection)
} else { } else {
sendMessage("c."+c.HashedCode+".0.0.0", connection) sendMessage("c."+c.HashedCode+".0.0.0", connection)
} }
} }
if c.IsSender { // this is a sender if c.IsSender { // this is a sender
logger.Debug("waiting for ok from relay") logger.Debug("waiting for ok from relay")
message = receiveMessage(connection) message = receiveMessage(connection)
logger.Debug("got ok from relay") logger.Debug("got ok from relay")
if id == 0 { if id == 0 {
fmt.Printf("\nSending (->%s)..\n", message) fmt.Printf("\nSending (->%s)..\n", message)
} }
// wait for pipe to be made // wait for pipe to be made
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
// Write data from file // Write data from file
logger.Debug("send file") logger.Debug("send file")
c.sendFile(id, connection) c.sendFile(id, connection)
} else { // this is a receiver } else { // this is a receiver
logger.Debug("waiting for meta data from sender") logger.Debug("waiting for meta data from sender")
message = receiveMessage(connection) message = receiveMessage(connection)
m := strings.Split(message, "-") m := strings.Split(message, "-")
encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3] encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3]
if sendersAddress == "0.0.0.0" { if sendersAddress == "0.0.0.0" {
notPresent = true notPresent = true
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
return return
} }
encryptedBytes, err := hex.DecodeString(encryptedData) encryptedBytes, err := hex.DecodeString(encryptedData)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return
} }
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt) decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt)
err = json.Unmarshal(decryptedBytes, &c.File) err = json.Unmarshal(decryptedBytes, &c.File)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return
} }
log.Debugf("meta data received: %v", c.File) log.Debugf("meta data received: %v", c.File)
// have the main thread ask for the okay // have the main thread ask for the okay
if id == 0 { if id == 0 {
fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name) fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name)
var sentFileNames []string var sentFileNames []string
if fileAlreadyExists(sentFileNames, c.File.Name) { if fileAlreadyExists(sentFileNames, c.File.Name) {
fmt.Printf("Will not overwrite file!") fmt.Printf("Will not overwrite file!")
os.Exit(1) os.Exit(1)
} }
getOK := getInput("ok? (y/n): ") getOK := getInput("ok? (y/n): ")
if getOK == "y" { if getOK == "y" {
gotOK = true gotOK = true
sentFileNames = append(sentFileNames, c.File.Name) sentFileNames = append(sentFileNames, c.File.Name)
} }
gotResponse = true gotResponse = true
} }
// wait for the main thread to get the okay // wait for the main thread to get the okay
for limit := 0; limit < 1000; limit++ { for limit := 0; limit < 1000; limit++ {
if gotResponse { if gotResponse {
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
if !gotOK { if !gotOK {
sendMessage("not ok", connection) sendMessage("not ok", connection)
} else { } else {
sendMessage("ok", connection) sendMessage("ok", connection)
logger.Debug("receive file") logger.Debug("receive file")
if id == 0 { if id == 0 {
fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress) fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress)
} }
c.receiveFile(id, connection) c.receiveFile(id, connection)
} }
} }
}(id) }(id)
} }
wg.Wait() wg.Wait()
if !c.IsSender { if !c.IsSender {
if notPresent { if notPresent {
fmt.Println("Sender/Code not present") fmt.Println("Sender/Code not present")
return nil return nil
} }
if !gotOK { if !gotOK {
return errors.New("Transfer interrupted") return errors.New("Transfer interrupted")
} }
c.catFile(c.File.Name) c.catFile(c.File.Name)
log.Debugf("Code: [%s]", c.Code) log.Debugf("Code: [%s]", c.Code)
if c.DontEncrypt { if c.DontEncrypt {
if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil { if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil {
return err return err
} }
} else { } else {
if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil { if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil {
return errors.Wrap(err, "Problem decrypting file") return errors.Wrap(err, "Problem decrypting file")
} }
} }
if !c.Debug { if !c.Debug {
os.Remove(c.File.Name + ".enc") os.Remove(c.File.Name + ".enc")
} }
fileHash, err := HashFile(c.File.Name) fileHash, err := HashFile(c.File.Name)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash) log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash)
log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash) log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash)
if c.File.Hash != fileHash { if c.File.Hash != fileHash {
return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name) return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
} else { } else {
fmt.Printf("\nReceived file written to %s", c.File.Name) fmt.Printf("\nReceived file written to %s", c.File.Name)
} }
} else { } else {
fmt.Println("File sent.") fmt.Println("File sent.")
// TODO: Add confirmation // TODO: Add confirmation
} }
return nil return nil
} }
func fileAlreadyExists(s []string, f string) bool { func fileAlreadyExists(s []string, f string) bool {
for _, a := range s { for _, a := range s {
if a == f { if a == f {
return true return true
} }
} }
return false return false
} }
func (c *Connection) catFile(fname string) { func (c *Connection) catFile(fname string) {
// cat the file // cat the file
os.Remove(fname) os.Remove(fname)
finished, err := os.Create(fname + ".enc") finished, err := os.Create(fname + ".enc")
defer finished.Close() defer finished.Close()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
for id := 0; id < c.NumberOfConnections; id++ { for id := 0; id < c.NumberOfConnections; id++ {
fh, err := os.Open(fname + "." + strconv.Itoa(id)) fh, err := os.Open(fname + "." + strconv.Itoa(id))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
_, err = io.Copy(finished, fh) _, err = io.Copy(finished, fh)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
fh.Close() fh.Close()
os.Remove(fname + "." + strconv.Itoa(id)) os.Remove(fname + "." + strconv.Itoa(id))
} }
} }
func (c *Connection) receiveFile(id int, connection net.Conn) error { func (c *Connection) receiveFile(id int, connection net.Conn) error {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "receiveFile #" + strconv.Itoa(id), "function": "receiveFile #" + strconv.Itoa(id),
}) })
logger.Debug("waiting for chunk size from sender") logger.Debug("waiting for chunk size from sender")
fileSizeBuffer := make([]byte, 10) fileSizeBuffer := make([]byte, 10)
connection.Read(fileSizeBuffer) connection.Read(fileSizeBuffer)
fileDataString := strings.Trim(string(fileSizeBuffer), ":") fileDataString := strings.Trim(string(fileSizeBuffer), ":")
fileSizeInt, _ := strconv.Atoi(fileDataString) fileSizeInt, _ := strconv.Atoi(fileDataString)
chunkSize := int64(fileSizeInt) chunkSize := int64(fileSizeInt)
logger.Debugf("chunk size: %d", chunkSize) logger.Debugf("chunk size: %d", chunkSize)
os.Remove(c.File.Name + "." + strconv.Itoa(id)) os.Remove(c.File.Name + "." + strconv.Itoa(id))
newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id)) newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer newFile.Close() defer newFile.Close()
if !c.Debug { if !c.Debug {
c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed() c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed()
} }
logger.Debug("waiting for file") logger.Debug("waiting for file")
var receivedBytes int64 var receivedBytes int64
receivedFirstBytes := false receivedFirstBytes := false
for { for {
if !c.Debug { if !c.Debug {
c.bars[id].Incr() c.bars[id].Incr()
} }
if (chunkSize - receivedBytes) < BUFFERSIZE { if (chunkSize - receivedBytes) < BUFFERSIZE {
logger.Debug("at the end") logger.Debug("at the end")
io.CopyN(newFile, connection, (chunkSize - receivedBytes)) io.CopyN(newFile, connection, (chunkSize - receivedBytes))
// Empty the remaining bytes that we don't need from the network buffer // Empty the remaining bytes that we don't need from the network buffer
if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE { if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
logger.Debug("empty remaining bytes from network buffer") logger.Debug("empty remaining bytes from network buffer")
connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize)) connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize))
} }
break break
} }
io.CopyN(newFile, connection, BUFFERSIZE) io.CopyN(newFile, connection, BUFFERSIZE)
receivedBytes += BUFFERSIZE receivedBytes += BUFFERSIZE
if !receivedFirstBytes { if !receivedFirstBytes {
receivedFirstBytes = true receivedFirstBytes = true
logger.Debug("Receieved first bytes!") logger.Debug("Receieved first bytes!")
} }
} }
logger.Debug("received file") logger.Debug("received file")
return nil return nil
} }
func (c *Connection) sendFile(id int, connection net.Conn) { func (c *Connection) sendFile(id int, connection net.Conn) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "sendFile #" + strconv.Itoa(id), "function": "sendFile #" + strconv.Itoa(id),
}) })
defer connection.Close() defer connection.Close()
var err error var err error
numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE)) numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections))) chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections)))
chunkSize := int64(chunksPerWorker * BUFFERSIZE) chunkSize := int64(chunksPerWorker * BUFFERSIZE)
if id+1 == c.NumberOfConnections { if id+1 == c.NumberOfConnections {
chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize
} }
if id == 0 || id == c.NumberOfConnections-1 { if id == 0 || id == c.NumberOfConnections-1 {
logger.Debugf("numChunks: %v", numChunks) logger.Debugf("numChunks: %v", numChunks)
logger.Debugf("chunksPerWorker: %v", chunksPerWorker) logger.Debugf("chunksPerWorker: %v", chunksPerWorker)
logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize) logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize)
} }
logger.Debugf("sending chunk size: %d", chunkSize) logger.Debugf("sending chunk size: %d", chunkSize)
connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10))) connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10)))
sendBuffer := make([]byte, BUFFERSIZE) sendBuffer := make([]byte, BUFFERSIZE)
// open encrypted file // open encrypted file
file, err := os.OpenFile(c.File.Name+".enc", os.O_RDONLY, 0755) file, err := os.OpenFile(c.File.Name+".enc", os.O_RDONLY, 0755)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return
} }
defer file.Close() defer file.Close()
chunkI := 0 chunkI := 0
if !c.Debug { if !c.Debug {
c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed() c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed()
} }
bufferSizeInKilobytes := BUFFERSIZE / 1024 bufferSizeInKilobytes := BUFFERSIZE / 1024
rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes) rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes)
throttle := time.NewTicker(time.Second / time.Duration(rate)) throttle := time.NewTicker(time.Second / time.Duration(rate))
defer throttle.Stop() defer throttle.Stop()
for range throttle.C { for range throttle.C {
_, err = file.Read(sendBuffer) _, err = file.Read(sendBuffer)
if err == io.EOF { if err == io.EOF {
//End of file reached, break out of for loop //End of file reached, break out of for loop
logger.Debug("EOF") logger.Debug("EOF")
break break
} }
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) { if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) {
connection.Write(sendBuffer) connection.Write(sendBuffer)
if !c.Debug { if !c.Debug {
c.bars[id].Incr() c.bars[id].Incr()
} }
} }
chunkI++ chunkI++
} }
logger.Debug("file is sent") logger.Debug("file is sent")
return return
} }

134
main.go
View file

@ -1,67 +1,67 @@
package main package main
import ( import (
"bufio" "bufio"
"flag" "flag"
"fmt" "fmt"
"os" "os"
"strings" "strings"
) )
const BUFFERSIZE = 1024 const BUFFERSIZE = 1024
var oneGigabytePerSecond = 1000000 // expressed as kbps var oneGigabytePerSecond = 1000000 // expressed as kbps
type Flags struct { type Flags struct {
Relay bool Relay bool
Debug bool Debug bool
Wait bool Wait bool
DontEncrypt bool DontEncrypt bool
Server string Server string
File string File string
Code string Code string
Rate int Rate int
NumberOfConnections int NumberOfConnections int
} }
var version string var version string
func main() { func main() {
fmt.Println(` fmt.Println(`
/\_/\ /\_/\
____/ o o \ ____/ o o \
/~____ =ø= / /~____ =ø= /
(______)__m_m) (______)__m_m)
croc version ` + version + ` croc version ` + version + `
`) `)
flags := new(Flags) flags := new(Flags)
flag.BoolVar(&flags.Relay, "relay", false, "run as relay") flag.BoolVar(&flags.Relay, "relay", false, "run as relay")
flag.BoolVar(&flags.Debug, "debug", false, "debug mode") flag.BoolVar(&flags.Debug, "debug", false, "debug mode")
flag.BoolVar(&flags.Wait, "wait", false, "wait for code to be sent") flag.BoolVar(&flags.Wait, "wait", false, "wait for code to be sent")
flag.StringVar(&flags.Server, "server", "cowyo.com", "address of relay server") flag.StringVar(&flags.Server, "server", "cowyo.com", "address of relay server")
flag.StringVar(&flags.File, "send", "", "file to send") flag.StringVar(&flags.File, "send", "", "file to send")
flag.StringVar(&flags.Code, "code", "", "use your own code phrase") flag.StringVar(&flags.Code, "code", "", "use your own code phrase")
flag.IntVar(&flags.Rate, "rate", oneGigabytePerSecond, "throttle down to speed in kbps") flag.IntVar(&flags.Rate, "rate", oneGigabytePerSecond, "throttle down to speed in kbps")
flag.BoolVar(&flags.DontEncrypt, "no-encrypt", false, "turn off encryption") flag.BoolVar(&flags.DontEncrypt, "no-encrypt", false, "turn off encryption")
flag.IntVar(&flags.NumberOfConnections, "threads", 4, "number of threads to use") flag.IntVar(&flags.NumberOfConnections, "threads", 4, "number of threads to use")
flag.Parse() flag.Parse()
if flags.Relay { if flags.Relay {
r := NewRelay(flags) r := NewRelay(flags)
r.Run() r.Run()
} else { } else {
c := NewConnection(flags) c := NewConnection(flags)
err := c.Run() err := c.Run()
if err != nil { if err != nil {
fmt.Printf("Error! Please submit the following error to https://github.com/schollz/croc/issues:\n\n'%s'\n\n", err.Error()) fmt.Printf("Error! Please submit the following error to https://github.com/schollz/croc/issues:\n\n'%s'\n\n", err.Error())
} }
} }
} }
func getInput(prompt string) string { func getInput(prompt string) string {
reader := bufio.NewReader(os.Stdin) reader := bufio.NewReader(os.Stdin)
fmt.Print(prompt) fmt.Print(prompt)
text, _ := reader.ReadString('\n') text, _ := reader.ReadString('\n')
return strings.TrimSpace(text) return strings.TrimSpace(text)
} }

504
relay.go
View file

@ -1,252 +1,252 @@
package main package main
import ( import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const MAX_NUMBER_THREADS = 8 const MAX_NUMBER_THREADS = 8
type connectionMap struct { type connectionMap struct {
reciever map[string]net.Conn reciever map[string]net.Conn
sender map[string]net.Conn sender map[string]net.Conn
metadata map[string]string metadata map[string]string
sync.RWMutex sync.RWMutex
} }
type Relay struct { type Relay struct {
connections connectionMap connections connectionMap
Debug bool Debug bool
NumberOfConnections int NumberOfConnections int
} }
func NewRelay(flags *Flags) *Relay { func NewRelay(flags *Flags) *Relay {
r := new(Relay) r := new(Relay)
r.Debug = flags.Debug r.Debug = flags.Debug
r.NumberOfConnections = MAX_NUMBER_THREADS r.NumberOfConnections = MAX_NUMBER_THREADS
log.SetFormatter(&log.TextFormatter{}) log.SetFormatter(&log.TextFormatter{})
if r.Debug { if r.Debug {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
} else { } else {
log.SetLevel(log.WarnLevel) log.SetLevel(log.WarnLevel)
} }
return r return r
} }
func (r *Relay) Run() { func (r *Relay) Run() {
r.connections = connectionMap{} r.connections = connectionMap{}
r.connections.Lock() r.connections.Lock()
r.connections.reciever = make(map[string]net.Conn) r.connections.reciever = make(map[string]net.Conn)
r.connections.sender = make(map[string]net.Conn) r.connections.sender = make(map[string]net.Conn)
r.connections.metadata = make(map[string]string) r.connections.metadata = make(map[string]string)
r.connections.Unlock() r.connections.Unlock()
r.runServer() r.runServer()
} }
func (r *Relay) runServer() { func (r *Relay) runServer() {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "main", "function": "main",
}) })
logger.Debug("Initializing") logger.Debug("Initializing")
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(r.NumberOfConnections) wg.Add(r.NumberOfConnections)
for id := 0; id < r.NumberOfConnections; id++ { for id := 0; id < r.NumberOfConnections; id++ {
go r.listenerThread(id, &wg) go r.listenerThread(id, &wg)
} }
wg.Wait() wg.Wait()
} }
func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) { func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "listenerThread:" + strconv.Itoa(27000+id), "function": "listenerThread:" + strconv.Itoa(27000+id),
}) })
defer wg.Done() defer wg.Done()
err := r.listener(id) err := r.listener(id)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
} }
func (r *Relay) listener(id int) (err error) { func (r *Relay) listener(id int) (err error) {
port := strconv.Itoa(27001 + id) port := strconv.Itoa(27001 + id)
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "listener" + ":" + port, "function": "listener" + ":" + port,
}) })
server, err := net.Listen("tcp", "0.0.0.0:"+port) server, err := net.Listen("tcp", "0.0.0.0:"+port)
if err != nil { if err != nil {
return errors.Wrap(err, "Error listening on "+":"+port) return errors.Wrap(err, "Error listening on "+":"+port)
} }
defer server.Close() defer server.Close()
logger.Debug("waiting for connections") logger.Debug("waiting for connections")
//Spawn a new goroutine whenever a client connects //Spawn a new goroutine whenever a client connects
for { for {
connection, err := server.Accept() connection, err := server.Accept()
if err != nil { if err != nil {
return errors.Wrap(err, "problem accepting connection") return errors.Wrap(err, "problem accepting connection")
} }
logger.Debugf("Client %s connected", connection.RemoteAddr().String()) logger.Debugf("Client %s connected", connection.RemoteAddr().String())
go r.clientCommuncation(id, connection) go r.clientCommuncation(id, connection)
} }
} }
func (r *Relay) clientCommuncation(id int, connection net.Conn) { func (r *Relay) clientCommuncation(id int, connection net.Conn) {
sendMessage("who?", connection) sendMessage("who?", connection)
m := strings.Split(receiveMessage(connection), ".") m := strings.Split(receiveMessage(connection), ".")
connectionType, codePhrase, metaData := m[0], m[1], m[2] connectionType, codePhrase, metaData := m[0], m[1], m[2]
key := codePhrase + "-" + strconv.Itoa(id) key := codePhrase + "-" + strconv.Itoa(id)
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"id": id, "id": id,
"codePhrase": codePhrase, "codePhrase": codePhrase,
}) })
if connectionType == "s" { if connectionType == "s" {
logger.Debug("got sender") logger.Debug("got sender")
r.connections.Lock() r.connections.Lock()
r.connections.metadata[key] = metaData r.connections.metadata[key] = metaData
r.connections.sender[key] = connection r.connections.sender[key] = connection
r.connections.Unlock() r.connections.Unlock()
// wait for receiver // wait for receiver
receiversAddress := "" receiversAddress := ""
for { for {
r.connections.RLock() r.connections.RLock()
if _, ok := r.connections.reciever[key]; ok { if _, ok := r.connections.reciever[key]; ok {
receiversAddress = r.connections.reciever[key].RemoteAddr().String() receiversAddress = r.connections.reciever[key].RemoteAddr().String()
logger.Debug("got reciever") logger.Debug("got reciever")
r.connections.RUnlock() r.connections.RUnlock()
break break
} }
r.connections.RUnlock() r.connections.RUnlock()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
logger.Debug("telling sender ok") logger.Debug("telling sender ok")
sendMessage(receiversAddress, connection) sendMessage(receiversAddress, connection)
logger.Debug("preparing pipe") logger.Debug("preparing pipe")
r.connections.Lock() r.connections.Lock()
con1 := r.connections.sender[key] con1 := r.connections.sender[key]
con2 := r.connections.reciever[key] con2 := r.connections.reciever[key]
r.connections.Unlock() r.connections.Unlock()
logger.Debug("piping connections") logger.Debug("piping connections")
Pipe(con1, con2) Pipe(con1, con2)
logger.Debug("done piping") logger.Debug("done piping")
r.connections.Lock() r.connections.Lock()
delete(r.connections.sender, key) delete(r.connections.sender, key)
delete(r.connections.reciever, key) delete(r.connections.reciever, key)
delete(r.connections.metadata, key) delete(r.connections.metadata, key)
r.connections.Unlock() r.connections.Unlock()
logger.Debug("deleted sender and receiver") logger.Debug("deleted sender and receiver")
} else { } else {
// wait for sender's metadata // wait for sender's metadata
sendersAddress := "" sendersAddress := ""
for { for {
r.connections.RLock() r.connections.RLock()
if _, ok := r.connections.metadata[key]; ok { if _, ok := r.connections.metadata[key]; ok {
if _, ok2 := r.connections.sender[key]; ok2 { if _, ok2 := r.connections.sender[key]; ok2 {
sendersAddress = r.connections.sender[key].RemoteAddr().String() sendersAddress = r.connections.sender[key].RemoteAddr().String()
logger.Debug("got sender meta data") logger.Debug("got sender meta data")
r.connections.RUnlock() r.connections.RUnlock()
break break
} }
} }
r.connections.RUnlock() r.connections.RUnlock()
if connectionType == "c" { if connectionType == "c" {
sendMessage("0-0-0-0.0.0.0", connection) sendMessage("0-0-0-0.0.0.0", connection)
return return
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
// send meta data // send meta data
r.connections.RLock() r.connections.RLock()
sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection) sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection)
r.connections.RUnlock() r.connections.RUnlock()
// check for receiver's consent // check for receiver's consent
consent := receiveMessage(connection) consent := receiveMessage(connection)
logger.Debugf("consent: %s", consent) logger.Debugf("consent: %s", consent)
if consent == "ok" { if consent == "ok" {
logger.Debug("got consent") logger.Debug("got consent")
r.connections.Lock() r.connections.Lock()
r.connections.reciever[key] = connection r.connections.reciever[key] = connection
r.connections.Unlock() r.connections.Unlock()
} }
} }
return return
} }
func sendMessage(message string, connection net.Conn) { func sendMessage(message string, connection net.Conn) {
message = fillString(message, BUFFERSIZE) message = fillString(message, BUFFERSIZE)
connection.Write([]byte(message)) connection.Write([]byte(message))
} }
func receiveMessage(connection net.Conn) string { func receiveMessage(connection net.Conn) string {
messageByte := make([]byte, BUFFERSIZE) messageByte := make([]byte, BUFFERSIZE)
connection.Read(messageByte) connection.Read(messageByte)
return strings.Replace(string(messageByte), ":", "", -1) return strings.Replace(string(messageByte), ":", "", -1)
} }
func fillString(retunString string, toLength int) string { func fillString(retunString string, toLength int) string {
for { for {
lengthString := len(retunString) lengthString := len(retunString)
if lengthString < toLength { if lengthString < toLength {
retunString = retunString + ":" retunString = retunString + ":"
continue continue
} }
break break
} }
return retunString return retunString
} }
// chanFromConn creates a channel from a Conn object, and sends everything it // chanFromConn creates a channel from a Conn object, and sends everything it
// Read()s from the socket to the channel. // Read()s from the socket to the channel.
func chanFromConn(conn net.Conn) chan []byte { func chanFromConn(conn net.Conn) chan []byte {
c := make(chan []byte) c := make(chan []byte)
go func() { go func() {
b := make([]byte, BUFFERSIZE) b := make([]byte, BUFFERSIZE)
for { for {
n, err := conn.Read(b) n, err := conn.Read(b)
if n > 0 { if n > 0 {
res := make([]byte, n) res := make([]byte, n)
// Copy the buffer so it doesn't get changed while read by the recipient. // Copy the buffer so it doesn't get changed while read by the recipient.
copy(res, b[:n]) copy(res, b[:n])
c <- res c <- res
} }
if err != nil { if err != nil {
c <- nil c <- nil
break break
} }
} }
}() }()
return c return c
} }
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other. // Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other.
func Pipe(conn1 net.Conn, conn2 net.Conn) { func Pipe(conn1 net.Conn, conn2 net.Conn) {
chan1 := chanFromConn(conn1) chan1 := chanFromConn(conn1)
chan2 := chanFromConn(conn2) chan2 := chanFromConn(conn2)
for { for {
select { select {
case b1 := <-chan1: case b1 := <-chan1:
if b1 == nil { if b1 == nil {
return return
} else { } else {
conn2.Write(b1) conn2.Write(b1)
} }
case b2 := <-chan2: case b2 := <-chan2:
if b2 == nil { if b2 == nil {
return return
} else { } else {
conn1.Write(b2) conn1.Write(b2)
} }
} }
} }
} }