0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 13:21:00 +02:00

Merge pull request #90 from schollz/resume

Resume
This commit is contained in:
Zack 2018-10-09 06:59:06 -07:00 committed by GitHub
commit 5b5c05d694
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 200 additions and 96 deletions

View file

@ -1,6 +1,7 @@
package recipient package recipient
import ( import (
"bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
@ -50,8 +51,11 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
var transferTime time.Duration var transferTime time.Duration
var hash256 []byte var hash256 []byte
var otherIP string var otherIP string
var progressFile string
var resumeFile bool
var tcpConnections []comm.Comm var tcpConnections []comm.Comm
dataChan := make(chan []byte, 1024*1024) dataChan := make(chan []byte, 1024*1024)
blocks := []string{}
useWebsockets := true useWebsockets := true
switch forceSend { switch forceSend {
@ -129,6 +133,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
return err return err
} }
log.Debugf("%x\n", sessionKey) log.Debugf("%x\n", sessionKey)
c.WriteMessage(websocket.BinaryMessage, []byte("ready")) c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
case 3: case 3:
spin.Stop() spin.Stop()
@ -151,9 +156,14 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
log.Debugf("got file stats: %+v", fstats) log.Debugf("got file stats: %+v", fstats)
// prompt user if its okay to receive file // prompt user if its okay to receive file
progressFile = fmt.Sprintf("%s.progress", fstats.SentName)
overwritingOrReceiving := "Receiving" overwritingOrReceiving := "Receiving"
if utils.Exists(fstats.Name) { if utils.Exists(fstats.Name) || utils.Exists(fstats.SentName) {
overwritingOrReceiving = "Overwriting" overwritingOrReceiving = "Overwriting"
if utils.Exists(progressFile) {
overwritingOrReceiving = "Resume receiving"
resumeFile = true
}
} }
fileOrFolder := "file" fileOrFolder := "file"
if fstats.IsDir { if fstats.IsDir {
@ -189,15 +199,50 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
} }
// await file // await file
f, err := os.Create(fstats.SentName) var f *os.File
if err != nil { if utils.Exists(fstats.SentName) && resumeFile {
log.Error(err) if !useWebsockets {
return err f, err = os.OpenFile(fstats.SentName, os.O_WRONLY, 0644)
} else {
f, err = os.OpenFile(fstats.SentName, os.O_APPEND, 0644)
}
if err != nil {
log.Error(err)
return err
}
} else {
f, err = os.Create(fstats.SentName)
if err != nil {
log.Error(err)
return err
}
if !useWebsockets {
if err = f.Truncate(fstats.Size); err != nil {
log.Error(err)
return err
}
}
} }
if err = f.Truncate(fstats.Size); err != nil {
log.Error(err) // append the previous blocks if there was progress previously
return err if resumeFile {
file, _ := os.Open(progressFile)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
}
file.Close()
} }
blocksBytes, _ := json.Marshal(blocks)
blockSize := 0
if useWebsockets {
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
} else {
blockSize = models.TCP_BUFFER_SIZE / 2
}
// start the ui for pgoress
bytesWritten := 0 bytesWritten := 0
fmt.Fprintf(os.Stderr, "\nReceiving (<-%s)...\n", otherIP) fmt.Fprintf(os.Stderr, "\nReceiving (<-%s)...\n", otherIP)
bar := progressbar.NewOptions( bar := progressbar.NewOptions(
@ -206,9 +251,32 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
progressbar.OptionSetBytes(int(fstats.Size)), progressbar.OptionSetBytes(int(fstats.Size)),
progressbar.OptionSetWriter(os.Stderr), progressbar.OptionSetWriter(os.Stderr),
) )
bar.Add((len(blocks) * blockSize))
finished := make(chan bool) finished := make(chan bool)
go func(finished chan bool, dataChan chan []byte) (err error) { go func(finished chan bool, dataChan chan []byte) (err error) {
// remove previous progress
var fProgress *os.File
var progressErr error
if resumeFile {
fProgress, progressErr = os.OpenFile(progressFile, os.O_APPEND, 0644)
bytesWritten = len(blocks) * blockSize
} else {
os.Remove(progressFile)
fProgress, progressErr = os.Create(progressFile)
}
if progressErr != nil {
panic(progressErr)
}
defer fProgress.Close()
blocksWritten := 0.0
blocksToWrite := float64(fstats.Size)
if useWebsockets {
blocksToWrite = blocksToWrite/float64(models.WEBSOCKET_BUFFER_SIZE/8) - float64(len(blocks))
} else {
blocksToWrite = blocksToWrite/float64(models.TCP_BUFFER_SIZE/2) - float64(len(blocks))
}
for { for {
message := <-dataChan message := <-dataChan
// do decryption // do decryption
@ -245,19 +313,25 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
return err return err
} }
n, err = f.WriteAt(decrypted, int64(locationToWrite)) n, err = f.WriteAt(decrypted, int64(locationToWrite))
fProgress.WriteString(fmt.Sprintf("%d\n", locationToWrite))
log.Debugf("writing to location %d (%2.0f/%2.0f)", locationToWrite, blocksWritten, blocksToWrite)
} else { } else {
// write to file // write to file
n, err = f.Write(decrypted) n, err = f.Write(decrypted)
log.Debugf("writing to location %d (%2.0f/%2.0f)", bytesWritten, blocksWritten, blocksToWrite)
fProgress.WriteString(fmt.Sprintf("%d\n", bytesWritten))
} }
if err != nil { if err != nil {
log.Error(err)
return err return err
} }
// update the bytes written // update the bytes written
bytesWritten += n bytesWritten += n
blocksWritten += 1.0
// update the progress bar // update the progress bar
bar.Add(n) bar.Add(n)
if int64(bytesWritten) == fstats.Size { if int64(bytesWritten) == fstats.Size || blocksWritten >= blocksToWrite {
log.Debug("finished") log.Debug("finished")
break break
} }
@ -267,7 +341,8 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
}(finished, dataChan) }(finished, dataChan)
log.Debug("telling sender i'm ready") log.Debug("telling sender i'm ready")
c.WriteMessage(websocket.BinaryMessage, []byte("ready")) c.WriteMessage(websocket.BinaryMessage, append([]byte("ready"), blocksBytes...))
startTime := time.Now() startTime := time.Now()
if useWebsockets { if useWebsockets {
for { for {
@ -388,6 +463,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
fstats.Name = "stdout" fstats.Name = "stdout"
} }
fmt.Fprintf(os.Stderr, "\nReceived %s written to %s (%2.1f %s)\n", folderOrFile, fstats.Name, transferRate, transferType) fmt.Fprintf(os.Stderr, "\nReceived %s written to %s (%2.1f %s)\n", folderOrFile, fstats.Name, transferRate, transferType)
os.Remove(progressFile)
} }
return err return err
} else { } else {
@ -397,7 +473,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
} }
return errors.New("file corrupted") return errors.New("file corrupted")
} }
default: default:
return fmt.Errorf("unknown step") return fmt.Errorf("unknown step")
} }

