0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00

use new version of croc

This commit is contained in:
Zack Scholl 2019-04-29 16:48:17 -06:00
parent dcc7689816
commit 1f49966bb1
7 changed files with 689 additions and 1765 deletions

View file

@ -1,187 +0,0 @@
package croc
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
"time"
"github.com/BurntSushi/toml"
homedir "github.com/mitchellh/go-homedir"
"github.com/schollz/croc/src/utils"
)
type Config struct {
// Relay parameters
RelayWebsocketPort string
RelayTCPPorts []string
// Sender parameters
CurveType string
// Options for connecting to server
PublicServerIP string
AddressTCPPorts []string
AddressWebsocketPort string
Timeout time.Duration
LocalOnly bool
NoLocal bool
// Options for file transfering
UseEncryption bool
UseCompression bool
AllowLocalDiscovery bool
NoRecipientPrompt bool
ForceTCP bool
ForceWebsockets bool
Codephrase string
}
func defaultConfig() Config {
c := Config{}
cr := Init(false)
c.RelayWebsocketPort = cr.RelayWebsocketPort
c.RelayTCPPorts = cr.RelayTCPPorts
c.CurveType = cr.CurveType
c.PublicServerIP = cr.Address
c.AddressTCPPorts = cr.AddressTCPPorts
c.AddressWebsocketPort = cr.AddressWebsocketPort
c.Timeout = cr.Timeout
c.LocalOnly = cr.LocalOnly
c.NoLocal = cr.NoLocal
c.UseEncryption = cr.UseEncryption
c.UseCompression = cr.UseCompression
c.AllowLocalDiscovery = cr.AllowLocalDiscovery
c.NoRecipientPrompt = cr.NoRecipientPrompt
c.ForceTCP = false
c.ForceWebsockets = false
c.Codephrase = ""
return c
}
func SaveDefaultConfig() error {
homedir, err := homedir.Dir()
if err != nil {
return err
}
os.MkdirAll(path.Join(homedir, ".config", "croc"), 0755)
c := defaultConfig()
buf := new(bytes.Buffer)
toml.NewEncoder(buf).Encode(c)
confTOML := buf.String()
err = ioutil.WriteFile(path.Join(homedir, ".config", "croc", "config.toml"), []byte(confTOML), 0644)
if err == nil {
fmt.Printf("Default config file written at '%s'\r\n", filepath.Clean(path.Join(homedir, ".config", "croc", "config.toml")))
}
return err
}
// LoadConfig will override parameters
func (cr *Croc) LoadConfig() (err error) {
homedir, err := homedir.Dir()
if err != nil {
return err
}
pathToConfig := path.Join(homedir, ".config", "croc", "config.toml")
if !utils.Exists(pathToConfig) {
// ignore if doesn't exist
return nil
}
var c Config
_, err = toml.DecodeFile(pathToConfig, &c)
if err != nil {
return
}
cDefault := defaultConfig()
// only load if things are different than defaults
// just in case the CLI parameters are used
if c.RelayWebsocketPort != cDefault.RelayWebsocketPort && cr.RelayWebsocketPort == cDefault.RelayWebsocketPort {
cr.RelayWebsocketPort = c.RelayWebsocketPort
fmt.Printf("loaded RelayWebsocketPort from config\n")
}
if !slicesEqual(c.RelayTCPPorts, cDefault.RelayTCPPorts) && slicesEqual(cr.RelayTCPPorts, cDefault.RelayTCPPorts) {
cr.RelayTCPPorts = c.RelayTCPPorts
fmt.Printf("loaded RelayTCPPorts from config\n")
}
if c.CurveType != cDefault.CurveType && cr.CurveType == cDefault.CurveType {
cr.CurveType = c.CurveType
fmt.Printf("loaded CurveType from config\n")
}
if c.PublicServerIP != cDefault.PublicServerIP && cr.Address == cDefault.PublicServerIP {
cr.Address = c.PublicServerIP
fmt.Printf("loaded Address from config\n")
}
if !slicesEqual(c.AddressTCPPorts, cDefault.AddressTCPPorts) {
cr.AddressTCPPorts = c.AddressTCPPorts
fmt.Printf("loaded AddressTCPPorts from config\n")
}
if c.AddressWebsocketPort != cDefault.AddressWebsocketPort && cr.AddressWebsocketPort == cDefault.AddressWebsocketPort {
cr.AddressWebsocketPort = c.AddressWebsocketPort
fmt.Printf("loaded AddressWebsocketPort from config\n")
}
if c.Timeout != cDefault.Timeout && cr.Timeout == cDefault.Timeout {
cr.Timeout = c.Timeout
fmt.Printf("loaded Timeout from config\n")
}
if c.LocalOnly != cDefault.LocalOnly && cr.LocalOnly == cDefault.LocalOnly {
cr.LocalOnly = c.LocalOnly
fmt.Printf("loaded LocalOnly from config\n")
}
if c.NoLocal != cDefault.NoLocal && cr.NoLocal == cDefault.NoLocal {
cr.NoLocal = c.NoLocal
fmt.Printf("loaded NoLocal from config\n")
}
if c.UseEncryption != cDefault.UseEncryption && cr.UseEncryption == cDefault.UseEncryption {
cr.UseEncryption = c.UseEncryption
fmt.Printf("loaded UseEncryption from config\n")
}
if c.UseCompression != cDefault.UseCompression && cr.UseCompression == cDefault.UseCompression {
cr.UseCompression = c.UseCompression
fmt.Printf("loaded UseCompression from config\n")
}
if c.AllowLocalDiscovery != cDefault.AllowLocalDiscovery && cr.AllowLocalDiscovery == cDefault.AllowLocalDiscovery {
cr.AllowLocalDiscovery = c.AllowLocalDiscovery
fmt.Printf("loaded AllowLocalDiscovery from config\n")
}
if c.NoRecipientPrompt != cDefault.NoRecipientPrompt && cr.NoRecipientPrompt == cDefault.NoRecipientPrompt {
cr.NoRecipientPrompt = c.NoRecipientPrompt
fmt.Printf("loaded NoRecipientPrompt from config\n")
}
if c.ForceWebsockets {
cr.ForceSend = 1
}
if c.ForceTCP {
cr.ForceSend = 2
}
if c.Codephrase != cDefault.Codephrase && cr.Codephrase == cDefault.Codephrase {
cr.Codephrase = c.Codephrase
fmt.Printf("loaded Codephrase from config\n")
}
return
}
// slicesEqual checcks if two slices are equal
// from https://stackoverflow.com/a/15312097
func slicesEqual(a, b []string) bool {
// If one is nil, the other must also be nil.
if (a == nil) != (b == nil) {
return false
}
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

View file

@ -1,96 +1,706 @@
package croc
import (
"runtime"
"bytes"
"crypto/elliptic"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
"github.com/schollz/croc/src/logger"
"github.com/schollz/croc/src/models"
"github.com/schollz/croc/src/relay"
"github.com/schollz/croc/src/zipper"
"github.com/denisbrodbeck/machineid"
"github.com/go-redis/redis"
"github.com/mattn/go-colorable"
"github.com/pions/webrtc"
"github.com/schollz/croc/v5/src/crypt"
"github.com/schollz/croc/v5/src/utils"
"github.com/schollz/croc/v5/src/webrtc/pkg/session/common"
"github.com/schollz/croc/v5/src/webrtc/pkg/session/receiver"
"github.com/schollz/croc/v5/src/webrtc/pkg/session/sender"
"github.com/schollz/pake"
"github.com/schollz/progressbar/v2"
"github.com/schollz/spinner"
"github.com/sirupsen/logrus"
)
const BufferSize = 4096 * 10
const Channels = 1
var log = logrus.New()
func init() {
runtime.GOMAXPROCS(runtime.NumCPU())
log.SetFormatter(&logrus.TextFormatter{ForceColors: true})
log.SetOutput(colorable.NewColorableStdout())
Debug(false)
}
// Croc options
type Croc struct {
// Version is the version of croc
Version string
// Options for all
Debug bool
// ShowText will display text on the stderr
ShowText bool
// Options for relay
RelayWebsocketPort string
RelayTCPPorts []string
CurveType string
// Options for connecting to server
Address string
AddressTCPPorts []string
AddressWebsocketPort string
Timeout time.Duration
LocalOnly bool
NoLocal bool
// Options for file transfering
UseEncryption bool
UseCompression bool
AllowLocalDiscovery bool
NoRecipientPrompt bool
Stdout bool
ForceSend int // 0: ignore, 1: websockets, 2: TCP
// Parameters for file transfer
Filename string
Codephrase string
// localIP address
localIP string
// is using local relay
isLocal bool
normalFinish bool
// state variables
StateString string
Bar *progressbar.ProgressBar
FileInfo models.FileStats
OtherIP string
// special for window
WindowRecipientPrompt bool
WindowRecipientAccept bool
WindowReceivingString string
}
// Init will initiate with the default parameters
func Init(debug bool) (c *Croc) {
c = new(Croc)
c.UseCompression = true
c.UseEncryption = true
c.AllowLocalDiscovery = true
c.RelayWebsocketPort = "8153"
c.RelayTCPPorts = []string{"8154", "8155", "8156", "8157", "8158", "8159", "8160", "8161"}
c.CurveType = "siec"
c.Address = "croc4.schollz.com"
c.AddressWebsocketPort = "8153"
c.AddressTCPPorts = []string{"8154", "8155", "8156", "8157", "8158", "8159", "8160", "8161"}
c.NoRecipientPrompt = true
debugLevel := "info"
func Debug(debug bool) {
receiver.Debug(debug)
sender.Debug(debug)
if debug {
debugLevel = "debug"
c.Debug = true
log.SetLevel(logrus.DebugLevel)
} else {
log.SetLevel(logrus.WarnLevel)
}
SetDebugLevel(debugLevel)
}
type Client struct {
Options Options
// basic setup
redisdb *redis.Client
log *logrus.Entry
Pake *pake.Pake
// steps involved in forming relationship
Step1ChannelSecured bool
Step2FileInfoTransfered bool
Step3RecipientRequestFile bool
Step4FileTransfer bool
Step5CloseChannels bool // TODO: Step5 should close files and reset things
// send / receive information of all files
FilesToTransfer []FileInfo
FilesToTransferCurrentNum int
// send / receive information of current file
CurrentFile *os.File
CurrentFileChunks []int64
sendSess *sender.Session
recvSess *receiver.Session
// channel data
incomingMessageChannel <-chan *redis.Message
nameOutChannel string
nameInChannel string
// webrtc connections
peerConnection [8]*webrtc.PeerConnection
dataChannel [8]*webrtc.DataChannel
bar *progressbar.ProgressBar
spinner *spinner.Spinner
machineID string
mutex *sync.Mutex
quit chan bool
}
type Message struct {
Type string `json:"t,omitempty"`
Message string `json:"m,omitempty"`
Bytes []byte `json:"b,omitempty"`
Num int `json:"n,omitempty"`
}
type Chunk struct {
Bytes []byte `json:"b,omitempty"`
Location int64 `json:"l,omitempty"`
}
type FileInfo struct {
Name string `json:"n,omitempty"`
FolderRemote string `json:"fr,omitempty"`
FolderSource string `json:"fs,omitempty"`
Hash []byte `json:"h,omitempty"`
Size int64 `json:"s,omitempty"`
ModTime time.Time `json:"m,omitempty"`
IsCompressed bool `json:"c,omitempty"`
IsEncrypted bool `json:"e,omitempty"`
}
type RemoteFileRequest struct {
CurrentFileChunks []int64
FilesToTransferCurrentNum int
}
type SenderInfo struct {
MachineID string
FilesToTransfer []FileInfo
}
func (m Message) String() string {
b, _ := json.Marshal(m)
return string(b)
}
type Options struct {
IsSender bool
SharedSecret string
Debug bool
AddressRelay string
Stdout bool
NoPrompt bool
}
// New establishes a new connection for transfering files between two instances.
func New(ops Options) (c *Client, err error) {
c = new(Client)
// setup basic info
c.Options = ops
Debug(c.Options.Debug)
log.Debugf("options: %+v", c.Options)
// set channels
if c.Options.IsSender {
c.nameOutChannel = c.Options.SharedSecret + "2"
c.nameInChannel = c.Options.SharedSecret + "1"
} else {
c.nameOutChannel = c.Options.SharedSecret + "1"
c.nameInChannel = c.Options.SharedSecret + "2"
}
// initialize redis for communication in establishing channel
c.redisdb = redis.NewClient(&redis.Options{
Addr: c.Options.AddressRelay,
Password: "",
DB: 4,
WriteTimeout: 1 * time.Hour,
ReadTimeout: 1 * time.Hour,
})
_, err = c.redisdb.Ping().Result()
if err != nil {
return
}
func SetDebugLevel(debugLevel string) {
logger.SetLogLevel(debugLevel)
relay.DebugLevel = debugLevel
zipper.DebugLevel = debugLevel
// setup channel for listening
pubsub := c.redisdb.Subscribe(c.nameInChannel)
_, err = pubsub.Receive()
if err != nil {
return
}
c.incomingMessageChannel = pubsub.Channel()
// initialize pake
if c.Options.IsSender {
c.Pake, err = pake.Init([]byte(c.Options.SharedSecret), 1, elliptic.P521(), 1*time.Microsecond)
} else {
c.Pake, err = pake.Init([]byte(c.Options.SharedSecret), 0, elliptic.P521(), 1*time.Microsecond)
}
if err != nil {
return
}
// initialize logger
c.log = log.WithFields(logrus.Fields{
"is": "sender",
})
if !c.Options.IsSender {
c.log = log.WithFields(logrus.Fields{
"is": "recipient",
})
}
c.spinner = spinner.New(spinner.CharSets[9], 100*time.Millisecond)
c.spinner.Writer = os.Stderr
c.spinner.Suffix = " connecting..."
c.mutex = &sync.Mutex{}
return
}
type TransferOptions struct {
PathToFiles []string
KeepPathInRemote bool
}
// Send will send the specified file
func (c *Client) Send(options TransferOptions) (err error) {
return c.transfer(options)
}
// Receive will receive a file
func (c *Client) Receive() (err error) {
return c.transfer(TransferOptions{})
}
func (c *Client) transfer(options TransferOptions) (err error) {
if c.Options.IsSender {
c.FilesToTransfer = make([]FileInfo, len(options.PathToFiles))
totalFilesSize := int64(0)
for i, pathToFile := range options.PathToFiles {
var fstats os.FileInfo
var fullPath string
fullPath, err = filepath.Abs(pathToFile)
if err != nil {
return
}
fullPath = filepath.Clean(fullPath)
var folderName string
folderName, _ = filepath.Split(fullPath)
fstats, err = os.Stat(fullPath)
if err != nil {
return
}
c.FilesToTransfer[i] = FileInfo{
Name: fstats.Name(),
FolderRemote: ".",
FolderSource: folderName,
Size: fstats.Size(),
ModTime: fstats.ModTime(),
}
c.FilesToTransfer[i].Hash, err = utils.HashFile(fullPath)
totalFilesSize += fstats.Size()
if err != nil {
return
}
if options.KeepPathInRemote {
var curFolder string
curFolder, err = os.Getwd()
if err != nil {
return
}
curFolder, err = filepath.Abs(curFolder)
if err != nil {
return
}
if !strings.HasPrefix(folderName, curFolder) {
err = fmt.Errorf("remote directory must be relative to current")
return
}
c.FilesToTransfer[i].FolderRemote = strings.TrimPrefix(folderName, curFolder)
c.FilesToTransfer[i].FolderRemote = filepath.ToSlash(c.FilesToTransfer[i].FolderRemote)
c.FilesToTransfer[i].FolderRemote = strings.TrimPrefix(c.FilesToTransfer[i].FolderRemote, "/")
if c.FilesToTransfer[i].FolderRemote == "" {
c.FilesToTransfer[i].FolderRemote = "."
}
}
log.Debugf("file %d info: %+v", i, c.FilesToTransfer[i])
}
fname := fmt.Sprintf("%d files", len(c.FilesToTransfer))
if len(c.FilesToTransfer) == 1 {
fname = fmt.Sprintf("'%s'", c.FilesToTransfer[0].Name)
}
machID, macIDerr := machineid.ID()
if macIDerr != nil {
log.Error(macIDerr)
return
}
if len(machID) > 6 {
machID = machID[:6]
}
c.machineID = machID
fmt.Fprintf(os.Stderr, "Sending %s (%s) from your machine, '%s'\n", fname, utils.ByteCountDecimal(totalFilesSize), machID)
fmt.Fprintf(os.Stderr, "Code is: %s\nOn the other computer run\n\ncroc %s\n", c.Options.SharedSecret, c.Options.SharedSecret)
c.spinner.Suffix = " waiting for recipient..."
}
c.spinner.Start()
// create channel for quitting
// quit with c.quit <- true
c.quit = make(chan bool)
// if recipient, initialize with sending pake information
c.log.Debug("ready")
if !c.Options.IsSender && !c.Step1ChannelSecured {
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "pake",
Bytes: c.Pake.Bytes(),
}.String()).Err()
if err != nil {
return
}
}
// listen for incoming messages and process them
for {
select {
case <-c.quit:
return
case msg := <-c.incomingMessageChannel:
var m Message
err = json.Unmarshal([]byte(msg.Payload), &m)
if err != nil {
return
}
if m.Type == "finished" {
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "finished",
}.String()).Err()
return err
}
err = c.processMessage(m)
if err != nil {
return
}
default:
time.Sleep(1 * time.Millisecond)
}
}
return
}
func (c *Client) sendOverRedis() (err error) {
go func() {
c.bar = progressbar.NewOptions(
int(c.FilesToTransfer[c.FilesToTransferCurrentNum].Size),
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetBytes(int(c.FilesToTransfer[c.FilesToTransferCurrentNum].Size)),
progressbar.OptionSetWriter(os.Stderr),
progressbar.OptionThrottle(1/60*time.Second),
)
c.CurrentFile, err = os.Open(c.FilesToTransfer[c.FilesToTransferCurrentNum].Name)
if err != nil {
panic(err)
}
location := int64(0)
for {
buf := make([]byte, 4096*128)
n, errRead := c.CurrentFile.Read(buf)
c.bar.Add(n)
chunk := Chunk{
Bytes: buf[:n],
Location: location,
}
chunkB, _ := json.Marshal(chunk)
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "chunk",
Bytes: chunkB,
}.String()).Err()
if err != nil {
panic(err)
}
location += int64(n)
if errRead == io.EOF {
break
}
if errRead != nil {
panic(errRead)
}
}
}()
return
}
func (c *Client) processMessage(m Message) (err error) {
switch m.Type {
case "pake":
if c.spinner.Suffix != " performing PAKE..." {
c.spinner.Stop()
c.spinner.Suffix = " performing PAKE..."
c.spinner.Start()
}
notVerified := !c.Pake.IsVerified()
err = c.Pake.Update(m.Bytes)
if err != nil {
return
}
if (notVerified && c.Pake.IsVerified() && !c.Options.IsSender) || !c.Pake.IsVerified() {
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "pake",
Bytes: c.Pake.Bytes(),
}.String()).Err()
}
if c.Pake.IsVerified() {
c.log.Debug(c.Pake.SessionKey())
c.Step1ChannelSecured = true
}
case "error":
c.spinner.Stop()
fmt.Print("\r")
err = fmt.Errorf("peer error: %s", m.Message)
return err
case "fileinfo":
var senderInfo SenderInfo
var decryptedBytes []byte
key, _ := c.Pake.SessionKey()
decryptedBytes, err = crypt.DecryptFromBytes(m.Bytes, key)
if err != nil {
log.Error(err)
return
}
err = json.Unmarshal(decryptedBytes, &senderInfo)
if err != nil {
log.Error(err)
return
}
c.FilesToTransfer = senderInfo.FilesToTransfer
fname := fmt.Sprintf("%d files", len(c.FilesToTransfer))
if len(c.FilesToTransfer) == 1 {
fname = fmt.Sprintf("'%s'", c.FilesToTransfer[0].Name)
}
totalSize := int64(0)
for _, fi := range c.FilesToTransfer {
totalSize += fi.Size
}
c.spinner.Stop()
if !c.Options.NoPrompt {
fmt.Fprintf(os.Stderr, "\rAccept %s (%s) from machine '%s'? (y/n) ", fname, utils.ByteCountDecimal(totalSize), senderInfo.MachineID)
if strings.ToLower(strings.TrimSpace(utils.GetInput(""))) != "y" {
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "error",
Message: "refusing files",
}.String()).Err()
return fmt.Errorf("refused files")
}
} else {
fmt.Fprintf(os.Stderr, "\rReceiving %s (%s) from machine '%s'\n", fname, utils.ByteCountDecimal(totalSize), senderInfo.MachineID)
}
c.log.Debug(c.FilesToTransfer)
c.Step2FileInfoTransfered = true
case "recipientready":
var remoteFile RemoteFileRequest
var decryptedBytes []byte
key, _ := c.Pake.SessionKey()
decryptedBytes, err = crypt.DecryptFromBytes(m.Bytes, key)
if err != nil {
log.Error(err)
return
}
err = json.Unmarshal(decryptedBytes, &remoteFile)
if err != nil {
return
}
c.FilesToTransferCurrentNum = remoteFile.FilesToTransferCurrentNum
c.CurrentFileChunks = remoteFile.CurrentFileChunks
c.Step3RecipientRequestFile = true
case "datachannel-offer":
err = c.dataChannelReceive()
if err != nil {
return
}
err = c.recvSess.SetSDP(m.Message)
if err != nil {
return
}
var answer string
answer, err = c.recvSess.CreateAnswer()
if err != nil {
return
}
// Output the answer in base64 so we can paste it in browser
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "datachannel-answer",
Message: answer,
Num: m.Num,
}.String()).Err()
// start receiving data
pathToFile := path.Join(c.FilesToTransfer[c.FilesToTransferCurrentNum].FolderRemote, c.FilesToTransfer[c.FilesToTransferCurrentNum].Name)
c.spinner.Stop()
key, _ := c.Pake.SessionKey()
c.recvSess.ReceiveData(pathToFile, c.FilesToTransfer[c.FilesToTransferCurrentNum].Size, key)
log.Debug("sending close-sender")
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "close-sender",
}.String()).Err()
case "datachannel-answer":
c.log.Debug("got answer:", m.Message)
// Apply the answer as the remote description
err = c.sendSess.SetSDP(m.Message)
pathToFile := path.Join(c.FilesToTransfer[c.FilesToTransferCurrentNum].FolderSource, c.FilesToTransfer[c.FilesToTransferCurrentNum].Name)
c.spinner.Stop()
fmt.Fprintf(os.Stderr, "\r\nTransfering...\n")
key, _ := c.Pake.SessionKey()
c.sendSess.TransferFile(pathToFile, key)
case "close-sender":
log.Debug("close-sender received...")
c.Step4FileTransfer = false
c.Step3RecipientRequestFile = false
c.sendSess.StopSending()
log.Debug("sending close-recipient")
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "close-recipient",
Num: m.Num,
}.String()).Err()
case "close-recipient":
c.Step4FileTransfer = false
c.Step3RecipientRequestFile = false
}
if err != nil {
return
}
err = c.updateState()
return
}
func (c *Client) updateState() (err error) {
if c.Options.IsSender && c.Step1ChannelSecured && !c.Step2FileInfoTransfered {
var b []byte
b, err = json.Marshal(SenderInfo{
MachineID: c.machineID,
FilesToTransfer: c.FilesToTransfer,
})
if err != nil {
log.Error(err)
return
}
key, _ := c.Pake.SessionKey()
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "fileinfo",
Bytes: crypt.EncryptToBytes(b, key),
}.String()).Err()
if err != nil {
return
}
c.Step2FileInfoTransfered = true
}
if !c.Options.IsSender && c.Step2FileInfoTransfered && !c.Step3RecipientRequestFile {
// find the next file to transfer and send that number
// if the files are the same size, then look for missing chunks
finished := true
for i, fileInfo := range c.FilesToTransfer {
if i < c.FilesToTransferCurrentNum {
continue
}
fileHash, errHash := utils.HashFile(path.Join(fileInfo.FolderRemote, fileInfo.Name))
if errHash != nil || !bytes.Equal(fileHash, fileInfo.Hash) {
if !bytes.Equal(fileHash, fileInfo.Hash) {
log.Debugf("hashes are not equal %x != %x", fileHash, fileInfo.Hash)
}
finished = false
c.FilesToTransferCurrentNum = i
break
}
// TODO: print out something about this file already existing
}
if finished {
// TODO: do the last finishing stuff
log.Debug("finished")
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "finished",
}.String()).Err()
if err != nil {
panic(err)
}
}
// start initiating the process to receive a new file
log.Debugf("working on file %d", c.FilesToTransferCurrentNum)
// recipient requests the file and chunks (if empty, then should receive all chunks)
bRequest, _ := json.Marshal(RemoteFileRequest{
CurrentFileChunks: c.CurrentFileChunks,
FilesToTransferCurrentNum: c.FilesToTransferCurrentNum,
})
key, _ := c.Pake.SessionKey()
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "recipientready",
Bytes: crypt.EncryptToBytes(bRequest, key),
}.String()).Err()
if err != nil {
return
}
c.Step3RecipientRequestFile = true
err = c.dataChannelReceive()
}
if c.Options.IsSender && c.Step3RecipientRequestFile && !c.Step4FileTransfer {
c.log.Debug("start sending data!")
err = c.dataChannelSend()
c.Step4FileTransfer = true
}
return
}
func (c *Client) dataChannelReceive() (err error) {
c.recvSess = receiver.NewWith(receiver.Config{})
err = c.recvSess.CreateConnection()
if err != nil {
return
}
c.recvSess.CreateDataHandler()
return
}
func (c *Client) dataChannelSend() (err error) {
c.sendSess = sender.NewWith(sender.Config{
Configuration: common.Configuration{
OnCompletion: func() {
},
},
})
if err := c.sendSess.CreateConnection(); err != nil {
log.Error(err)
return err
}
if err := c.sendSess.CreateDataChannel(); err != nil {
log.Error(err)
return err
}
offer, err := c.sendSess.CreateOffer()
if err != nil {
log.Error(err)
return err
}
// sending offer
err = c.redisdb.Publish(c.nameOutChannel, Message{
Type: "datachannel-offer",
Message: offer,
}.String()).Err()
if err != nil {
return
}
return
}
// MissingChunks returns the positions of missing chunks.
// If file doesn't exist, it returns an empty chunk list (all chunks).
// If the file size is not the same as requested, it returns an empty chunk list (all chunks).
func MissingChunks(fname string, fsize int64, chunkSize int) (chunks []int64) {
fstat, err := os.Stat(fname)
if fstat.Size() != fsize {
return
}
f, err := os.Open(fname)
if err != nil {
return
}
defer f.Close()
buffer := make([]byte, chunkSize)
emptyBuffer := make([]byte, chunkSize)
chunkNum := 0
chunks = make([]int64, int64(math.Ceil(float64(fsize)/float64(chunkSize))))
var currentLocation int64
for {
bytesread, err := f.Read(buffer)
if err != nil {
break
}
if bytes.Equal(buffer[:bytesread], emptyBuffer[:bytesread]) {
chunks[chunkNum] = currentLocation
}
currentLocation += int64(bytesread)
}
if chunkNum == 0 {
chunks = []int64{}
} else {
chunks = chunks[:chunkNum]
}
return
}
// Encode encodes the input in base64
// It can optionally zip the input before encoding
func Encode(obj interface{}) string {
b, err := json.Marshal(obj)
if err != nil {
panic(err)
}
return base64.StdEncoding.EncodeToString(b)
}
// Decode decodes the input from base64
// It can optionally unzip the input after decoding
func Decode(in string, obj interface{}) (err error) {
b, err := base64.StdEncoding.DecodeString(in)
if err != nil {
return
}
err = json.Unmarshal(b, obj)
return
}

