0
0
Fork 0
mirror of https://github.com/schollz/croc.git synced 2025-10-11 21:30:16 +02:00

use pointers

This commit is contained in:
Zack Scholl 2019-04-29 14:06:18 -06:00
parent b9a5f450c5
commit ef25c556a9
3 changed files with 36 additions and 18 deletions

View file

@ -15,8 +15,23 @@ type Comm struct {
connection net.Conn connection net.Conn
} }
func (c *Comm) IsClosed() bool {
one := []byte{}
c.connection.SetReadDeadline(time.Now())
_, err := c.connection.Read(one)
if err != nil {
fmt.Println(err)
c.connection.Close()
c.connection = nil
return true
} else {
c.connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
}
return false
}
// NewConnection gets a new comm to a tcp address // NewConnection gets a new comm to a tcp address
func NewConnection(address string) (c Comm, err error) { func NewConnection(address string) (c *Comm, err error) {
connection, err := net.DialTimeout("tcp", address, 3*time.Second) connection, err := net.DialTimeout("tcp", address, 3*time.Second)
if err != nil { if err != nil {
return return
@ -26,24 +41,26 @@ func NewConnection(address string) (c Comm, err error) {
} }
// New returns a new comm // New returns a new comm
func New(c net.Conn) Comm { func New(c net.Conn) *Comm {
c.SetReadDeadline(time.Now().Add(3 * time.Hour)) c.SetReadDeadline(time.Now().Add(3 * time.Hour))
c.SetDeadline(time.Now().Add(3 * time.Hour)) c.SetDeadline(time.Now().Add(3 * time.Hour))
c.SetWriteDeadline(time.Now().Add(3 * time.Hour)) c.SetWriteDeadline(time.Now().Add(3 * time.Hour))
return Comm{c} comm := new(Comm)
comm.connection = c
return comm
} }
// Connection returns the net.Conn connection // Connection returns the net.Conn connection
func (c Comm) Connection() net.Conn { func (c *Comm) Connection() net.Conn {
return c.connection return c.connection
} }
// Close closes the connection // Close closes the connection
func (c Comm) Close() { func (c *Comm) Close() {
c.connection.Close() c.connection.Close()
} }
func (c Comm) Write(b []byte) (int, error) { func (c *Comm) Write(b []byte) (int, error) {
header := new(bytes.Buffer) header := new(bytes.Buffer)
err := binary.Write(header, binary.LittleEndian, uint32(len(b))) err := binary.Write(header, binary.LittleEndian, uint32(len(b)))
if err != nil { if err != nil {
@ -62,7 +79,7 @@ func (c Comm) Write(b []byte) (int, error) {
return n, err return n, err
} }
func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
// read until we get 5 bytes // read until we get 5 bytes
header := make([]byte, 4) header := make([]byte, 4)
n, err := c.connection.Read(header) n, err := c.connection.Read(header)
@ -99,13 +116,13 @@ func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
} }
// Send a message // Send a message
func (c Comm) Send(message []byte) (err error) { func (c *Comm) Send(message []byte) (err error) {
_, err = c.Write(message) _, err = c.Write(message)
return return
} }
// Receive a message // Receive a message
func (c Comm) Receive() (b []byte, err error) { func (c *Comm) Receive() (b []byte, err error) {
b, _, _, err = c.Read() b, _, _, err = c.Read()
return return
} }

View file

@ -14,8 +14,8 @@ import (
const TCP_BUFFER_SIZE = 1024 * 64 const TCP_BUFFER_SIZE = 1024 * 64
type roomInfo struct { type roomInfo struct {
first comm.Comm first *comm.Comm
second comm.Comm second *comm.Comm
opened time.Time opened time.Time
full bool full bool
} }
@ -77,7 +77,7 @@ func run(port string) (err error) {
} }
} }
func clientCommuncation(port string, c comm.Comm) (err error) { func clientCommuncation(port string, c *comm.Comm) (err error) {
// send ok to tell client they are connected // send ok to tell client they are connected
log.Debug("sending ok") log.Debug("sending ok")
err = c.Send([]byte("ok")) err = c.Send([]byte("ok"))
@ -134,7 +134,7 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
wg.Add(1) wg.Add(1)
// start piping // start piping
go func(com1, com2 comm.Comm, wg *sync.WaitGroup) { go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) {
log.Debug("starting pipes") log.Debug("starting pipes")
pipe(com1.Connection(), com2.Connection()) pipe(com1.Connection(), com2.Connection())
wg.Done() wg.Done()
@ -153,6 +153,7 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
log.Debugf("deleting room: %s", room) log.Debugf("deleting room: %s", room)
rooms.rooms[room].first.Close() rooms.rooms[room].first.Close()
rooms.rooms[room].second.Close() rooms.rooms[room].second.Close()
rooms.rooms[room] = roomInfo{first: nil, second: nil}
delete(rooms.rooms, room) delete(rooms.rooms, room)
rooms.Unlock() rooms.Unlock()
return nil return nil

View file

@ -20,6 +20,7 @@ func TestTCP(t *testing.T) {
_, err = ConnectToTCPServer("localhost:8081", "testRoom") _, err = ConnectToTCPServer("localhost:8081", "testRoom")
assert.NotNil(t, err) assert.NotNil(t, err)
assert.False(t, c1.IsClosed())
// try sending data // try sending data
assert.Nil(t, c1.Send([]byte("hello, c2"))) assert.Nil(t, c1.Send([]byte("hello, c2")))
data, err := c2.Receive() data, err := c2.Receive()
@ -32,14 +33,13 @@ func TestTCP(t *testing.T) {
assert.Equal(t, []byte("hello, c1"), data) assert.Equal(t, []byte("hello, c1"), data)
c1.Close() c1.Close()
assert.True(t, c1.IsClosed())
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
err = c2.Send([]byte("test")) assert.True(t, c2.IsClosed())
assert.Nil(t, err)
_, err = c2.Receive()
assert.NotNil(t, err)
} }
func ConnectToTCPServer(address, room string) (c comm.Comm, err error) { func ConnectToTCPServer(address, room string) (c *comm.Comm, err error) {
c, err = comm.NewConnection("localhost:8081") c, err = comm.NewConnection("localhost:8081")
if err != nil { if err != nil {
return return