View file

@ -8,6 +8,7 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -52,6 +53,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
var otherIP string var otherIP string
var startTransfer time.Time var startTransfer time.Time
var tcpConnections []comm.Comm var tcpConnections []comm.Comm
blocksToSkip := make(map[int64]struct{})
type DataChan struct { type DataChan struct {
b []byte b []byte
@ -169,87 +171,6 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
} }
fileReady <- nil fileReady <- nil
// start streaming encryption/compression
go func(dataChan chan DataChan) {
var buffer []byte
if useWebsockets {
buffer = make([]byte, models.WEBSOCKET_BUFFER_SIZE/8)
} else {
buffer = make([]byte, models.TCP_BUFFER_SIZE/2)
}
currentPostition := int64(0)
for {
bytesread, err := f.Read(buffer)
if bytesread > 0 {
// do compression
var compressedBytes []byte
if useCompression && !fstats.IsDir {
compressedBytes = compress.Compress(buffer[:bytesread])
} else {
compressedBytes = buffer[:bytesread]
}
// if using TCP, prepend the location to write the data to in the resulting file
if !useWebsockets {
compressedBytes = append([]byte(fmt.Sprintf("%d-", currentPostition)), compressedBytes...)
}
// do encryption
enc := crypt.Encrypt(compressedBytes, sessionKey, !useEncryption)
encBytes, err := json.Marshal(enc)
if err != nil {
dataChan <- DataChan{
b: nil,
bytesRead: 0,
err: err,
}
return
}
select {
case dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}
}
currentPostition += int64(bytesread)
}
if err != nil {
if err != io.EOF {
log.Error(err)
}
break
}
}
// finish
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
if !useWebsockets {
log.Debug("sending extra magic to %d others", len(tcpPorts)-1)
for i := 0; i < len(tcpPorts)-1; i++ {
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
}
}
}(dataChan)
}() }()
// send pake data // send pake data
@ -275,9 +196,10 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
spin.Start() spin.Start()
case 3: case 3:
log.Debugf("[%d] recipient declares readiness for file info", step) log.Debugf("[%d] recipient declares readiness for file info", step)
if !bytes.Equal(message, []byte("ready")) { if !bytes.HasPrefix(message, []byte("ready")) {
return errors.New("recipient refused file") return errors.New("recipient refused file")
} }
err = <-fileReady // block until file is ready err = <-fileReady // block until file is ready
if err != nil { if err != nil {
return err return err
@ -295,10 +217,111 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
spin.Stop() spin.Stop()
log.Debugf("[%d] recipient declares readiness for file data", step) log.Debugf("[%d] recipient declares readiness for file data", step)
if !bytes.Equal(message, []byte("ready")) { if !bytes.HasPrefix(message, []byte("ready")) {
return errors.New("recipient refused file") return errors.New("recipient refused file")
} }
// determine if any blocks were sent to skip
var blocks []string
errBlocks := json.Unmarshal(message[5:], &blocks)
if errBlocks == nil {
for _, block := range blocks {
blockInt64, errBlock := strconv.Atoi(block)
if errBlock == nil {
blocksToSkip[int64(blockInt64)] = struct{}{}
}
}
}
log.Debugf("found blocks: %+v", blocksToSkip)
// start streaming encryption/compression
go func(dataChan chan DataChan) {
var buffer []byte
if useWebsockets {
buffer = make([]byte, models.WEBSOCKET_BUFFER_SIZE/8)
} else {
buffer = make([]byte, models.TCP_BUFFER_SIZE/2)
}
currentPostition := int64(0)
for {
bytesread, err := f.Read(buffer)
if bytesread > 0 {
if _, ok := blocksToSkip[currentPostition]; ok {
log.Debugf("skipping the sending of block %d", currentPostition)
currentPostition += int64(bytesread)
continue
}
// do compression
var compressedBytes []byte
if useCompression && !fstats.IsDir {
compressedBytes = compress.Compress(buffer[:bytesread])
} else {
compressedBytes = buffer[:bytesread]
}
// if using TCP, prepend the location to write the data to in the resulting file
if !useWebsockets {
compressedBytes = append([]byte(fmt.Sprintf("%d-", currentPostition)), compressedBytes...)
}
// do encryption
enc := crypt.Encrypt(compressedBytes, sessionKey, !useEncryption)
encBytes, err := json.Marshal(enc)
if err != nil {
dataChan <- DataChan{
b: nil,
bytesRead: 0,
err: err,
}
return
}
select {
case dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}
}
currentPostition += int64(bytesread)
}
if err != nil {
if err != io.EOF {
log.Error(err)
}
break
}
}
// finish
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
if !useWebsockets {
log.Debug("sending extra magic to %d others", len(tcpPorts)-1)
for i := 0; i < len(tcpPorts)-1; i++ {
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
}
}
}(dataChan)
// connect to TCP to receive file // connect to TCP to receive file
if !useWebsockets { if !useWebsockets {
log.Debugf("connecting to server") log.Debugf("connecting to server")
@ -318,12 +341,19 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
// send file, compure hash simultaneously // send file, compure hash simultaneously
startTransfer = time.Now() startTransfer = time.Now()
blockSize := 0
if useWebsockets {
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
} else {
blockSize = models.TCP_BUFFER_SIZE / 2
}
bar := progressbar.NewOptions( bar := progressbar.NewOptions(
int(fstats.Size), int(fstats.Size),
progressbar.OptionSetRenderBlankState(true), progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetBytes(int(fstats.Size)), progressbar.OptionSetBytes(int(fstats.Size)),
progressbar.OptionSetWriter(os.Stderr), progressbar.OptionSetWriter(os.Stderr),
) )
bar.Add(blockSize * len(blocksToSkip))
if useWebsockets { if useWebsockets {
for { for {

View file

@ -4,7 +4,6 @@ import (
"archive/zip" "archive/zip"
"compress/flate" "compress/flate"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -97,7 +96,7 @@ func ZipFile(fname string, compress bool) (writtenFilename string, err error) {
return return
} }
log.Debugf("current directory: %s", curdir) log.Debugf("current directory: %s", curdir)
newfile, err := ioutil.TempFile(curdir, filename+".") newfile, err := os.Create(fname + ".croc.zip")
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return