From e222dbd1e35d1a93a97f4a984c326bcd2426332c Mon Sep 17 00:00:00 2001 From: Zack Scholl Date: Mon, 8 Apr 2019 07:14:54 -0700 Subject: [PATCH] fork Antonito/gfile --- cmd/bench/cmd.go | 38 +++++ cmd/install.go | 23 +++ cmd/receive/cmd.go | 45 ++++++ cmd/send/cmd.go | 48 ++++++ internal/buffer/buffer.go | 47 ++++++ internal/session/getters.go | 8 + internal/session/session.go | 133 +++++++++++++++ internal/session/session_test.go | 24 +++ pkg/session/bench/benchmark.go | 61 +++++++ pkg/session/bench/benchmark_test.go | 86 ++++++++++ pkg/session/bench/id.go | 24 +++ pkg/session/bench/id_test.go | 28 ++++ pkg/session/bench/init.go | 62 +++++++ pkg/session/bench/session.go | 59 +++++++ pkg/session/bench/session_test.go | 28 ++++ pkg/session/bench/state.go | 12 ++ pkg/session/bench/state_download.go | 59 +++++++ pkg/session/bench/state_upload.go | 75 +++++++++ pkg/session/bench/timeout_test.go | 38 +++++ pkg/session/common/config.go | 14 ++ pkg/session/receiver/init.go | 78 +++++++++ pkg/session/receiver/receiver.go | 47 ++++++ pkg/session/receiver/receiver_test.go | 19 +++ pkg/session/receiver/state.go | 25 +++ pkg/session/sender/getters.go | 8 + pkg/session/sender/init.go | 68 ++++++++ pkg/session/sender/io.go | 78 +++++++++ pkg/session/sender/sender.go | 75 +++++++++ pkg/session/sender/sender_test.go | 19 +++ pkg/session/sender/state.go | 58 +++++++ pkg/session/session.go | 7 + pkg/session/session_test.go | 99 ++++++++++++ pkg/stats/bytes.go | 17 ++ pkg/stats/bytes_test.go | 40 +++++ pkg/stats/ctrl.go | 55 +++++++ pkg/stats/ctrl_test.go | 55 +++++++ pkg/stats/data.go | 26 +++ pkg/stats/data_test.go | 79 +++++++++ pkg/stats/stats.go | 31 ++++ pkg/utils/utils.go | 102 ++++++++++++ pkg/utils/utils_test.go | 223 ++++++++++++++++++++++++++ 41 files changed, 2121 insertions(+) create mode 100644 cmd/bench/cmd.go create mode 100644 cmd/install.go create mode 100644 cmd/receive/cmd.go create mode 100644 cmd/send/cmd.go create mode 100644 internal/buffer/buffer.go create mode 100644 internal/session/getters.go create mode 100644 internal/session/session.go create mode 100644 internal/session/session_test.go create mode 100644 pkg/session/bench/benchmark.go create mode 100644 pkg/session/bench/benchmark_test.go create mode 100644 pkg/session/bench/id.go create mode 100644 pkg/session/bench/id_test.go create mode 100644 pkg/session/bench/init.go create mode 100644 pkg/session/bench/session.go create mode 100644 pkg/session/bench/session_test.go create mode 100644 pkg/session/bench/state.go create mode 100644 pkg/session/bench/state_download.go create mode 100644 pkg/session/bench/state_upload.go create mode 100644 pkg/session/bench/timeout_test.go create mode 100644 pkg/session/common/config.go create mode 100644 pkg/session/receiver/init.go create mode 100644 pkg/session/receiver/receiver.go create mode 100644 pkg/session/receiver/receiver_test.go create mode 100644 pkg/session/receiver/state.go create mode 100644 pkg/session/sender/getters.go create mode 100644 pkg/session/sender/init.go create mode 100644 pkg/session/sender/io.go create mode 100644 pkg/session/sender/sender.go create mode 100644 pkg/session/sender/sender_test.go create mode 100644 pkg/session/sender/state.go create mode 100644 pkg/session/session.go create mode 100644 pkg/session/session_test.go create mode 100644 pkg/stats/bytes.go create mode 100644 pkg/stats/bytes_test.go create mode 100644 pkg/stats/ctrl.go create mode 100644 pkg/stats/ctrl_test.go create mode 100644 pkg/stats/data.go create mode 100644 pkg/stats/data_test.go create mode 100644 pkg/stats/stats.go create mode 100644 pkg/utils/utils.go create mode 100644 pkg/utils/utils_test.go diff --git a/cmd/bench/cmd.go b/cmd/bench/cmd.go new file mode 100644 index 00000000..16e793a9 --- /dev/null +++ b/cmd/bench/cmd.go @@ -0,0 +1,38 @@ +package bench + +import ( + "github.com/antonito/gfile/pkg/session/bench" + "github.com/antonito/gfile/pkg/session/common" + log "github.com/sirupsen/logrus" + "gopkg.in/urfave/cli.v1" +) + +func handler(c *cli.Context) error { + isMaster := c.Bool("master") + + sess := bench.NewWith(bench.Config{ + Master: isMaster, + Configuration: common.Configuration{ + OnCompletion: func() { + }, + }, + }) + return sess.Start() +} + +// New creates the command +func New() cli.Command { + log.Traceln("Installing 'bench' command") + return cli.Command{ + Name: "bench", + Aliases: []string{"b"}, + Usage: "Benchmark the connexion", + Action: handler, + Flags: []cli.Flag{ + cli.BoolFlag{ + Name: "master, m", + Usage: "Is creating the SDP offer?", + }, + }, + } +} diff --git a/cmd/install.go b/cmd/install.go new file mode 100644 index 00000000..bf85742c --- /dev/null +++ b/cmd/install.go @@ -0,0 +1,23 @@ +package cmd + +import ( + "sort" + + "github.com/antonito/gfile/cmd/bench" + "github.com/antonito/gfile/cmd/receive" + "github.com/antonito/gfile/cmd/send" + log "github.com/sirupsen/logrus" + "gopkg.in/urfave/cli.v1" +) + +// Install all the commands +func Install(app *cli.App) { + app.Commands = []cli.Command{ + send.New(), + receive.New(), + bench.New(), + } + log.Trace("Installed commands") + + sort.Sort(cli.CommandsByName(app.Commands)) +} diff --git a/cmd/receive/cmd.go b/cmd/receive/cmd.go new file mode 100644 index 00000000..7093460a --- /dev/null +++ b/cmd/receive/cmd.go @@ -0,0 +1,45 @@ +package receive + +import ( + "fmt" + "os" + + log "github.com/sirupsen/logrus" + + "github.com/antonito/gfile/pkg/session/receiver" + "gopkg.in/urfave/cli.v1" +) + +func handler(c *cli.Context) error { + output := c.String("output") + if output == "" { + return fmt.Errorf("output parameter missing") + } + f, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer f.Close() + + sess := receiver.NewWith(receiver.Config{ + Stream: f, + }) + return sess.Start() +} + +// New creates the command +func New() cli.Command { + log.Traceln("Installing 'receive' command") + return cli.Command{ + Name: "receive", + Aliases: []string{"r"}, + Usage: "Receive a file", + Action: handler, + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "output, o", + Usage: "Output", + }, + }, + } +} diff --git a/cmd/send/cmd.go b/cmd/send/cmd.go new file mode 100644 index 00000000..ceb76bb0 --- /dev/null +++ b/cmd/send/cmd.go @@ -0,0 +1,48 @@ +package send + +import ( + "fmt" + "os" + + "github.com/antonito/gfile/pkg/session/common" + "github.com/antonito/gfile/pkg/session/sender" + log "github.com/sirupsen/logrus" + "gopkg.in/urfave/cli.v1" +) + +func handler(c *cli.Context) error { + fileToSend := c.String("file") + if fileToSend == "" { + return fmt.Errorf("file parameter missing") + } + f, err := os.Open(fileToSend) + if err != nil { + return err + } + defer f.Close() + sess := sender.NewWith(sender.Config{ + Stream: f, + Configuration: common.Configuration{ + OnCompletion: func() { + }, + }, + }) + return sess.Start() +} + +// New creates the command +func New() cli.Command { + log.Traceln("Installing 'send' command") + return cli.Command{ + Name: "send", + Aliases: []string{"s"}, + Usage: "Sends a file", + Action: handler, + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "file, f", + Usage: "Send content of file `FILE`", + }, + }, + } +} diff --git a/internal/buffer/buffer.go b/internal/buffer/buffer.go new file mode 100644 index 00000000..108a4997 --- /dev/null +++ b/internal/buffer/buffer.go @@ -0,0 +1,47 @@ +package buffer + +import ( + "bytes" + "sync" +) + +// Buffer is a threadsafe buffer +type Buffer struct { + b bytes.Buffer + m sync.Mutex +} + +// Read in a thread-safe way +func (b *Buffer) Read(p []byte) (n int, err error) { + b.m.Lock() + defer b.m.Unlock() + return b.b.Read(p) +} + +// ReadString in a thread-safe way +func (b *Buffer) ReadString(delim byte) (line string, err error) { + b.m.Lock() + defer b.m.Unlock() + return b.b.ReadString(delim) +} + +// Write in a thread-safe way +func (b *Buffer) Write(p []byte) (n int, err error) { + b.m.Lock() + defer b.m.Unlock() + return b.b.Write(p) +} + +// WriteString in a thread-safe way +func (b *Buffer) WriteString(s string) (n int, err error) { + b.m.Lock() + defer b.m.Unlock() + return b.b.WriteString(s) +} + +// String in a thread-safe way +func (b *Buffer) String() string { + b.m.Lock() + defer b.m.Unlock() + return b.b.String() +} diff --git a/internal/session/getters.go b/internal/session/getters.go new file mode 100644 index 00000000..435e5371 --- /dev/null +++ b/internal/session/getters.go @@ -0,0 +1,8 @@ +package session + +import "io" + +// SDPProvider returns the SDP input +func (s *Session) SDPProvider() io.Reader { + return s.sdpInput +} diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 00000000..d7627397 --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,133 @@ +package session + +import ( + "fmt" + "io" + "os" + + "github.com/antonito/gfile/pkg/stats" + "github.com/antonito/gfile/pkg/utils" + "github.com/pion/webrtc/v2" +) + +// CompletionHandler to be called when transfer is done +type CompletionHandler func() + +// Session contains common elements to perform send/receive +type Session struct { + Done chan struct{} + NetworkStats *stats.Stats + sdpInput io.Reader + sdpOutput io.Writer + peerConnection *webrtc.PeerConnection + onCompletion CompletionHandler +} + +// New creates a new Session +func New(sdpInput io.Reader, sdpOutput io.Writer) Session { + if sdpInput == nil { + sdpInput = os.Stdin + } + if sdpOutput == nil { + sdpOutput = os.Stdout + } + return Session{ + sdpInput: sdpInput, + sdpOutput: sdpOutput, + Done: make(chan struct{}), + NetworkStats: stats.New(), + } +} + +// CreateConnection prepares a WebRTC connection +func (s *Session) CreateConnection(onConnectionStateChange func(connectionState webrtc.ICEConnectionState)) error { + config := webrtc.Configuration{ + ICEServers: []webrtc.ICEServer{ + { + URLs: []string{"stun:stun.l.google.com:19302"}, + }, + }, + } + + // Create a new RTCPeerConnection + peerConnection, err := webrtc.NewPeerConnection(config) + if err != nil { + return err + } + s.peerConnection = peerConnection + peerConnection.OnICEConnectionStateChange(onConnectionStateChange) + + return nil +} + +// ReadSDP from the SDP input stream +func (s *Session) ReadSDP() error { + var sdp webrtc.SessionDescription + + fmt.Println("Please, paste the remote SDP:") + for { + encoded, err := utils.MustReadStream(s.sdpInput) + if err == nil { + if err := utils.Decode(encoded, &sdp); err == nil { + break + } + } + fmt.Println("Invalid SDP, try again...") + } + return s.peerConnection.SetRemoteDescription(sdp) +} + +// CreateDataChannel that will be used to send data +func (s *Session) CreateDataChannel(c *webrtc.DataChannelInit) (*webrtc.DataChannel, error) { + return s.peerConnection.CreateDataChannel("data", c) +} + +// OnDataChannel sets an OnDataChannel handler +func (s *Session) OnDataChannel(handler func(d *webrtc.DataChannel)) { + s.peerConnection.OnDataChannel(handler) +} + +// CreateAnswer set the local description and print the answer SDP +func (s *Session) CreateAnswer() error { + // Create an answer + answer, err := s.peerConnection.CreateAnswer(nil) + if err != nil { + return err + } + return s.createSessionDescription(answer) +} + +// CreateOffer set the local description and print the offer SDP +func (s *Session) CreateOffer() error { + // Create an offer + answer, err := s.peerConnection.CreateOffer(nil) + if err != nil { + return err + } + return s.createSessionDescription(answer) +} + +// createSessionDescription set the local description and print the SDP +func (s *Session) createSessionDescription(desc webrtc.SessionDescription) error { + // Sets the LocalDescription, and starts our UDP listeners + if err := s.peerConnection.SetLocalDescription(desc); err != nil { + return err + } + desc.SDP = utils.StripSDP(desc.SDP) + + // Output the SDP in base64 so we can paste it in browser + resp, err := utils.Encode(desc) + if err != nil { + return err + } + fmt.Println("Send this SDP:") + fmt.Fprintf(s.sdpOutput, "%s\n", resp) + return nil +} + +// OnCompletion is called when session ends +func (s *Session) OnCompletion() { + if s.onCompletion != nil { + s.onCompletion() + } +} diff --git a/internal/session/session_test.go b/internal/session/session_test.go new file mode 100644 index 00000000..9770aca3 --- /dev/null +++ b/internal/session/session_test.go @@ -0,0 +1,24 @@ +package session + +import ( + "bufio" + "bytes" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_New(t *testing.T) { + assert := assert.New(t) + input := bufio.NewReader(&bytes.Buffer{}) + output := bufio.NewWriter(&bytes.Buffer{}) + + sess := New(nil, nil) + assert.Equal(os.Stdin, sess.sdpInput) + assert.Equal(os.Stdout, sess.sdpOutput) + + sess = New(input, output) + assert.Equal(input, sess.sdpInput) + assert.Equal(output, sess.sdpOutput) +} diff --git a/pkg/session/bench/benchmark.go b/pkg/session/bench/benchmark.go new file mode 100644 index 00000000..5c4854a2 --- /dev/null +++ b/pkg/session/bench/benchmark.go @@ -0,0 +1,61 @@ +package bench + +import ( + "sync" + "time" + + internalSess "github.com/antonito/gfile/internal/session" + "github.com/antonito/gfile/pkg/session/common" + "github.com/antonito/gfile/pkg/stats" +) + +const ( + bufferThresholdDefault = 64 * 1024 // 64kB + testDurationDefault = 20 * time.Second + testDurationErrorDefault = (testDurationDefault * 10) / 7 +) + +// Session is a benchmark session +type Session struct { + sess internalSess.Session + master bool + wg sync.WaitGroup + + // Settings + bufferThreshold uint64 + testDuration time.Duration + testDurationError time.Duration + + startPhase2 chan struct{} + uploadNetworkStats *stats.Stats + downloadDone chan bool + downloadNetworkStats *stats.Stats +} + +// New creates a new sender session +func new(s internalSess.Session, isMaster bool) *Session { + return &Session{ + sess: s, + master: isMaster, + + bufferThreshold: bufferThresholdDefault, + testDuration: testDurationDefault, + testDurationError: testDurationErrorDefault, + + startPhase2: make(chan struct{}), + downloadDone: make(chan bool), + uploadNetworkStats: stats.New(), + downloadNetworkStats: stats.New(), + } +} + +// Config contains custom configuration for a session +type Config struct { + common.Configuration + Master bool // Will create the SDP offer ? +} + +// NewWith createa a new benchmark Session with custom configuration +func NewWith(c Config) *Session { + return new(internalSess.New(c.SDPProvider, c.SDPOutput), c.Master) +} diff --git a/pkg/session/bench/benchmark_test.go b/pkg/session/bench/benchmark_test.go new file mode 100644 index 00000000..5a738828 --- /dev/null +++ b/pkg/session/bench/benchmark_test.go @@ -0,0 +1,86 @@ +package bench + +import ( + "testing" + "time" + + "github.com/antonito/gfile/internal/buffer" + "github.com/antonito/gfile/pkg/session/common" + "github.com/antonito/gfile/pkg/utils" + "github.com/stretchr/testify/assert" +) + +func Test_New(t *testing.T) { + assert := assert.New(t) + + sess := NewWith(Config{ + Master: false, + }) + + assert.NotNil(sess) + assert.Equal(false, sess.master) +} + +func Test_Bench(t *testing.T) { + assert := assert.New(t) + + sessionSDPProvider := &buffer.Buffer{} + sessionSDPOutput := &buffer.Buffer{} + sessionMasterSDPProvider := &buffer.Buffer{} + sessionMasterSDPOutput := &buffer.Buffer{} + + testDuration := 2 * time.Second + + sess := NewWith(Config{ + Configuration: common.Configuration{ + SDPProvider: sessionSDPProvider, + SDPOutput: sessionSDPOutput, + }, + Master: false, + }) + assert.NotNil(sess) + sess.testDuration = testDuration + sess.testDurationError = (testDuration * 10) / 8 + + sessMaster := NewWith(Config{ + Configuration: common.Configuration{ + SDPProvider: sessionMasterSDPProvider, + SDPOutput: sessionMasterSDPOutput, + }, + Master: true, + }) + assert.NotNil(sessMaster) + sessMaster.testDuration = testDuration + sessMaster.testDurationError = (testDuration * 10) / 8 + + masterDone := make(chan struct{}) + go func() { + defer close(masterDone) + err := sessMaster.Start() + assert.Nil(err) + }() + + sdp, err := utils.MustReadStream(sessionMasterSDPOutput) + assert.Nil(err) + sdp += "\n" + n, err := sessionSDPProvider.WriteString(sdp) + assert.Nil(err) + assert.Equal(len(sdp), n) + + slaveDone := make(chan struct{}) + go func() { + defer close(slaveDone) + err := sess.Start() + assert.Nil(err) + }() + + // Get SDP from slave and send it to the master + sdp, err = utils.MustReadStream(sessionSDPOutput) + assert.Nil(err) + n, err = sessionMasterSDPProvider.WriteString(sdp) + assert.Nil(err) + assert.Equal(len(sdp), n) + + <-masterDone + <-slaveDone +} diff --git a/pkg/session/bench/id.go b/pkg/session/bench/id.go new file mode 100644 index 00000000..9b6994bf --- /dev/null +++ b/pkg/session/bench/id.go @@ -0,0 +1,24 @@ +package bench + +const ( + // Used as upload channel for master (and download channel for non-master) + // 43981 -> 0xABCD + dataChannel1ID = uint16(43981) + // Used as download channel for master (and upload channel for non-master) + // 61185 -> 0xef01 + dataChannel2ID = uint16(61185) +) + +func (s *Session) uploadChannelID() uint16 { + if s.master { + return dataChannel1ID + } + return dataChannel2ID +} + +func (s *Session) downloadChannelID() uint16 { + if s.master { + return dataChannel2ID + } + return dataChannel1ID +} diff --git a/pkg/session/bench/id_test.go b/pkg/session/bench/id_test.go new file mode 100644 index 00000000..b93b00a2 --- /dev/null +++ b/pkg/session/bench/id_test.go @@ -0,0 +1,28 @@ +package bench + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_IDs(t *testing.T) { + assert := assert.New(t) + + sess := NewWith(Config{ + Master: false, + }) + assert.NotNil(sess) + assert.Equal(false, sess.master) + + sessMaster := NewWith(Config{ + Master: true, + }) + assert.NotNil(sessMaster) + assert.Equal(true, sessMaster.master) + + assert.Equal(sessMaster.downloadChannelID(), sess.uploadChannelID()) + assert.Equal(sessMaster.uploadChannelID(), sess.downloadChannelID()) + assert.NotEqual(sessMaster.downloadChannelID(), sess.downloadChannelID()) + +} diff --git a/pkg/session/bench/init.go b/pkg/session/bench/init.go new file mode 100644 index 00000000..0fd7c3fe --- /dev/null +++ b/pkg/session/bench/init.go @@ -0,0 +1,62 @@ +package bench + +import ( + "fmt" + + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +// Start initializes the connection and the benchmark +func (s *Session) Start() error { + if err := s.sess.CreateConnection(s.onConnectionStateChange()); err != nil { + log.Errorln(err) + return err + } + + s.sess.OnDataChannel(s.onNewDataChannel()) + if err := s.createUploadDataChannel(); err != nil { + log.Errorln(err) + return err + } + + s.wg.Add(2) // Download + Upload + if s.master { + if err := s.createMasterSession(); err != nil { + return err + } + } else { + if err := s.createSlaveSession(); err != nil { + return err + } + } + // Wait for benchmarks to be done + s.wg.Wait() + + fmt.Printf("Upload: %s\n", s.uploadNetworkStats.String()) + fmt.Printf("Download: %s\n", s.downloadNetworkStats.String()) + s.sess.OnCompletion() + return nil +} + +func (s *Session) initDataChannel(channelID *uint16) (*webrtc.DataChannel, error) { + ordered := true + maxPacketLifeTime := uint16(10000) + return s.sess.CreateDataChannel(&webrtc.DataChannelInit{ + Ordered: &ordered, + MaxPacketLifeTime: &maxPacketLifeTime, + ID: channelID, + }) +} + +func (s *Session) createUploadDataChannel() error { + channelID := s.uploadChannelID() + dataChannel, err := s.initDataChannel(&channelID) + if err != nil { + return err + } + + dataChannel.OnOpen(s.onOpenUploadHandler(dataChannel)) + + return nil +} diff --git a/pkg/session/bench/session.go b/pkg/session/bench/session.go new file mode 100644 index 00000000..3df7c0ae --- /dev/null +++ b/pkg/session/bench/session.go @@ -0,0 +1,59 @@ +package bench + +import ( + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +// Useful for unit tests +func (s *Session) onNewDataChannelHelper(name string, channelID uint16, d *webrtc.DataChannel) { + log.Tracef("New DataChannel %s (id: %x)\n", name, channelID) + + switch channelID { + case s.downloadChannelID(): + log.Traceln("Created Download data channel") + d.OnClose(s.onCloseHandlerDownload()) + go s.onOpenHandlerDownload(d)() + + case s.uploadChannelID(): + log.Traceln("Created Upload data channel") + + default: + log.Warningln("Created unknown data channel") + } +} + +func (s *Session) onNewDataChannel() func(d *webrtc.DataChannel) { + return func(d *webrtc.DataChannel) { + if d == nil || d.ID() == nil { + return + } + s.onNewDataChannelHelper(d.Label(), *d.ID(), d) + } +} + +func (s *Session) createMasterSession() error { + if err := s.sess.CreateOffer(); err != nil { + log.Errorln(err) + return err + } + + if err := s.sess.ReadSDP(); err != nil { + log.Errorln(err) + return err + } + return nil +} + +func (s *Session) createSlaveSession() error { + if err := s.sess.ReadSDP(); err != nil { + log.Errorln(err) + return err + } + + if err := s.sess.CreateAnswer(); err != nil { + log.Errorln(err) + return err + } + return nil +} diff --git a/pkg/session/bench/session_test.go b/pkg/session/bench/session_test.go new file mode 100644 index 00000000..328b79b8 --- /dev/null +++ b/pkg/session/bench/session_test.go @@ -0,0 +1,28 @@ +package bench + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_OnNewDataChannel(t *testing.T) { + assert := assert.New(t) + testDuration := 2 * time.Second + + sess := NewWith(Config{ + Master: false, + }) + assert.NotNil(sess) + sess.testDuration = testDuration + sess.testDurationError = (testDuration * 10) / 8 + + sess.onNewDataChannel()(nil) + + testID := sess.uploadChannelID() + sess.onNewDataChannelHelper("", testID, nil) + + testID = sess.uploadChannelID() | sess.downloadChannelID() + sess.onNewDataChannelHelper("", testID, nil) +} diff --git a/pkg/session/bench/state.go b/pkg/session/bench/state.go new file mode 100644 index 00000000..eecdff13 --- /dev/null +++ b/pkg/session/bench/state.go @@ -0,0 +1,12 @@ +package bench + +import ( + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +func (s *Session) onConnectionStateChange() func(connectionState webrtc.ICEConnectionState) { + return func(connectionState webrtc.ICEConnectionState) { + log.Infof("ICE Connection State has changed: %s\n", connectionState.String()) + } +} diff --git a/pkg/session/bench/state_download.go b/pkg/session/bench/state_download.go new file mode 100644 index 00000000..4ddd5491 --- /dev/null +++ b/pkg/session/bench/state_download.go @@ -0,0 +1,59 @@ +package bench + +import ( + "fmt" + "time" + + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +func (s *Session) onOpenHandlerDownload(dc *webrtc.DataChannel) func() { + // If master, wait for the upload to complete + // If not master, close the channel so the upload can start + return func() { + if s.master { + <-s.startPhase2 + } + + log.Debugf("Starting to download data...") + defer log.Debugf("Stopped downloading data...") + + s.downloadNetworkStats.Start() + + // Useful for unit tests + if dc != nil { + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + fmt.Printf("Downloading at %.2f MB/s\r", s.downloadNetworkStats.Bandwidth()) + s.downloadNetworkStats.AddBytes(uint64(len(msg.Data))) + }) + } else { + log.Warningln("No DataChannel provided") + } + + timeoutErr := time.After(s.testDurationError) + fmt.Printf("Downloading random datas ... (%d s)\n", int(s.testDuration.Seconds())) + + select { + case <-s.downloadDone: + case <-timeoutErr: + log.Error("Time'd out") + } + + log.Traceln("Done downloading") + + if !s.master { + close(s.startPhase2) + } + + fmt.Printf("\n") + s.downloadNetworkStats.Stop() + s.wg.Done() + } +} + +func (s *Session) onCloseHandlerDownload() func() { + return func() { + close(s.downloadDone) + } +} diff --git a/pkg/session/bench/state_upload.go b/pkg/session/bench/state_upload.go new file mode 100644 index 00000000..58960a0b --- /dev/null +++ b/pkg/session/bench/state_upload.go @@ -0,0 +1,75 @@ +package bench + +import ( + "crypto/rand" + "fmt" + "time" + + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +func (s *Session) onOpenUploadHandler(dc *webrtc.DataChannel) func() { + return func() { + if !s.master { + <-s.startPhase2 + } + + log.Debugln("Starting to upload data...") + defer log.Debugln("Stopped uploading data...") + + lenToken := uint64(4096) + token := make([]byte, lenToken) + if _, err := rand.Read(token); err != nil { + log.Fatalln("Err: ", err) + } + + s.uploadNetworkStats.Start() + + // Useful for unit tests + if dc != nil { + dc.SetBufferedAmountLowThreshold(s.bufferThreshold) + dc.OnBufferedAmountLow(func() { + if err := dc.Send(token); err == nil { + fmt.Printf("Uploading at %.2f MB/s\r", s.uploadNetworkStats.Bandwidth()) + s.uploadNetworkStats.AddBytes(lenToken) + } + }) + } else { + log.Warningln("No DataChannel provided") + } + + fmt.Printf("Uploading random datas ... (%d s)\n", int(s.testDuration.Seconds())) + timeout := time.After(s.testDuration) + timeoutErr := time.After(s.testDurationError) + + if dc != nil { + // Ignore potential error + _ = dc.Send(token) + } + SENDING_LOOP: + for { + select { + case <-timeoutErr: + log.Error("Time'd out") + break SENDING_LOOP + + case <-timeout: + log.Traceln("Done uploading") + break SENDING_LOOP + } + } + fmt.Printf("\n") + s.uploadNetworkStats.Stop() + + if dc != nil { + dc.Close() + } + + if s.master { + close(s.startPhase2) + } + + s.wg.Done() + } +} diff --git a/pkg/session/bench/timeout_test.go b/pkg/session/bench/timeout_test.go new file mode 100644 index 00000000..d102bc85 --- /dev/null +++ b/pkg/session/bench/timeout_test.go @@ -0,0 +1,38 @@ +package bench + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_TimeoutDownload(t *testing.T) { + assert := assert.New(t) + + sess := NewWith(Config{ + Master: false, + }) + + assert.NotNil(sess) + assert.Equal(false, sess.master) + sess.testDurationError = 2 * time.Millisecond + + sess.wg.Add(1) + sess.onOpenHandlerDownload(nil)() +} + +func Test_TimeoutUpload(t *testing.T) { + assert := assert.New(t) + + sess := NewWith(Config{ + Master: true, + }) + + assert.NotNil(sess) + assert.Equal(true, sess.master) + sess.testDurationError = 2 * time.Millisecond + + sess.wg.Add(1) + sess.onOpenUploadHandler(nil)() +} diff --git a/pkg/session/common/config.go b/pkg/session/common/config.go new file mode 100644 index 00000000..d2a29935 --- /dev/null +++ b/pkg/session/common/config.go @@ -0,0 +1,14 @@ +package common + +import ( + "io" + + "github.com/antonito/gfile/internal/session" +) + +// Configuration common to both Sender and Receiver session +type Configuration struct { + SDPProvider io.Reader // The SDP reader + SDPOutput io.Writer // The SDP writer + OnCompletion session.CompletionHandler // Handler to call on session completion +} diff --git a/pkg/session/receiver/init.go b/pkg/session/receiver/init.go new file mode 100644 index 00000000..12beb324 --- /dev/null +++ b/pkg/session/receiver/init.go @@ -0,0 +1,78 @@ +package receiver + +import ( + "fmt" + + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +// Initialize creates the connection, the datachannel and creates the offer +func (s *Session) Initialize() error { + if s.initialized { + return nil + } + if err := s.sess.CreateConnection(s.onConnectionStateChange()); err != nil { + log.Errorln(err) + return err + } + s.createDataHandler() + if err := s.sess.ReadSDP(); err != nil { + log.Errorln(err) + return err + } + if err := s.sess.CreateAnswer(); err != nil { + log.Errorln(err) + return err + } + + s.initialized = true + return nil +} + +// Start initializes the connection and the file transfer +func (s *Session) Start() error { + if err := s.Initialize(); err != nil { + return err + } + + // Handle data + s.receiveData() + s.sess.OnCompletion() + return nil +} + +func (s *Session) createDataHandler() { + s.sess.OnDataChannel(func(d *webrtc.DataChannel) { + log.Debugf("New DataChannel %s %d\n", d.Label(), d.ID()) + s.sess.NetworkStats.Start() + d.OnMessage(s.onMessage()) + d.OnClose(s.onClose()) + }) +} + +func (s *Session) receiveData() { + log.Infoln("Starting to receive data...") + defer log.Infoln("Stopped receiving data...") + + // Consume the message channel, until done + // Does not stop on error + for { + select { + case <-s.sess.Done: + s.sess.NetworkStats.Stop() + fmt.Printf("\nNetwork: %s\n", s.sess.NetworkStats.String()) + return + case msg := <-s.msgChannel: + n, err := s.stream.Write(msg.Data) + + if err != nil { + log.Errorln(err) + } else { + currentSpeed := s.sess.NetworkStats.Bandwidth() + fmt.Printf("Transferring at %.2f MB/s\r", currentSpeed) + s.sess.NetworkStats.AddBytes(uint64(n)) + } + } + } +} diff --git a/pkg/session/receiver/receiver.go b/pkg/session/receiver/receiver.go new file mode 100644 index 00000000..045b2730 --- /dev/null +++ b/pkg/session/receiver/receiver.go @@ -0,0 +1,47 @@ +package receiver + +import ( + "io" + + internalSess "github.com/antonito/gfile/internal/session" + "github.com/antonito/gfile/pkg/session/common" + "github.com/pion/webrtc/v2" +) + +// Session is a receiver session +type Session struct { + sess internalSess.Session + stream io.Writer + msgChannel chan webrtc.DataChannelMessage + initialized bool +} + +func new(s internalSess.Session, f io.Writer) *Session { + return &Session{ + sess: s, + stream: f, + msgChannel: make(chan webrtc.DataChannelMessage, 4096*2), + initialized: false, + } +} + +// New creates a new receiver session +func New(f io.Writer) *Session { + return new(internalSess.New(nil, nil), f) +} + +// Config contains custom configuration for a session +type Config struct { + common.Configuration + Stream io.Writer // The Stream to write to +} + +// NewWith createa a new receiver Session with custom configuration +func NewWith(c Config) *Session { + return new(internalSess.New(c.SDPProvider, c.SDPOutput), c.Stream) +} + +// SetStream changes the stream, useful for WASM integration +func (s *Session) SetStream(stream io.Writer) { + s.stream = stream +} diff --git a/pkg/session/receiver/receiver_test.go b/pkg/session/receiver/receiver_test.go new file mode 100644 index 00000000..0224805c --- /dev/null +++ b/pkg/session/receiver/receiver_test.go @@ -0,0 +1,19 @@ +package receiver + +import ( + "bufio" + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_New(t *testing.T) { + assert := assert.New(t) + output := bufio.NewWriter(&bytes.Buffer{}) + + sess := New(output) + + assert.NotNil(sess) + assert.Equal(output, sess.stream) +} diff --git a/pkg/session/receiver/state.go b/pkg/session/receiver/state.go new file mode 100644 index 00000000..99257f63 --- /dev/null +++ b/pkg/session/receiver/state.go @@ -0,0 +1,25 @@ +package receiver + +import ( + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +func (s *Session) onConnectionStateChange() func(connectionState webrtc.ICEConnectionState) { + return func(connectionState webrtc.ICEConnectionState) { + log.Infof("ICE Connection State has changed: %s\n", connectionState.String()) + } +} + +func (s *Session) onMessage() func(msg webrtc.DataChannelMessage) { + return func(msg webrtc.DataChannelMessage) { + // Store each message in the message channel + s.msgChannel <- msg + } +} + +func (s *Session) onClose() func() { + return func() { + close(s.sess.Done) + } +} diff --git a/pkg/session/sender/getters.go b/pkg/session/sender/getters.go new file mode 100644 index 00000000..c9c6dad8 --- /dev/null +++ b/pkg/session/sender/getters.go @@ -0,0 +1,8 @@ +package sender + +import "io" + +// SDPProvider returns the underlying SDPProvider +func (s *Session) SDPProvider() io.Reader { + return s.sess.SDPProvider() +} diff --git a/pkg/session/sender/init.go b/pkg/session/sender/init.go new file mode 100644 index 00000000..eeebc0d1 --- /dev/null +++ b/pkg/session/sender/init.go @@ -0,0 +1,68 @@ +package sender + +import ( + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +const ( + bufferThreshold = 512 * 1024 // 512kB +) + +// Initialize creates the connection, the datachannel and creates the offer +func (s *Session) Initialize() error { + if s.initialized { + return nil + } + + if err := s.sess.CreateConnection(s.onConnectionStateChange()); err != nil { + log.Errorln(err) + return err + } + if err := s.createDataChannel(); err != nil { + log.Errorln(err) + return err + } + if err := s.sess.CreateOffer(); err != nil { + log.Errorln(err) + return err + } + + s.initialized = true + return nil +} + +// Start the connection and the file transfer +func (s *Session) Start() error { + if err := s.Initialize(); err != nil { + return err + } + go s.readFile() + if err := s.sess.ReadSDP(); err != nil { + log.Errorln(err) + return err + } + <-s.sess.Done + s.sess.OnCompletion() + return nil +} + +func (s *Session) createDataChannel() error { + ordered := true + maxPacketLifeTime := uint16(10000) + dataChannel, err := s.sess.CreateDataChannel(&webrtc.DataChannelInit{ + Ordered: &ordered, + MaxPacketLifeTime: &maxPacketLifeTime, + }) + if err != nil { + return err + } + + s.dataChannel = dataChannel + s.dataChannel.OnBufferedAmountLow(s.onBufferedAmountLow()) + s.dataChannel.SetBufferedAmountLowThreshold(bufferThreshold) + s.dataChannel.OnOpen(s.onOpenHandler()) + s.dataChannel.OnClose(s.onCloseHandler()) + + return nil +} diff --git a/pkg/session/sender/io.go b/pkg/session/sender/io.go new file mode 100644 index 00000000..b56524b7 --- /dev/null +++ b/pkg/session/sender/io.go @@ -0,0 +1,78 @@ +package sender + +import ( + "fmt" + "io" + + log "github.com/sirupsen/logrus" +) + +func (s *Session) readFile() { + log.Infof("Starting to read data...") + s.readingStats.Start() + defer func() { + s.readingStats.Pause() + log.Infof("Stopped reading data...") + close(s.output) + }() + + for { + // Read file + s.dataBuff = s.dataBuff[:cap(s.dataBuff)] + n, err := s.stream.Read(s.dataBuff) + if err != nil { + if err == io.EOF { + s.readingStats.Stop() + log.Debugf("Got EOF after %v bytes!\n", s.readingStats.Bytes()) + return + } + log.Errorf("Read Error: %v\n", err) + return + } + s.dataBuff = s.dataBuff[:n] + s.readingStats.AddBytes(uint64(n)) + + s.output <- outputMsg{ + n: n, + // Make a copy of the buffer + buff: append([]byte(nil), s.dataBuff...), + } + } +} + +func (s *Session) onBufferedAmountLow() func() { + return func() { + data := <-s.output + if data.n != 0 { + s.msgToBeSent = append(s.msgToBeSent, data) + } else if len(s.msgToBeSent) == 0 && s.dataChannel.BufferedAmount() == 0 { + s.sess.NetworkStats.Stop() + s.close(false) + return + } + + currentSpeed := s.sess.NetworkStats.Bandwidth() + fmt.Printf("Transferring at %.2f MB/s\r", currentSpeed) + + for len(s.msgToBeSent) != 0 { + cur := s.msgToBeSent[0] + + if err := s.dataChannel.Send(cur.buff); err != nil { + log.Errorf("Error, cannot send to client: %v\n", err) + return + } + s.sess.NetworkStats.AddBytes(uint64(cur.n)) + s.msgToBeSent = s.msgToBeSent[1:] + } + } +} + +func (s *Session) writeToNetwork() { + // Set callback, as transfer may be paused + s.dataChannel.OnBufferedAmountLow(s.onBufferedAmountLow()) + + <-s.stopSending + s.dataChannel.OnBufferedAmountLow(nil) + s.sess.NetworkStats.Pause() + log.Infof("Pausing network I/O... (remaining at least %v packets)\n", len(s.output)) +} diff --git a/pkg/session/sender/sender.go b/pkg/session/sender/sender.go new file mode 100644 index 00000000..6dec8826 --- /dev/null +++ b/pkg/session/sender/sender.go @@ -0,0 +1,75 @@ +package sender + +import ( + "io" + "sync" + + internalSess "github.com/antonito/gfile/internal/session" + "github.com/antonito/gfile/pkg/session/common" + "github.com/antonito/gfile/pkg/stats" + "github.com/pion/webrtc/v2" +) + +const ( + // Must be <= 16384 + senderBuffSize = 16384 +) + +type outputMsg struct { + n int + buff []byte +} + +// Session is a sender session +type Session struct { + sess internalSess.Session + stream io.Reader + initialized bool + + dataChannel *webrtc.DataChannel + dataBuff []byte + msgToBeSent []outputMsg + stopSending chan struct{} + output chan outputMsg + + doneCheckLock sync.Mutex + doneCheck bool + + // Stats/infos + readingStats *stats.Stats +} + +// New creates a new sender session +func new(s internalSess.Session, f io.Reader) *Session { + return &Session{ + sess: s, + stream: f, + initialized: false, + dataBuff: make([]byte, senderBuffSize), + stopSending: make(chan struct{}, 1), + output: make(chan outputMsg, senderBuffSize*10), + doneCheck: false, + readingStats: stats.New(), + } +} + +// New creates a new receiver session +func New(f io.Reader) *Session { + return new(internalSess.New(nil, nil), f) +} + +// Config contains custom configuration for a session +type Config struct { + common.Configuration + Stream io.Reader // The Stream to read from +} + +// NewWith createa a new sender Session with custom configuration +func NewWith(c Config) *Session { + return new(internalSess.New(c.SDPProvider, c.SDPOutput), c.Stream) +} + +// SetStream changes the stream, useful for WASM integration +func (s *Session) SetStream(stream io.Reader) { + s.stream = stream +} diff --git a/pkg/session/sender/sender_test.go b/pkg/session/sender/sender_test.go new file mode 100644 index 00000000..029a405c --- /dev/null +++ b/pkg/session/sender/sender_test.go @@ -0,0 +1,19 @@ +package sender + +import ( + "bufio" + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_New(t *testing.T) { + assert := assert.New(t) + input := bufio.NewReader(&bytes.Buffer{}) + + sess := New(input) + + assert.NotNil(sess) + assert.Equal(input, sess.stream) +} diff --git a/pkg/session/sender/state.go b/pkg/session/sender/state.go new file mode 100644 index 00000000..47075424 --- /dev/null +++ b/pkg/session/sender/state.go @@ -0,0 +1,58 @@ +package sender + +import ( + "fmt" + + "github.com/pion/webrtc/v2" + log "github.com/sirupsen/logrus" +) + +func (s *Session) onConnectionStateChange() func(connectionState webrtc.ICEConnectionState) { + return func(connectionState webrtc.ICEConnectionState) { + log.Infof("ICE Connection State has changed: %s\n", connectionState.String()) + if connectionState == webrtc.ICEConnectionStateDisconnected { + s.stopSending <- struct{}{} + } + } +} + +func (s *Session) onOpenHandler() func() { + return func() { + s.sess.NetworkStats.Start() + + log.Infof("Starting to send data...") + defer log.Infof("Stopped sending data...") + + s.writeToNetwork() + } +} + +func (s *Session) onCloseHandler() func() { + return func() { + s.close(true) + } +} + +func (s *Session) close(calledFromCloseHandler bool) { + if !calledFromCloseHandler { + s.dataChannel.Close() + } + + // Sometime, onCloseHandler is not invoked, so it's a work-around + s.doneCheckLock.Lock() + if s.doneCheck { + s.doneCheckLock.Unlock() + return + } + s.doneCheck = true + s.doneCheckLock.Unlock() + s.dumpStats() + close(s.sess.Done) +} + +func (s *Session) dumpStats() { + fmt.Printf(` +Disk : %s +Network: %s +`, s.readingStats.String(), s.sess.NetworkStats.String()) +} diff --git a/pkg/session/session.go b/pkg/session/session.go new file mode 100644 index 00000000..eab1644d --- /dev/null +++ b/pkg/session/session.go @@ -0,0 +1,7 @@ +package session + +// Session defines a common interface for sender and receiver sessions +type Session interface { + // Start a connection and starts the file transfer + Start() error +} diff --git a/pkg/session/session_test.go b/pkg/session/session_test.go new file mode 100644 index 00000000..34a24986 --- /dev/null +++ b/pkg/session/session_test.go @@ -0,0 +1,99 @@ +package session + +import ( + "bytes" + "fmt" + "testing" + + "github.com/antonito/gfile/internal/buffer" + "github.com/antonito/gfile/pkg/session/common" + "github.com/antonito/gfile/pkg/session/receiver" + "github.com/antonito/gfile/pkg/session/sender" + "github.com/antonito/gfile/pkg/utils" + "github.com/stretchr/testify/assert" +) + +// Tests + +func Test_CreateReceiverSession(t *testing.T) { + assert := assert.New(t) + stream := &bytes.Buffer{} + + sess := receiver.NewWith(receiver.Config{ + Stream: stream, + }) + assert.NotNil(sess) +} + +func Test_TransferSmallMessage(t *testing.T) { + assert := assert.New(t) + + // Create client receiver + clientStream := &buffer.Buffer{} + clientSDPProvider := &buffer.Buffer{} + clientSDPOutput := &buffer.Buffer{} + clientConfig := receiver.Config{ + Stream: clientStream, + Configuration: common.Configuration{ + SDPProvider: clientSDPProvider, + SDPOutput: clientSDPOutput, + }, + } + clientSession := receiver.NewWith(clientConfig) + assert.NotNil(clientSession) + + // Create sender + senderStream := &buffer.Buffer{} + senderSDPProvider := &buffer.Buffer{} + senderSDPOutput := &buffer.Buffer{} + n, err := senderStream.WriteString("Hello World!\n") + assert.Nil(err) + assert.Equal(13, n) // Len "Hello World\n" + senderConfig := sender.Config{ + Stream: senderStream, + Configuration: common.Configuration{ + SDPProvider: senderSDPProvider, + SDPOutput: senderSDPOutput, + }, + } + senderSession := sender.NewWith(senderConfig) + assert.NotNil(senderSession) + + senderDone := make(chan struct{}) + go func() { + defer close(senderDone) + err := senderSession.Start() + assert.Nil(err) + }() + + // Get SDP from sender and send it to the client + sdp, err := utils.MustReadStream(senderSDPOutput) + assert.Nil(err) + fmt.Printf("READ SDP -> %s\n", sdp) + sdp += "\n" + n, err = clientSDPProvider.WriteString(sdp) + assert.Nil(err) + assert.Equal(len(sdp), n) + + clientDone := make(chan struct{}) + go func() { + defer close(clientDone) + err := clientSession.Start() + assert.Nil(err) + }() + + // Get SDP from client and send it to the sender + sdp, err = utils.MustReadStream(clientSDPOutput) + assert.Nil(err) + n, err = senderSDPProvider.WriteString(sdp) + assert.Nil(err) + assert.Equal(len(sdp), n) + + fmt.Println("Waiting for everyone to be done...") + <-senderDone + <-clientDone + + msg, err := clientStream.ReadString('\n') + assert.Nil(err) + assert.Equal("Hello World!\n", msg) +} diff --git a/pkg/stats/bytes.go b/pkg/stats/bytes.go new file mode 100644 index 00000000..b3dd73ef --- /dev/null +++ b/pkg/stats/bytes.go @@ -0,0 +1,17 @@ +package stats + +// Bytes returns the stored number of bytes +func (s *Stats) Bytes() uint64 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.nbBytes +} + +// AddBytes increase the nbBytes counter +func (s *Stats) AddBytes(c uint64) { + s.lock.Lock() + defer s.lock.Unlock() + + s.nbBytes += c +} diff --git a/pkg/stats/bytes_test.go b/pkg/stats/bytes_test.go new file mode 100644 index 00000000..085238f1 --- /dev/null +++ b/pkg/stats/bytes_test.go @@ -0,0 +1,40 @@ +package stats + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Bytes(t *testing.T) { + assert := assert.New(t) + + tests := []struct { + before uint64 + add uint64 + after uint64 + }{ + { + before: 0, + add: 0, + after: 0, + }, + { + before: 0, + add: 1, + after: 1, + }, + { + before: 1, + add: 10, + after: 11, + }, + } + + s := New() + for _, cur := range tests { + assert.Equal(cur.before, s.Bytes()) + s.AddBytes(cur.add) + assert.Equal(cur.after, s.Bytes()) + } +} diff --git a/pkg/stats/ctrl.go b/pkg/stats/ctrl.go new file mode 100644 index 00000000..f33bf5fb --- /dev/null +++ b/pkg/stats/ctrl.go @@ -0,0 +1,55 @@ +package stats + +import "time" + +// Start stores the "start" timestamp +func (s *Stats) Start() { + s.lock.Lock() + defer s.lock.Unlock() + + if s.timeStart.IsZero() { + s.timeStart = time.Now() + } else if !s.timePause.IsZero() { + s.timePaused += time.Since(s.timePause) + // Reset + s.timePause = time.Time{} + } +} + +// Pause stores an interruption timestamp +func (s *Stats) Pause() { + s.lock.RLock() + + if s.timeStart.IsZero() || !s.timeStop.IsZero() { + // Can't stop if not started, or if stopped + s.lock.RUnlock() + return + } + s.lock.RUnlock() + + s.lock.Lock() + defer s.lock.Unlock() + + if s.timePause.IsZero() { + s.timePause = time.Now() + } +} + +// Stop stores the "stop" timestamp +func (s *Stats) Stop() { + s.lock.RLock() + + if s.timeStart.IsZero() { + // Can't stop if not started + s.lock.RUnlock() + return + } + s.lock.RUnlock() + + s.lock.Lock() + defer s.lock.Unlock() + + if s.timeStop.IsZero() { + s.timeStop = time.Now() + } +} diff --git a/pkg/stats/ctrl_test.go b/pkg/stats/ctrl_test.go new file mode 100644 index 00000000..0fbb6c41 --- /dev/null +++ b/pkg/stats/ctrl_test.go @@ -0,0 +1,55 @@ +package stats + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_ControlFlow(t *testing.T) { + assert := assert.New(t) + s := New() + + // Everything should be 0 at the beginning + assert.Equal(true, s.timeStart.IsZero()) + assert.Equal(true, s.timeStop.IsZero()) + assert.Equal(true, s.timePause.IsZero()) + + // Should not do anything + s.Stop() + assert.Equal(true, s.timeStop.IsZero()) + + // Should not do anything + s.Pause() + assert.Equal(true, s.timePause.IsZero()) + + // Should start + s.Start() + originalStart := s.timeStart + assert.Equal(false, s.timeStart.IsZero()) + + // Should pause + s.Pause() + assert.Equal(false, s.timePause.IsZero()) + originalPause := s.timePause + // Should not modify + s.Pause() + assert.Equal(originalPause, s.timePause) + + // Should release + assert.Equal(int64(0), s.timePaused.Nanoseconds()) + s.Start() + assert.NotEqual(0, s.timePaused.Nanoseconds()) + originalPausedDuration := s.timePaused + assert.Equal(true, s.timePause.IsZero()) + assert.Equal(originalStart, s.timeStart) + + s.Pause() + time.Sleep(10 * time.Nanosecond) + s.Start() + assert.Equal(true, s.timePaused > originalPausedDuration) + + s.Stop() + assert.Equal(false, s.timeStop.IsZero()) +} diff --git a/pkg/stats/data.go b/pkg/stats/data.go new file mode 100644 index 00000000..91f834ad --- /dev/null +++ b/pkg/stats/data.go @@ -0,0 +1,26 @@ +package stats + +import "time" + +// Duration returns the 'stop - start' duration, if stopped +// Returns 0 if not started +// Returns time.Since(s.timeStart) if not stopped +func (s *Stats) Duration() time.Duration { + s.lock.RLock() + defer s.lock.RUnlock() + + if s.timeStart.IsZero() { + return 0 + } else if s.timeStop.IsZero() { + return time.Since(s.timeStart) - s.timePaused + } + return s.timeStop.Sub(s.timeStart) - s.timePaused +} + +// Bandwidth returns the IO speed in MB/s +func (s *Stats) Bandwidth() float64 { + s.lock.RLock() + defer s.lock.RUnlock() + + return (float64(s.nbBytes) / 1024 / 1024) / s.Duration().Seconds() +} diff --git a/pkg/stats/data_test.go b/pkg/stats/data_test.go new file mode 100644 index 00000000..551dec9a --- /dev/null +++ b/pkg/stats/data_test.go @@ -0,0 +1,79 @@ +package stats + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_Bandwidth(t *testing.T) { + assert := assert.New(t) + s := New() + + now := time.Now() + tests := []struct { + startTime time.Time + stopTime time.Time + pauseDuration time.Duration + bytesCount uint64 + expectedBandwidth float64 + }{ + { + startTime: time.Time{}, + stopTime: time.Time{}, + pauseDuration: 0, + bytesCount: 0, + expectedBandwidth: math.NaN(), + }, + { + startTime: now, + stopTime: time.Time{}, + pauseDuration: 0, + bytesCount: 0, + expectedBandwidth: 0, + }, + { + startTime: now, + stopTime: now.Add(time.Duration(1 * 1000000000)), + pauseDuration: 0, + bytesCount: 1024 * 1024, + expectedBandwidth: 1, + }, + { + startTime: now, + stopTime: now.Add(time.Duration(2 * 1000000000)), + pauseDuration: time.Duration(1 * 1000000000), + bytesCount: 1024 * 1024, + expectedBandwidth: 1, + }, + } + + for _, cur := range tests { + s.timeStart = cur.startTime + s.timeStop = cur.stopTime + s.timePaused = cur.pauseDuration + s.nbBytes = cur.bytesCount + + if math.IsNaN(cur.expectedBandwidth) { + assert.Equal(true, math.IsNaN(s.Bandwidth())) + } else { + assert.Equal(cur.expectedBandwidth, s.Bandwidth()) + } + } +} + +func Test_Duration(t *testing.T) { + assert := assert.New(t) + s := New() + + // Should be 0 + assert.Equal(time.Duration(0), s.Duration()) + + // Should return time.Since() + s.Start() + durationTmp := s.Duration() + time.Sleep(10 * time.Nanosecond) + assert.Equal(true, s.Duration() > durationTmp) +} diff --git a/pkg/stats/stats.go b/pkg/stats/stats.go new file mode 100644 index 00000000..e6fba76c --- /dev/null +++ b/pkg/stats/stats.go @@ -0,0 +1,31 @@ +package stats + +import ( + "fmt" + "sync" + "time" +) + +// Stats provide a way to track statistics infos +type Stats struct { + lock *sync.RWMutex + nbBytes uint64 + timeStart time.Time + timeStop time.Time + + timePause time.Time + timePaused time.Duration +} + +// New creates a new Stats +func New() *Stats { + return &Stats{ + lock: &sync.RWMutex{}, + } +} + +func (s *Stats) String() string { + s.lock.RLock() + defer s.lock.RUnlock() + return fmt.Sprintf("%v bytes | %-v | %0.4f MB/s", s.Bytes(), s.Duration(), s.Bandwidth()) +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 00000000..ee5037d1 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,102 @@ +package utils + +import ( + "bufio" + "bytes" + "compress/gzip" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "strings" +) + +// MustReadStream blocks until input is received from the stream +func MustReadStream(stream io.Reader) (string, error) { + r := bufio.NewReader(stream) + + var in string + for { + var err error + in, err = r.ReadString('\n') + if err != io.EOF { + if err != nil { + return "", err + } + } + in = strings.TrimSpace(in) + if len(in) > 0 { + break + } + } + + fmt.Println("") + return in, nil +} + +// StripSDP remove useless elements from an SDP +func StripSDP(originalSDP string) string { + finalSDP := strings.Replace(originalSDP, "a=group:BUNDLE audio video data", "a=group:BUNDLE data", -1) + tmp := strings.Split(finalSDP, "m=audio") + beginningSdp := tmp[0] + + var endSdp string + if len(tmp) > 1 { + tmp = strings.Split(tmp[1], "a=end-of-candidates") + endSdp = strings.Join(tmp[2:], "a=end-of-candidates") + } else { + endSdp = strings.Join(tmp[1:], "a=end-of-candidates") + } + + finalSDP = beginningSdp + endSdp + finalSDP = strings.Replace(finalSDP, "\r\n\r\n", "\r\n", -1) + finalSDP = strings.Replace(finalSDP, "\n\n", "\n", -1) + return finalSDP +} + +// Encode encodes the input in base64 +// It can optionally zip the input before encoding +func Encode(obj interface{}) (string, error) { + b, err := json.Marshal(obj) + if err != nil { + return "", err + } + var gzbuff bytes.Buffer + gz, err := gzip.NewWriterLevel(&gzbuff, gzip.BestCompression) + if err != nil { + return "", err + } + if _, err := gz.Write(b); err != nil { + return "", err + } + if err := gz.Flush(); err != nil { + return "", err + } + if err := gz.Close(); err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(gzbuff.Bytes()), nil +} + +// Decode decodes the input from base64 +// It can optionally unzip the input after decoding +func Decode(in string, obj interface{}) error { + b, err := base64.StdEncoding.DecodeString(in) + if err != nil { + return err + } + + gz, err := gzip.NewReader(bytes.NewReader(b)) + if err != nil { + return err + } + defer gz.Close() + s, err := ioutil.ReadAll(gz) + if err != nil { + return err + } + + return json.Unmarshal(s, obj) +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 00000000..9c8da0e5 --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,223 @@ +package utils + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ReadStream(t *testing.T) { + assert := assert.New(t) + stream := &bytes.Buffer{} + + _, err := stream.WriteString("Hello\n") + assert.Nil(err) + + str, err := MustReadStream(stream) + assert.Equal("Hello", str) + assert.Nil(err) +} + +func Test_StripSDP(t *testing.T) { + assert := assert.New(t) + + tests := []struct { + sdp string + expected string + }{ + { + sdp: "", + expected: "", + }, + { + sdp: `v=0 +o=- 297292268 1552262038 IN IP4 0.0.0.0 +s=- +t=0 0 +a=fingerprint:sha-256 70:E0:B2:DA:F8:04:D6:0C:32:03:DF:CD:A8:70:EC:45:10:FF:66:6F:3D:72:B1:BA:4C:AF:FB:5E:BE:F9:CF:6A +a=group:BUNDLE audio video data +m=audio 9 UDP/TLS/RTP/SAVPF 111 9 +c=IN IP4 0.0.0.0 +a=setup:actpass +a=mid:audio +a=ice-ufrag:SNxNaqIiaNoDiCNM +a=ice-pwd:dSZlwOEOKEmBfNiXCtpmPTOVJlwUCaFX +a=rtcp-mux +a=rtcp-rsize +a=rtpmap:111 opus/48000/2 +a=fmtp:111 minptime=10;useinbandfec=1 +a=rtpmap:9 G722/8000 +a=recvonly +a=candidate:foundation 1 udp 3776 192.168.100.207 61879 typ host generation 0 +a=candidate:foundation 2 udp 3776 192.168.100.207 61879 typ host generation 0 +a=end-of-candidates +a=setup:actpass +m=video 9 UDP/TLS/RTP/SAVPF 96 100 98 +c=IN IP4 0.0.0.0 +a=setup:actpass +a=mid:video +a=ice-ufrag:SNxNaqIiaNoDiCNM +a=ice-pwd:dSZlwOEOKEmBfNiXCtpmPTOVJlwUCaFX +a=rtcp-mux +a=rtcp-rsize +a=rtpmap:96 VP8/90000 +a=rtpmap:100 H264/90000 +a=fmtp:100 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f +a=rtpmap:98 VP9/90000 +a=recvonly +a=candidate:foundation 1 udp 3776 192.168.100.207 61879 typ host generation 0 +a=candidate:foundation 2 udp 3776 192.168.100.207 61879 typ host generation 0 +a=end-of-candidates +a=setup:actpass +m=application 9 DTLS/SCTP 5000 +c=IN IP4 0.0.0.0 +a=setup:actpass +a=mid:data +a=sendrecv +a=sctpmap:5000 webrtc-datachannel 1024 +a=ice-ufrag:SNxNaqIiaNoDiCNM +a=ice-pwd:dSZlwOEOKEmBfNiXCtpmPTOVJlwUCaFX +a=candidate:foundation 1 udp 3776 192.168.100.207 61879 typ host generation 0 +a=candidate:foundation 2 udp 3776 192.168.100.207 61879 typ host generation 0 +a=end-of-candidates +a=setup:actpass +`, + expected: `v=0 +o=- 297292268 1552262038 IN IP4 0.0.0.0 +s=- +t=0 0 +a=fingerprint:sha-256 70:E0:B2:DA:F8:04:D6:0C:32:03:DF:CD:A8:70:EC:45:10:FF:66:6F:3D:72:B1:BA:4C:AF:FB:5E:BE:F9:CF:6A +a=group:BUNDLE data +a=setup:actpass +m=application 9 DTLS/SCTP 5000 +c=IN IP4 0.0.0.0 +a=setup:actpass +a=mid:data +a=sendrecv +a=sctpmap:5000 webrtc-datachannel 1024 +a=ice-ufrag:SNxNaqIiaNoDiCNM +a=ice-pwd:dSZlwOEOKEmBfNiXCtpmPTOVJlwUCaFX +a=candidate:foundation 1 udp 3776 192.168.100.207 61879 typ host generation 0 +a=candidate:foundation 2 udp 3776 192.168.100.207 61879 typ host generation 0 +a=end-of-candidates +a=setup:actpass +`, + }, + } + + for _, cur := range tests { + assert.Equal(cur.expected, StripSDP(cur.sdp)) + } +} + +func Test_Encode(t *testing.T) { + assert := assert.New(t) + + tests := []struct { + input interface{} + shouldErr bool + expected string + }{ + // Invalid object + { + input: make(chan int), + shouldErr: true, + }, + // Empty input + { + input: nil, + shouldErr: false, + expected: "H4sIAAAAAAAC/8orzckBAAAA//8BAAD//0/8yyUEAAAA", + }, + // Not JSON + { + input: "ThisTestIsNotInB64", + shouldErr: false, + expected: "H4sIAAAAAAAC/1IKycgsDkktLvEs9ssv8cxzMjNRAgAAAP//AQAA//8+sWiWFAAAAA==", + }, + // JSON + { + input: struct { + Name string `json:"name"` + }{ + Name: "TestJson", + }, + shouldErr: false, + expected: "H4sIAAAAAAAC/6pWykvMTVWyUgpJLS7xKs7PU6oFAAAA//8BAAD//3cqgZQTAAAA", + }, + } + + for _, cur := range tests { + res, err := Encode(cur.input) + + if cur.shouldErr { + assert.NotNil(err) + } else { + assert.Nil(err) + assert.Equal(cur.expected, res) + } + } +} + +func Test_Decode(t *testing.T) { + assert := assert.New(t) + + tests := []struct { + input string + shouldErr bool + }{ + // Empty string + { + input: "", + shouldErr: true, + }, + // Not base64 + { + input: "ThisTestIsNotInB64", + shouldErr: true, + }, + // Not base64 JSON + { + input: "aGVsbG8gd29ybGQ=", + shouldErr: true, + }, + // Base64 JSON + { + input: "H4sIAAAAAAAC/+xVTY/bNhD9KwOdK5ukqK8JdFh77XabNPXGXqcJcmFEystGogiKsuMU/e+FLG/qtEWBBRo0h4UgCTPz+OaRegP9FvijVQEGwnQH5YLvgk7aAIN9Qd65d6YtQojzmCRpmmZA45ixhFNO4eYl3Kw4kMnpGqBdEQ4vXxA4xaKotNkpZ502Hrt7EbI4gYjiLMKE4oIhYZgskWXIyfCcLZFdY5ZgxnEeYTRDHmGeYsyQ57icIWdI5ji/wnmClCBLMUowjjBLcXGFEUHOx8Y71/YWZ3cvr18sQPRSt7DXUrUghRcDpCnGbA5316vp5sV6+mqzmq6vtqslUEohH0Bl8fdNiqJTvrcoSq/3asw0WuKJbgx1qcK+cmKH5avl6zflzeGu8z9vd4df/qzbg8TXerH7cfnrQj9/u/a36+2b55ubulosDx9u5eb2p7cj2vnShk3/8SJynf6kHmLbCIuD5Nb23ZRnhJApOx9/48dSo431ulEFJc/6TmnzXhhZqbKgX7Dk8H3K2HSgOCs1l9sshZFaCq+wansjhdetAQq9tBClaQI0ZxOaZBMyoVkOcRQnCfijhfu287BTRrlxCfkXOvbf0p3VUUKApxM63CyfRAxinhJ6outcVX8EJ6R0f2npbOv8Gfk4+V+jnzIybKvwc9tutPFo63+ycZ7AoCPPHmvlE+X/ZuU8ge0qm+bkswsfPE4I/MASflkaHU4I1Gqv6lB0x6ZR3h1DUdftQcmCPrOi/KC8/nQ6zLBppRqSrq10rcJxmZYFZ4TQ6kshGWxX+WW3p3H45sdBWFvrckTmcD1MxHq+WUF8/oiPmYOHf8VQN9Kpcn+OytEgAycc1Hvny3DAlvfCGFUDJYx/jfF5Mtw3Z7jg9z8AAAD//wEAAP//RjpVQj8JAAA=", + shouldErr: false, + }, + } + + var obj interface{} + for _, cur := range tests { + err := Decode(cur.input, &obj) + + if cur.shouldErr { + assert.NotNil(err) + } else { + assert.Nil(err) + } + } +} + +func Test_EncodeDecode(t *testing.T) { + assert := assert.New(t) + + input := struct { + Name string `json:"name"` + }{ + Name: "TestJson", + } + + encoded, err := Encode(input) + assert.Nil(err) + assert.Equal("H4sIAAAAAAAC/6pWykvMTVWyUgpJLS7xKs7PU6oFAAAA//8BAAD//3cqgZQTAAAA", encoded) + + var obj struct { + Name string `json:"name"` + } + err = Decode(encoded, &obj) + assert.Nil(err) + assert.Equal(input, obj) +}