View file

@ -1,81 +0,0 @@
package croc
import (
"crypto/rand"
"fmt"
"io/ioutil"
"os"
"sync"
"testing"
"time"
"github.com/schollz/croc/src/utils"
"github.com/stretchr/testify/assert"
)
func sendAndReceive(t *testing.T, forceSend int, local bool) {
room := utils.GetRandomName()
var startTime time.Time
var durationPerMegabyte float64
megabytes := 1
if local {
megabytes = 100
}
fname := generateRandomFile(megabytes)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
c := Init(true)
c.NoLocal = !local
// c.AddressTCPPorts = []string{"8154", "8155"}
c.ForceSend = forceSend
c.UseCompression = true
c.UseEncryption = true
assert.Nil(t, c.Send(fname, room))
}()
go func() {
defer wg.Done()
time.Sleep(5 * time.Second)
os.MkdirAll("test", 0755)
os.Chdir("test")
c := Init(true)
c.NoLocal = !local
// c.AddressTCPPorts = []string{"8154", "8155"}
c.ForceSend = forceSend
startTime = time.Now()
assert.Nil(t, c.Receive(room))
durationPerMegabyte = float64(megabytes) / time.Since(startTime).Seconds()
assert.True(t, utils.Exists(fname))
}()
wg.Wait()
os.Chdir("..")
os.RemoveAll("test")
os.Remove(fname)
fmt.Printf("\n-----\n%2.1f MB/s\n----\n", durationPerMegabyte)
}
func TestSendReceivePubWebsockets(t *testing.T) {
sendAndReceive(t, 1, false)
}
func TestSendReceivePubTCP(t *testing.T) {
sendAndReceive(t, 2, false)
}
func TestSendReceiveLocalWebsockets(t *testing.T) {
sendAndReceive(t, 1, true)
}
// func TestSendReceiveLocalTCP(t *testing.T) {
// sendAndReceive(t, 2, true)
// }
func generateRandomFile(megabytes int) (fname string) {
// generate a random file
bigBuff := make([]byte, 1024*1024*megabytes)
rand.Read(bigBuff)
fname = fmt.Sprintf("%dmb.file", megabytes)
ioutil.WriteFile(fname, bigBuff, 0666)
return
}

