From eb3af56ccf96d57b430e7b60ef6d33f97b253320 Mon Sep 17 00:00:00 2001 From: Bradley T Lunsford Date: Fri, 20 Oct 2017 14:51:30 -0700 Subject: [PATCH 1/3] adding wait funtionality: if waiting, will send a different identification to the server --- connect.go | 17 ++++++++++++++++- main.go | 2 ++ relay.go | 4 ++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/connect.go b/connect.go index 47feb7f8..0d6aa700 100644 --- a/connect.go +++ b/connect.go @@ -27,6 +27,7 @@ type Connection struct { IsSender bool Debug bool DontEncrypt bool + Wait bool bars []*uiprogress.Bar rate int } @@ -41,6 +42,7 @@ func NewConnection(flags *Flags) *Connection { c := new(Connection) c.Debug = flags.Debug c.DontEncrypt = flags.DontEncrypt + c.Wait = flags.Wait c.Server = flags.Server c.Code = flags.Code c.NumberOfConnections = flags.NumberOfConnections @@ -160,6 +162,7 @@ func (c *Connection) runClient() error { } gotOK := false gotResponse := false + notPresent := false for id := 0; id < c.NumberOfConnections; id++ { go func(id int) { defer wg.Done() @@ -182,7 +185,11 @@ func (c *Connection) runClient() error { sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection) } else { logger.Debugf("telling relay: %s", "r."+c.Code) - sendMessage("r."+c.HashedCode+".0.0.0", connection) + if c.Wait { + sendMessage("r."+c.HashedCode+".0.0.0", connection) + } else { + sendMessage("c."+c.HashedCode+".0.0.0", connection) + } } if c.IsSender { // this is a sender logger.Debug("waiting for ok from relay") @@ -201,6 +208,10 @@ func (c *Connection) runClient() error { message = receiveMessage(connection) m := strings.Split(message, "-") encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3] + if sendersAddress == "0.0.0.0" { + notPresent = true + return + } encryptedBytes, err := hex.DecodeString(encryptedData) if err != nil { log.Error(err) @@ -252,6 +263,10 @@ func (c *Connection) runClient() error { wg.Wait() if !c.IsSender { + if notPresent { + fmt.Println("Sender/Code not present") + return nil + } if !gotOK { return errors.New("Transfer interrupted") } diff --git a/main.go b/main.go index fbc719ec..199aa509 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ var oneGigabytePerSecond = 1000000 // expressed as kbps type Flags struct { Relay bool Debug bool + Wait bool DontEncrypt bool Server string File string @@ -37,6 +38,7 @@ croc version ` + version + ` flags := new(Flags) flag.BoolVar(&flags.Relay, "relay", false, "run as relay") flag.BoolVar(&flags.Debug, "debug", false, "debug mode") + flag.BoolVar(&flags.Wait, "wait", false, "wait for code to be sent") flag.StringVar(&flags.Server, "server", "cowyo.com", "address of relay server") flag.StringVar(&flags.File, "send", "", "file to send") flag.StringVar(&flags.Code, "code", "", "use your own code phrase") diff --git a/relay.go b/relay.go index ef35c207..0e2d5561 100644 --- a/relay.go +++ b/relay.go @@ -156,6 +156,10 @@ func (r *Relay) clientCommuncation(id int, connection net.Conn) { } } r.connections.RUnlock() + if connectionType == "c" { + sendMessage("0-0-0-0.0.0.0", connection) + return + } time.Sleep(100 * time.Millisecond) } // send meta data From 798a0d2c52d178ba1788e57cf04bb27c205c0107 Mon Sep 17 00:00:00 2001 From: Brad Lunsford Date: Fri, 20 Oct 2017 15:11:29 -0700 Subject: [PATCH 2/3] Adding sleep time of 1 second if not waiting and code not present --- connect.go | 1 + 1 file changed, 1 insertion(+) diff --git a/connect.go b/connect.go index 0d6aa700..e82139ff 100644 --- a/connect.go +++ b/connect.go @@ -210,6 +210,7 @@ func (c *Connection) runClient() error { encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3] if sendersAddress == "0.0.0.0" { notPresent = true + time.Sleep(1 * time.Second) return } encryptedBytes, err := hex.DecodeString(encryptedData) From 6bdbdce655339e0c217d3fed4a8ca5d48eabfd12 Mon Sep 17 00:00:00 2001 From: Brad Lunsford Date: Fri, 20 Oct 2017 15:18:06 -0700 Subject: [PATCH 3/3] ran 'go fmt *.go' to (hopefully) get rid of commit issues --- connect.go | 910 ++++++++++++++++++++++++++--------------------------- main.go | 134 ++++---- relay.go | 504 ++++++++++++++--------------- 3 files changed, 774 insertions(+), 774 deletions(-) diff --git a/connect.go b/connect.go index e82139ff..80f4e820 100644 --- a/connect.go +++ b/connect.go @@ -1,455 +1,455 @@ -package main - -import ( - "encoding/hex" - "encoding/json" - "fmt" - "io" - "math" - "net" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/gosuri/uiprogress" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" -) - -type Connection struct { - Server string - File FileMetaData - NumberOfConnections int - Code string - HashedCode string - IsSender bool - Debug bool - DontEncrypt bool - Wait bool - bars []*uiprogress.Bar - rate int -} - -type FileMetaData struct { - Name string - Size int - Hash string -} - -func NewConnection(flags *Flags) *Connection { - c := new(Connection) - c.Debug = flags.Debug - c.DontEncrypt = flags.DontEncrypt - c.Wait = flags.Wait - c.Server = flags.Server - c.Code = flags.Code - c.NumberOfConnections = flags.NumberOfConnections - c.rate = flags.Rate - if len(flags.File) > 0 { - c.File.Name = flags.File - c.IsSender = true - } else { - c.IsSender = false - } - - log.SetFormatter(&log.TextFormatter{}) - if c.Debug { - log.SetLevel(log.DebugLevel) - } else { - log.SetLevel(log.WarnLevel) - } - - return c -} - -func (c *Connection) Run() error { - forceSingleThreaded := false - if c.IsSender { - fsize, err := FileSize(c.File.Name) - if err != nil { - return err - } - if fsize < MAX_NUMBER_THREADS*BUFFERSIZE { - forceSingleThreaded = true - log.Debug("forcing single thread") - } - } - log.Debug("checking code validity") - for { - // check code - goodCode := true - m := strings.Split(c.Code, "-") - log.Debug(m) - numThreads, errParse := strconv.Atoi(m[0]) - if len(m) < 2 { - goodCode = false - log.Debug("code too short") - } else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) { - c.NumberOfConnections = MAX_NUMBER_THREADS - goodCode = false - log.Debug("incorrect number of threads") - } else if errParse != nil { - goodCode = false - log.Debug("problem parsing threads") - } - log.Debug(m) - log.Debug(goodCode) - if !goodCode { - if c.IsSender { - if forceSingleThreaded { - c.NumberOfConnections = 1 - } - c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName() - } else { - if len(c.Code) != 0 { - fmt.Println("Code must begin with number of threads (e.g. 3-some-code)") - } - c.Code = getInput("Enter receive code: ") - } - } else { - break - } - } - // assign number of connections - c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0]) - - if c.IsSender { - if c.DontEncrypt { - // don't encrypt - CopyFile(c.File.Name, c.File.Name+".enc") - } else { - // encrypt - log.Debug("encrypting...") - if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil { - return err - } - } - // get file hash - var err error - c.File.Hash, err = HashFile(c.File.Name) - if err != nil { - return err - } - // get file size - c.File.Size, err = FileSize(c.File.Name + ".enc") - if err != nil { - return err - } - fmt.Printf("Sending %d byte file named '%s'\n", c.File.Size, c.File.Name) - fmt.Printf("Code is: %s\n", c.Code) - } - - return c.runClient() -} - -// runClient spawns threads for parallel uplink/downlink via TCP -func (c *Connection) runClient() error { - logger := log.WithFields(log.Fields{ - "code": c.Code, - "sender?": c.IsSender, - }) - - c.HashedCode = Hash(c.Code) - - var wg sync.WaitGroup - wg.Add(c.NumberOfConnections) - - uiprogress.Start() - if !c.Debug { - c.bars = make([]*uiprogress.Bar, c.NumberOfConnections) - } - gotOK := false - gotResponse := false - notPresent := false - for id := 0; id < c.NumberOfConnections; id++ { - go func(id int) { - defer wg.Done() - port := strconv.Itoa(27001 + id) - connection, err := net.Dial("tcp", c.Server+":"+port) - if err != nil { - panic(err) - } - defer connection.Close() - - message := receiveMessage(connection) - logger.Debugf("relay says: %s", message) - if c.IsSender { - logger.Debugf("telling relay: %s", "s."+c.Code) - metaData, err := json.Marshal(c.File) - if err != nil { - log.Error(err) - } - encryptedMetaData, salt, iv := Encrypt(metaData, c.Code) - sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection) - } else { - logger.Debugf("telling relay: %s", "r."+c.Code) - if c.Wait { - sendMessage("r."+c.HashedCode+".0.0.0", connection) - } else { - sendMessage("c."+c.HashedCode+".0.0.0", connection) - } - } - if c.IsSender { // this is a sender - logger.Debug("waiting for ok from relay") - message = receiveMessage(connection) - logger.Debug("got ok from relay") - if id == 0 { - fmt.Printf("\nSending (->%s)..\n", message) - } - // wait for pipe to be made - time.Sleep(100 * time.Millisecond) - // Write data from file - logger.Debug("send file") - c.sendFile(id, connection) - } else { // this is a receiver - logger.Debug("waiting for meta data from sender") - message = receiveMessage(connection) - m := strings.Split(message, "-") - encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3] - if sendersAddress == "0.0.0.0" { - notPresent = true - time.Sleep(1 * time.Second) - return - } - encryptedBytes, err := hex.DecodeString(encryptedData) - if err != nil { - log.Error(err) - return - } - decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt) - err = json.Unmarshal(decryptedBytes, &c.File) - if err != nil { - log.Error(err) - return - } - log.Debugf("meta data received: %v", c.File) - // have the main thread ask for the okay - if id == 0 { - fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name) - var sentFileNames []string - - if fileAlreadyExists(sentFileNames, c.File.Name) { - fmt.Printf("Will not overwrite file!") - os.Exit(1) - } - getOK := getInput("ok? (y/n): ") - if getOK == "y" { - gotOK = true - sentFileNames = append(sentFileNames, c.File.Name) - } - gotResponse = true - } - // wait for the main thread to get the okay - for limit := 0; limit < 1000; limit++ { - if gotResponse { - break - } - time.Sleep(10 * time.Millisecond) - } - if !gotOK { - sendMessage("not ok", connection) - } else { - sendMessage("ok", connection) - logger.Debug("receive file") - if id == 0 { - fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress) - } - c.receiveFile(id, connection) - } - } - }(id) - } - wg.Wait() - - if !c.IsSender { - if notPresent { - fmt.Println("Sender/Code not present") - return nil - } - if !gotOK { - return errors.New("Transfer interrupted") - } - c.catFile(c.File.Name) - log.Debugf("Code: [%s]", c.Code) - if c.DontEncrypt { - if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil { - return err - } - } else { - if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil { - return errors.Wrap(err, "Problem decrypting file") - } - } - if !c.Debug { - os.Remove(c.File.Name + ".enc") - } - - fileHash, err := HashFile(c.File.Name) - if err != nil { - log.Error(err) - } - log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash) - log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash) - - if c.File.Hash != fileHash { - return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name) - } else { - fmt.Printf("\nReceived file written to %s", c.File.Name) - } - } else { - fmt.Println("File sent.") - // TODO: Add confirmation - } - return nil -} - -func fileAlreadyExists(s []string, f string) bool { - for _, a := range s { - if a == f { - return true - } - } - return false -} - -func (c *Connection) catFile(fname string) { - // cat the file - os.Remove(fname) - finished, err := os.Create(fname + ".enc") - defer finished.Close() - if err != nil { - log.Fatal(err) - } - for id := 0; id < c.NumberOfConnections; id++ { - fh, err := os.Open(fname + "." + strconv.Itoa(id)) - if err != nil { - log.Fatal(err) - } - - _, err = io.Copy(finished, fh) - if err != nil { - log.Fatal(err) - } - fh.Close() - os.Remove(fname + "." + strconv.Itoa(id)) - } - -} - -func (c *Connection) receiveFile(id int, connection net.Conn) error { - logger := log.WithFields(log.Fields{ - "function": "receiveFile #" + strconv.Itoa(id), - }) - - logger.Debug("waiting for chunk size from sender") - fileSizeBuffer := make([]byte, 10) - connection.Read(fileSizeBuffer) - fileDataString := strings.Trim(string(fileSizeBuffer), ":") - fileSizeInt, _ := strconv.Atoi(fileDataString) - chunkSize := int64(fileSizeInt) - logger.Debugf("chunk size: %d", chunkSize) - - os.Remove(c.File.Name + "." + strconv.Itoa(id)) - newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id)) - if err != nil { - panic(err) - } - defer newFile.Close() - - if !c.Debug { - c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed() - } - - logger.Debug("waiting for file") - var receivedBytes int64 - receivedFirstBytes := false - for { - if !c.Debug { - c.bars[id].Incr() - } - if (chunkSize - receivedBytes) < BUFFERSIZE { - logger.Debug("at the end") - io.CopyN(newFile, connection, (chunkSize - receivedBytes)) - // Empty the remaining bytes that we don't need from the network buffer - if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE { - logger.Debug("empty remaining bytes from network buffer") - connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize)) - } - break - } - io.CopyN(newFile, connection, BUFFERSIZE) - receivedBytes += BUFFERSIZE - if !receivedFirstBytes { - receivedFirstBytes = true - logger.Debug("Receieved first bytes!") - } - } - logger.Debug("received file") - return nil -} - -func (c *Connection) sendFile(id int, connection net.Conn) { - logger := log.WithFields(log.Fields{ - "function": "sendFile #" + strconv.Itoa(id), - }) - defer connection.Close() - - var err error - - numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE)) - chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections))) - - chunkSize := int64(chunksPerWorker * BUFFERSIZE) - if id+1 == c.NumberOfConnections { - chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize - } - - if id == 0 || id == c.NumberOfConnections-1 { - logger.Debugf("numChunks: %v", numChunks) - logger.Debugf("chunksPerWorker: %v", chunksPerWorker) - logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize) - } - - logger.Debugf("sending chunk size: %d", chunkSize) - connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10))) - - sendBuffer := make([]byte, BUFFERSIZE) - - // open encrypted file - file, err := os.OpenFile(c.File.Name+".enc", os.O_RDONLY, 0755) - if err != nil { - log.Error(err) - return - } - defer file.Close() - - chunkI := 0 - if !c.Debug { - c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed() - } - - bufferSizeInKilobytes := BUFFERSIZE / 1024 - rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes) - throttle := time.NewTicker(time.Second / time.Duration(rate)) - defer throttle.Stop() - - for range throttle.C { - _, err = file.Read(sendBuffer) - if err == io.EOF { - //End of file reached, break out of for loop - logger.Debug("EOF") - break - } - if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) { - connection.Write(sendBuffer) - if !c.Debug { - c.bars[id].Incr() - } - } - chunkI++ - } - logger.Debug("file is sent") - return -} +package main + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "io" + "math" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/gosuri/uiprogress" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +type Connection struct { + Server string + File FileMetaData + NumberOfConnections int + Code string + HashedCode string + IsSender bool + Debug bool + DontEncrypt bool + Wait bool + bars []*uiprogress.Bar + rate int +} + +type FileMetaData struct { + Name string + Size int + Hash string +} + +func NewConnection(flags *Flags) *Connection { + c := new(Connection) + c.Debug = flags.Debug + c.DontEncrypt = flags.DontEncrypt + c.Wait = flags.Wait + c.Server = flags.Server + c.Code = flags.Code + c.NumberOfConnections = flags.NumberOfConnections + c.rate = flags.Rate + if len(flags.File) > 0 { + c.File.Name = flags.File + c.IsSender = true + } else { + c.IsSender = false + } + + log.SetFormatter(&log.TextFormatter{}) + if c.Debug { + log.SetLevel(log.DebugLevel) + } else { + log.SetLevel(log.WarnLevel) + } + + return c +} + +func (c *Connection) Run() error { + forceSingleThreaded := false + if c.IsSender { + fsize, err := FileSize(c.File.Name) + if err != nil { + return err + } + if fsize < MAX_NUMBER_THREADS*BUFFERSIZE { + forceSingleThreaded = true + log.Debug("forcing single thread") + } + } + log.Debug("checking code validity") + for { + // check code + goodCode := true + m := strings.Split(c.Code, "-") + log.Debug(m) + numThreads, errParse := strconv.Atoi(m[0]) + if len(m) < 2 { + goodCode = false + log.Debug("code too short") + } else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) { + c.NumberOfConnections = MAX_NUMBER_THREADS + goodCode = false + log.Debug("incorrect number of threads") + } else if errParse != nil { + goodCode = false + log.Debug("problem parsing threads") + } + log.Debug(m) + log.Debug(goodCode) + if !goodCode { + if c.IsSender { + if forceSingleThreaded { + c.NumberOfConnections = 1 + } + c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName() + } else { + if len(c.Code) != 0 { + fmt.Println("Code must begin with number of threads (e.g. 3-some-code)") + } + c.Code = getInput("Enter receive code: ") + } + } else { + break + } + } + // assign number of connections + c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0]) + + if c.IsSender { + if c.DontEncrypt { + // don't encrypt + CopyFile(c.File.Name, c.File.Name+".enc") + } else { + // encrypt + log.Debug("encrypting...") + if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil { + return err + } + } + // get file hash + var err error + c.File.Hash, err = HashFile(c.File.Name) + if err != nil { + return err + } + // get file size + c.File.Size, err = FileSize(c.File.Name + ".enc") + if err != nil { + return err + } + fmt.Printf("Sending %d byte file named '%s'\n", c.File.Size, c.File.Name) + fmt.Printf("Code is: %s\n", c.Code) + } + + return c.runClient() +} + +// runClient spawns threads for parallel uplink/downlink via TCP +func (c *Connection) runClient() error { + logger := log.WithFields(log.Fields{ + "code": c.Code, + "sender?": c.IsSender, + }) + + c.HashedCode = Hash(c.Code) + + var wg sync.WaitGroup + wg.Add(c.NumberOfConnections) + + uiprogress.Start() + if !c.Debug { + c.bars = make([]*uiprogress.Bar, c.NumberOfConnections) + } + gotOK := false + gotResponse := false + notPresent := false + for id := 0; id < c.NumberOfConnections; id++ { + go func(id int) { + defer wg.Done() + port := strconv.Itoa(27001 + id) + connection, err := net.Dial("tcp", c.Server+":"+port) + if err != nil { + panic(err) + } + defer connection.Close() + + message := receiveMessage(connection) + logger.Debugf("relay says: %s", message) + if c.IsSender { + logger.Debugf("telling relay: %s", "s."+c.Code) + metaData, err := json.Marshal(c.File) + if err != nil { + log.Error(err) + } + encryptedMetaData, salt, iv := Encrypt(metaData, c.Code) + sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection) + } else { + logger.Debugf("telling relay: %s", "r."+c.Code) + if c.Wait { + sendMessage("r."+c.HashedCode+".0.0.0", connection) + } else { + sendMessage("c."+c.HashedCode+".0.0.0", connection) + } + } + if c.IsSender { // this is a sender + logger.Debug("waiting for ok from relay") + message = receiveMessage(connection) + logger.Debug("got ok from relay") + if id == 0 { + fmt.Printf("\nSending (->%s)..\n", message) + } + // wait for pipe to be made + time.Sleep(100 * time.Millisecond) + // Write data from file + logger.Debug("send file") + c.sendFile(id, connection) + } else { // this is a receiver + logger.Debug("waiting for meta data from sender") + message = receiveMessage(connection) + m := strings.Split(message, "-") + encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3] + if sendersAddress == "0.0.0.0" { + notPresent = true + time.Sleep(1 * time.Second) + return + } + encryptedBytes, err := hex.DecodeString(encryptedData) + if err != nil { + log.Error(err) + return + } + decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt) + err = json.Unmarshal(decryptedBytes, &c.File) + if err != nil { + log.Error(err) + return + } + log.Debugf("meta data received: %v", c.File) + // have the main thread ask for the okay + if id == 0 { + fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name) + var sentFileNames []string + + if fileAlreadyExists(sentFileNames, c.File.Name) { + fmt.Printf("Will not overwrite file!") + os.Exit(1) + } + getOK := getInput("ok? (y/n): ") + if getOK == "y" { + gotOK = true + sentFileNames = append(sentFileNames, c.File.Name) + } + gotResponse = true + } + // wait for the main thread to get the okay + for limit := 0; limit < 1000; limit++ { + if gotResponse { + break + } + time.Sleep(10 * time.Millisecond) + } + if !gotOK { + sendMessage("not ok", connection) + } else { + sendMessage("ok", connection) + logger.Debug("receive file") + if id == 0 { + fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress) + } + c.receiveFile(id, connection) + } + } + }(id) + } + wg.Wait() + + if !c.IsSender { + if notPresent { + fmt.Println("Sender/Code not present") + return nil + } + if !gotOK { + return errors.New("Transfer interrupted") + } + c.catFile(c.File.Name) + log.Debugf("Code: [%s]", c.Code) + if c.DontEncrypt { + if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil { + return err + } + } else { + if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil { + return errors.Wrap(err, "Problem decrypting file") + } + } + if !c.Debug { + os.Remove(c.File.Name + ".enc") + } + + fileHash, err := HashFile(c.File.Name) + if err != nil { + log.Error(err) + } + log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash) + log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash) + + if c.File.Hash != fileHash { + return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name) + } else { + fmt.Printf("\nReceived file written to %s", c.File.Name) + } + } else { + fmt.Println("File sent.") + // TODO: Add confirmation + } + return nil +} + +func fileAlreadyExists(s []string, f string) bool { + for _, a := range s { + if a == f { + return true + } + } + return false +} + +func (c *Connection) catFile(fname string) { + // cat the file + os.Remove(fname) + finished, err := os.Create(fname + ".enc") + defer finished.Close() + if err != nil { + log.Fatal(err) + } + for id := 0; id < c.NumberOfConnections; id++ { + fh, err := os.Open(fname + "." + strconv.Itoa(id)) + if err != nil { + log.Fatal(err) + } + + _, err = io.Copy(finished, fh) + if err != nil { + log.Fatal(err) + } + fh.Close() + os.Remove(fname + "." + strconv.Itoa(id)) + } + +} + +func (c *Connection) receiveFile(id int, connection net.Conn) error { + logger := log.WithFields(log.Fields{ + "function": "receiveFile #" + strconv.Itoa(id), + }) + + logger.Debug("waiting for chunk size from sender") + fileSizeBuffer := make([]byte, 10) + connection.Read(fileSizeBuffer) + fileDataString := strings.Trim(string(fileSizeBuffer), ":") + fileSizeInt, _ := strconv.Atoi(fileDataString) + chunkSize := int64(fileSizeInt) + logger.Debugf("chunk size: %d", chunkSize) + + os.Remove(c.File.Name + "." + strconv.Itoa(id)) + newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id)) + if err != nil { + panic(err) + } + defer newFile.Close() + + if !c.Debug { + c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed() + } + + logger.Debug("waiting for file") + var receivedBytes int64 + receivedFirstBytes := false + for { + if !c.Debug { + c.bars[id].Incr() + } + if (chunkSize - receivedBytes) < BUFFERSIZE { + logger.Debug("at the end") + io.CopyN(newFile, connection, (chunkSize - receivedBytes)) + // Empty the remaining bytes that we don't need from the network buffer + if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE { + logger.Debug("empty remaining bytes from network buffer") + connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize)) + } + break + } + io.CopyN(newFile, connection, BUFFERSIZE) + receivedBytes += BUFFERSIZE + if !receivedFirstBytes { + receivedFirstBytes = true + logger.Debug("Receieved first bytes!") + } + } + logger.Debug("received file") + return nil +} + +func (c *Connection) sendFile(id int, connection net.Conn) { + logger := log.WithFields(log.Fields{ + "function": "sendFile #" + strconv.Itoa(id), + }) + defer connection.Close() + + var err error + + numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE)) + chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections))) + + chunkSize := int64(chunksPerWorker * BUFFERSIZE) + if id+1 == c.NumberOfConnections { + chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize + } + + if id == 0 || id == c.NumberOfConnections-1 { + logger.Debugf("numChunks: %v", numChunks) + logger.Debugf("chunksPerWorker: %v", chunksPerWorker) + logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize) + } + + logger.Debugf("sending chunk size: %d", chunkSize) + connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10))) + + sendBuffer := make([]byte, BUFFERSIZE) + + // open encrypted file + file, err := os.OpenFile(c.File.Name+".enc", os.O_RDONLY, 0755) + if err != nil { + log.Error(err) + return + } + defer file.Close() + + chunkI := 0 + if !c.Debug { + c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed() + } + + bufferSizeInKilobytes := BUFFERSIZE / 1024 + rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes) + throttle := time.NewTicker(time.Second / time.Duration(rate)) + defer throttle.Stop() + + for range throttle.C { + _, err = file.Read(sendBuffer) + if err == io.EOF { + //End of file reached, break out of for loop + logger.Debug("EOF") + break + } + if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) { + connection.Write(sendBuffer) + if !c.Debug { + c.bars[id].Incr() + } + } + chunkI++ + } + logger.Debug("file is sent") + return +} diff --git a/main.go b/main.go index 199aa509..7945b54a 100644 --- a/main.go +++ b/main.go @@ -1,67 +1,67 @@ -package main - -import ( - "bufio" - "flag" - "fmt" - "os" - "strings" -) - -const BUFFERSIZE = 1024 - -var oneGigabytePerSecond = 1000000 // expressed as kbps - -type Flags struct { - Relay bool - Debug bool - Wait bool - DontEncrypt bool - Server string - File string - Code string - Rate int - NumberOfConnections int -} - -var version string - -func main() { - fmt.Println(` - /\_/\ - ____/ o o \ - /~____ =ø= / - (______)__m_m) - -croc version ` + version + ` -`) - flags := new(Flags) - flag.BoolVar(&flags.Relay, "relay", false, "run as relay") - flag.BoolVar(&flags.Debug, "debug", false, "debug mode") - flag.BoolVar(&flags.Wait, "wait", false, "wait for code to be sent") - flag.StringVar(&flags.Server, "server", "cowyo.com", "address of relay server") - flag.StringVar(&flags.File, "send", "", "file to send") - flag.StringVar(&flags.Code, "code", "", "use your own code phrase") - flag.IntVar(&flags.Rate, "rate", oneGigabytePerSecond, "throttle down to speed in kbps") - flag.BoolVar(&flags.DontEncrypt, "no-encrypt", false, "turn off encryption") - flag.IntVar(&flags.NumberOfConnections, "threads", 4, "number of threads to use") - flag.Parse() - - if flags.Relay { - r := NewRelay(flags) - r.Run() - } else { - c := NewConnection(flags) - err := c.Run() - if err != nil { - fmt.Printf("Error! Please submit the following error to https://github.com/schollz/croc/issues:\n\n'%s'\n\n", err.Error()) - } - } -} - -func getInput(prompt string) string { - reader := bufio.NewReader(os.Stdin) - fmt.Print(prompt) - text, _ := reader.ReadString('\n') - return strings.TrimSpace(text) -} +package main + +import ( + "bufio" + "flag" + "fmt" + "os" + "strings" +) + +const BUFFERSIZE = 1024 + +var oneGigabytePerSecond = 1000000 // expressed as kbps + +type Flags struct { + Relay bool + Debug bool + Wait bool + DontEncrypt bool + Server string + File string + Code string + Rate int + NumberOfConnections int +} + +var version string + +func main() { + fmt.Println(` + /\_/\ + ____/ o o \ + /~____ =ø= / + (______)__m_m) + +croc version ` + version + ` +`) + flags := new(Flags) + flag.BoolVar(&flags.Relay, "relay", false, "run as relay") + flag.BoolVar(&flags.Debug, "debug", false, "debug mode") + flag.BoolVar(&flags.Wait, "wait", false, "wait for code to be sent") + flag.StringVar(&flags.Server, "server", "cowyo.com", "address of relay server") + flag.StringVar(&flags.File, "send", "", "file to send") + flag.StringVar(&flags.Code, "code", "", "use your own code phrase") + flag.IntVar(&flags.Rate, "rate", oneGigabytePerSecond, "throttle down to speed in kbps") + flag.BoolVar(&flags.DontEncrypt, "no-encrypt", false, "turn off encryption") + flag.IntVar(&flags.NumberOfConnections, "threads", 4, "number of threads to use") + flag.Parse() + + if flags.Relay { + r := NewRelay(flags) + r.Run() + } else { + c := NewConnection(flags) + err := c.Run() + if err != nil { + fmt.Printf("Error! Please submit the following error to https://github.com/schollz/croc/issues:\n\n'%s'\n\n", err.Error()) + } + } +} + +func getInput(prompt string) string { + reader := bufio.NewReader(os.Stdin) + fmt.Print(prompt) + text, _ := reader.ReadString('\n') + return strings.TrimSpace(text) +} diff --git a/relay.go b/relay.go index 0e2d5561..829dfee3 100644 --- a/relay.go +++ b/relay.go @@ -1,252 +1,252 @@ -package main - -import ( - "net" - "strconv" - "strings" - "sync" - "time" - - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" -) - -const MAX_NUMBER_THREADS = 8 - -type connectionMap struct { - reciever map[string]net.Conn - sender map[string]net.Conn - metadata map[string]string - sync.RWMutex -} - -type Relay struct { - connections connectionMap - Debug bool - NumberOfConnections int -} - -func NewRelay(flags *Flags) *Relay { - r := new(Relay) - r.Debug = flags.Debug - r.NumberOfConnections = MAX_NUMBER_THREADS - log.SetFormatter(&log.TextFormatter{}) - if r.Debug { - log.SetLevel(log.DebugLevel) - } else { - log.SetLevel(log.WarnLevel) - } - return r -} - -func (r *Relay) Run() { - r.connections = connectionMap{} - r.connections.Lock() - r.connections.reciever = make(map[string]net.Conn) - r.connections.sender = make(map[string]net.Conn) - r.connections.metadata = make(map[string]string) - r.connections.Unlock() - r.runServer() -} - -func (r *Relay) runServer() { - logger := log.WithFields(log.Fields{ - "function": "main", - }) - logger.Debug("Initializing") - var wg sync.WaitGroup - wg.Add(r.NumberOfConnections) - for id := 0; id < r.NumberOfConnections; id++ { - go r.listenerThread(id, &wg) - } - wg.Wait() -} - -func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) { - logger := log.WithFields(log.Fields{ - "function": "listenerThread:" + strconv.Itoa(27000+id), - }) - - defer wg.Done() - err := r.listener(id) - if err != nil { - logger.Error(err) - } -} - -func (r *Relay) listener(id int) (err error) { - port := strconv.Itoa(27001 + id) - logger := log.WithFields(log.Fields{ - "function": "listener" + ":" + port, - }) - server, err := net.Listen("tcp", "0.0.0.0:"+port) - if err != nil { - return errors.Wrap(err, "Error listening on "+":"+port) - } - defer server.Close() - logger.Debug("waiting for connections") - //Spawn a new goroutine whenever a client connects - for { - connection, err := server.Accept() - if err != nil { - return errors.Wrap(err, "problem accepting connection") - } - logger.Debugf("Client %s connected", connection.RemoteAddr().String()) - go r.clientCommuncation(id, connection) - } -} - -func (r *Relay) clientCommuncation(id int, connection net.Conn) { - sendMessage("who?", connection) - - m := strings.Split(receiveMessage(connection), ".") - connectionType, codePhrase, metaData := m[0], m[1], m[2] - key := codePhrase + "-" + strconv.Itoa(id) - logger := log.WithFields(log.Fields{ - "id": id, - "codePhrase": codePhrase, - }) - - if connectionType == "s" { - logger.Debug("got sender") - r.connections.Lock() - r.connections.metadata[key] = metaData - r.connections.sender[key] = connection - r.connections.Unlock() - // wait for receiver - receiversAddress := "" - for { - r.connections.RLock() - if _, ok := r.connections.reciever[key]; ok { - receiversAddress = r.connections.reciever[key].RemoteAddr().String() - logger.Debug("got reciever") - r.connections.RUnlock() - break - } - r.connections.RUnlock() - time.Sleep(100 * time.Millisecond) - } - logger.Debug("telling sender ok") - sendMessage(receiversAddress, connection) - logger.Debug("preparing pipe") - r.connections.Lock() - con1 := r.connections.sender[key] - con2 := r.connections.reciever[key] - r.connections.Unlock() - logger.Debug("piping connections") - Pipe(con1, con2) - logger.Debug("done piping") - r.connections.Lock() - delete(r.connections.sender, key) - delete(r.connections.reciever, key) - delete(r.connections.metadata, key) - r.connections.Unlock() - logger.Debug("deleted sender and receiver") - } else { - // wait for sender's metadata - sendersAddress := "" - for { - r.connections.RLock() - if _, ok := r.connections.metadata[key]; ok { - if _, ok2 := r.connections.sender[key]; ok2 { - sendersAddress = r.connections.sender[key].RemoteAddr().String() - logger.Debug("got sender meta data") - r.connections.RUnlock() - break - } - } - r.connections.RUnlock() - if connectionType == "c" { - sendMessage("0-0-0-0.0.0.0", connection) - return - } - time.Sleep(100 * time.Millisecond) - } - // send meta data - r.connections.RLock() - sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection) - r.connections.RUnlock() - // check for receiver's consent - consent := receiveMessage(connection) - logger.Debugf("consent: %s", consent) - if consent == "ok" { - logger.Debug("got consent") - r.connections.Lock() - r.connections.reciever[key] = connection - r.connections.Unlock() - } - } - return -} - -func sendMessage(message string, connection net.Conn) { - message = fillString(message, BUFFERSIZE) - connection.Write([]byte(message)) -} - -func receiveMessage(connection net.Conn) string { - messageByte := make([]byte, BUFFERSIZE) - connection.Read(messageByte) - return strings.Replace(string(messageByte), ":", "", -1) -} - -func fillString(retunString string, toLength int) string { - for { - lengthString := len(retunString) - if lengthString < toLength { - retunString = retunString + ":" - continue - } - break - } - return retunString -} - -// chanFromConn creates a channel from a Conn object, and sends everything it -// Read()s from the socket to the channel. -func chanFromConn(conn net.Conn) chan []byte { - c := make(chan []byte) - - go func() { - b := make([]byte, BUFFERSIZE) - - for { - n, err := conn.Read(b) - if n > 0 { - res := make([]byte, n) - // Copy the buffer so it doesn't get changed while read by the recipient. - copy(res, b[:n]) - c <- res - } - if err != nil { - c <- nil - break - } - } - }() - - return c -} - -// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other. -func Pipe(conn1 net.Conn, conn2 net.Conn) { - chan1 := chanFromConn(conn1) - chan2 := chanFromConn(conn2) - - for { - select { - case b1 := <-chan1: - if b1 == nil { - return - } else { - conn2.Write(b1) - } - case b2 := <-chan2: - if b2 == nil { - return - } else { - conn1.Write(b2) - } - } - } -} +package main + +import ( + "net" + "strconv" + "strings" + "sync" + "time" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +const MAX_NUMBER_THREADS = 8 + +type connectionMap struct { + reciever map[string]net.Conn + sender map[string]net.Conn + metadata map[string]string + sync.RWMutex +} + +type Relay struct { + connections connectionMap + Debug bool + NumberOfConnections int +} + +func NewRelay(flags *Flags) *Relay { + r := new(Relay) + r.Debug = flags.Debug + r.NumberOfConnections = MAX_NUMBER_THREADS + log.SetFormatter(&log.TextFormatter{}) + if r.Debug { + log.SetLevel(log.DebugLevel) + } else { + log.SetLevel(log.WarnLevel) + } + return r +} + +func (r *Relay) Run() { + r.connections = connectionMap{} + r.connections.Lock() + r.connections.reciever = make(map[string]net.Conn) + r.connections.sender = make(map[string]net.Conn) + r.connections.metadata = make(map[string]string) + r.connections.Unlock() + r.runServer() +} + +func (r *Relay) runServer() { + logger := log.WithFields(log.Fields{ + "function": "main", + }) + logger.Debug("Initializing") + var wg sync.WaitGroup + wg.Add(r.NumberOfConnections) + for id := 0; id < r.NumberOfConnections; id++ { + go r.listenerThread(id, &wg) + } + wg.Wait() +} + +func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) { + logger := log.WithFields(log.Fields{ + "function": "listenerThread:" + strconv.Itoa(27000+id), + }) + + defer wg.Done() + err := r.listener(id) + if err != nil { + logger.Error(err) + } +} + +func (r *Relay) listener(id int) (err error) { + port := strconv.Itoa(27001 + id) + logger := log.WithFields(log.Fields{ + "function": "listener" + ":" + port, + }) + server, err := net.Listen("tcp", "0.0.0.0:"+port) + if err != nil { + return errors.Wrap(err, "Error listening on "+":"+port) + } + defer server.Close() + logger.Debug("waiting for connections") + //Spawn a new goroutine whenever a client connects + for { + connection, err := server.Accept() + if err != nil { + return errors.Wrap(err, "problem accepting connection") + } + logger.Debugf("Client %s connected", connection.RemoteAddr().String()) + go r.clientCommuncation(id, connection) + } +} + +func (r *Relay) clientCommuncation(id int, connection net.Conn) { + sendMessage("who?", connection) + + m := strings.Split(receiveMessage(connection), ".") + connectionType, codePhrase, metaData := m[0], m[1], m[2] + key := codePhrase + "-" + strconv.Itoa(id) + logger := log.WithFields(log.Fields{ + "id": id, + "codePhrase": codePhrase, + }) + + if connectionType == "s" { + logger.Debug("got sender") + r.connections.Lock() + r.connections.metadata[key] = metaData + r.connections.sender[key] = connection + r.connections.Unlock() + // wait for receiver + receiversAddress := "" + for { + r.connections.RLock() + if _, ok := r.connections.reciever[key]; ok { + receiversAddress = r.connections.reciever[key].RemoteAddr().String() + logger.Debug("got reciever") + r.connections.RUnlock() + break + } + r.connections.RUnlock() + time.Sleep(100 * time.Millisecond) + } + logger.Debug("telling sender ok") + sendMessage(receiversAddress, connection) + logger.Debug("preparing pipe") + r.connections.Lock() + con1 := r.connections.sender[key] + con2 := r.connections.reciever[key] + r.connections.Unlock() + logger.Debug("piping connections") + Pipe(con1, con2) + logger.Debug("done piping") + r.connections.Lock() + delete(r.connections.sender, key) + delete(r.connections.reciever, key) + delete(r.connections.metadata, key) + r.connections.Unlock() + logger.Debug("deleted sender and receiver") + } else { + // wait for sender's metadata + sendersAddress := "" + for { + r.connections.RLock() + if _, ok := r.connections.metadata[key]; ok { + if _, ok2 := r.connections.sender[key]; ok2 { + sendersAddress = r.connections.sender[key].RemoteAddr().String() + logger.Debug("got sender meta data") + r.connections.RUnlock() + break + } + } + r.connections.RUnlock() + if connectionType == "c" { + sendMessage("0-0-0-0.0.0.0", connection) + return + } + time.Sleep(100 * time.Millisecond) + } + // send meta data + r.connections.RLock() + sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection) + r.connections.RUnlock() + // check for receiver's consent + consent := receiveMessage(connection) + logger.Debugf("consent: %s", consent) + if consent == "ok" { + logger.Debug("got consent") + r.connections.Lock() + r.connections.reciever[key] = connection + r.connections.Unlock() + } + } + return +} + +func sendMessage(message string, connection net.Conn) { + message = fillString(message, BUFFERSIZE) + connection.Write([]byte(message)) +} + +func receiveMessage(connection net.Conn) string { + messageByte := make([]byte, BUFFERSIZE) + connection.Read(messageByte) + return strings.Replace(string(messageByte), ":", "", -1) +} + +func fillString(retunString string, toLength int) string { + for { + lengthString := len(retunString) + if lengthString < toLength { + retunString = retunString + ":" + continue + } + break + } + return retunString +} + +// chanFromConn creates a channel from a Conn object, and sends everything it +// Read()s from the socket to the channel. +func chanFromConn(conn net.Conn) chan []byte { + c := make(chan []byte) + + go func() { + b := make([]byte, BUFFERSIZE) + + for { + n, err := conn.Read(b) + if n > 0 { + res := make([]byte, n) + // Copy the buffer so it doesn't get changed while read by the recipient. + copy(res, b[:n]) + c <- res + } + if err != nil { + c <- nil + break + } + } + }() + + return c +} + +// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other. +func Pipe(conn1 net.Conn, conn2 net.Conn) { + chan1 := chanFromConn(conn1) + chan2 := chanFromConn(conn2) + + for { + select { + case b1 := <-chan1: + if b1 == nil { + return + } else { + conn2.Write(b1) + } + case b2 := <-chan2: + if b2 == nil { + return + } else { + conn1.Write(b2) + } + } + } +}