From 81bc06eabb295417186b33770ae6c1f90bf76909 Mon Sep 17 00:00:00 2001 From: Zack Scholl Date: Sun, 21 Oct 2018 08:21:58 -0700 Subject: [PATCH] add configuration file --- go.mod | 3 +- src/cli/cli.go | 27 +++++++-- src/croc/config.go | 134 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 154 insertions(+), 10 deletions(-) diff --git a/go.mod b/go.mod index 177934f9..3f74465a 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module github.com/schollz/croc require ( - github.com/BurntSushi/toml v0.3.1 // indirect + github.com/BurntSushi/toml v0.3.1 github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 github.com/dustin/go-humanize v1.0.0 github.com/fatih/color v1.7.0 // indirect @@ -9,6 +9,7 @@ require ( github.com/gorilla/websocket v1.4.0 github.com/mattn/go-colorable v0.0.9 // indirect github.com/mattn/go-isatty v0.0.4 // indirect + github.com/mitchellh/go-homedir v1.0.0 github.com/pkg/errors v0.8.0 github.com/schollz/mnemonicode v1.0.1 github.com/schollz/pake v1.1.0 diff --git a/src/cli/cli.go b/src/cli/cli.go index e76c2b88..9afe16d4 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -60,6 +60,16 @@ func Run() { return relay(c) }, }, + { + Name: "config", + Usage: "generates a config file", + Description: "the croc config can be used to set static parameters", + Flags: []cli.Flag{}, + HelpName: "croc config", + Action: func(c *cli.Context) error { + return saveDefaultConfig(c) + }, + }, } app.Flags = []cli.Flag{ cli.StringFlag{Name: "addr", Value: "croc4.schollz.com", Usage: "address of the public relay"}, @@ -115,6 +125,10 @@ func Run() { } } +func saveDefaultConfig(c *cli.Context) error { + return croc.SaveDefaultConfig() +} + func send(c *cli.Context) error { stat, _ := os.Stdin.Stat() var fname string @@ -147,11 +161,12 @@ func send(c *cli.Context) error { cr.UseCompression = !c.Bool("no-compress") cr.UseEncryption = !c.Bool("no-encrypt") if c.String("code") != "" { - codePhrase = c.String("code") + cr.Codephrase = c.String("code") } - if len(codePhrase) == 0 { + cr.LoadConfig() + if len(cr.Codephrase) == 0 { // generate code phrase - codePhrase = utils.GetRandomName() + cr.Codephrase = utils.GetRandomName() } // print the text @@ -176,10 +191,10 @@ func send(c *cli.Context) error { humanize.Bytes(uint64(fsize)), fileOrFolder, filename, - codePhrase, - codePhrase, + cr.Codephrase, + cr.Codephrase, ) - return cr.Send(fname, codePhrase) + return cr.Send(fname, cr.Codephrase) } func receive(c *cli.Context) error { diff --git a/src/croc/config.go b/src/croc/config.go index 0e433427..b1750599 100644 --- a/src/croc/config.go +++ b/src/croc/config.go @@ -2,9 +2,16 @@ package croc import ( "bytes" + "fmt" + "io/ioutil" + "os" + "path" + "path/filepath" "time" "github.com/BurntSushi/toml" + homedir "github.com/mitchellh/go-homedir" + "github.com/schollz/croc/src/utils" ) type Config struct { @@ -33,8 +40,7 @@ type Config struct { Codephrase string } -// DefaultConfig returns the default config -func DefaultConfig() string { +func defaultConfig() Config { c := Config{} cr := Init(false) c.RelayWebsocketPort = cr.RelayWebsocketPort @@ -53,7 +59,129 @@ func DefaultConfig() string { c.ForceTCP = false c.ForceWebsockets = false c.Codephrase = "" + return c +} + +func SaveDefaultConfig() error { + homedir, err := homedir.Dir() + if err != nil { + return err + } + os.MkdirAll(path.Join(homedir, ".config", "croc"), 0644) + c := defaultConfig() buf := new(bytes.Buffer) toml.NewEncoder(buf).Encode(c) - return buf.String() + confTOML := buf.String() + err = ioutil.WriteFile(path.Join(homedir, ".config", "croc", "config.toml"), []byte(confTOML), 0644) + if err == nil { + fmt.Printf("Default config file written at '%s'", filepath.Clean(path.Join(homedir, ".config", "croc", "config.toml"))) + } + return err +} + +// LoadConfig will override parameters +func (cr *Croc) LoadConfig() (err error) { + homedir, err := homedir.Dir() + if err != nil { + return err + } + pathToConfig := path.Join(homedir, ".config", "croc", "config.toml") + if !utils.Exists(pathToConfig) { + // ignore if doesn't exist + return nil + } + + var c Config + _, err = toml.DecodeFile(pathToConfig, &c) + if err != nil { + return + } + + cDefault := defaultConfig() + // only load if things are different than defaults + // just in case the CLI parameters are used + if c.RelayWebsocketPort != cDefault.RelayWebsocketPort && cr.RelayWebsocketPort == cDefault.RelayWebsocketPort { + cr.RelayWebsocketPort = c.RelayWebsocketPort + fmt.Printf("loaded RelayWebsocketPort from config\n") + } + if !slicesEqual(c.RelayTCPPorts, cDefault.RelayTCPPorts) && slicesEqual(cr.RelayTCPPorts, cDefault.RelayTCPPorts) { + cr.RelayTCPPorts = c.RelayTCPPorts + fmt.Printf("loaded RelayTCPPorts from config\n") + } + if c.CurveType != cDefault.CurveType && cr.CurveType == cDefault.CurveType { + cr.CurveType = c.CurveType + fmt.Printf("loaded CurveType from config\n") + } + if c.PublicServerIP != cDefault.PublicServerIP && cr.Address == cDefault.PublicServerIP { + cr.Address = c.PublicServerIP + fmt.Printf("loaded Address from config\n") + } + if !slicesEqual(c.AddressTCPPorts, cDefault.AddressTCPPorts) { + cr.AddressTCPPorts = c.AddressTCPPorts + fmt.Printf("loaded AddressTCPPorts from config\n") + } + if c.AddressWebsocketPort != cDefault.AddressWebsocketPort && cr.AddressWebsocketPort == cDefault.AddressWebsocketPort { + cr.AddressWebsocketPort = c.AddressWebsocketPort + fmt.Printf("loaded AddressWebsocketPort from config\n") + } + if c.Timeout != cDefault.Timeout && cr.Timeout == cDefault.Timeout { + cr.Timeout = c.Timeout + fmt.Printf("loaded Timeout from config\n") + } + if c.LocalOnly != cDefault.LocalOnly && cr.LocalOnly == cDefault.LocalOnly { + cr.LocalOnly = c.LocalOnly + fmt.Printf("loaded LocalOnly from config\n") + } + if c.NoLocal != cDefault.NoLocal && cr.NoLocal == cDefault.NoLocal { + cr.NoLocal = c.NoLocal + fmt.Printf("loaded NoLocal from config\n") + } + if c.UseEncryption != cDefault.UseEncryption && cr.UseEncryption == cDefault.UseEncryption { + cr.UseEncryption = c.UseEncryption + fmt.Printf("loaded UseEncryption from config\n") + } + if c.UseCompression != cDefault.UseCompression && cr.UseCompression == cDefault.UseCompression { + cr.UseCompression = c.UseCompression + fmt.Printf("loaded UseCompression from config\n") + } + if c.AllowLocalDiscovery != cDefault.AllowLocalDiscovery && cr.AllowLocalDiscovery == cDefault.AllowLocalDiscovery { + cr.AllowLocalDiscovery = c.AllowLocalDiscovery + fmt.Printf("loaded AllowLocalDiscovery from config\n") + } + if c.NoRecipientPrompt != cDefault.NoRecipientPrompt && cr.NoRecipientPrompt == cDefault.NoRecipientPrompt { + cr.NoRecipientPrompt = c.NoRecipientPrompt + fmt.Printf("loaded NoRecipientPrompt from config\n") + } + if c.ForceWebsockets { + cr.ForceSend = 1 + } + if c.ForceTCP { + cr.ForceSend = 2 + } + if c.Codephrase != cDefault.Codephrase && cr.Codephrase == cDefault.Codephrase { + cr.Codephrase = c.Codephrase + fmt.Printf("loaded Codephrase from config\n") + } + return +} + +// slicesEqual checcks if two slices are equal +// from https://stackoverflow.com/a/15312097 +func slicesEqual(a, b []string) bool { + // If one is nil, the other must also be nil. + if (a == nil) != (b == nil) { + return false + } + + if len(a) != len(b) { + return false + } + + for i := range a { + if a[i] != b[i] { + return false + } + } + + return true }