View file

@ -1,7 +0,0 @@
package croc
type WebSocketMessage struct {
messageType int
message []byte
err error
}

View file

@ -1,624 +0,0 @@
package croc
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"strconv"
"strings"
"sync"
"time"
log "github.com/cihub/seelog"
humanize "github.com/dustin/go-humanize"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/schollz/croc/src/comm"
"github.com/schollz/croc/src/compress"
"github.com/schollz/croc/src/crypt"
"github.com/schollz/croc/src/logger"
"github.com/schollz/croc/src/models"
"github.com/schollz/croc/src/utils"
"github.com/schollz/croc/src/zipper"
"github.com/schollz/pake"
"github.com/schollz/progressbar/v2"
"github.com/schollz/spinner"
)
var DebugLevel string
// Receive is the async operation to receive a file
func (cr *Croc) startRecipient(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, done chan error, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) {
logger.SetLogLevel(DebugLevel)
err := cr.receive(forceSend, serverAddress, tcpPorts, isLocal, c, codephrase, noPrompt, useStdout)
if err != nil && strings.HasPrefix(err.Error(), "websocket: close 100") {
err = nil
}
done <- err
}
func (cr *Croc) receive(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) (err error) {
var sessionKey []byte
var transferTime time.Duration
var hash256 []byte
var progressFile string
var resumeFile bool
var tcpConnections []comm.Comm
var Q *pake.Pake
dataChan := make(chan []byte, 1024*1024)
isConnectedIfUsingTCP := make(chan bool)
blocks := []string{}
useWebsockets := true
switch forceSend {
case 0:
if !isLocal {
useWebsockets = false
}
case 1:
useWebsockets = true
case 2:
useWebsockets = false
}
// start a spinner
spin := spinner.New(spinner.CharSets[9], 100*time.Millisecond)
spin.Writer = os.Stderr
spin.Suffix = " connecting..."
cr.StateString = "Connecting as recipient..."
spin.Start()
defer spin.Stop()
// both parties should have a weak key
pw := []byte(codephrase)
// start the reader
websocketMessages := make(chan WebSocketMessage, 1024)
go func() {
defer func() {
if r := recover(); r != nil {
log.Debugf("recovered from %s", r)
}
}()
for {
messageType, message, err := c.ReadMessage()
websocketMessages <- WebSocketMessage{messageType, message, err}
}
}()
step := 0
for {
var websocketMessageMain WebSocketMessage
// websocketMessageMain = <-websocketMessages
timeWaitingForMessage := time.Now()
for {
done := false
select {
case websocketMessageMain = <-websocketMessages:
done = true
default:
time.Sleep(10 * time.Millisecond)
}
if done {
break
}
if time.Since(timeWaitingForMessage).Seconds() > 3 && step == 0 {
return fmt.Errorf("You are trying to receive a file with no sender.")
}
}
messageType := websocketMessageMain.messageType
message := websocketMessageMain.message
err := websocketMessageMain.err
if err != nil {
return err
}
if messageType == websocket.PongMessage || messageType == websocket.PingMessage {
continue
}
if messageType == websocket.TextMessage && bytes.Equal(message, []byte("interrupt")) {
return errors.New("\rinterrupted by other party")
}
log.Debugf("got %d: %s", messageType, message)
switch step {
case 0:
spin.Stop()
spin.Suffix = " performing PAKE..."
cr.StateString = "Performing PAKE..."
spin.Start()
// sender has initiated, sends their initial data
var initialData models.Initial
err = json.Unmarshal(message, &initialData)
if err != nil {
err = errors.Wrap(err, "incompatible versions of croc")
return err
}
cr.OtherIP = initialData.IPAddress
log.Debugf("sender IP: %s", cr.OtherIP)
// check whether the version strings are compatible
versionStringsOther := strings.Split(initialData.VersionString, ".")
versionStringsSelf := strings.Split(cr.Version, ".")
if len(versionStringsOther) == 3 && len(versionStringsSelf) == 3 {
if versionStringsSelf[0] != versionStringsOther[0] || versionStringsSelf[1] != versionStringsOther[1] {
return fmt.Errorf("version sender %s is not compatible with recipient %s", cr.Version, initialData.VersionString)
}
}
// initialize the PAKE with the curve sent from the sender
Q, err = pake.InitCurve(pw, 1, initialData.CurveType, 1*time.Millisecond)
if err != nil {
err = errors.Wrap(err, "incompatible curve type")
return err
}
// recipient begins by sending back initial data to sender
ip := ""
if isLocal {
ip = utils.LocalIP()
} else {
ip, _ = utils.PublicIP()
}
initialData.VersionString = cr.Version
initialData.IPAddress = ip
bInitialData, _ := json.Marshal(initialData)
c.WriteMessage(websocket.BinaryMessage, bInitialData)
case 1:
// Q receives u
log.Debugf("[%d] Q computes k, sends H(k), v back to P", step)
if err := Q.Update(message); err != nil {
return fmt.Errorf("Recipient is using wrong code phrase.")
}
// Q has the session key now, but we will still check if its valid
sessionKey, err = Q.SessionKey()
if err != nil {
return fmt.Errorf("Recipient is using wrong code phrase.")
}
log.Debugf("%x\n", sessionKey)
// initialize TCP connections if using (possible, but unlikely, race condition)
go func() {
log.Debug("initializing TCP connections")
if !useWebsockets {
log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts))
var wg sync.WaitGroup
wg.Add(len(tcpPorts))
for i, tcpPort := range tcpPorts {
go func(i int, tcpPort string) {
defer wg.Done()
log.Debugf("connecting to %d", i)
var message string
tcpConnections[i], message, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
}
if message != "recipient" {
log.Errorf("got wrong message: %s", message)
}
}(i, tcpPort)
}
wg.Wait()
log.Debugf("fully connected")
}
isConnectedIfUsingTCP <- true
}()
c.WriteMessage(websocket.BinaryMessage, Q.Bytes())
case 2:
log.Debugf("[%d] Q recieves H(k) from P", step)
// check if everything is still kosher with our computed session key
if err := Q.Update(message); err != nil {
log.Debug(err)
return fmt.Errorf("Recipient is using wrong code phrase.")
}
c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
case 3:
spin.Stop()
cr.StateString = "Recieving file info..."
// unmarshal the file info
log.Debugf("[%d] recieve file info", step)
// do decryption on the file stats
enc, err := crypt.FromBytes(message)
if err != nil {
return err
}
decryptedFileData, err := enc.Decrypt(sessionKey)
if err != nil {
return err
}
err = json.Unmarshal(decryptedFileData, &cr.FileInfo)
if err != nil {
return err
}
log.Debugf("got file stats: %+v", cr.FileInfo)
// determine if the file is resuming or not
progressFile = fmt.Sprintf("%s.progress", cr.FileInfo.SentName)
overwritingOrReceiving := "Receiving"
if utils.Exists(cr.FileInfo.Name) || utils.Exists(cr.FileInfo.SentName) {
overwritingOrReceiving = "Overwriting"
if utils.Exists(progressFile) {
overwritingOrReceiving = "Resume receiving"
resumeFile = true
}
}
// send blocks
if resumeFile {
fileWithBlocks, _ := os.Open(progressFile)
scanner := bufio.NewScanner(fileWithBlocks)
for scanner.Scan() {
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
}
fileWithBlocks.Close()
}
blocksBytes, _ := json.Marshal(blocks)
// encrypt the block data and send
encblockBytes := crypt.Encrypt(blocksBytes, sessionKey)
// wait for TCP connections if using them
_ = <-isConnectedIfUsingTCP
c.WriteMessage(websocket.BinaryMessage, encblockBytes.Bytes())
// prompt user about the file
fileOrFolder := "file"
if cr.FileInfo.IsDir {
fileOrFolder = "folder"
}
cr.WindowReceivingString = fmt.Sprintf("%s %s (%s) into: %s",
overwritingOrReceiving,
fileOrFolder,
humanize.Bytes(uint64(cr.FileInfo.Size)),
cr.FileInfo.Name,
)
fmt.Fprintf(os.Stderr, "\r%s\n",
cr.WindowReceivingString,
)
if !noPrompt {
if "y" != utils.GetInput("ok? (y/N): ") {
fmt.Fprintf(os.Stderr, "Cancelling request")
c.WriteMessage(websocket.BinaryMessage, []byte("no"))
return nil
}
}
if cr.WindowRecipientPrompt {
// wait until it switches to false
// the window should then set WindowRecipientAccept
for {
if !cr.WindowRecipientPrompt {
if cr.WindowRecipientAccept {
break
} else {
fmt.Fprintf(os.Stderr, "Cancelling request")
c.WriteMessage(websocket.BinaryMessage, []byte("no"))
return nil
}
}
time.Sleep(10 * time.Millisecond)
}
}
// await file
// erase file if overwriting
if overwritingOrReceiving == "Overwriting" {
os.Remove(cr.FileInfo.SentName)
}
var f *os.File
if utils.Exists(cr.FileInfo.SentName) && resumeFile {
if !useWebsockets {
f, err = os.OpenFile(cr.FileInfo.SentName, os.O_WRONLY, 0644)
} else {
f, err = os.OpenFile(cr.FileInfo.SentName, os.O_APPEND|os.O_WRONLY, 0644)
}
if err != nil {
log.Error(err)
return err
}
} else {
f, err = os.Create(cr.FileInfo.SentName)
if err != nil {
log.Error(err)
return err
}
if !useWebsockets {
if err = f.Truncate(cr.FileInfo.Size); err != nil {
log.Error(err)
return err
}
}
}
blockSize := 0
if useWebsockets {
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
} else {
blockSize = models.TCP_BUFFER_SIZE / 2
}
// start the ui for pgoress
cr.StateString = "Recieving file..."
bytesWritten := 0
fmt.Fprintf(os.Stderr, "\nReceiving (<-%s)...\n", cr.OtherIP)
cr.Bar = progressbar.NewOptions(
int(cr.FileInfo.Size),
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetBytes(int(cr.FileInfo.Size)),
progressbar.OptionSetWriter(os.Stderr),
progressbar.OptionThrottle(1/60*time.Second),
)
cr.Bar.Add((len(blocks) * blockSize))
finished := make(chan bool)
go func(finished chan bool, dataChan chan []byte) (err error) {
// remove previous progress
var fProgress *os.File
var progressErr error
if resumeFile {
fProgress, progressErr = os.OpenFile(progressFile, os.O_APPEND|os.O_WRONLY, 0644)
bytesWritten = len(blocks) * blockSize
} else {
os.Remove(progressFile)
fProgress, progressErr = os.Create(progressFile)
}
if progressErr != nil {
panic(progressErr)
}
defer fProgress.Close()
blocksWritten := 0.0
blocksToWrite := float64(cr.FileInfo.Size)
if useWebsockets {
blocksToWrite = blocksToWrite/float64(models.WEBSOCKET_BUFFER_SIZE/8) - float64(len(blocks))
} else {
blocksToWrite = blocksToWrite/float64(models.TCP_BUFFER_SIZE/2) - float64(len(blocks))
}
for {
message := <-dataChan
// do decryption
var enc crypt.Encryption
err = json.Unmarshal(message, &enc)
if err != nil {
// log.Errorf("%s: [%s] [%+v] (%d/%d) %+v", err.Error(), message, message, len(message), numBytes, bs)
log.Error(err)
return err
}
decrypted, err := enc.Decrypt(sessionKey, !cr.FileInfo.IsEncrypted)
if err != nil {
log.Error(err)
return err
}
// get location if TCP
var locationToWrite int
if !useWebsockets {
pieces := bytes.SplitN(decrypted, []byte("-"), 2)
decrypted = pieces[1]
locationToWrite, _ = strconv.Atoi(string(pieces[0]))
}
// do decompression
if cr.FileInfo.IsCompressed && !cr.FileInfo.IsDir {
decrypted = compress.Decompress(decrypted)
}
var n int
if !useWebsockets {
if err != nil {
log.Error(err)
return err
}
n, err = f.WriteAt(decrypted, int64(locationToWrite))
fProgress.WriteString(fmt.Sprintf("%d\n", locationToWrite))
log.Debugf("wrote %d bytes to location %d (%2.0f/%2.0f)", n, locationToWrite, blocksWritten, blocksToWrite)
} else {
// write to file
n, err = f.Write(decrypted)
log.Debugf("wrote %d bytes to location %d (%2.0f/%2.0f)", n, bytesWritten, blocksWritten, blocksToWrite)
fProgress.WriteString(fmt.Sprintf("%d\n", bytesWritten))
}
if err != nil {
log.Error(err)
return err
}
// update the bytes written
bytesWritten += n
blocksWritten += 1.0
// update the progress bar
cr.Bar.Add(n)
if int64(bytesWritten) == cr.FileInfo.Size || blocksWritten >= blocksToWrite {
log.Debug("finished", int64(bytesWritten), cr.FileInfo.Size, blocksWritten, blocksToWrite)
break
}
}
finished <- true
return
}(finished, dataChan)
log.Debug("telling sender i'm ready")
c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
startTime := time.Now()
if useWebsockets {
for {
// read from websockets
websocketMessageData := <-websocketMessages
if bytes.HasPrefix(websocketMessageData.message, []byte("error")) {
return fmt.Errorf("%s", websocketMessageData.message)
}
if websocketMessageData.messageType != websocket.BinaryMessage {
continue
}
if err != nil {
log.Error(err)
return err
}
if bytes.Equal(websocketMessageData.message, []byte("magic")) {
log.Debug("got magic")
break
}
dataChan <- websocketMessageData.message
}
} else {
log.Debugf("starting listening with tcp with %d connections", len(tcpConnections))
// check to see if any messages are sent
stopMessageSignal := make(chan bool, 1)
errorsDuringTransfer := make(chan error, 24)
go func() {
for {
select {
case sig := <-stopMessageSignal:
errorsDuringTransfer <- nil
log.Debugf("got message signal: %+v", sig)
return
case wsMessage := <-websocketMessages:
log.Debugf("got message: %s", wsMessage.message)
if bytes.HasPrefix(wsMessage.message, []byte("error")) {
log.Debug("stopping transfer")
for i := 0; i < len(tcpConnections)+1; i++ {
errorsDuringTransfer <- fmt.Errorf("%s", wsMessage.message)
}
return
}
default:
continue
}
}
}()
// using TCP
go func() {
var wg sync.WaitGroup
wg.Add(len(tcpConnections))
for i := range tcpConnections {
defer func(i int) {
log.Debugf("closing connection %d", i)
tcpConnections[i].Close()
}(i)
go func(wg *sync.WaitGroup, j int) {
defer wg.Done()
for {
select {
case _ = <-errorsDuringTransfer:
log.Debugf("%d got stop", i)
return
default:
}
log.Debugf("waiting to read on %d", j)
// read from TCP connection
message, _, _, err := tcpConnections[j].Read()
// log.Debugf("message: %s", message)
if err != nil {
panic(err)
}
if bytes.Equal(message, []byte("magic")) {
log.Debugf("%d got magic, leaving", j)
return
}
dataChan <- message
}
}(&wg, i)
}
log.Debug("waiting for tcp goroutines")
wg.Wait()
errorsDuringTransfer <- nil
}()
// block until this is done
log.Debug("waiting for error")
errorDuringTransfer := <-errorsDuringTransfer
log.Debug("sending stop message signal")
stopMessageSignal <- true
if errorDuringTransfer != nil {
log.Debugf("got error during transfer: %s", errorDuringTransfer.Error())
return errorDuringTransfer
}
}
_ = <-finished
log.Debug("telling sender i'm done")
c.WriteMessage(websocket.BinaryMessage, []byte("done"))
// we are finished
transferTime = time.Since(startTime)
// close file
err = f.Close()
if err != nil {
return err
}
// finish bar
cr.Bar.Finish()
// check hash
hash256, err = utils.HashFile(cr.FileInfo.SentName)
if err != nil {
log.Error(err)
return err
}
// tell the sender the hash so they can quit
c.WriteMessage(websocket.BinaryMessage, append([]byte("hash:"), hash256...))
case 4:
// receive the hash from the sender so we can check it and quit
log.Debugf("got hash: %x", message)
if bytes.Equal(hash256, message) {
// open directory
if cr.FileInfo.IsDir {
err = zipper.UnzipFile(cr.FileInfo.SentName, ".")
if DebugLevel != "debug" {
os.Remove(cr.FileInfo.SentName)
}
} else {
err = nil
}
if err == nil {
if useStdout && !cr.FileInfo.IsDir {
var bFile []byte
bFile, err = ioutil.ReadFile(cr.FileInfo.SentName)
if err != nil {
return err
}
os.Stdout.Write(bFile)
os.Remove(cr.FileInfo.SentName)
}
transferRate := float64(cr.FileInfo.Size) / 1000000.0 / transferTime.Seconds()
transferType := "MB/s"
if transferRate < 1 {
transferRate = float64(cr.FileInfo.Size) / 1000.0 / transferTime.Seconds()
transferType = "kB/s"
}
folderOrFile := "file"
if cr.FileInfo.IsDir {
folderOrFile = "folder"
}
if useStdout {
cr.FileInfo.Name = "stdout"
}
fmt.Fprintf(os.Stderr, "\nReceived %s written to %s (%2.1f %s)", folderOrFile, cr.FileInfo.Name, transferRate, transferType)
os.Remove(progressFile)
cr.StateString = fmt.Sprintf("Received %s written to %s (%2.1f %s)", folderOrFile, cr.FileInfo.Name, transferRate, transferType)
}
return err
} else {
if DebugLevel != "debug" {
log.Debug("removing corrupted file")
os.Remove(cr.FileInfo.SentName)
}
return errors.New("file corrupted")
}
default:
return fmt.Errorf("unknown step")
}
step++
}
}

