diff --git a/main.go b/main.go index bddde771..cdfe99bf 100644 --- a/main.go +++ b/main.go @@ -12,10 +12,12 @@ func main() { flag.Parse() c := croc.Init() - if *role == 0 { + if *role == -1 { err = c.Relay() - } else if *role == 1 { + } else if *role == 0 { err = c.Send("foo") + } else { + err = c.Receive() } if err != nil { panic(err) diff --git a/src/api.go b/src/api.go index c1fb7312..35116773 100644 --- a/src/api.go +++ b/src/api.go @@ -2,20 +2,6 @@ package croc import "time" -type Croc struct { - TcpPorts []string - ServerPort string - Timeout time.Duration - UseEncryption bool - UseCompression bool - CurveType string - AllowLocalDiscovery bool - - // private variables - // rs relay state is only for the relay - rs relayState -} - // Init will initialize the croc relay func Init() (c *Croc) { c = new(Croc) @@ -26,15 +12,14 @@ func Init() (c *Croc) { c.UseCompression = true c.AllowLocalDiscovery = true c.CurveType = "p521" + c.rs.Lock() + c.rs.channel = make(map[string]*channelData) + c.rs.Unlock() return } // Relay initiates a relay func (c *Croc) Relay() error { - c.rs.Lock() - c.rs.channel = make(map[string]*channelData) - c.rs.Unlock() - // start relay go c.startRelay(c.TcpPorts) @@ -44,12 +29,12 @@ func (c *Croc) Relay() error { // Send will take an existing file or folder and send it through the croc relay func (c *Croc) Send(fname string) (err error) { - err = c.send(fname) + err = c.client(0) return } // Receive will receive something through the croc relay func (c *Croc) Receive() (err error) { - + err = c.client(1) return } diff --git a/src/client.go b/src/client.go new file mode 100644 index 00000000..3ceada16 --- /dev/null +++ b/src/client.go @@ -0,0 +1,137 @@ +package croc + +import ( + "errors" + "net/url" + "os" + "os/signal" + "time" + + log "github.com/cihub/seelog" + "github.com/gorilla/websocket" +) + +func (c *Croc) client(role int) (err error) { + defer log.Flush() + + // initialize the channel data for this client + c.cs.Lock() + c.cs.channel = newChannelData("") + c.cs.Unlock() + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + + // connect to the websocket + // TODO: + // use predefined host and HTTPS, if exists + u := url.URL{Scheme: "ws", Host: "localhost:8003", Path: "/"} + log.Debugf("connecting to %s", u.String()) + ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Error("dial:", err) + return + } + defer ws.Close() + + // read in the messages and process them + done := make(chan struct{}) + go func() { + defer close(done) + for { + var cd channelData + err := ws.ReadJSON(&cd) + if err != nil { + log.Debugf("sender read error:", err) + return + } + log.Debugf("recv: %+v", cd) + err = c.processState(cd) + if err != nil { + log.Warn(err) + return + } + } + }() + + // initialize by joining as corresponding role + // TODO: + // allowing suggesting a channel + err = ws.WriteJSON(payload{ + Open: true, + Role: role, + }) + if err != nil { + log.Errorf("problem opening: %s", err.Error()) + return + } + + for { + select { + case <-done: + return + case <-interrupt: + // send Close signal to relay on interrupt + log.Debugf("interrupt") + c.cs.Lock() + channel := c.cs.channel.Channel + uuid := c.cs.channel.UUID + c.cs.Unlock() + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. + log.Debug("sending close signal") + errWrite := ws.WriteJSON(payload{ + Channel: channel, + UUID: uuid, + Close: true, + }) + if errWrite != nil { + log.Debugf("write close:", err) + return + } + select { + case <-done: + case <-time.After(time.Second): + } + return + } + } + return +} + +func (c *Croc) processState(cd channelData) (err error) { + c.cs.Lock() + defer c.cs.Unlock() + + // first check if there is relay reported error + if cd.Error != "" { + err = errors.New(cd.Error) + return + } + // TODO: + // check if the state is not aligned (i.e. have h(k) but no hh(k)) + // throw error if not aligned so it can exit + + // first update the channel data + // initialize if has UUID + if cd.UUID != "" { + c.cs.channel.UUID = cd.UUID + c.cs.channel.Ports = cd.Ports + c.cs.channel.Channel = cd.Channel + c.cs.channel.Role = cd.Role + c.cs.channel.Ports = cd.Ports + log.Debugf("initialized client state") + } + // copy over the rest of the state + if cd.TransferReady { + c.cs.channel.TransferReady = true + } + for key := range cd.State { + c.cs.channel.State[key] = cd.State[key] + } + + // TODO: + // process the client state + log.Debugf("processing client state: %+v", c.cs.channel) + return +} diff --git a/src/models.go b/src/models.go index 9dfb5f9c..691d543a 100644 --- a/src/models.go +++ b/src/models.go @@ -19,11 +19,33 @@ var ( availableStates = []string{"curve", "h_k", "hh_k", "x", "y"} ) +type Croc struct { + TcpPorts []string + ServerPort string + Timeout time.Duration + UseEncryption bool + UseCompression bool + CurveType string + AllowLocalDiscovery bool + + // private variables + // rs relay state is only for the relay + rs relayState + + // cs keeps the client state + cs clientState +} + type relayState struct { channel map[string]*channelData sync.RWMutex } +type clientState struct { + channel *channelData + sync.RWMutex +} + type channelData struct { // Public // Channel is the name of the channel @@ -36,8 +58,14 @@ type channelData struct { // Ports returns which TCP ports to connect to Ports []string `json:"ports"` + // Error is sent if there is an error + Error string `json:"error"` + + // Sent on initialization // UUID is sent out only to one person at a time UUID string `json:"uuid"` + // Role is the role the person will play + Role int `json:"role"` // Private // isopen determine whether or not the channel has been opened diff --git a/src/sender.go b/src/sender.go deleted file mode 100644 index 0a54be6b..00000000 --- a/src/sender.go +++ /dev/null @@ -1,73 +0,0 @@ -package croc - -import ( - "net/url" - "os" - "os/signal" - "time" - - log "github.com/cihub/seelog" - "github.com/gorilla/websocket" -) - -func (c *Croc) send(fname string) (err error) { - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt) - - u := url.URL{Scheme: "ws", Host: "localhost:8003", Path: "/"} - log.Debugf("connecting to %s", u.String()) - - ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil) - if err != nil { - log.Error("dial:", err) - return - } - defer ws.Close() - - done := make(chan struct{}) - - go func() { - defer close(done) - for { - var cd channelData - err := ws.ReadJSON(&cd) - if err != nil { - log.Debugf("sender read error:", err) - return - } - log.Debugf("recv: %+v", cd) - } - }() - - // initialize - err = ws.WriteJSON(payload{ - Open: true, - }) - if err != nil { - log.Errorf("problem opening: %s", err.Error()) - return - } - - for { - select { - case <-done: - return - case <-interrupt: - log.Debugf("interrupt") - - // Cleanly close the connection by sending a close message and then - // waiting (with timeout) for the server to close the connection. - errWrite := ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if errWrite != nil { - log.Debugf("write close:", err) - return - } - select { - case <-done: - case <-time.After(time.Second): - } - return - } - } - return -} diff --git a/src/server.go b/src/server.go index 0c58e5ca..56853b3e 100644 --- a/src/server.go +++ b/src/server.go @@ -18,21 +18,41 @@ func (c *Croc) startServer(tcpPorts []string, port string) (err error) { var upgrader = websocket.Upgrader{} // use default options http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { ws, err := upgrader.Upgrade(w, r, nil) + log.Debugf("connecting remote addr: %s", ws.RemoteAddr().String()) if err != nil { log.Error("upgrade:", err) return } defer ws.Close() + var channel string for { + log.Debug("waiting for next message") var p payload err := ws.ReadJSON(&p) if err != nil { - log.Debugf("read:", err) + if _, ok := err.(*websocket.CloseError); ok { + // on forced close, delete the channel + log.Debug("closed channel") + c.rs.Lock() + if _, ok := c.rs.channel[channel]; ok { + delete(c.rs.channel, channel) + } + c.rs.Unlock() + } else { + log.Debugf("read:", err) + } + break } - err = c.processPayload(ws, p) + channel, err = c.processPayload(ws, p) if err != nil { - log.Warn("problem processing payload %+v: %s", err.Error()) + // if error, send the error back and then delete the channel + log.Warn("problem processing payload %+v: %s", p, err.Error()) + ws.WriteJSON(channelData{Error: err.Error()}) + c.rs.Lock() + delete(c.rs.channel, p.Channel) + c.rs.Unlock() + return } } }) @@ -110,6 +130,7 @@ func (c *Croc) joinChannel(ws *websocket.Conn, p payload) (channel string, err e err = ws.WriteJSON(channelData{ Channel: p.Channel, UUID: c.rs.channel[p.Channel].uuids[p.Role], + Role: p.Role, }) if err != nil { return @@ -144,19 +165,44 @@ func (c *Croc) joinChannel(ws *websocket.Conn, p payload) (channel string, err e return } -func (c *Croc) processPayload(ws *websocket.Conn, p payload) (err error) { +func (c *Croc) processPayload(ws *websocket.Conn, p payload) (channel string, err error) { + channel = p.Channel + + // if the request is to close, delete the channel if p.Close { + log.Debugf("closing channel %s", p.Channel) c.rs.Lock() delete(c.rs.channel, p.Channel) c.rs.Unlock() return } - channel := p.Channel + // if request is to Open, try to open if p.Open { channel, err = c.joinChannel(ws, p) - } else if p.Update { + if err != nil { + return + } + } + + // check if open, otherwise return error + c.rs.Lock() + if _, ok := c.rs.channel[channel]; ok { + if !c.rs.channel[channel].isopen { + err = errors.Errorf("channel %s is not open, need to open first", channel) + c.rs.Unlock() + return + } + } + c.rs.Unlock() + + // if the request is to Update, then update the state + if p.Update { // update + err = c.updateChannel(p) + if err != nil { + return + } } // TODO: @@ -176,7 +222,7 @@ func (c *Croc) processPayload(ws *websocket.Conn, p payload) (err error) { } } } - c.rs.Lock() + c.rs.Unlock() return }