diff --git a/src/croc/config.go b/src/croc/config.go deleted file mode 100644 index 8a37560b..00000000 --- a/src/croc/config.go +++ /dev/null @@ -1,187 +0,0 @@ -package croc - -import ( - "bytes" - "fmt" - "io/ioutil" - "os" - "path" - "path/filepath" - "time" - - "github.com/BurntSushi/toml" - homedir "github.com/mitchellh/go-homedir" - "github.com/schollz/croc/src/utils" -) - -type Config struct { - // Relay parameters - RelayWebsocketPort string - RelayTCPPorts []string - - // Sender parameters - CurveType string - - // Options for connecting to server - PublicServerIP string - AddressTCPPorts []string - AddressWebsocketPort string - Timeout time.Duration - LocalOnly bool - NoLocal bool - - // Options for file transfering - UseEncryption bool - UseCompression bool - AllowLocalDiscovery bool - NoRecipientPrompt bool - ForceTCP bool - ForceWebsockets bool - Codephrase string -} - -func defaultConfig() Config { - c := Config{} - cr := Init(false) - c.RelayWebsocketPort = cr.RelayWebsocketPort - c.RelayTCPPorts = cr.RelayTCPPorts - c.CurveType = cr.CurveType - c.PublicServerIP = cr.Address - c.AddressTCPPorts = cr.AddressTCPPorts - c.AddressWebsocketPort = cr.AddressWebsocketPort - c.Timeout = cr.Timeout - c.LocalOnly = cr.LocalOnly - c.NoLocal = cr.NoLocal - c.UseEncryption = cr.UseEncryption - c.UseCompression = cr.UseCompression - c.AllowLocalDiscovery = cr.AllowLocalDiscovery - c.NoRecipientPrompt = cr.NoRecipientPrompt - c.ForceTCP = false - c.ForceWebsockets = false - c.Codephrase = "" - return c -} - -func SaveDefaultConfig() error { - homedir, err := homedir.Dir() - if err != nil { - return err - } - os.MkdirAll(path.Join(homedir, ".config", "croc"), 0755) - c := defaultConfig() - buf := new(bytes.Buffer) - toml.NewEncoder(buf).Encode(c) - confTOML := buf.String() - err = ioutil.WriteFile(path.Join(homedir, ".config", "croc", "config.toml"), []byte(confTOML), 0644) - if err == nil { - fmt.Printf("Default config file written at '%s'\r\n", filepath.Clean(path.Join(homedir, ".config", "croc", "config.toml"))) - } - return err -} - -// LoadConfig will override parameters -func (cr *Croc) LoadConfig() (err error) { - homedir, err := homedir.Dir() - if err != nil { - return err - } - pathToConfig := path.Join(homedir, ".config", "croc", "config.toml") - if !utils.Exists(pathToConfig) { - // ignore if doesn't exist - return nil - } - - var c Config - _, err = toml.DecodeFile(pathToConfig, &c) - if err != nil { - return - } - - cDefault := defaultConfig() - // only load if things are different than defaults - // just in case the CLI parameters are used - if c.RelayWebsocketPort != cDefault.RelayWebsocketPort && cr.RelayWebsocketPort == cDefault.RelayWebsocketPort { - cr.RelayWebsocketPort = c.RelayWebsocketPort - fmt.Printf("loaded RelayWebsocketPort from config\n") - } - if !slicesEqual(c.RelayTCPPorts, cDefault.RelayTCPPorts) && slicesEqual(cr.RelayTCPPorts, cDefault.RelayTCPPorts) { - cr.RelayTCPPorts = c.RelayTCPPorts - fmt.Printf("loaded RelayTCPPorts from config\n") - } - if c.CurveType != cDefault.CurveType && cr.CurveType == cDefault.CurveType { - cr.CurveType = c.CurveType - fmt.Printf("loaded CurveType from config\n") - } - if c.PublicServerIP != cDefault.PublicServerIP && cr.Address == cDefault.PublicServerIP { - cr.Address = c.PublicServerIP - fmt.Printf("loaded Address from config\n") - } - if !slicesEqual(c.AddressTCPPorts, cDefault.AddressTCPPorts) { - cr.AddressTCPPorts = c.AddressTCPPorts - fmt.Printf("loaded AddressTCPPorts from config\n") - } - if c.AddressWebsocketPort != cDefault.AddressWebsocketPort && cr.AddressWebsocketPort == cDefault.AddressWebsocketPort { - cr.AddressWebsocketPort = c.AddressWebsocketPort - fmt.Printf("loaded AddressWebsocketPort from config\n") - } - if c.Timeout != cDefault.Timeout && cr.Timeout == cDefault.Timeout { - cr.Timeout = c.Timeout - fmt.Printf("loaded Timeout from config\n") - } - if c.LocalOnly != cDefault.LocalOnly && cr.LocalOnly == cDefault.LocalOnly { - cr.LocalOnly = c.LocalOnly - fmt.Printf("loaded LocalOnly from config\n") - } - if c.NoLocal != cDefault.NoLocal && cr.NoLocal == cDefault.NoLocal { - cr.NoLocal = c.NoLocal - fmt.Printf("loaded NoLocal from config\n") - } - if c.UseEncryption != cDefault.UseEncryption && cr.UseEncryption == cDefault.UseEncryption { - cr.UseEncryption = c.UseEncryption - fmt.Printf("loaded UseEncryption from config\n") - } - if c.UseCompression != cDefault.UseCompression && cr.UseCompression == cDefault.UseCompression { - cr.UseCompression = c.UseCompression - fmt.Printf("loaded UseCompression from config\n") - } - if c.AllowLocalDiscovery != cDefault.AllowLocalDiscovery && cr.AllowLocalDiscovery == cDefault.AllowLocalDiscovery { - cr.AllowLocalDiscovery = c.AllowLocalDiscovery - fmt.Printf("loaded AllowLocalDiscovery from config\n") - } - if c.NoRecipientPrompt != cDefault.NoRecipientPrompt && cr.NoRecipientPrompt == cDefault.NoRecipientPrompt { - cr.NoRecipientPrompt = c.NoRecipientPrompt - fmt.Printf("loaded NoRecipientPrompt from config\n") - } - if c.ForceWebsockets { - cr.ForceSend = 1 - } - if c.ForceTCP { - cr.ForceSend = 2 - } - if c.Codephrase != cDefault.Codephrase && cr.Codephrase == cDefault.Codephrase { - cr.Codephrase = c.Codephrase - fmt.Printf("loaded Codephrase from config\n") - } - return -} - -// slicesEqual checcks if two slices are equal -// from https://stackoverflow.com/a/15312097 -func slicesEqual(a, b []string) bool { - // If one is nil, the other must also be nil. - if (a == nil) != (b == nil) { - return false - } - - if len(a) != len(b) { - return false - } - - for i := range a { - if a[i] != b[i] { - return false - } - } - - return true -} diff --git a/src/croc/croc.go b/src/croc/croc.go index fd11e207..490c4241 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -1,96 +1,706 @@ package croc import ( - "runtime" + "bytes" + "crypto/elliptic" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math" + "os" + "path" + "path/filepath" + "strings" + "sync" "time" - "github.com/schollz/croc/src/logger" - "github.com/schollz/croc/src/models" - "github.com/schollz/croc/src/relay" - "github.com/schollz/croc/src/zipper" + "github.com/denisbrodbeck/machineid" + "github.com/go-redis/redis" + "github.com/mattn/go-colorable" + "github.com/pions/webrtc" + "github.com/schollz/croc/v5/src/crypt" + "github.com/schollz/croc/v5/src/utils" + "github.com/schollz/croc/v5/src/webrtc/pkg/session/common" + "github.com/schollz/croc/v5/src/webrtc/pkg/session/receiver" + "github.com/schollz/croc/v5/src/webrtc/pkg/session/sender" + "github.com/schollz/pake" "github.com/schollz/progressbar/v2" + "github.com/schollz/spinner" + "github.com/sirupsen/logrus" ) +const BufferSize = 4096 * 10 +const Channels = 1 + +var log = logrus.New() + func init() { - runtime.GOMAXPROCS(runtime.NumCPU()) + log.SetFormatter(&logrus.TextFormatter{ForceColors: true}) + log.SetOutput(colorable.NewColorableStdout()) + Debug(false) } -// Croc options -type Croc struct { - // Version is the version of croc - Version string - // Options for all - Debug bool - // ShowText will display text on the stderr - ShowText bool - - // Options for relay - RelayWebsocketPort string - RelayTCPPorts []string - CurveType string - - // Options for connecting to server - Address string - AddressTCPPorts []string - AddressWebsocketPort string - Timeout time.Duration - LocalOnly bool - NoLocal bool - - // Options for file transfering - UseEncryption bool - UseCompression bool - AllowLocalDiscovery bool - NoRecipientPrompt bool - Stdout bool - ForceSend int // 0: ignore, 1: websockets, 2: TCP - - // Parameters for file transfer - Filename string - Codephrase string - - // localIP address - localIP string - // is using local relay - isLocal bool - normalFinish bool - - // state variables - StateString string - Bar *progressbar.ProgressBar - FileInfo models.FileStats - OtherIP string - - // special for window - WindowRecipientPrompt bool - WindowRecipientAccept bool - WindowReceivingString string -} - -// Init will initiate with the default parameters -func Init(debug bool) (c *Croc) { - c = new(Croc) - c.UseCompression = true - c.UseEncryption = true - c.AllowLocalDiscovery = true - c.RelayWebsocketPort = "8153" - c.RelayTCPPorts = []string{"8154", "8155", "8156", "8157", "8158", "8159", "8160", "8161"} - c.CurveType = "siec" - c.Address = "croc4.schollz.com" - c.AddressWebsocketPort = "8153" - c.AddressTCPPorts = []string{"8154", "8155", "8156", "8157", "8158", "8159", "8160", "8161"} - c.NoRecipientPrompt = true - debugLevel := "info" +func Debug(debug bool) { + receiver.Debug(debug) + sender.Debug(debug) if debug { - debugLevel = "debug" - c.Debug = true + log.SetLevel(logrus.DebugLevel) + } else { + log.SetLevel(logrus.WarnLevel) } - SetDebugLevel(debugLevel) +} + +type Client struct { + Options Options + // basic setup + redisdb *redis.Client + log *logrus.Entry + Pake *pake.Pake + + // steps involved in forming relationship + Step1ChannelSecured bool + Step2FileInfoTransfered bool + Step3RecipientRequestFile bool + Step4FileTransfer bool + Step5CloseChannels bool // TODO: Step5 should close files and reset things + + // send / receive information of all files + FilesToTransfer []FileInfo + FilesToTransferCurrentNum int + + // send / receive information of current file + CurrentFile *os.File + CurrentFileChunks []int64 + + sendSess *sender.Session + recvSess *receiver.Session + + // channel data + incomingMessageChannel <-chan *redis.Message + nameOutChannel string + nameInChannel string + + // webrtc connections + peerConnection [8]*webrtc.PeerConnection + dataChannel [8]*webrtc.DataChannel + + bar *progressbar.ProgressBar + spinner *spinner.Spinner + machineID string + + mutex *sync.Mutex + quit chan bool +} + +type Message struct { + Type string `json:"t,omitempty"` + Message string `json:"m,omitempty"` + Bytes []byte `json:"b,omitempty"` + Num int `json:"n,omitempty"` +} + +type Chunk struct { + Bytes []byte `json:"b,omitempty"` + Location int64 `json:"l,omitempty"` +} + +type FileInfo struct { + Name string `json:"n,omitempty"` + FolderRemote string `json:"fr,omitempty"` + FolderSource string `json:"fs,omitempty"` + Hash []byte `json:"h,omitempty"` + Size int64 `json:"s,omitempty"` + ModTime time.Time `json:"m,omitempty"` + IsCompressed bool `json:"c,omitempty"` + IsEncrypted bool `json:"e,omitempty"` +} + +type RemoteFileRequest struct { + CurrentFileChunks []int64 + FilesToTransferCurrentNum int +} + +type SenderInfo struct { + MachineID string + FilesToTransfer []FileInfo +} + +func (m Message) String() string { + b, _ := json.Marshal(m) + return string(b) +} + +type Options struct { + IsSender bool + SharedSecret string + Debug bool + AddressRelay string + Stdout bool + NoPrompt bool +} + +// New establishes a new connection for transfering files between two instances. +func New(ops Options) (c *Client, err error) { + c = new(Client) + + // setup basic info + c.Options = ops + Debug(c.Options.Debug) + log.Debugf("options: %+v", c.Options) + + // set channels + if c.Options.IsSender { + c.nameOutChannel = c.Options.SharedSecret + "2" + c.nameInChannel = c.Options.SharedSecret + "1" + } else { + c.nameOutChannel = c.Options.SharedSecret + "1" + c.nameInChannel = c.Options.SharedSecret + "2" + } + + // initialize redis for communication in establishing channel + c.redisdb = redis.NewClient(&redis.Options{ + Addr: c.Options.AddressRelay, + Password: "", + DB: 4, + WriteTimeout: 1 * time.Hour, + ReadTimeout: 1 * time.Hour, + }) + _, err = c.redisdb.Ping().Result() + if err != nil { + return + } + + // setup channel for listening + pubsub := c.redisdb.Subscribe(c.nameInChannel) + _, err = pubsub.Receive() + if err != nil { + return + } + c.incomingMessageChannel = pubsub.Channel() + + // initialize pake + if c.Options.IsSender { + c.Pake, err = pake.Init([]byte(c.Options.SharedSecret), 1, elliptic.P521(), 1*time.Microsecond) + } else { + c.Pake, err = pake.Init([]byte(c.Options.SharedSecret), 0, elliptic.P521(), 1*time.Microsecond) + } + if err != nil { + return + } + + // initialize logger + c.log = log.WithFields(logrus.Fields{ + "is": "sender", + }) + if !c.Options.IsSender { + c.log = log.WithFields(logrus.Fields{ + "is": "recipient", + }) + } + + c.spinner = spinner.New(spinner.CharSets[9], 100*time.Millisecond) + c.spinner.Writer = os.Stderr + c.spinner.Suffix = " connecting..." + + c.mutex = &sync.Mutex{} return } -func SetDebugLevel(debugLevel string) { - logger.SetLogLevel(debugLevel) - relay.DebugLevel = debugLevel - zipper.DebugLevel = debugLevel +type TransferOptions struct { + PathToFiles []string + KeepPathInRemote bool +} + +// Send will send the specified file +func (c *Client) Send(options TransferOptions) (err error) { + return c.transfer(options) +} + +// Receive will receive a file +func (c *Client) Receive() (err error) { + return c.transfer(TransferOptions{}) +} + +func (c *Client) transfer(options TransferOptions) (err error) { + if c.Options.IsSender { + c.FilesToTransfer = make([]FileInfo, len(options.PathToFiles)) + totalFilesSize := int64(0) + for i, pathToFile := range options.PathToFiles { + var fstats os.FileInfo + var fullPath string + fullPath, err = filepath.Abs(pathToFile) + if err != nil { + return + } + fullPath = filepath.Clean(fullPath) + var folderName string + folderName, _ = filepath.Split(fullPath) + + fstats, err = os.Stat(fullPath) + if err != nil { + return + } + c.FilesToTransfer[i] = FileInfo{ + Name: fstats.Name(), + FolderRemote: ".", + FolderSource: folderName, + Size: fstats.Size(), + ModTime: fstats.ModTime(), + } + c.FilesToTransfer[i].Hash, err = utils.HashFile(fullPath) + totalFilesSize += fstats.Size() + if err != nil { + return + } + if options.KeepPathInRemote { + var curFolder string + curFolder, err = os.Getwd() + if err != nil { + return + } + curFolder, err = filepath.Abs(curFolder) + if err != nil { + return + } + if !strings.HasPrefix(folderName, curFolder) { + err = fmt.Errorf("remote directory must be relative to current") + return + } + c.FilesToTransfer[i].FolderRemote = strings.TrimPrefix(folderName, curFolder) + c.FilesToTransfer[i].FolderRemote = filepath.ToSlash(c.FilesToTransfer[i].FolderRemote) + c.FilesToTransfer[i].FolderRemote = strings.TrimPrefix(c.FilesToTransfer[i].FolderRemote, "/") + if c.FilesToTransfer[i].FolderRemote == "" { + c.FilesToTransfer[i].FolderRemote = "." + } + } + log.Debugf("file %d info: %+v", i, c.FilesToTransfer[i]) + } + fname := fmt.Sprintf("%d files", len(c.FilesToTransfer)) + if len(c.FilesToTransfer) == 1 { + fname = fmt.Sprintf("'%s'", c.FilesToTransfer[0].Name) + } + machID, macIDerr := machineid.ID() + if macIDerr != nil { + log.Error(macIDerr) + return + } + if len(machID) > 6 { + machID = machID[:6] + } + c.machineID = machID + fmt.Fprintf(os.Stderr, "Sending %s (%s) from your machine, '%s'\n", fname, utils.ByteCountDecimal(totalFilesSize), machID) + fmt.Fprintf(os.Stderr, "Code is: %s\nOn the other computer run\n\ncroc %s\n", c.Options.SharedSecret, c.Options.SharedSecret) + c.spinner.Suffix = " waiting for recipient..." + } + c.spinner.Start() + // create channel for quitting + // quit with c.quit <- true + c.quit = make(chan bool) + + // if recipient, initialize with sending pake information + c.log.Debug("ready") + if !c.Options.IsSender && !c.Step1ChannelSecured { + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "pake", + Bytes: c.Pake.Bytes(), + }.String()).Err() + if err != nil { + return + } + } + + // listen for incoming messages and process them + for { + select { + case <-c.quit: + return + case msg := <-c.incomingMessageChannel: + var m Message + err = json.Unmarshal([]byte(msg.Payload), &m) + if err != nil { + return + } + if m.Type == "finished" { + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "finished", + }.String()).Err() + return err + } + err = c.processMessage(m) + if err != nil { + return + } + default: + time.Sleep(1 * time.Millisecond) + } + } + return +} + +func (c *Client) sendOverRedis() (err error) { + go func() { + c.bar = progressbar.NewOptions( + int(c.FilesToTransfer[c.FilesToTransferCurrentNum].Size), + progressbar.OptionSetRenderBlankState(true), + progressbar.OptionSetBytes(int(c.FilesToTransfer[c.FilesToTransferCurrentNum].Size)), + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionThrottle(1/60*time.Second), + ) + c.CurrentFile, err = os.Open(c.FilesToTransfer[c.FilesToTransferCurrentNum].Name) + if err != nil { + panic(err) + } + location := int64(0) + for { + buf := make([]byte, 4096*128) + n, errRead := c.CurrentFile.Read(buf) + c.bar.Add(n) + chunk := Chunk{ + Bytes: buf[:n], + Location: location, + } + chunkB, _ := json.Marshal(chunk) + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "chunk", + Bytes: chunkB, + }.String()).Err() + if err != nil { + panic(err) + } + location += int64(n) + if errRead == io.EOF { + break + } + if errRead != nil { + panic(errRead) + } + } + }() + return +} + +func (c *Client) processMessage(m Message) (err error) { + switch m.Type { + case "pake": + if c.spinner.Suffix != " performing PAKE..." { + c.spinner.Stop() + c.spinner.Suffix = " performing PAKE..." + c.spinner.Start() + } + notVerified := !c.Pake.IsVerified() + err = c.Pake.Update(m.Bytes) + if err != nil { + return + } + if (notVerified && c.Pake.IsVerified() && !c.Options.IsSender) || !c.Pake.IsVerified() { + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "pake", + Bytes: c.Pake.Bytes(), + }.String()).Err() + } + if c.Pake.IsVerified() { + c.log.Debug(c.Pake.SessionKey()) + c.Step1ChannelSecured = true + } + case "error": + c.spinner.Stop() + fmt.Print("\r") + err = fmt.Errorf("peer error: %s", m.Message) + return err + case "fileinfo": + var senderInfo SenderInfo + var decryptedBytes []byte + key, _ := c.Pake.SessionKey() + decryptedBytes, err = crypt.DecryptFromBytes(m.Bytes, key) + if err != nil { + log.Error(err) + return + } + err = json.Unmarshal(decryptedBytes, &senderInfo) + if err != nil { + log.Error(err) + return + } + c.FilesToTransfer = senderInfo.FilesToTransfer + fname := fmt.Sprintf("%d files", len(c.FilesToTransfer)) + if len(c.FilesToTransfer) == 1 { + fname = fmt.Sprintf("'%s'", c.FilesToTransfer[0].Name) + } + totalSize := int64(0) + for _, fi := range c.FilesToTransfer { + totalSize += fi.Size + } + c.spinner.Stop() + if !c.Options.NoPrompt { + fmt.Fprintf(os.Stderr, "\rAccept %s (%s) from machine '%s'? (y/n) ", fname, utils.ByteCountDecimal(totalSize), senderInfo.MachineID) + if strings.ToLower(strings.TrimSpace(utils.GetInput(""))) != "y" { + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "error", + Message: "refusing files", + }.String()).Err() + return fmt.Errorf("refused files") + } + } else { + fmt.Fprintf(os.Stderr, "\rReceiving %s (%s) from machine '%s'\n", fname, utils.ByteCountDecimal(totalSize), senderInfo.MachineID) + } + c.log.Debug(c.FilesToTransfer) + c.Step2FileInfoTransfered = true + case "recipientready": + var remoteFile RemoteFileRequest + var decryptedBytes []byte + key, _ := c.Pake.SessionKey() + decryptedBytes, err = crypt.DecryptFromBytes(m.Bytes, key) + if err != nil { + log.Error(err) + return + } + err = json.Unmarshal(decryptedBytes, &remoteFile) + if err != nil { + return + } + c.FilesToTransferCurrentNum = remoteFile.FilesToTransferCurrentNum + c.CurrentFileChunks = remoteFile.CurrentFileChunks + c.Step3RecipientRequestFile = true + case "datachannel-offer": + err = c.dataChannelReceive() + if err != nil { + return + } + err = c.recvSess.SetSDP(m.Message) + if err != nil { + return + } + var answer string + answer, err = c.recvSess.CreateAnswer() + if err != nil { + return + } + // Output the answer in base64 so we can paste it in browser + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "datachannel-answer", + Message: answer, + Num: m.Num, + }.String()).Err() + // start receiving data + pathToFile := path.Join(c.FilesToTransfer[c.FilesToTransferCurrentNum].FolderRemote, c.FilesToTransfer[c.FilesToTransferCurrentNum].Name) + c.spinner.Stop() + key, _ := c.Pake.SessionKey() + c.recvSess.ReceiveData(pathToFile, c.FilesToTransfer[c.FilesToTransferCurrentNum].Size, key) + log.Debug("sending close-sender") + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "close-sender", + }.String()).Err() + case "datachannel-answer": + c.log.Debug("got answer:", m.Message) + // Apply the answer as the remote description + err = c.sendSess.SetSDP(m.Message) + pathToFile := path.Join(c.FilesToTransfer[c.FilesToTransferCurrentNum].FolderSource, c.FilesToTransfer[c.FilesToTransferCurrentNum].Name) + c.spinner.Stop() + fmt.Fprintf(os.Stderr, "\r\nTransfering...\n") + key, _ := c.Pake.SessionKey() + c.sendSess.TransferFile(pathToFile, key) + case "close-sender": + log.Debug("close-sender received...") + c.Step4FileTransfer = false + c.Step3RecipientRequestFile = false + c.sendSess.StopSending() + log.Debug("sending close-recipient") + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "close-recipient", + Num: m.Num, + }.String()).Err() + case "close-recipient": + c.Step4FileTransfer = false + c.Step3RecipientRequestFile = false + } + if err != nil { + return + } + err = c.updateState() + + return +} + +func (c *Client) updateState() (err error) { + if c.Options.IsSender && c.Step1ChannelSecured && !c.Step2FileInfoTransfered { + var b []byte + b, err = json.Marshal(SenderInfo{ + MachineID: c.machineID, + FilesToTransfer: c.FilesToTransfer, + }) + if err != nil { + log.Error(err) + return + } + key, _ := c.Pake.SessionKey() + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "fileinfo", + Bytes: crypt.EncryptToBytes(b, key), + }.String()).Err() + if err != nil { + return + } + c.Step2FileInfoTransfered = true + } + if !c.Options.IsSender && c.Step2FileInfoTransfered && !c.Step3RecipientRequestFile { + // find the next file to transfer and send that number + // if the files are the same size, then look for missing chunks + finished := true + for i, fileInfo := range c.FilesToTransfer { + if i < c.FilesToTransferCurrentNum { + continue + } + fileHash, errHash := utils.HashFile(path.Join(fileInfo.FolderRemote, fileInfo.Name)) + if errHash != nil || !bytes.Equal(fileHash, fileInfo.Hash) { + if !bytes.Equal(fileHash, fileInfo.Hash) { + log.Debugf("hashes are not equal %x != %x", fileHash, fileInfo.Hash) + } + finished = false + c.FilesToTransferCurrentNum = i + break + } + // TODO: print out something about this file already existing + } + if finished { + // TODO: do the last finishing stuff + log.Debug("finished") + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "finished", + }.String()).Err() + if err != nil { + panic(err) + } + } + + // start initiating the process to receive a new file + log.Debugf("working on file %d", c.FilesToTransferCurrentNum) + + // recipient requests the file and chunks (if empty, then should receive all chunks) + bRequest, _ := json.Marshal(RemoteFileRequest{ + CurrentFileChunks: c.CurrentFileChunks, + FilesToTransferCurrentNum: c.FilesToTransferCurrentNum, + }) + key, _ := c.Pake.SessionKey() + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "recipientready", + Bytes: crypt.EncryptToBytes(bRequest, key), + }.String()).Err() + if err != nil { + return + } + c.Step3RecipientRequestFile = true + err = c.dataChannelReceive() + } + if c.Options.IsSender && c.Step3RecipientRequestFile && !c.Step4FileTransfer { + c.log.Debug("start sending data!") + err = c.dataChannelSend() + c.Step4FileTransfer = true + } + return +} + +func (c *Client) dataChannelReceive() (err error) { + c.recvSess = receiver.NewWith(receiver.Config{}) + err = c.recvSess.CreateConnection() + if err != nil { + return + } + c.recvSess.CreateDataHandler() + return +} + +func (c *Client) dataChannelSend() (err error) { + c.sendSess = sender.NewWith(sender.Config{ + Configuration: common.Configuration{ + OnCompletion: func() { + }, + }, + }) + + if err := c.sendSess.CreateConnection(); err != nil { + log.Error(err) + return err + } + if err := c.sendSess.CreateDataChannel(); err != nil { + log.Error(err) + return err + } + offer, err := c.sendSess.CreateOffer() + if err != nil { + log.Error(err) + return err + } + + // sending offer + err = c.redisdb.Publish(c.nameOutChannel, Message{ + Type: "datachannel-offer", + Message: offer, + }.String()).Err() + if err != nil { + return + } + + return +} + +// MissingChunks returns the positions of missing chunks. +// If file doesn't exist, it returns an empty chunk list (all chunks). +// If the file size is not the same as requested, it returns an empty chunk list (all chunks). +func MissingChunks(fname string, fsize int64, chunkSize int) (chunks []int64) { + fstat, err := os.Stat(fname) + if fstat.Size() != fsize { + return + } + + f, err := os.Open(fname) + if err != nil { + return + } + defer f.Close() + + buffer := make([]byte, chunkSize) + emptyBuffer := make([]byte, chunkSize) + chunkNum := 0 + chunks = make([]int64, int64(math.Ceil(float64(fsize)/float64(chunkSize)))) + var currentLocation int64 + for { + bytesread, err := f.Read(buffer) + if err != nil { + break + } + if bytes.Equal(buffer[:bytesread], emptyBuffer[:bytesread]) { + chunks[chunkNum] = currentLocation + } + currentLocation += int64(bytesread) + } + if chunkNum == 0 { + chunks = []int64{} + } else { + chunks = chunks[:chunkNum] + } + return +} + +// Encode encodes the input in base64 +// It can optionally zip the input before encoding +func Encode(obj interface{}) string { + b, err := json.Marshal(obj) + if err != nil { + panic(err) + } + + return base64.StdEncoding.EncodeToString(b) +} + +// Decode decodes the input from base64 +// It can optionally unzip the input after decoding +func Decode(in string, obj interface{}) (err error) { + b, err := base64.StdEncoding.DecodeString(in) + if err != nil { + return + } + + err = json.Unmarshal(b, obj) + return } diff --git a/src/croc/croc_test.go b/src/croc/croc_test.go deleted file mode 100644 index 65b07c60..00000000 --- a/src/croc/croc_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package croc - -import ( - "crypto/rand" - "fmt" - "io/ioutil" - "os" - "sync" - "testing" - "time" - - "github.com/schollz/croc/src/utils" - "github.com/stretchr/testify/assert" -) - -func sendAndReceive(t *testing.T, forceSend int, local bool) { - room := utils.GetRandomName() - var startTime time.Time - var durationPerMegabyte float64 - megabytes := 1 - if local { - megabytes = 100 - } - fname := generateRandomFile(megabytes) - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - c := Init(true) - c.NoLocal = !local - // c.AddressTCPPorts = []string{"8154", "8155"} - c.ForceSend = forceSend - c.UseCompression = true - c.UseEncryption = true - assert.Nil(t, c.Send(fname, room)) - }() - go func() { - defer wg.Done() - time.Sleep(5 * time.Second) - os.MkdirAll("test", 0755) - os.Chdir("test") - c := Init(true) - c.NoLocal = !local - // c.AddressTCPPorts = []string{"8154", "8155"} - c.ForceSend = forceSend - startTime = time.Now() - assert.Nil(t, c.Receive(room)) - durationPerMegabyte = float64(megabytes) / time.Since(startTime).Seconds() - assert.True(t, utils.Exists(fname)) - }() - wg.Wait() - os.Chdir("..") - os.RemoveAll("test") - os.Remove(fname) - fmt.Printf("\n-----\n%2.1f MB/s\n----\n", durationPerMegabyte) -} - -func TestSendReceivePubWebsockets(t *testing.T) { - sendAndReceive(t, 1, false) -} - -func TestSendReceivePubTCP(t *testing.T) { - sendAndReceive(t, 2, false) -} - -func TestSendReceiveLocalWebsockets(t *testing.T) { - sendAndReceive(t, 1, true) -} - -// func TestSendReceiveLocalTCP(t *testing.T) { -// sendAndReceive(t, 2, true) -// } - -func generateRandomFile(megabytes int) (fname string) { - // generate a random file - bigBuff := make([]byte, 1024*1024*megabytes) - rand.Read(bigBuff) - fname = fmt.Sprintf("%dmb.file", megabytes) - ioutil.WriteFile(fname, bigBuff, 0666) - return -} diff --git a/src/croc/models.go b/src/croc/models.go deleted file mode 100644 index de2650a3..00000000 --- a/src/croc/models.go +++ /dev/null @@ -1,7 +0,0 @@ -package croc - -type WebSocketMessage struct { - messageType int - message []byte - err error -} diff --git a/src/croc/recipient.go b/src/croc/recipient.go deleted file mode 100644 index 48e529fd..00000000 --- a/src/croc/recipient.go +++ /dev/null @@ -1,624 +0,0 @@ -package croc - -import ( - "bufio" - "bytes" - "encoding/json" - "fmt" - "io/ioutil" - "os" - "strconv" - "strings" - "sync" - "time" - - log "github.com/cihub/seelog" - humanize "github.com/dustin/go-humanize" - "github.com/gorilla/websocket" - "github.com/pkg/errors" - "github.com/schollz/croc/src/comm" - "github.com/schollz/croc/src/compress" - "github.com/schollz/croc/src/crypt" - "github.com/schollz/croc/src/logger" - "github.com/schollz/croc/src/models" - "github.com/schollz/croc/src/utils" - "github.com/schollz/croc/src/zipper" - "github.com/schollz/pake" - "github.com/schollz/progressbar/v2" - "github.com/schollz/spinner" -) - -var DebugLevel string - -// Receive is the async operation to receive a file -func (cr *Croc) startRecipient(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, done chan error, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) { - logger.SetLogLevel(DebugLevel) - err := cr.receive(forceSend, serverAddress, tcpPorts, isLocal, c, codephrase, noPrompt, useStdout) - if err != nil && strings.HasPrefix(err.Error(), "websocket: close 100") { - err = nil - } - done <- err -} - -func (cr *Croc) receive(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) (err error) { - var sessionKey []byte - var transferTime time.Duration - var hash256 []byte - var progressFile string - var resumeFile bool - var tcpConnections []comm.Comm - var Q *pake.Pake - - dataChan := make(chan []byte, 1024*1024) - isConnectedIfUsingTCP := make(chan bool) - blocks := []string{} - - useWebsockets := true - switch forceSend { - case 0: - if !isLocal { - useWebsockets = false - } - case 1: - useWebsockets = true - case 2: - useWebsockets = false - } - - // start a spinner - spin := spinner.New(spinner.CharSets[9], 100*time.Millisecond) - spin.Writer = os.Stderr - spin.Suffix = " connecting..." - cr.StateString = "Connecting as recipient..." - spin.Start() - defer spin.Stop() - - // both parties should have a weak key - pw := []byte(codephrase) - - // start the reader - websocketMessages := make(chan WebSocketMessage, 1024) - go func() { - defer func() { - if r := recover(); r != nil { - log.Debugf("recovered from %s", r) - } - }() - for { - messageType, message, err := c.ReadMessage() - websocketMessages <- WebSocketMessage{messageType, message, err} - } - }() - - step := 0 - for { - var websocketMessageMain WebSocketMessage - // websocketMessageMain = <-websocketMessages - timeWaitingForMessage := time.Now() - for { - done := false - select { - case websocketMessageMain = <-websocketMessages: - done = true - default: - time.Sleep(10 * time.Millisecond) - } - if done { - break - } - if time.Since(timeWaitingForMessage).Seconds() > 3 && step == 0 { - return fmt.Errorf("You are trying to receive a file with no sender.") - } - } - - messageType := websocketMessageMain.messageType - message := websocketMessageMain.message - err := websocketMessageMain.err - if err != nil { - return err - } - if messageType == websocket.PongMessage || messageType == websocket.PingMessage { - continue - } - if messageType == websocket.TextMessage && bytes.Equal(message, []byte("interrupt")) { - return errors.New("\rinterrupted by other party") - } - - log.Debugf("got %d: %s", messageType, message) - switch step { - case 0: - spin.Stop() - spin.Suffix = " performing PAKE..." - cr.StateString = "Performing PAKE..." - spin.Start() - // sender has initiated, sends their initial data - var initialData models.Initial - err = json.Unmarshal(message, &initialData) - if err != nil { - err = errors.Wrap(err, "incompatible versions of croc") - return err - } - cr.OtherIP = initialData.IPAddress - log.Debugf("sender IP: %s", cr.OtherIP) - - // check whether the version strings are compatible - versionStringsOther := strings.Split(initialData.VersionString, ".") - versionStringsSelf := strings.Split(cr.Version, ".") - if len(versionStringsOther) == 3 && len(versionStringsSelf) == 3 { - if versionStringsSelf[0] != versionStringsOther[0] || versionStringsSelf[1] != versionStringsOther[1] { - return fmt.Errorf("version sender %s is not compatible with recipient %s", cr.Version, initialData.VersionString) - } - } - - // initialize the PAKE with the curve sent from the sender - Q, err = pake.InitCurve(pw, 1, initialData.CurveType, 1*time.Millisecond) - if err != nil { - err = errors.Wrap(err, "incompatible curve type") - return err - } - - // recipient begins by sending back initial data to sender - ip := "" - if isLocal { - ip = utils.LocalIP() - } else { - ip, _ = utils.PublicIP() - } - initialData.VersionString = cr.Version - initialData.IPAddress = ip - bInitialData, _ := json.Marshal(initialData) - c.WriteMessage(websocket.BinaryMessage, bInitialData) - case 1: - // Q receives u - log.Debugf("[%d] Q computes k, sends H(k), v back to P", step) - if err := Q.Update(message); err != nil { - return fmt.Errorf("Recipient is using wrong code phrase.") - } - - // Q has the session key now, but we will still check if its valid - sessionKey, err = Q.SessionKey() - if err != nil { - return fmt.Errorf("Recipient is using wrong code phrase.") - } - log.Debugf("%x\n", sessionKey) - - // initialize TCP connections if using (possible, but unlikely, race condition) - go func() { - log.Debug("initializing TCP connections") - if !useWebsockets { - log.Debugf("connecting to server") - tcpConnections = make([]comm.Comm, len(tcpPorts)) - var wg sync.WaitGroup - wg.Add(len(tcpPorts)) - for i, tcpPort := range tcpPorts { - go func(i int, tcpPort string) { - defer wg.Done() - log.Debugf("connecting to %d", i) - var message string - tcpConnections[i], message, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) - if err != nil { - log.Error(err) - } - if message != "recipient" { - log.Errorf("got wrong message: %s", message) - } - }(i, tcpPort) - } - wg.Wait() - log.Debugf("fully connected") - } - isConnectedIfUsingTCP <- true - }() - - c.WriteMessage(websocket.BinaryMessage, Q.Bytes()) - case 2: - log.Debugf("[%d] Q recieves H(k) from P", step) - // check if everything is still kosher with our computed session key - if err := Q.Update(message); err != nil { - log.Debug(err) - return fmt.Errorf("Recipient is using wrong code phrase.") - } - c.WriteMessage(websocket.BinaryMessage, []byte("ready")) - case 3: - spin.Stop() - cr.StateString = "Recieving file info..." - - // unmarshal the file info - log.Debugf("[%d] recieve file info", step) - // do decryption on the file stats - enc, err := crypt.FromBytes(message) - if err != nil { - return err - } - decryptedFileData, err := enc.Decrypt(sessionKey) - if err != nil { - return err - } - err = json.Unmarshal(decryptedFileData, &cr.FileInfo) - if err != nil { - return err - } - log.Debugf("got file stats: %+v", cr.FileInfo) - - // determine if the file is resuming or not - progressFile = fmt.Sprintf("%s.progress", cr.FileInfo.SentName) - overwritingOrReceiving := "Receiving" - if utils.Exists(cr.FileInfo.Name) || utils.Exists(cr.FileInfo.SentName) { - overwritingOrReceiving = "Overwriting" - if utils.Exists(progressFile) { - overwritingOrReceiving = "Resume receiving" - resumeFile = true - } - } - - // send blocks - if resumeFile { - fileWithBlocks, _ := os.Open(progressFile) - scanner := bufio.NewScanner(fileWithBlocks) - for scanner.Scan() { - blocks = append(blocks, strings.TrimSpace(scanner.Text())) - } - fileWithBlocks.Close() - } - blocksBytes, _ := json.Marshal(blocks) - // encrypt the block data and send - encblockBytes := crypt.Encrypt(blocksBytes, sessionKey) - - // wait for TCP connections if using them - _ = <-isConnectedIfUsingTCP - c.WriteMessage(websocket.BinaryMessage, encblockBytes.Bytes()) - - // prompt user about the file - fileOrFolder := "file" - if cr.FileInfo.IsDir { - fileOrFolder = "folder" - } - cr.WindowReceivingString = fmt.Sprintf("%s %s (%s) into: %s", - overwritingOrReceiving, - fileOrFolder, - humanize.Bytes(uint64(cr.FileInfo.Size)), - cr.FileInfo.Name, - ) - fmt.Fprintf(os.Stderr, "\r%s\n", - cr.WindowReceivingString, - ) - if !noPrompt { - if "y" != utils.GetInput("ok? (y/N): ") { - fmt.Fprintf(os.Stderr, "Cancelling request") - c.WriteMessage(websocket.BinaryMessage, []byte("no")) - return nil - } - } - if cr.WindowRecipientPrompt { - // wait until it switches to false - // the window should then set WindowRecipientAccept - for { - if !cr.WindowRecipientPrompt { - if cr.WindowRecipientAccept { - break - } else { - fmt.Fprintf(os.Stderr, "Cancelling request") - c.WriteMessage(websocket.BinaryMessage, []byte("no")) - return nil - } - } - time.Sleep(10 * time.Millisecond) - } - } - - // await file - // erase file if overwriting - if overwritingOrReceiving == "Overwriting" { - os.Remove(cr.FileInfo.SentName) - } - var f *os.File - if utils.Exists(cr.FileInfo.SentName) && resumeFile { - if !useWebsockets { - f, err = os.OpenFile(cr.FileInfo.SentName, os.O_WRONLY, 0644) - } else { - f, err = os.OpenFile(cr.FileInfo.SentName, os.O_APPEND|os.O_WRONLY, 0644) - } - if err != nil { - log.Error(err) - return err - } - } else { - f, err = os.Create(cr.FileInfo.SentName) - if err != nil { - log.Error(err) - return err - } - if !useWebsockets { - if err = f.Truncate(cr.FileInfo.Size); err != nil { - log.Error(err) - return err - } - } - } - - blockSize := 0 - if useWebsockets { - blockSize = models.WEBSOCKET_BUFFER_SIZE / 8 - } else { - blockSize = models.TCP_BUFFER_SIZE / 2 - } - // start the ui for pgoress - cr.StateString = "Recieving file..." - bytesWritten := 0 - fmt.Fprintf(os.Stderr, "\nReceiving (<-%s)...\n", cr.OtherIP) - cr.Bar = progressbar.NewOptions( - int(cr.FileInfo.Size), - progressbar.OptionSetRenderBlankState(true), - progressbar.OptionSetBytes(int(cr.FileInfo.Size)), - progressbar.OptionSetWriter(os.Stderr), - progressbar.OptionThrottle(1/60*time.Second), - ) - cr.Bar.Add((len(blocks) * blockSize)) - finished := make(chan bool) - - go func(finished chan bool, dataChan chan []byte) (err error) { - // remove previous progress - var fProgress *os.File - var progressErr error - if resumeFile { - fProgress, progressErr = os.OpenFile(progressFile, os.O_APPEND|os.O_WRONLY, 0644) - bytesWritten = len(blocks) * blockSize - } else { - os.Remove(progressFile) - fProgress, progressErr = os.Create(progressFile) - } - if progressErr != nil { - panic(progressErr) - } - defer fProgress.Close() - - blocksWritten := 0.0 - blocksToWrite := float64(cr.FileInfo.Size) - if useWebsockets { - blocksToWrite = blocksToWrite/float64(models.WEBSOCKET_BUFFER_SIZE/8) - float64(len(blocks)) - } else { - blocksToWrite = blocksToWrite/float64(models.TCP_BUFFER_SIZE/2) - float64(len(blocks)) - } - for { - message := <-dataChan - // do decryption - var enc crypt.Encryption - err = json.Unmarshal(message, &enc) - if err != nil { - // log.Errorf("%s: [%s] [%+v] (%d/%d) %+v", err.Error(), message, message, len(message), numBytes, bs) - log.Error(err) - return err - } - decrypted, err := enc.Decrypt(sessionKey, !cr.FileInfo.IsEncrypted) - if err != nil { - log.Error(err) - return err - } - - // get location if TCP - var locationToWrite int - if !useWebsockets { - pieces := bytes.SplitN(decrypted, []byte("-"), 2) - decrypted = pieces[1] - locationToWrite, _ = strconv.Atoi(string(pieces[0])) - } - - // do decompression - if cr.FileInfo.IsCompressed && !cr.FileInfo.IsDir { - decrypted = compress.Decompress(decrypted) - } - - var n int - if !useWebsockets { - if err != nil { - log.Error(err) - return err - } - n, err = f.WriteAt(decrypted, int64(locationToWrite)) - fProgress.WriteString(fmt.Sprintf("%d\n", locationToWrite)) - log.Debugf("wrote %d bytes to location %d (%2.0f/%2.0f)", n, locationToWrite, blocksWritten, blocksToWrite) - } else { - // write to file - n, err = f.Write(decrypted) - log.Debugf("wrote %d bytes to location %d (%2.0f/%2.0f)", n, bytesWritten, blocksWritten, blocksToWrite) - fProgress.WriteString(fmt.Sprintf("%d\n", bytesWritten)) - } - if err != nil { - log.Error(err) - return err - } - - // update the bytes written - bytesWritten += n - blocksWritten += 1.0 - // update the progress bar - cr.Bar.Add(n) - if int64(bytesWritten) == cr.FileInfo.Size || blocksWritten >= blocksToWrite { - log.Debug("finished", int64(bytesWritten), cr.FileInfo.Size, blocksWritten, blocksToWrite) - break - } - } - finished <- true - return - }(finished, dataChan) - - log.Debug("telling sender i'm ready") - c.WriteMessage(websocket.BinaryMessage, []byte("ready")) - - startTime := time.Now() - if useWebsockets { - for { - // read from websockets - websocketMessageData := <-websocketMessages - if bytes.HasPrefix(websocketMessageData.message, []byte("error")) { - return fmt.Errorf("%s", websocketMessageData.message) - } - if websocketMessageData.messageType != websocket.BinaryMessage { - continue - } - if err != nil { - log.Error(err) - return err - } - if bytes.Equal(websocketMessageData.message, []byte("magic")) { - log.Debug("got magic") - break - } - dataChan <- websocketMessageData.message - } - } else { - log.Debugf("starting listening with tcp with %d connections", len(tcpConnections)) - - // check to see if any messages are sent - stopMessageSignal := make(chan bool, 1) - errorsDuringTransfer := make(chan error, 24) - go func() { - for { - select { - case sig := <-stopMessageSignal: - errorsDuringTransfer <- nil - log.Debugf("got message signal: %+v", sig) - return - case wsMessage := <-websocketMessages: - log.Debugf("got message: %s", wsMessage.message) - if bytes.HasPrefix(wsMessage.message, []byte("error")) { - log.Debug("stopping transfer") - for i := 0; i < len(tcpConnections)+1; i++ { - errorsDuringTransfer <- fmt.Errorf("%s", wsMessage.message) - } - return - } - default: - continue - } - } - }() - - // using TCP - go func() { - var wg sync.WaitGroup - wg.Add(len(tcpConnections)) - for i := range tcpConnections { - defer func(i int) { - log.Debugf("closing connection %d", i) - tcpConnections[i].Close() - }(i) - go func(wg *sync.WaitGroup, j int) { - defer wg.Done() - for { - select { - case _ = <-errorsDuringTransfer: - log.Debugf("%d got stop", i) - return - default: - } - - log.Debugf("waiting to read on %d", j) - // read from TCP connection - message, _, _, err := tcpConnections[j].Read() - // log.Debugf("message: %s", message) - if err != nil { - panic(err) - } - if bytes.Equal(message, []byte("magic")) { - log.Debugf("%d got magic, leaving", j) - return - } - dataChan <- message - } - }(&wg, i) - } - log.Debug("waiting for tcp goroutines") - wg.Wait() - errorsDuringTransfer <- nil - }() - - // block until this is done - - log.Debug("waiting for error") - errorDuringTransfer := <-errorsDuringTransfer - log.Debug("sending stop message signal") - stopMessageSignal <- true - if errorDuringTransfer != nil { - log.Debugf("got error during transfer: %s", errorDuringTransfer.Error()) - return errorDuringTransfer - } - } - - _ = <-finished - log.Debug("telling sender i'm done") - c.WriteMessage(websocket.BinaryMessage, []byte("done")) - // we are finished - transferTime = time.Since(startTime) - - // close file - err = f.Close() - if err != nil { - return err - } - - // finish bar - cr.Bar.Finish() - - // check hash - hash256, err = utils.HashFile(cr.FileInfo.SentName) - if err != nil { - log.Error(err) - return err - } - // tell the sender the hash so they can quit - c.WriteMessage(websocket.BinaryMessage, append([]byte("hash:"), hash256...)) - case 4: - // receive the hash from the sender so we can check it and quit - log.Debugf("got hash: %x", message) - if bytes.Equal(hash256, message) { - // open directory - if cr.FileInfo.IsDir { - err = zipper.UnzipFile(cr.FileInfo.SentName, ".") - if DebugLevel != "debug" { - os.Remove(cr.FileInfo.SentName) - } - } else { - err = nil - } - if err == nil { - if useStdout && !cr.FileInfo.IsDir { - var bFile []byte - bFile, err = ioutil.ReadFile(cr.FileInfo.SentName) - if err != nil { - return err - } - os.Stdout.Write(bFile) - os.Remove(cr.FileInfo.SentName) - } - transferRate := float64(cr.FileInfo.Size) / 1000000.0 / transferTime.Seconds() - transferType := "MB/s" - if transferRate < 1 { - transferRate = float64(cr.FileInfo.Size) / 1000.0 / transferTime.Seconds() - transferType = "kB/s" - } - folderOrFile := "file" - if cr.FileInfo.IsDir { - folderOrFile = "folder" - } - if useStdout { - cr.FileInfo.Name = "stdout" - } - fmt.Fprintf(os.Stderr, "\nReceived %s written to %s (%2.1f %s)", folderOrFile, cr.FileInfo.Name, transferRate, transferType) - os.Remove(progressFile) - cr.StateString = fmt.Sprintf("Received %s written to %s (%2.1f %s)", folderOrFile, cr.FileInfo.Name, transferRate, transferType) - } - return err - } else { - if DebugLevel != "debug" { - log.Debug("removing corrupted file") - os.Remove(cr.FileInfo.SentName) - } - return errors.New("file corrupted") - } - default: - return fmt.Errorf("unknown step") - } - step++ - } -} diff --git a/src/croc/sender.go b/src/croc/sender.go deleted file mode 100644 index 0615fecf..00000000 --- a/src/croc/sender.go +++ /dev/null @@ -1,570 +0,0 @@ -package croc - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - log "github.com/cihub/seelog" - "github.com/gorilla/websocket" - "github.com/pkg/errors" - "github.com/schollz/croc/src/comm" - "github.com/schollz/croc/src/compress" - "github.com/schollz/croc/src/crypt" - "github.com/schollz/croc/src/logger" - "github.com/schollz/croc/src/models" - "github.com/schollz/croc/src/utils" - "github.com/schollz/croc/src/zipper" - "github.com/schollz/pake" - progressbar "github.com/schollz/progressbar/v2" - "github.com/schollz/spinner" -) - -// Send is the async call to send data -func (cr *Croc) startSender(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, done chan error, c *websocket.Conn, fname string, codephrase string, useCompression bool, useEncryption bool) { - logger.SetLogLevel(DebugLevel) - log.Debugf("sending %s", fname) - err := cr.send(forceSend, serverAddress, tcpPorts, isLocal, c, fname, codephrase, useCompression, useEncryption) - if err != nil && strings.HasPrefix(err.Error(), "websocket: close 100") { - err = nil - } - done <- err -} - -func (cr *Croc) send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, c *websocket.Conn, fname string, codephrase string, useCompression bool, useEncryption bool) (err error) { - var f *os.File - defer f.Close() // ignore the error if it wasn't opened :( - var fileHash []byte - var startTransfer time.Time - var tcpConnections []comm.Comm - blocksToSkip := make(map[int64]struct{}) - isConnectedIfUsingTCP := make(chan bool) - - type DataChan struct { - b []byte - currentPostition int64 - bytesRead int - err error - } - dataChan := make(chan DataChan, 1024*1024) - defer close(dataChan) - - useWebsockets := true - switch forceSend { - case 0: - if !isLocal { - useWebsockets = false - } - case 1: - useWebsockets = true - case 2: - useWebsockets = false - } - - fileReady := make(chan error) - - // normalize the file name - fname, err = filepath.Abs(fname) - if err != nil { - return err - } - _, filename := filepath.Split(fname) - - // get ready to generate session key - var sessionKey []byte - - // start a spinner - spin := spinner.New(spinner.CharSets[9], 100*time.Millisecond) - spin.Writer = os.Stderr - defer spin.Stop() - - // both parties should have a weak key - pw := []byte(codephrase) - // initialize sender P ("0" indicates sender) - P, err := pake.InitCurve(pw, 0, cr.CurveType, 1*time.Millisecond) - if err != nil { - return - } - - // start the reader - websocketMessages := make(chan WebSocketMessage, 1024) - go func() { - defer func() { - if r := recover(); r != nil { - log.Debugf("recovered from %s", r) - } - }() - for { - messageType, message, err := c.ReadMessage() - websocketMessages <- WebSocketMessage{messageType, message, err} - } - }() - - step := 0 - for { - websocketMessage := <-websocketMessages - messageType := websocketMessage.messageType - message := websocketMessage.message - errRead := websocketMessage.err - if errRead != nil { - return errRead - } - if messageType == websocket.PongMessage || messageType == websocket.PingMessage { - continue - } - if messageType == websocket.TextMessage && bytes.HasPrefix(message, []byte("interrupt")) { - return errors.New("\rinterrupted by other party") - } - if messageType == websocket.TextMessage && bytes.HasPrefix(message, []byte("err")) { - return errors.New("\r" + string(message)) - } - log.Debugf("got %d: %s", messageType, message) - switch step { - case 0: - // sender initiates communication - ip := "" - if isLocal { - ip = utils.LocalIP() - } else { - ip, _ = utils.PublicIP() - } - - initialData := models.Initial{ - CurveType: cr.CurveType, - IPAddress: ip, - VersionString: cr.Version, // version should match - } - bInitialData, _ := json.Marshal(initialData) - // send the initial data - c.WriteMessage(websocket.BinaryMessage, bInitialData) - case 1: - // first receive the initial data from the recipient - var initialData models.Initial - err = json.Unmarshal(message, &initialData) - if err != nil { - err = errors.Wrap(err, "incompatible versions of croc") - return - } - cr.OtherIP = initialData.IPAddress - log.Debugf("recipient IP: %s", cr.OtherIP) - - go func() { - // recipient might want file! start gathering information about file - fstat, err := os.Stat(fname) - if err != nil { - fileReady <- err - return - } - cr.FileInfo = models.FileStats{ - Name: filename, - Size: fstat.Size(), - ModTime: fstat.ModTime(), - IsDir: fstat.IsDir(), - SentName: fstat.Name(), - IsCompressed: useCompression, - IsEncrypted: useEncryption, - } - if cr.FileInfo.IsDir { - // zip the directory - cr.FileInfo.SentName, err = zipper.ZipFile(fname, true) - if err != nil { - log.Error(err) - fileReady <- err - return - } - fname = cr.FileInfo.SentName - - fstat, err := os.Stat(fname) - if err != nil { - fileReady <- err - return - } - // get new size - cr.FileInfo.Size = fstat.Size() - } - - // open the file - f, err = os.Open(fname) - if err != nil { - fileReady <- err - return - } - fileReady <- nil - - }() - - // send pake data - log.Debugf("[%d] first, P sends u to Q", step) - c.WriteMessage(websocket.BinaryMessage, P.Bytes()) - // start PAKE spinnner - spin.Suffix = " performing PAKE..." - cr.StateString = "Performing PAKE..." - spin.Start() - case 2: - // P recieves H(k),v from Q - log.Debugf("[%d] P computes k, H(k), sends H(k) to Q", step) - err := P.Update(message) - c.WriteMessage(websocket.BinaryMessage, P.Bytes()) - if err != nil { - return fmt.Errorf("Recipient is using wrong code phrase.") - } - - sessionKey, _ = P.SessionKey() - // check(err) - log.Debugf("%x\n", sessionKey) - - // wait for readiness - spin.Stop() - spin.Suffix = " waiting for recipient ok..." - cr.StateString = "Waiting for recipient ok...." - spin.Start() - case 3: - log.Debugf("[%d] recipient declares readiness for file info", step) - if !bytes.HasPrefix(message, []byte("ready")) { - return errors.New("Recipient refused file") - } - - err = <-fileReady // block until file is ready - if err != nil { - return err - } - fstatsBytes, err := json.Marshal(cr.FileInfo) - if err != nil { - return err - } - - // encrypt the file meta data - enc := crypt.Encrypt(fstatsBytes, sessionKey) - // send the file meta data - c.WriteMessage(websocket.BinaryMessage, enc.Bytes()) - case 4: - log.Debugf("[%d] recipient gives blocks", step) - // recipient sends blocks, and sender does not send anything back - // determine if any blocks were sent to skip - enc, err := crypt.FromBytes(message) - if err != nil { - log.Error(err) - return err - } - decrypted, err := enc.Decrypt(sessionKey) - if err != nil { - err = errors.Wrap(err, "could not decrypt blocks with session key") - log.Error(err) - return err - } - - var blocks []string - errBlocks := json.Unmarshal(decrypted, &blocks) - if errBlocks == nil { - for _, block := range blocks { - blockInt64, errBlock := strconv.Atoi(block) - if errBlock == nil { - blocksToSkip[int64(blockInt64)] = struct{}{} - } - } - } - log.Debugf("found blocks: %+v", blocksToSkip) - - // connect to TCP in background - tcpConnections = make([]comm.Comm, len(tcpPorts)) - go func() { - if !useWebsockets { - log.Debugf("connecting to server") - var wg sync.WaitGroup - wg.Add(len(tcpPorts)) - for i, tcpPort := range tcpPorts { - go func(i int, tcpPort string) { - defer wg.Done() - log.Debugf("connecting to %s on connection %d", tcpPort, i) - var message string - tcpConnections[i], message, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) - if err != nil { - log.Error(err) - } - if message != "sender" { - log.Errorf("got wrong message: %s", message) - } - }(i, tcpPort) - } - wg.Wait() - } - isConnectedIfUsingTCP <- true - }() - - // start loading the file into memory - // start streaming encryption/compression - if cr.FileInfo.IsDir { - // remove file if zipped - defer os.Remove(cr.FileInfo.SentName) - } - go func(dataChan chan DataChan) { - var buffer []byte - if useWebsockets { - buffer = make([]byte, models.WEBSOCKET_BUFFER_SIZE/8) - } else { - buffer = make([]byte, models.TCP_BUFFER_SIZE/2) - } - - currentPostition := int64(0) - for { - bytesread, err := f.Read(buffer) - if bytesread > 0 { - if _, ok := blocksToSkip[currentPostition]; ok { - log.Debugf("skipping the sending of block %d", currentPostition) - currentPostition += int64(bytesread) - continue - } - - // do compression - var compressedBytes []byte - if useCompression && !cr.FileInfo.IsDir { - compressedBytes = compress.Compress(buffer[:bytesread]) - } else { - compressedBytes = buffer[:bytesread] - } - - // if using TCP, prepend the location to write the data to in the resulting file - if !useWebsockets { - compressedBytes = append([]byte(fmt.Sprintf("%d-", currentPostition)), compressedBytes...) - } - - // do encryption - enc := crypt.Encrypt(compressedBytes, sessionKey, !useEncryption) - encBytes, err := json.Marshal(enc) - if err != nil { - dataChan <- DataChan{ - b: nil, - bytesRead: 0, - err: err, - } - return - } - - dataChan <- DataChan{ - b: encBytes, - bytesRead: bytesread, - err: nil, - } - currentPostition += int64(bytesread) - } - if err != nil { - if err != io.EOF { - log.Error(err) - } - break - } - } - // finish - log.Debug("sending magic") - dataChan <- DataChan{ - b: []byte("magic"), - bytesRead: 0, - err: nil, - } - if !useWebsockets { - log.Debug("sending extra magic to %d others", len(tcpPorts)-1) - for i := 0; i < len(tcpPorts)-1; i++ { - log.Debug("sending magic") - dataChan <- DataChan{ - b: []byte("magic"), - bytesRead: 0, - err: nil, - } - } - } - }(dataChan) - - case 5: - spin.Stop() - - log.Debugf("[%d] recipient declares readiness for file data", step) - if !bytes.HasPrefix(message, []byte("ready")) { - return errors.New("Recipient refused file") - } - cr.StateString = "Transfer in progress..." - fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", cr.OtherIP) - // send file, compure hash simultaneously - startTransfer = time.Now() - - blockSize := 0 - if useWebsockets { - blockSize = models.WEBSOCKET_BUFFER_SIZE / 8 - } else { - blockSize = models.TCP_BUFFER_SIZE / 2 - } - cr.Bar = progressbar.NewOptions( - int(cr.FileInfo.Size), - progressbar.OptionSetRenderBlankState(true), - progressbar.OptionSetBytes(int(cr.FileInfo.Size)), - progressbar.OptionSetWriter(os.Stderr), - progressbar.OptionThrottle(1/60*time.Second), - ) - cr.Bar.Add(blockSize * len(blocksToSkip)) - - if useWebsockets { - for { - data := <-dataChan - if data.err != nil { - return data.err - } - cr.Bar.Add(data.bytesRead) - - // write data to websockets - err = c.WriteMessage(websocket.BinaryMessage, data.b) - if err != nil { - err = errors.Wrap(err, "problem writing message") - return err - } - if bytes.Equal(data.b, []byte("magic")) { - break - } - } - } else { - _ = <-isConnectedIfUsingTCP - log.Debug("connected and ready to send on tcp") - - // check to see if any messages are sent - stopMessageSignal := make(chan bool, 1) - errorsDuringTransfer := make(chan error, 24) - go func() { - for { - select { - case sig := <-stopMessageSignal: - errorsDuringTransfer <- nil - log.Debugf("got message signal: %+v", sig) - return - case wsMessage := <-websocketMessages: - log.Debugf("got message: %s", wsMessage.message) - if bytes.HasPrefix(wsMessage.message, []byte("error")) { - log.Debug("stopping transfer") - for i := 0; i < len(tcpConnections)+1; i++ { - errorsDuringTransfer <- fmt.Errorf("%s", wsMessage.message) - } - return - } - default: - continue - } - } - }() - - var wg sync.WaitGroup - wg.Add(len(tcpConnections)) - for i := range tcpConnections { - defer func(i int) { - log.Debugf("closing connection %d", i) - tcpConnections[i].Close() - }(i) - go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan) { - defer wg.Done() - for data := range dataChan { - select { - case _ = <-errorsDuringTransfer: - log.Debugf("%d got stop", i) - return - default: - } - if data.err != nil { - log.Error(data.err) - return - } - cr.Bar.Add(data.bytesRead) - // write data to tcp connection - _, errTcp := tcpConnections[i].Write(data.b) - if errTcp != nil { - errTcp = errors.Wrap(errTcp, "problem writing message") - log.Debug(errTcp) - errorsDuringTransfer <- errTcp - return - } - if bytes.Equal(data.b, []byte("magic")) { - log.Debugf("%d got magic", i) - return - } - } - }(i, &wg, dataChan) - } - - // block until this is done - log.Debug("waiting for tcp goroutines") - wg.Wait() - log.Debug("sending stop message signal") - stopMessageSignal <- true - log.Debug("waiting for error") - errorDuringTransfer := <-errorsDuringTransfer - if errorDuringTransfer != nil { - log.Debugf("got error during transfer: %s", errorDuringTransfer.Error()) - return errorDuringTransfer - } - } - - cr.Bar.Finish() - log.Debug("send hash to finish file") - fileHash, err = utils.HashFile(fname) - if err != nil { - return err - } - case 6: - // recevied something, maybe the file hash - transferTime := time.Since(startTransfer) - if !bytes.HasPrefix(message, []byte("hash:")) { - log.Debugf("%s", message) - continue - } - c.WriteMessage(websocket.BinaryMessage, fileHash) - message = bytes.TrimPrefix(message, []byte("hash:")) - log.Debugf("[%d] determing whether it went ok", step) - if bytes.Equal(message, fileHash) { - log.Debug("file transfered successfully") - transferRate := float64(cr.FileInfo.Size) / 1000000.0 / transferTime.Seconds() - transferType := "MB/s" - if transferRate < 1 { - transferRate = float64(cr.FileInfo.Size) / 1000.0 / transferTime.Seconds() - transferType = "kB/s" - } - fmt.Fprintf(os.Stderr, "\nTransfer complete (%2.1f %s)", transferRate, transferType) - cr.StateString = fmt.Sprintf("Transfer complete (%2.1f %s)", transferRate, transferType) - return nil - } else { - fmt.Fprintf(os.Stderr, "\nTransfer corrupted") - return errors.New("file not transfered succesfully") - } - default: - return fmt.Errorf("unknown step") - } - step++ - } -} - -func connectToTCPServer(room string, address string) (com comm.Comm, message string, err error) { - connection, err := net.DialTimeout("tcp", address, 3*time.Hour) - if err != nil { - return - } - connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) - connection.SetDeadline(time.Now().Add(3 * time.Hour)) - connection.SetWriteDeadline(time.Now().Add(3 * time.Hour)) - - com = comm.New(connection) - ok, err := com.Receive() - if err != nil { - return - } - log.Debugf("server says: %s", ok) - - err = com.Send(room) - if err != nil { - return - } - message, err = com.Receive() - log.Debugf("server says: %s", message) - return -} diff --git a/src/croc/sending.go b/src/croc/sending.go deleted file mode 100644 index 250de266..00000000 --- a/src/croc/sending.go +++ /dev/null @@ -1,217 +0,0 @@ -package croc - -import ( - "errors" - "fmt" - "net/http" - "os" - "os/signal" - "strings" - "time" - - log "github.com/cihub/seelog" - "github.com/gorilla/websocket" - "github.com/schollz/croc/src/relay" - "github.com/schollz/croc/src/utils" - "github.com/schollz/peerdiscovery" -) - -// Send the file -func (c *Croc) Send(fname, codephrase string) (err error) { - defer log.Flush() - log.Debugf("sending %s", fname) - errChan := make(chan error) - - // normally attempt two connections - waitingFor := 2 - - // use public relay - if !c.LocalOnly { - go func() { - // atttempt to connect to public relay - errChan <- c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPorts, fname, codephrase, true, false) - }() - } else { - waitingFor = 1 - } - - // use local relay - if !c.NoLocal { - defer func() { - log.Debug("sending relay stop signal") - relay.Stop() - }() - go func() { - // start own relay and connect to it - go relay.Run(c.RelayWebsocketPort, c.RelayTCPPorts) - time.Sleep(250 * time.Millisecond) // race condition here, but this should work most of the time :( - - // broadcast for peer discovery - go func() { - log.Debug("starting local discovery...") - discovered, err := peerdiscovery.Discover(peerdiscovery.Settings{ - Limit: 1, - TimeLimit: 600 * time.Second, - Delay: 50 * time.Millisecond, - Payload: []byte(c.RelayWebsocketPort + "- " + strings.Join(c.RelayTCPPorts, ",")), - MulticastAddress: fmt.Sprintf("239.255.255.%d", 230+int64(time.Now().Minute()/5)), - }) - log.Debug(discovered, err) - }() - - // connect to own relay - errChan <- c.sendReceive("localhost", c.RelayWebsocketPort, c.RelayTCPPorts, fname, codephrase, true, true) - }() - } else { - waitingFor = 1 - } - - err = <-errChan - if err == nil || waitingFor == 1 { - log.Debug("returning") - return - } - log.Debug(err) - return <-errChan -} - -// Receive the file -func (c *Croc) Receive(codephrase string) (err error) { - defer log.Flush() - log.Debug("receiving") - - // use local relay first - if !c.NoLocal { - log.Debug("trying to discover") - // try to discovery codephrase and server through peer network - discovered, errDiscover := peerdiscovery.Discover(peerdiscovery.Settings{ - Limit: 1, - TimeLimit: 300 * time.Millisecond, - Delay: 50 * time.Millisecond, - Payload: []byte("checking"), - AllowSelf: true, - DisableBroadcast: true, - MulticastAddress: fmt.Sprintf("239.255.255.%d", 230+int64(time.Now().Minute()/5)), - }) - log.Debug("finished") - log.Debug(discovered) - if errDiscover != nil { - log.Debug(errDiscover) - } - if len(discovered) > 0 { - if discovered[0].Address == utils.LocalIP() { - discovered[0].Address = "localhost" - } - log.Debugf("discovered %s:%s", discovered[0].Address, discovered[0].Payload) - // see if we can actually connect to it - timeout := time.Duration(200 * time.Millisecond) - client := http.Client{ - Timeout: timeout, - } - ports := strings.Split(string(discovered[0].Payload), "-") - if len(ports) != 2 { - return errors.New("bad payload") - } - resp, err := client.Get(fmt.Sprintf("http://%s:%s/", discovered[0].Address, ports[0])) - if err == nil { - if resp.StatusCode == http.StatusOK { - // we connected, so use this - return c.sendReceive(discovered[0].Address, strings.TrimSpace(ports[0]), strings.Split(strings.TrimSpace(ports[1]), ","), "", codephrase, false, true) - } - } else { - log.Debugf("could not connect: %s", err.Error()) - } - } else { - log.Debug("discovered no peers") - } - } - - // use public relay - if !c.LocalOnly { - log.Debug("using public relay") - return c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPorts, "", codephrase, false, false) - } - - return errors.New("must use local or public relay") -} - -func (c *Croc) sendReceive(address, websocketPort string, tcpPorts []string, fname string, codephrase string, isSender bool, isLocal bool) (err error) { - defer log.Flush() - if len(codephrase) < 4 { - return fmt.Errorf("codephrase is too short") - } - - // allow interrupts from Ctl+C - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt) - - done := make(chan error) - // connect to server - websocketAddress := "" - if len(websocketPort) > 0 { - websocketAddress = fmt.Sprintf("ws://%s:%s/ws?room=%s", address, websocketPort, codephrase[:3]) - } else { - websocketAddress = fmt.Sprintf("ws://%s/ws?room=%s", address, codephrase[:3]) - } - log.Debugf("connecting to %s", websocketAddress) - sock, _, err := websocket.DefaultDialer.Dial(websocketAddress, nil) - if err != nil { - log.Error(err) - return - } - defer sock.Close() - - // tell the websockets we are connected - err = sock.WriteMessage(websocket.BinaryMessage, []byte("connected")) - if err != nil { - log.Error(err) - return err - } - - if isSender { - go c.startSender(c.ForceSend, address, tcpPorts, isLocal, done, sock, fname, codephrase, c.UseCompression, c.UseEncryption) - } else { - go c.startRecipient(c.ForceSend, address, tcpPorts, isLocal, done, sock, codephrase, c.NoRecipientPrompt, c.Stdout) - } - - for { - select { - case doneError := <-done: - log.Debug("received done signal") - if doneError != nil { - c.StateString = doneError.Error() - sock.WriteMessage(websocket.TextMessage, []byte("error: "+doneError.Error())) - time.Sleep(50 * time.Millisecond) - } - return doneError - case <-interrupt: - if !c.Debug { - SetDebugLevel("critical") - } - log.Debug("interrupt") - err = sock.WriteMessage(websocket.TextMessage, []byte("error: interrupted by other party")) - if err != nil { - return err - } - time.Sleep(50 * time.Millisecond) - - // Cleanly close the connection by sending a close message and then - // waiting (with timeout) for the server to close the connection. - err := sock.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { - log.Debug("write close:", err) - return nil - } - select { - case <-done: - case <-time.After(100 * time.Millisecond): - } - return nil - } - } -} - -// Relay will start a relay on the specified port -func (c *Croc) Relay() (err error) { - return relay.Run(c.RelayWebsocketPort, c.RelayTCPPorts) -}