View file

@ -1,570 +0,0 @@
package croc
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
log "github.com/cihub/seelog"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/schollz/croc/src/comm"
"github.com/schollz/croc/src/compress"
"github.com/schollz/croc/src/crypt"
"github.com/schollz/croc/src/logger"
"github.com/schollz/croc/src/models"
"github.com/schollz/croc/src/utils"
"github.com/schollz/croc/src/zipper"
"github.com/schollz/pake"
progressbar "github.com/schollz/progressbar/v2"
"github.com/schollz/spinner"
)
// Send is the async call to send data
func (cr *Croc) startSender(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, done chan error, c *websocket.Conn, fname string, codephrase string, useCompression bool, useEncryption bool) {
logger.SetLogLevel(DebugLevel)
log.Debugf("sending %s", fname)
err := cr.send(forceSend, serverAddress, tcpPorts, isLocal, c, fname, codephrase, useCompression, useEncryption)
if err != nil && strings.HasPrefix(err.Error(), "websocket: close 100") {
err = nil
}
done <- err
}
func (cr *Croc) send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, c *websocket.Conn, fname string, codephrase string, useCompression bool, useEncryption bool) (err error) {
var f *os.File
defer f.Close() // ignore the error if it wasn't opened :(
var fileHash []byte
var startTransfer time.Time
var tcpConnections []comm.Comm
blocksToSkip := make(map[int64]struct{})
isConnectedIfUsingTCP := make(chan bool)
type DataChan struct {
b []byte
currentPostition int64
bytesRead int
err error
}
dataChan := make(chan DataChan, 1024*1024)
defer close(dataChan)
useWebsockets := true
switch forceSend {
case 0:
if !isLocal {
useWebsockets = false
}
case 1:
useWebsockets = true
case 2:
useWebsockets = false
}
fileReady := make(chan error)
// normalize the file name
fname, err = filepath.Abs(fname)
if err != nil {
return err
}
_, filename := filepath.Split(fname)
// get ready to generate session key
var sessionKey []byte
// start a spinner
spin := spinner.New(spinner.CharSets[9], 100*time.Millisecond)
spin.Writer = os.Stderr
defer spin.Stop()
// both parties should have a weak key
pw := []byte(codephrase)
// initialize sender P ("0" indicates sender)
P, err := pake.InitCurve(pw, 0, cr.CurveType, 1*time.Millisecond)
if err != nil {
return
}
// start the reader
websocketMessages := make(chan WebSocketMessage, 1024)
go func() {
defer func() {
if r := recover(); r != nil {
log.Debugf("recovered from %s", r)
}
}()
for {
messageType, message, err := c.ReadMessage()
websocketMessages <- WebSocketMessage{messageType, message, err}
}
}()
step := 0
for {
websocketMessage := <-websocketMessages
messageType := websocketMessage.messageType
message := websocketMessage.message
errRead := websocketMessage.err
if errRead != nil {
return errRead
}
if messageType == websocket.PongMessage || messageType == websocket.PingMessage {
continue
}
if messageType == websocket.TextMessage && bytes.HasPrefix(message, []byte("interrupt")) {
return errors.New("\rinterrupted by other party")
}
if messageType == websocket.TextMessage && bytes.HasPrefix(message, []byte("err")) {
return errors.New("\r" + string(message))
}
log.Debugf("got %d: %s", messageType, message)
switch step {
case 0:
// sender initiates communication
ip := ""
if isLocal {
ip = utils.LocalIP()
} else {
ip, _ = utils.PublicIP()
}
initialData := models.Initial{
CurveType: cr.CurveType,
IPAddress: ip,
VersionString: cr.Version, // version should match
}
bInitialData, _ := json.Marshal(initialData)
// send the initial data
c.WriteMessage(websocket.BinaryMessage, bInitialData)
case 1:
// first receive the initial data from the recipient
var initialData models.Initial
err = json.Unmarshal(message, &initialData)
if err != nil {
err = errors.Wrap(err, "incompatible versions of croc")
return
}
cr.OtherIP = initialData.IPAddress
log.Debugf("recipient IP: %s", cr.OtherIP)
go func() {
// recipient might want file! start gathering information about file
fstat, err := os.Stat(fname)
if err != nil {
fileReady <- err
return
}
cr.FileInfo = models.FileStats{
Name: filename,
Size: fstat.Size(),
ModTime: fstat.ModTime(),
IsDir: fstat.IsDir(),
SentName: fstat.Name(),
IsCompressed: useCompression,
IsEncrypted: useEncryption,
}
if cr.FileInfo.IsDir {
// zip the directory
cr.FileInfo.SentName, err = zipper.ZipFile(fname, true)
if err != nil {
log.Error(err)
fileReady <- err
return
}
fname = cr.FileInfo.SentName
fstat, err := os.Stat(fname)
if err != nil {
fileReady <- err
return
}
// get new size
cr.FileInfo.Size = fstat.Size()
}
// open the file
f, err = os.Open(fname)
if err != nil {
fileReady <- err
return
}
fileReady <- nil
}()
// send pake data
log.Debugf("[%d] first, P sends u to Q", step)
c.WriteMessage(websocket.BinaryMessage, P.Bytes())
// start PAKE spinnner
spin.Suffix = " performing PAKE..."
cr.StateString = "Performing PAKE..."
spin.Start()
case 2:
// P recieves H(k),v from Q
log.Debugf("[%d] P computes k, H(k), sends H(k) to Q", step)
err := P.Update(message)
c.WriteMessage(websocket.BinaryMessage, P.Bytes())
if err != nil {
return fmt.Errorf("Recipient is using wrong code phrase.")
}
sessionKey, _ = P.SessionKey()
// check(err)
log.Debugf("%x\n", sessionKey)
// wait for readiness
spin.Stop()
spin.Suffix = " waiting for recipient ok..."
cr.StateString = "Waiting for recipient ok...."
spin.Start()
case 3:
log.Debugf("[%d] recipient declares readiness for file info", step)
if !bytes.HasPrefix(message, []byte("ready")) {
return errors.New("Recipient refused file")
}
err = <-fileReady // block until file is ready
if err != nil {
return err
}
fstatsBytes, err := json.Marshal(cr.FileInfo)
if err != nil {
return err
}
// encrypt the file meta data
enc := crypt.Encrypt(fstatsBytes, sessionKey)
// send the file meta data
c.WriteMessage(websocket.BinaryMessage, enc.Bytes())
case 4:
log.Debugf("[%d] recipient gives blocks", step)
// recipient sends blocks, and sender does not send anything back
// determine if any blocks were sent to skip
enc, err := crypt.FromBytes(message)
if err != nil {
log.Error(err)
return err
}
decrypted, err := enc.Decrypt(sessionKey)
if err != nil {
err = errors.Wrap(err, "could not decrypt blocks with session key")
log.Error(err)
return err
}
var blocks []string
errBlocks := json.Unmarshal(decrypted, &blocks)
if errBlocks == nil {
for _, block := range blocks {
blockInt64, errBlock := strconv.Atoi(block)
if errBlock == nil {
blocksToSkip[int64(blockInt64)] = struct{}{}
}
}
}
log.Debugf("found blocks: %+v", blocksToSkip)
// connect to TCP in background
tcpConnections = make([]comm.Comm, len(tcpPorts))
go func() {
if !useWebsockets {
log.Debugf("connecting to server")
var wg sync.WaitGroup
wg.Add(len(tcpPorts))
for i, tcpPort := range tcpPorts {
go func(i int, tcpPort string) {
defer wg.Done()
log.Debugf("connecting to %s on connection %d", tcpPort, i)
var message string
tcpConnections[i], message, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil {
log.Error(err)
}
if message != "sender" {
log.Errorf("got wrong message: %s", message)
}
}(i, tcpPort)
}
wg.Wait()
}
isConnectedIfUsingTCP <- true
}()
// start loading the file into memory
// start streaming encryption/compression
if cr.FileInfo.IsDir {
// remove file if zipped
defer os.Remove(cr.FileInfo.SentName)
}
go func(dataChan chan DataChan) {
var buffer []byte
if useWebsockets {
buffer = make([]byte, models.WEBSOCKET_BUFFER_SIZE/8)
} else {
buffer = make([]byte, models.TCP_BUFFER_SIZE/2)
}
currentPostition := int64(0)
for {
bytesread, err := f.Read(buffer)
if bytesread > 0 {
if _, ok := blocksToSkip[currentPostition]; ok {
log.Debugf("skipping the sending of block %d", currentPostition)
currentPostition += int64(bytesread)
continue
}
// do compression
var compressedBytes []byte
if useCompression && !cr.FileInfo.IsDir {
compressedBytes = compress.Compress(buffer[:bytesread])
} else {
compressedBytes = buffer[:bytesread]
}
// if using TCP, prepend the location to write the data to in the resulting file
if !useWebsockets {
compressedBytes = append([]byte(fmt.Sprintf("%d-", currentPostition)), compressedBytes...)
}
// do encryption
enc := crypt.Encrypt(compressedBytes, sessionKey, !useEncryption)
encBytes, err := json.Marshal(enc)
if err != nil {
dataChan <- DataChan{
b: nil,
bytesRead: 0,
err: err,
}
return
}
dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}
currentPostition += int64(bytesread)
}
if err != nil {
if err != io.EOF {
log.Error(err)
}
break
}
}
// finish
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
if !useWebsockets {
log.Debug("sending extra magic to %d others", len(tcpPorts)-1)
for i := 0; i < len(tcpPorts)-1; i++ {
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
}
}
}(dataChan)
case 5:
spin.Stop()
log.Debugf("[%d] recipient declares readiness for file data", step)
if !bytes.HasPrefix(message, []byte("ready")) {
return errors.New("Recipient refused file")
}
cr.StateString = "Transfer in progress..."
fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", cr.OtherIP)
// send file, compure hash simultaneously
startTransfer = time.Now()
blockSize := 0
if useWebsockets {
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
} else {
blockSize = models.TCP_BUFFER_SIZE / 2
}
cr.Bar = progressbar.NewOptions(
int(cr.FileInfo.Size),
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetBytes(int(cr.FileInfo.Size)),
progressbar.OptionSetWriter(os.Stderr),
progressbar.OptionThrottle(1/60*time.Second),
)
cr.Bar.Add(blockSize * len(blocksToSkip))
if useWebsockets {
for {
data := <-dataChan
if data.err != nil {
return data.err
}
cr.Bar.Add(data.bytesRead)
// write data to websockets
err = c.WriteMessage(websocket.BinaryMessage, data.b)
if err != nil {
err = errors.Wrap(err, "problem writing message")
return err
}
if bytes.Equal(data.b, []byte("magic")) {
break
}
}
} else {
_ = <-isConnectedIfUsingTCP
log.Debug("connected and ready to send on tcp")
// check to see if any messages are sent
stopMessageSignal := make(chan bool, 1)
errorsDuringTransfer := make(chan error, 24)
go func() {
for {
select {
case sig := <-stopMessageSignal:
errorsDuringTransfer <- nil
log.Debugf("got message signal: %+v", sig)
return
case wsMessage := <-websocketMessages:
log.Debugf("got message: %s", wsMessage.message)
if bytes.HasPrefix(wsMessage.message, []byte("error")) {
log.Debug("stopping transfer")
for i := 0; i < len(tcpConnections)+1; i++ {
errorsDuringTransfer <- fmt.Errorf("%s", wsMessage.message)
}
return
}
default:
continue
}
}
}()
var wg sync.WaitGroup
wg.Add(len(tcpConnections))
for i := range tcpConnections {
defer func(i int) {
log.Debugf("closing connection %d", i)
tcpConnections[i].Close()
}(i)
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan) {
defer wg.Done()
for data := range dataChan {
select {
case _ = <-errorsDuringTransfer:
log.Debugf("%d got stop", i)
return
default:
}
if data.err != nil {
log.Error(data.err)
return
}
cr.Bar.Add(data.bytesRead)
// write data to tcp connection
_, errTcp := tcpConnections[i].Write(data.b)
if errTcp != nil {
errTcp = errors.Wrap(errTcp, "problem writing message")
log.Debug(errTcp)
errorsDuringTransfer <- errTcp
return
}
if bytes.Equal(data.b, []byte("magic")) {
log.Debugf("%d got magic", i)
return
}
}
}(i, &wg, dataChan)
}
// block until this is done
log.Debug("waiting for tcp goroutines")
wg.Wait()
log.Debug("sending stop message signal")
stopMessageSignal <- true
log.Debug("waiting for error")
errorDuringTransfer := <-errorsDuringTransfer
if errorDuringTransfer != nil {
log.Debugf("got error during transfer: %s", errorDuringTransfer.Error())
return errorDuringTransfer
}
}
cr.Bar.Finish()
log.Debug("send hash to finish file")
fileHash, err = utils.HashFile(fname)
if err != nil {
return err
}
case 6:
// recevied something, maybe the file hash
transferTime := time.Since(startTransfer)
if !bytes.HasPrefix(message, []byte("hash:")) {
log.Debugf("%s", message)
continue
}
c.WriteMessage(websocket.BinaryMessage, fileHash)
message = bytes.TrimPrefix(message, []byte("hash:"))
log.Debugf("[%d] determing whether it went ok", step)
if bytes.Equal(message, fileHash) {
log.Debug("file transfered successfully")
transferRate := float64(cr.FileInfo.Size) / 1000000.0 / transferTime.Seconds()
transferType := "MB/s"
if transferRate < 1 {
transferRate = float64(cr.FileInfo.Size) / 1000.0 / transferTime.Seconds()
transferType = "kB/s"
}
fmt.Fprintf(os.Stderr, "\nTransfer complete (%2.1f %s)", transferRate, transferType)
cr.StateString = fmt.Sprintf("Transfer complete (%2.1f %s)", transferRate, transferType)
return nil
} else {
fmt.Fprintf(os.Stderr, "\nTransfer corrupted")
return errors.New("file not transfered succesfully")
}
default:
return fmt.Errorf("unknown step")
}
step++
}
}
func connectToTCPServer(room string, address string) (com comm.Comm, message string, err error) {
connection, err := net.DialTimeout("tcp", address, 3*time.Hour)
if err != nil {
return
}
connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
connection.SetDeadline(time.Now().Add(3 * time.Hour))
connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
com = comm.New(connection)
ok, err := com.Receive()
if err != nil {
return
}
log.Debugf("server says: %s", ok)
err = com.Send(room)
if err != nil {
return
}
message, err = com.Receive()
log.Debugf("server says: %s", message)
return
}

