1
1
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 05:11:06 +02:00

Remove Temporary Files if the program terminates abnormal.

Fixes #799
This commit is contained in:
Zack 2024-09-03 09:28:42 -07:00
parent 149d7364fb
commit bb74eafd36
4 changed files with 61 additions and 26 deletions

11
main.go
View file

@ -5,12 +5,13 @@ package main
//go:generate git tag -af v$VERSION -m "v$VERSION" //go:generate git tag -af v$VERSION -m "v$VERSION"
import ( import (
"log"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/schollz/croc/v10/src/cli" "github.com/schollz/croc/v10/src/cli"
"github.com/schollz/croc/v10/src/utils"
log "github.com/schollz/logger"
) )
func main() { func main() {
@ -37,13 +38,17 @@ func main() {
go func() { go func() {
if err := cli.Run(); err != nil { if err := cli.Run(); err != nil {
log.Fatalln(err) log.Error(err)
} }
// Exit the program gracefully
utils.RemoveMarkedFiles()
os.Exit(0)
}() }()
// Wait for a termination signal // Wait for a termination signal
sig := <-sigs sig := <-sigs
log.Println("Received signal:", sig) log.Debugf("Received signal:", sig)
utils.RemoveMarkedFiles()
// Exit the program gracefully // Exit the program gracefully
os.Exit(0) os.Exit(0)

View file

@ -358,6 +358,7 @@ func send(c *cli.Context) (err error) {
if err != nil { if err != nil {
return return
} }
utils.MarkFileForRemoval(fnames[0])
defer func() { defer func() {
e := os.Remove(fnames[0]) e := os.Remove(fnames[0])
if e != nil { if e != nil {
@ -369,6 +370,7 @@ func send(c *cli.Context) (err error) {
if err != nil { if err != nil {
return return
} }
utils.MarkFileForRemoval(fnames[0])
defer func() { defer func() {
e := os.Remove(fnames[0]) e := os.Remove(fnames[0])
if e != nil { if e != nil {
@ -446,15 +448,9 @@ func getStdin() (fnames []string, err error) {
fnames = []string{f.Name()} fnames = []string{f.Name()}
return return
} }
func makeTempFolder() {
path := "temp"
if _, err := os.Stat(path); os.IsNotExist(err) {
os.Mkdir(path, os.ModePerm)
}
}
func makeTempFileWithString(s string) (fnames []string, err error) { func makeTempFileWithString(s string) (fnames []string, err error) {
makeTempFolder() f, err := os.CreateTemp(".", "croc-stdin-")
f, err := os.CreateTemp("temp", "croc-stdin-")
if err != nil { if err != nil {
return return
} }

View file

@ -388,6 +388,7 @@ func GetFilesInfo(fnames []string, zipfolder bool, ignoreGit bool) (filesInfo []
fpath = filepath.Dir(fpath) fpath = filepath.Dir(fpath)
dest := filepath.Base(fpath) + ".zip" dest := filepath.Base(fpath) + ".zip"
utils.ZipDirectory(dest, fpath) utils.ZipDirectory(dest, fpath)
utils.MarkFileForRemoval(dest)
stat, errStat = os.Lstat(dest) stat, errStat = os.Lstat(dest)
if errStat != nil { if errStat != nil {
err = errStat err = errStat

View file

@ -11,7 +11,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"log"
"math" "math"
"math/big" "math/big"
"net" "net"
@ -26,6 +25,7 @@ import (
"github.com/kalafut/imohash" "github.com/kalafut/imohash"
"github.com/minio/highwayhash" "github.com/minio/highwayhash"
"github.com/pion/stun" "github.com/pion/stun"
log "github.com/schollz/logger"
"github.com/schollz/mnemonicode" "github.com/schollz/mnemonicode"
"github.com/schollz/progressbar/v3" "github.com/schollz/progressbar/v3"
) )
@ -276,7 +276,8 @@ func PublicIP() (ip string, err error) {
func LocalIP() string { func LocalIP() string {
conn, err := net.Dial("udp", "8.8.8.8:80") conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil { if err != nil {
log.Fatal(err) log.Error(err)
return ""
} }
defer conn.Close() defer conn.Close()
@ -477,12 +478,12 @@ func IsLocalIP(ipaddress string) bool {
func ZipDirectory(destination string, source string) (err error) { func ZipDirectory(destination string, source string) (err error) {
if _, err = os.Stat(destination); err == nil { if _, err = os.Stat(destination); err == nil {
log.Fatalf("%s file already exists!\n", destination) log.Errorf("%s file already exists!\n", destination)
} }
fmt.Fprintf(os.Stderr, "Zipping %s to %s\n", source, destination) fmt.Fprintf(os.Stderr, "Zipping %s to %s\n", source, destination)
file, err := os.Create(destination) file, err := os.Create(destination)
if err != nil { if err != nil {
log.Fatalln(err) log.Error(err)
} }
defer file.Close() defer file.Close()
writer := zip.NewWriter(file) writer := zip.NewWriter(file)
@ -493,22 +494,22 @@ func ZipDirectory(destination string, source string) (err error) {
defer writer.Close() defer writer.Close()
err = filepath.Walk(source, func(path string, info os.FileInfo, err error) error { err = filepath.Walk(source, func(path string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
log.Fatalln(err) log.Error(err)
} }
if info.Mode().IsRegular() { if info.Mode().IsRegular() {
f1, err := os.Open(path) f1, err := os.Open(path)
if err != nil { if err != nil {
log.Fatalln(err) log.Error(err)
} }
defer f1.Close() defer f1.Close()
zipPath := strings.ReplaceAll(path, source, strings.TrimSuffix(destination, ".zip")) zipPath := strings.ReplaceAll(path, source, strings.TrimSuffix(destination, ".zip"))
zipPath = filepath.ToSlash(zipPath) zipPath = filepath.ToSlash(zipPath)
w1, err := writer.Create(zipPath) w1, err := writer.Create(zipPath)
if err != nil { if err != nil {
log.Fatalln(err) log.Error(err)
} }
if _, err := io.Copy(w1, f1); err != nil { if _, err := io.Copy(w1, f1); err != nil {
log.Fatalln(err) log.Error(err)
} }
fmt.Fprintf(os.Stderr, "\r\033[2K") fmt.Fprintf(os.Stderr, "\r\033[2K")
fmt.Fprintf(os.Stderr, "\rAdding %s", zipPath) fmt.Fprintf(os.Stderr, "\rAdding %s", zipPath)
@ -516,7 +517,7 @@ func ZipDirectory(destination string, source string) (err error) {
return nil return nil
}) })
if err != nil { if err != nil {
log.Fatalln(err) log.Error(err)
} }
fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, "\n")
return nil return nil
@ -525,7 +526,7 @@ func ZipDirectory(destination string, source string) (err error) {
func UnzipDirectory(destination string, source string) error { func UnzipDirectory(destination string, source string) error {
archive, err := zip.OpenReader(source) archive, err := zip.OpenReader(source)
if err != nil { if err != nil {
log.Fatalln(err) log.Error(err)
} }
defer archive.Close() defer archive.Close()
@ -537,7 +538,7 @@ func UnzipDirectory(destination string, source string) error {
// make sure the filepath does not have ".." // make sure the filepath does not have ".."
filePath = filepath.Clean(filePath) filePath = filepath.Clean(filePath)
if strings.Contains(filePath, "..") { if strings.Contains(filePath, "..") {
log.Fatalf("Invalid file path %s\n", filePath) log.Errorf("Invalid file path %s\n", filePath)
} }
if f.FileInfo().IsDir() { if f.FileInfo().IsDir() {
os.MkdirAll(filePath, os.ModePerm) os.MkdirAll(filePath, os.ModePerm)
@ -545,7 +546,7 @@ func UnzipDirectory(destination string, source string) error {
} }
if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil { if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
log.Fatalln(err) log.Error(err)
} }
// check if file exists // check if file exists
@ -560,16 +561,16 @@ func UnzipDirectory(destination string, source string) error {
dstFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) dstFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil { if err != nil {
log.Fatalln(err) log.Error(err)
} }
fileInArchive, err := f.Open() fileInArchive, err := f.Open()
if err != nil { if err != nil {
log.Fatalln(err) log.Error(err)
} }
if _, err := io.Copy(dstFile, fileInArchive); err != nil { if _, err := io.Copy(dstFile, fileInArchive); err != nil {
log.Fatalln(err) log.Error(err)
} }
dstFile.Close() dstFile.Close()
@ -610,3 +611,35 @@ func ValidFileName(fname string) (err error) {
} }
return return
} }
const crocRemovalFile = "croc-marked-files.txt"
func MarkFileForRemoval(fname string) {
// append the fname to the list of files to remove
f, err := os.OpenFile(crocRemovalFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
if err != nil {
log.Debug(err)
return
}
defer f.Close()
_, err = f.WriteString(fname + "\n")
}
func RemoveMarkedFiles() (err error) {
// read the file and remove all the files
f, err := os.Open(crocRemovalFile)
if err != nil {
return
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
fname := scanner.Text()
err = os.Remove(fname)
if err == nil {
log.Tracef("Removed %s", fname)
}
}
os.Remove(crocRemovalFile)
return
}