View file

@ -1,217 +0,0 @@
package croc
import (
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"strings"
"time"
log "github.com/cihub/seelog"
"github.com/gorilla/websocket"
"github.com/schollz/croc/src/relay"
"github.com/schollz/croc/src/utils"
"github.com/schollz/peerdiscovery"
)
// Send the file
func (c *Croc) Send(fname, codephrase string) (err error) {
defer log.Flush()
log.Debugf("sending %s", fname)
errChan := make(chan error)
// normally attempt two connections
waitingFor := 2
// use public relay
if !c.LocalOnly {
go func() {
// atttempt to connect to public relay
errChan <- c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPorts, fname, codephrase, true, false)
}()
} else {
waitingFor = 1
}
// use local relay
if !c.NoLocal {
defer func() {
log.Debug("sending relay stop signal")
relay.Stop()
}()
go func() {
// start own relay and connect to it
go relay.Run(c.RelayWebsocketPort, c.RelayTCPPorts)
time.Sleep(250 * time.Millisecond) // race condition here, but this should work most of the time :(
// broadcast for peer discovery
go func() {
log.Debug("starting local discovery...")
discovered, err := peerdiscovery.Discover(peerdiscovery.Settings{
Limit: 1,
TimeLimit: 600 * time.Second,
Delay: 50 * time.Millisecond,
Payload: []byte(c.RelayWebsocketPort + "- " + strings.Join(c.RelayTCPPorts, ",")),
MulticastAddress: fmt.Sprintf("239.255.255.%d", 230+int64(time.Now().Minute()/5)),
})
log.Debug(discovered, err)
}()
// connect to own relay
errChan <- c.sendReceive("localhost", c.RelayWebsocketPort, c.RelayTCPPorts, fname, codephrase, true, true)
}()
} else {
waitingFor = 1
}
err = <-errChan
if err == nil || waitingFor == 1 {
log.Debug("returning")
return
}
log.Debug(err)
return <-errChan
}
// Receive the file
func (c *Croc) Receive(codephrase string) (err error) {
defer log.Flush()
log.Debug("receiving")
// use local relay first
if !c.NoLocal {
log.Debug("trying to discover")
// try to discovery codephrase and server through peer network
discovered, errDiscover := peerdiscovery.Discover(peerdiscovery.Settings{
Limit: 1,
TimeLimit: 300 * time.Millisecond,
Delay: 50 * time.Millisecond,
Payload: []byte("checking"),
AllowSelf: true,
DisableBroadcast: true,
MulticastAddress: fmt.Sprintf("239.255.255.%d", 230+int64(time.Now().Minute()/5)),
})
log.Debug("finished")
log.Debug(discovered)
if errDiscover != nil {
log.Debug(errDiscover)
}
if len(discovered) > 0 {
if discovered[0].Address == utils.LocalIP() {
discovered[0].Address = "localhost"
}
log.Debugf("discovered %s:%s", discovered[0].Address, discovered[0].Payload)
// see if we can actually connect to it
timeout := time.Duration(200 * time.Millisecond)
client := http.Client{
Timeout: timeout,
}
ports := strings.Split(string(discovered[0].Payload), "-")
if len(ports) != 2 {
return errors.New("bad payload")
}
resp, err := client.Get(fmt.Sprintf("http://%s:%s/", discovered[0].Address, ports[0]))
if err == nil {
if resp.StatusCode == http.StatusOK {
// we connected, so use this
return c.sendReceive(discovered[0].Address, strings.TrimSpace(ports[0]), strings.Split(strings.TrimSpace(ports[1]), ","), "", codephrase, false, true)
}
} else {
log.Debugf("could not connect: %s", err.Error())
}
} else {
log.Debug("discovered no peers")
}
}
// use public relay
if !c.LocalOnly {
log.Debug("using public relay")
return c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPorts, "", codephrase, false, false)
}
return errors.New("must use local or public relay")
}
func (c *Croc) sendReceive(address, websocketPort string, tcpPorts []string, fname string, codephrase string, isSender bool, isLocal bool) (err error) {
defer log.Flush()
if len(codephrase) < 4 {
return fmt.Errorf("codephrase is too short")
}
// allow interrupts from Ctl+C
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
done := make(chan error)
// connect to server
websocketAddress := ""
if len(websocketPort) > 0 {
websocketAddress = fmt.Sprintf("ws://%s:%s/ws?room=%s", address, websocketPort, codephrase[:3])
} else {
websocketAddress = fmt.Sprintf("ws://%s/ws?room=%s", address, codephrase[:3])
}
log.Debugf("connecting to %s", websocketAddress)
sock, _, err := websocket.DefaultDialer.Dial(websocketAddress, nil)
if err != nil {
log.Error(err)
return
}
defer sock.Close()
// tell the websockets we are connected
err = sock.WriteMessage(websocket.BinaryMessage, []byte("connected"))
if err != nil {
log.Error(err)
return err
}
if isSender {
go c.startSender(c.ForceSend, address, tcpPorts, isLocal, done, sock, fname, codephrase, c.UseCompression, c.UseEncryption)
} else {
go c.startRecipient(c.ForceSend, address, tcpPorts, isLocal, done, sock, codephrase, c.NoRecipientPrompt, c.Stdout)
}
for {
select {
case doneError := <-done:
log.Debug("received done signal")
if doneError != nil {
c.StateString = doneError.Error()
sock.WriteMessage(websocket.TextMessage, []byte("error: "+doneError.Error()))
time.Sleep(50 * time.Millisecond)
}
return doneError
case <-interrupt:
if !c.Debug {
SetDebugLevel("critical")
}
log.Debug("interrupt")
err = sock.WriteMessage(websocket.TextMessage, []byte("error: interrupted by other party"))
if err != nil {
return err
}
time.Sleep(50 * time.Millisecond)
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
err := sock.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
log.Debug("write close:", err)
return nil
}
select {
case <-done:
case <-time.After(100 * time.Millisecond):
}
return nil
}
}
}
// Relay will start a relay on the specified port
func (c *Croc) Relay() (err error) {
return relay.Run(c.RelayWebsocketPort, c.RelayTCPPorts)
}