From 6f527f9af77c12daa47d4650c326a4bf2ee4dccb Mon Sep 17 00:00:00 2001 From: Hong Xiaojian Date: Mon, 6 Dec 2021 19:30:39 +0800 Subject: [PATCH] feat(bridge): add WebSocket bridge (#263) * feat(bridge): add WebSocket bridge * feat(bridge): send DataFrame back to WebSocket connection * feat(bridge): broadcast data to multi connections in WebSocket bridge * refactor: remove session in context * set websocket conn payload type * feat(websocket): broadcast offline message * Revert "feat(websocket): broadcast offline message" This reverts commit da76906913ce7ace1bcd297f35d734204148824a. Co-authored-by: venjiang --- .gitignore | 3 + core/bridge.go | 18 ++++ core/client.go | 1 + core/connector.go | 52 ++++++--- core/context.go | 61 +++++++---- core/router.go | 8 ++ core/server.go | 30 ++++-- example/Taskfile.yml | 67 ++++++++++++ example/websocket-bridge/client-1/main.go | 59 +++++++++++ example/websocket-bridge/client-2/main.go | 36 +++++++ example/websocket-bridge/zipper/config.yaml | 8 ++ example/websocket-bridge/zipper/main.go | 32 ++++++ go.mod | 1 + pkg/bridge/bridges.go | 35 ++++++ pkg/bridge/websocket.go | 111 ++++++++++++++++++++ pkg/config/zipper_workflow.go | 10 ++ zipper.go | 18 ++++ 17 files changed, 507 insertions(+), 43 deletions(-) create mode 100644 core/bridge.go create mode 100644 example/websocket-bridge/client-1/main.go create mode 100644 example/websocket-bridge/client-2/main.go create mode 100644 example/websocket-bridge/zipper/config.yaml create mode 100644 example/websocket-bridge/zipper/main.go create mode 100644 pkg/bridge/bridges.go create mode 100644 pkg/bridge/websocket.go diff --git a/.gitignore b/.gitignore index a9acb7f41..a1e1b85b2 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,9 @@ example/same-stream-fn/stream-fn/fn1 example/same-stream-fn/stream-fn/fn2 example/same-stream-fn/zipper/zipper example/multi-zipper/bin +example/websocket-bridge/client-1/client +example/websocket-bridge/client-2/client +example/websocket-bridge/zipper/zipper # cli cli/example/source/source cli/example/stream-fn-db/stream-fn-db diff --git a/core/bridge.go b/core/bridge.go new file mode 100644 index 000000000..b0cc6012c --- /dev/null +++ b/core/bridge.go @@ -0,0 +1,18 @@ +package core + +import "github.com/yomorun/yomo/core/frame" + +// Bridge is an interface of bridge which connects the clients of different transport protocols (f.e. WebSocket) with zipper. +type Bridge interface { + // Name returns the name of bridge. + Name() string + + // Addr returns the address of bridge. + Addr() string + + // ListenAndServe starts a server with a given handler. + ListenAndServe(handler func(ctx *Context)) error + + // Send the data to clients. + Send(f frame.Frame) error +} diff --git a/core/client.go b/core/client.go index 903fa3277..94ddee041 100644 --- a/core/client.go +++ b/core/client.go @@ -50,6 +50,7 @@ func NewClient(appName string, connType ClientType, opts ...ClientOption) *Clien return c } +// Init the options. func (c *Client) Init(opts ...ClientOption) error { for _, o := range opts { o(&c.opts) diff --git a/core/connector.go b/core/connector.go index 8eee1267e..73338e30c 100644 --- a/core/connector.go +++ b/core/connector.go @@ -2,18 +2,13 @@ package core import ( "fmt" + "io" "sync" - "github.com/lucas-clemente/quic-go" "github.com/yomorun/yomo/core/frame" "github.com/yomorun/yomo/pkg/logger" ) -type connStream struct { - id string // connection id (remote addr) - stream *quic.Stream // quic stream -} - type app struct { id string // app id name string // app name @@ -33,20 +28,33 @@ func (a *app) Name() string { var _ Connector = &connector{} +// Connector is a interface to manage the connections and applications. type Connector interface { - Add(connID string, stream *quic.Stream) + // Add a connection. + Add(connID string, stream io.ReadWriteCloser) + // Remove a connection. Remove(connID string) - Get(connID string) *quic.Stream + // Get a connection by connection id. + Get(connID string) io.ReadWriteCloser + // ConnID gets the connection id by appID and mae. ConnID(appID string, name string) (string, bool) + // Write a DataFrame from a connection to another one. Write(f *frame.DataFrame, fromID string, toID string) error - GetSnapshot() map[string]*quic.Stream + // GetSnapshot gets the snapshot of all connections. + GetSnapshot() map[string]io.ReadWriteCloser + // App gets the app by connID. App(connID string) (*app, bool) + // AppID gets the ID of app by connID. AppID(connID string) (string, bool) + // AppName gets the name of app by connID. AppName(connID string) (string, bool) + // LinkApp links the app and connection. LinkApp(connID string, appID string, name string) + // UnlinkApp removes the app by connID. UnlinkApp(connID string, appID string, name string) + // Clean the connector. Clean() } @@ -62,11 +70,13 @@ func newConnector() Connector { } } -func (c *connector) Add(connID string, stream *quic.Stream) { +// Add a connection. +func (c *connector) Add(connID string, stream io.ReadWriteCloser) { logger.Debugf("%sconnector add: connID=%s", ServerLogPrefix, connID) c.conns.Store(connID, stream) } +// Remove a connection. func (c *connector) Remove(connID string) { logger.Debugf("%sconnector remove: connID=%s", ServerLogPrefix, connID) c.conns.Delete(connID) @@ -74,14 +84,16 @@ func (c *connector) Remove(connID string) { c.apps.Delete(connID) } -func (c *connector) Get(connID string) *quic.Stream { +// Get a connection by connection id. +func (c *connector) Get(connID string) io.ReadWriteCloser { logger.Debugf("%sconnector get connection: connID=%s", ServerLogPrefix, connID) if stream, ok := c.conns.Load(connID); ok { - return stream.(*quic.Stream) + return stream.(io.ReadWriteCloser) } return nil } +// App gets the app by connID. func (c *connector) App(connID string) (*app, bool) { if result, found := c.apps.Load(connID); found { app, ok := result.(*app) @@ -96,6 +108,7 @@ func (c *connector) App(connID string) (*app, bool) { return nil, false } +// AppID gets the ID of app by connID. func (c *connector) AppID(connID string) (string, bool) { if app, ok := c.App(connID); ok { return app.id, true @@ -103,6 +116,7 @@ func (c *connector) AppID(connID string) (string, bool) { return "", false } +// AppName gets the name of app by connID. func (c *connector) AppName(connID string) (string, bool) { if app, ok := c.App(connID); ok { return app.name, true @@ -110,6 +124,7 @@ func (c *connector) AppName(connID string) (string, bool) { return "", false } +// ConnID gets the connection id by appID and mae. func (c *connector) ConnID(appID string, name string) (string, bool) { var connID string var ok bool @@ -131,30 +146,34 @@ func (c *connector) ConnID(appID string, name string) (string, bool) { return connID, true } +// Write a DataFrame from a connection to another one. func (c *connector) Write(f *frame.DataFrame, fromID string, toID string) error { targetStream := c.Get(toID) if targetStream == nil { logger.Warnf("%swill write to: [%s] -> [%s], target stream is nil", ServerLogPrefix, fromID, toID) return fmt.Errorf("target[%s] stream is nil", toID) } - _, err := (*targetStream).Write(f.Encode()) + _, err := targetStream.Write(f.Encode()) return err } -func (c *connector) GetSnapshot() map[string]*quic.Stream { - result := make(map[string]*quic.Stream) +// GetSnapshot gets the snapshot of all connections. +func (c *connector) GetSnapshot() map[string]io.ReadWriteCloser { + result := make(map[string]io.ReadWriteCloser) c.conns.Range(func(key interface{}, val interface{}) bool { - result[key.(string)] = val.(*quic.Stream) + result[key.(string)] = val.(io.ReadWriteCloser) return true }) return result } +// LinkApp links the app and connection. func (c *connector) LinkApp(connID string, appID string, name string) { logger.Debugf("%sconnector link application: connID[%s] --> app[%s::%s]", ServerLogPrefix, connID, appID, name) c.apps.Store(connID, newApp(appID, name)) } +// UnlinkApp removes the app by connID. func (c *connector) UnlinkApp(connID string, appID string, name string) { logger.Debugf("%sconnector unlink application: connID[%s] x-> app[%s::%s]", ServerLogPrefix, connID, appID, name) c.apps.Delete(connID) @@ -178,6 +197,7 @@ func (c *connector) UnlinkApp(connID string, appID string, name string) { // return conns // } +// Clean the connector. func (c *connector) Clean() { c.conns = sync.Map{} c.apps = sync.Map{} diff --git a/core/context.go b/core/context.go index c77cd9d18..e88d1e23e 100644 --- a/core/context.go +++ b/core/context.go @@ -1,6 +1,7 @@ package core import ( + "io" "sync" "time" @@ -9,47 +10,62 @@ import ( "github.com/yomorun/yomo/pkg/logger" ) +// Context for YoMo Server. type Context struct { - Session quic.Session - Stream quic.Stream - Frame frame.Frame - Keys map[string]interface{} + // ConnID is the connection ID of client. + ConnID string + // Stream is the long-lived connection between client and server. + Stream io.ReadWriteCloser + // Frame receives from client. + Frame frame.Frame + // Keys store the key/value pairs in context. + Keys map[string]interface{} + + // SendDataBack is the callback function when the zipper needs to send the data back to the client's connection. + // For example, the data needs to be sent back to the connections from WebSocket Bridge. + SendDataBack func(f frame.Frame) error + + // OnClose is the callback function when the conn (or stream) is closed. + OnClose func(code uint64, msg string) mu sync.RWMutex } -func newContext(session quic.Session, stream quic.Stream) *Context { +func newContext(connID string, stream quic.Stream) *Context { return &Context{ - Session: session, - Stream: stream, + ConnID: connID, + Stream: stream, // keys: make(map[string]interface{}), } } +// WithFrame sets a frame to context. func (c *Context) WithFrame(f frame.Frame) *Context { c.Frame = f return c } +// Clean the context. func (c *Context) Clean() { - logger.Debugf("%sconn[%s] context clean", ServerLogPrefix, c.ConnID()) - c.Session = nil + logger.Debugf("%sconn[%s] context clean", ServerLogPrefix, c.ConnID) c.Stream = nil c.Frame = nil c.Keys = nil } +// CloseWithError closes the stream and cleans the context. func (c *Context) CloseWithError(code uint64, msg string) { - logger.Debugf("%sconn[%s] context close, errCode=%d, msg=%s", ServerLogPrefix, c.ConnID(), code, msg) + logger.Debugf("%sconn[%s] context close, errCode=%d, msg=%s", ServerLogPrefix, c.ConnID, code, msg) if c.Stream != nil { c.Stream.Close() } - if c.Session != nil { - c.Session.CloseWithError(quic.ApplicationErrorCode(code), msg) + if c.OnClose != nil { + c.OnClose(code, msg) } c.Clean() } +// Set a key/value pair to context. func (c *Context) Set(key string, value interface{}) { c.mu.Lock() if c.Keys == nil { @@ -60,6 +76,7 @@ func (c *Context) Set(key string, value interface{}) { c.mu.Unlock() } +// Get the value by a specified key. func (c *Context) Get(key string) (value interface{}, exists bool) { c.mu.RLock() value, exists = c.Keys[key] @@ -67,6 +84,7 @@ func (c *Context) Get(key string) (value interface{}, exists bool) { return } +// GetString gets a string value by a specified key. func (c *Context) GetString(key string) (s string) { if val, ok := c.Get(key); ok && val != nil { s, _ = val.(string) @@ -74,6 +92,7 @@ func (c *Context) GetString(key string) (s string) { return } +// GetBool gets a bool value by a specified key. func (c *Context) GetBool(key string) (b bool) { if val, ok := c.Get(key); ok && val != nil { b, _ = val.(bool) @@ -81,6 +100,7 @@ func (c *Context) GetBool(key string) (b bool) { return } +// GetInt gets an int value by a specified key. func (c *Context) GetInt(key string) (i int) { if val, ok := c.Get(key); ok && val != nil { i, _ = val.(int) @@ -88,6 +108,7 @@ func (c *Context) GetInt(key string) (i int) { return } +// GetInt64 gets an int64 value by a specified key. func (c *Context) GetInt64(key string) (i64 int64) { if val, ok := c.Get(key); ok && val != nil { i64, _ = val.(int64) @@ -95,6 +116,7 @@ func (c *Context) GetInt64(key string) (i64 int64) { return } +// GetUint gets an uint value by a specified key. func (c *Context) GetUint(key string) (ui uint) { if val, ok := c.Get(key); ok && val != nil { ui, _ = val.(uint) @@ -102,6 +124,7 @@ func (c *Context) GetUint(key string) (ui uint) { return } +// GetUint64 gets an uint64 value by a specified key. func (c *Context) GetUint64(key string) (ui64 uint64) { if val, ok := c.Get(key); ok && val != nil { ui64, _ = val.(uint64) @@ -109,6 +132,7 @@ func (c *Context) GetUint64(key string) (ui64 uint64) { return } +// GetFloat64 gets a float64 value by a specified key. func (c *Context) GetFloat64(key string) (f64 float64) { if val, ok := c.Get(key); ok && val != nil { f64, _ = val.(float64) @@ -116,6 +140,7 @@ func (c *Context) GetFloat64(key string) (f64 float64) { return } +// GetTime gets a time.Time value by a specified key. func (c *Context) GetTime(key string) (t time.Time) { if val, ok := c.Get(key); ok && val != nil { t, _ = val.(time.Time) @@ -123,6 +148,7 @@ func (c *Context) GetTime(key string) (t time.Time) { return } +// GetDuration gets a time.Duration value by a specified key. func (c *Context) GetDuration(key string) (d time.Duration) { if val, ok := c.Get(key); ok && val != nil { d, _ = val.(time.Duration) @@ -130,6 +156,7 @@ func (c *Context) GetDuration(key string) (d time.Duration) { return } +// GetStringSlice gets a []string value by a specified key. func (c *Context) GetStringSlice(key string) (ss []string) { if val, ok := c.Get(key); ok && val != nil { ss, _ = val.([]string) @@ -137,6 +164,7 @@ func (c *Context) GetStringSlice(key string) (ss []string) { return } +// GetStringMap gets a map[string]interface{} value by a specified key. func (c *Context) GetStringMap(key string) (sm map[string]interface{}) { if val, ok := c.Get(key); ok && val != nil { sm, _ = val.(map[string]interface{}) @@ -144,6 +172,7 @@ func (c *Context) GetStringMap(key string) (sm map[string]interface{}) { return } +// GetStringMapString gets a map[string]string value by a specified key. func (c *Context) GetStringMapString(key string) (sms map[string]string) { if val, ok := c.Get(key); ok && val != nil { sms, _ = val.(map[string]string) @@ -151,16 +180,10 @@ func (c *Context) GetStringMapString(key string) (sms map[string]string) { return } +// GetStringMapStringSlice gets a map[string][]string value by a specified key. func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) { if val, ok := c.Get(key); ok && val != nil { smss, _ = val.(map[string][]string) } return } - -func (c *Context) ConnID() string { - if c.Session != nil { - return c.Session.RemoteAddr().String() - } - return "" -} diff --git a/core/router.go b/core/router.go index 5eed39203..23b1a5be2 100644 --- a/core/router.go +++ b/core/router.go @@ -1,11 +1,19 @@ package core +// Router is the interface to manage the routes for applications. type Router interface { + // Route gets the route by appID. Route(appID string) Route + // Clean the routes. Clean() } + +// Route is the interface for route. type Route interface { + // Add a route. Add(index int, name string) + // Next gets the next route. Next(current string) (string, bool) + // Exists indicates whether the route exists or not. Exists(name string) bool } diff --git a/core/server.go b/core/server.go index ea1dd4356..7c815f9a2 100644 --- a/core/server.go +++ b/core/server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "reflect" "sync" @@ -55,6 +56,7 @@ func NewServer(name string, opts ...ServerOption) *Server { return s } +// Init the options. func (s *Server) Init(opts ...ServerOption) error { for _, o := range opts { o(&s.opts) @@ -81,6 +83,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) error { return s.Serve(ctx, conn) } +// Serve the server with a net.PacketConn. func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { listener := newListener() // listen the address @@ -101,7 +104,6 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { session, err := listener.Accept(sctx) if err != nil { logger.Errorf("%screate session error: %v", ServerLogPrefix, err) - sctx.Done() return err } @@ -131,8 +133,13 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { logger.Infof("%s❤️4/ [stream:%d] created, connID=%s", ServerLogPrefix, stream.StreamID(), connID) // process frames on stream - c := newContext(sess, stream) + c := newContext(connID, stream) defer c.Clean() + c.OnClose = func(code uint64, msg string) { + if sess != nil { + sess.CloseWithError(quic.ApplicationErrorCode(code), msg) + } + } s.handleSession(c) logger.Infof("%s❤️5/ [stream:%d] handleSession DONE", ServerLogPrefix, stream.StreamID()) } @@ -257,7 +264,7 @@ func (s *Server) handleHandshakeFrame(c *Context) error { if err := s.validateRouter(); err != nil { return err } - connID := c.ConnID() + connID := c.ConnID route := s.router.Route(appID) if reflect.ValueOf(route).IsNil() { err := errors.New("handleHandshakeFrame route is nil") @@ -272,7 +279,7 @@ func (s *Server) handleHandshakeFrame(c *Context) error { stream := c.Stream switch clientType { case ClientTypeSource: - s.connector.Add(connID, &stream) + s.connector.Add(connID, stream) s.connector.LinkApp(connID, appID, name) case ClientTypeStreamFunction: // when sfn connect, it will provide its name to the server. server will check if this client @@ -287,11 +294,11 @@ func (s *Server) handleHandshakeFrame(c *Context) error { return err } - s.connector.Add(connID, &stream) + s.connector.Add(connID, stream) // link connection to stream function s.connector.LinkApp(connID, appID, name) case ClientTypeUpstreamZipper: - s.connector.Add(connID, &stream) + s.connector.Add(connID, stream) s.connector.LinkApp(connID, appID, name) default: // unknown client type @@ -314,7 +321,7 @@ func (s *Server) handleDataFrame(c *Context) error { // counter +1 atomic.AddInt64(&s.counterOfDataFrame, 1) // currentIssuer := f.GetIssuer() - fromID := c.ConnID() + fromID := c.ConnID from, ok := s.connector.AppName(fromID) if !ok { logger.Warnf("%shandleDataFrame have connection[%s], but not have function", ServerLogPrefix, fromID) @@ -356,7 +363,7 @@ func (s *Server) handleDataFrame(c *Context) error { // StatsFunctions returns the sfn stats of server. // func (s *Server) StatsFunctions() map[string][]*quic.Stream { -func (s *Server) StatsFunctions() map[string]*quic.Stream { +func (s *Server) StatsFunctions() map[string]io.ReadWriteCloser { return s.connector.GetSnapshot() } @@ -416,6 +423,13 @@ func (s *Server) dispatchToDownstreams(df *frame.DataFrame) { } } +// AddBridge add a bridge to this server. +func (s *Server) AddBridge(bridge Bridge) { + // serve bridge + go bridge.ListenAndServe(s.handleSession) + logger.Debugf("%sadd a bridge, name=[%s], addr=[%s]", ServerLogPrefix, bridge.Name(), bridge.Addr()) +} + // GetConnID get quic session connection id func GetConnID(sess quic.Session) string { return sess.RemoteAddr().String() diff --git a/example/Taskfile.yml b/example/Taskfile.yml index c6ebb6c9d..82f0ba725 100644 --- a/example/Taskfile.yml +++ b/example/Taskfile.yml @@ -23,6 +23,9 @@ tasks: - rm -rf same-stream-fn/stream-fn/fn1 - rm -rf same-stream-fn/stream-fn/fn2 - rm -rf same-stream-fn/zipper/zipper + - rm -rf websocket-bridge/client-1/client + - rm -rf websocket-bridge/client-2/client + - rm -rf websocket-bridge/zipper/zipper - echo 'example clean.' # basic example basic: @@ -332,3 +335,67 @@ tasks: YOMO_LOG_LEVEL: debug cmds: - ./source + + websocket-bridge: + desc: run websocket bridge example + deps: [websocket-bridge-zipper, websocket-bridge-client-1, websocket-bridge-client-2] + cmds: + - echo 'websocket bridge example' + + websocket-bridge-zipper: + desc: run websocket bridge zipper + deps: [websocket-bridge-zipper-build] + dir: "websocket-bridge/zipper" + cmds: + # - "yomo serve -v -c workflow.yaml" + - "./zipper{{exeExt}}" + env: + YOMO_ENABLE_DEBUG: true + YOMO_LOG_LEVEL: debug + + websocket-bridge-zipper-build: + desc: build websocket bridge zipper + dir: "websocket-bridge/zipper" + cmds: + - echo "websocket bridge zipper building..." + - "go build -o zipper{{exeExt}} main.go" + - echo "websocket bridge zipper built." + silent: false + + websocket-bridge-client-1: + desc: run websocket bridge client-1 example + deps: [websocket-bridge-client-1-build] + dir: "websocket-bridge/client-1" + cmds: + - "./client{{exeExt}}" + env: + YOMO_ENABLE_DEBUG: true + YOMO_LOG_LEVEL: debug + + websocket-bridge-client-1-build: + desc: build websocket bridge client-1 example + dir: "websocket-bridge/client-1" + cmds: + - echo "websocket bridge client-1 building..." + - "go build -o client{{exeExt}} main.go" + - echo "websocket bridge client-1 built." + silent: false + + websocket-bridge-client-2: + desc: run websocket bridge client-2 example + deps: [websocket-bridge-client-2-build] + dir: "websocket-bridge/client-2" + cmds: + - "./client{{exeExt}}" + env: + YOMO_ENABLE_DEBUG: true + YOMO_LOG_LEVEL: debug + + websocket-bridge-client-2-build: + desc: build websocket bridge client-2 example + dir: "websocket-bridge/client-2" + cmds: + - echo "websocket bridge client-2 building..." + - "go build -o client{{exeExt}} main.go" + - echo "websocket bridge client-2 built." + silent: false diff --git a/example/websocket-bridge/client-1/main.go b/example/websocket-bridge/client-1/main.go new file mode 100644 index 000000000..608931c9a --- /dev/null +++ b/example/websocket-bridge/client-1/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "fmt" + "log" + "time" + + "github.com/yomorun/yomo/core" + "github.com/yomorun/yomo/core/auth" + "github.com/yomorun/yomo/core/frame" + "golang.org/x/net/websocket" +) + +func main() { + origin := "http://localhost/" + url := "ws://localhost:7000/" + ws, err := websocket.Dial(url, "", origin) + if err != nil { + log.Println(err) + // wait 2s for zipper start-up. + time.Sleep(2 * time.Second) + // reconnect + ws, _ = websocket.Dial(url, "", origin) + } + + // handshake + credential := auth.NewCredendialNone() + handshakeFrame := frame.NewHandshakeFrame("ws-bridge-client", byte(core.ClientTypeSource), credential.AppID(), byte(credential.Type()), credential.Payload()) + if _, err := ws.Write(handshakeFrame.Encode()); err != nil { + log.Fatal(err) + } + + count := 1 + for { + // send data. + msg := fmt.Sprintf("websocket-bridge #%d from [client 1]", count) + dataFrame := frame.NewDataFrame() + dataFrame.SetCarriage(0x33, []byte(msg)) + if _, err := ws.Write(dataFrame.Encode()); err != nil { + log.Fatal(err) + } + log.Printf("Sent: %s.\n", msg) + count++ + + time.Sleep(1 * time.Second) + + // receive echo data. + var buf = make([]byte, 512) + var n int + if n, err = ws.Read(buf); err == nil { + dataFrame, err := frame.DecodeToDataFrame(buf[:n]) + if err != nil { + log.Fatalf("Decode data to DataFrame failed, frame=%# x", buf[:n]) + } else { + log.Printf("Received: %s.\n", dataFrame.GetCarriage()) + } + } + } +} diff --git a/example/websocket-bridge/client-2/main.go b/example/websocket-bridge/client-2/main.go new file mode 100644 index 000000000..eca4f7785 --- /dev/null +++ b/example/websocket-bridge/client-2/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "log" + "time" + + "github.com/yomorun/yomo/core/frame" + "golang.org/x/net/websocket" +) + +func main() { + origin := "http://localhost/" + url := "ws://localhost:7000/" + ws, err := websocket.Dial(url, "", origin) + if err != nil { + log.Println(err) + // wait 2s for zipper start-up. + time.Sleep(2 * time.Second) + // reconnect + ws, _ = websocket.Dial(url, "", origin) + } + + for { + // receive data from client-1. + var buf = make([]byte, 512) + var n int + if n, err = ws.Read(buf); err == nil { + dataFrame, err := frame.DecodeToDataFrame(buf[:n]) + if err != nil { + log.Fatalf("Decode data to DataFrame failed, frame=%# x", buf[:n]) + } else { + log.Printf("Received: %s.\n", dataFrame.GetCarriage()) + } + } + } +} diff --git a/example/websocket-bridge/zipper/config.yaml b/example/websocket-bridge/zipper/config.yaml new file mode 100644 index 000000000..8e46dbaa2 --- /dev/null +++ b/example/websocket-bridge/zipper/config.yaml @@ -0,0 +1,8 @@ +name: example-websocket-bridge +host: 0.0.0.0 +port: 9000 +functions: + - name: rand +bridges: + - name: websocket + port: 7000 \ No newline at end of file diff --git a/example/websocket-bridge/zipper/main.go b/example/websocket-bridge/zipper/main.go new file mode 100644 index 000000000..c0c2893bc --- /dev/null +++ b/example/websocket-bridge/zipper/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "os" + + "github.com/yomorun/yomo" + "github.com/yomorun/yomo/pkg/logger" +) + +func main() { + zipper := yomo.NewZipperWithOptions( + "zipper-with-websocket-bridge", + yomo.WithZipperAddr("localhost:9000"), + ) + defer zipper.Close() + + err := zipper.ConfigWorkflow("config.yaml") + if err != nil { + panic(err) + } + + // start zipper service + go func(zipper yomo.Zipper) { + err := zipper.ListenAndServe() + if err != nil { + panic(err) + } + }(zipper) + + logger.Printf("Server has started!, pid: %d", os.Getpid()) + select {} +} diff --git a/go.mod b/go.mod index 4b147e9f2..75d2a79dd 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,6 @@ require ( github.com/stretchr/testify v1.7.0 github.com/yomorun/y3 v1.0.4 go.uber.org/zap v1.19.0 + golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/pkg/bridge/bridges.go b/pkg/bridge/bridges.go new file mode 100644 index 000000000..a02a5ca9e --- /dev/null +++ b/pkg/bridge/bridges.go @@ -0,0 +1,35 @@ +package bridge + +import ( + "fmt" + + "github.com/yomorun/yomo/core" + "github.com/yomorun/yomo/pkg/config" + "github.com/yomorun/yomo/pkg/logger" +) + +const ( + nameOfWebSocket = "websocket" +) + +// Init the bridges from conf. +func Init(conf *config.WorkflowConfig) []core.Bridge { + bridges := make([]core.Bridge, 0) + if conf.Bridges == nil { + return bridges + } + + for _, cb := range conf.Bridges { + // all bridges will be running in the same host of zipper. + addr := fmt.Sprintf("%s:%d", conf.Host, cb.Port) + + switch cb.Name { + case nameOfWebSocket: + bridges = append(bridges, NewWebSocketBridge(addr)) + default: + logger.Errorf("InitBridges: the name of bridge %s is not implemented", cb.Name) + } + } + + return bridges +} diff --git a/pkg/bridge/websocket.go b/pkg/bridge/websocket.go new file mode 100644 index 000000000..cff012998 --- /dev/null +++ b/pkg/bridge/websocket.go @@ -0,0 +1,111 @@ +package bridge + +import ( + "net/http" + "net/url" + "sync" + + "github.com/yomorun/yomo/core" + "github.com/yomorun/yomo/core/frame" + "github.com/yomorun/yomo/pkg/logger" + "golang.org/x/net/websocket" +) + +const defaultRoomID = "default" + +// WebSocketBridge implements the Bridge interface for WebSocket. +type WebSocketBridge struct { + addr string + server *websocket.Server + + // Registered the connections in each room. + // Key: room id (string) + // Value: conns in room (sync.Map) + rooms sync.Map +} + +// NewWebSocketBridge initializes an instance for WebSocketBridge. +func NewWebSocketBridge(addr string) *WebSocketBridge { + return &WebSocketBridge{ + addr: addr, + server: &websocket.Server{ + Config: websocket.Config{ + Origin: &url.URL{ + Host: addr, + }, + }, + Handshake: func(c *websocket.Config, r *http.Request) error { + // TODO: check Origin header for auth. + return nil + }, + }, + rooms: sync.Map{}, + } +} + +// Name returns the name of WebSocket bridge. +func (ws *WebSocketBridge) Name() string { + return nameOfWebSocket +} + +// Addr returns the address of bridge. +func (ws *WebSocketBridge) Addr() string { + return ws.addr +} + +// ListenAndServe starts a WebSocket server with a given handler. +func (ws *WebSocketBridge) ListenAndServe(handler func(ctx *core.Context)) error { + // wrap the WebSocket handler. + ws.server.Handler = func(c *websocket.Conn) { + // set payload type + c.PayloadType = websocket.BinaryFrame + // TODO: support multi rooms. + roomID := defaultRoomID + conns := ws.getConnsByRoomID(roomID) + + // register new connections. + conns.Store(c, true) + ws.rooms.Store(roomID, conns) + + // trigger the YoMo Server's Handler in bridge. + handler(&core.Context{ + ConnID: c.Request().RemoteAddr, + Stream: c, + SendDataBack: ws.Send, + OnClose: func(code uint64, msg string) { + // remove this connection in room. + conns := ws.getConnsByRoomID(roomID) + conns.Delete(c) + }, + }) + } + + // serve + return http.ListenAndServe(ws.addr, ws.server) +} + +// Send the data to WebSocket clients. +func (ws *WebSocketBridge) Send(f frame.Frame) error { + // TODO: get RoomID from MetaFrame. + roomID := defaultRoomID + conns := ws.getConnsByRoomID(roomID) + conns.Range(func(key, value interface{}) bool { + if c, ok := key.(*websocket.Conn); ok { + _, err := c.Write(f.Encode()) + if err != nil { + logger.Errorf("[WebSocketBridge] send data to conn failed, roomID=%s", roomID) + } + } + return true + }) + return nil +} + +func (ws *WebSocketBridge) getConnsByRoomID(roomID string) sync.Map { + v, ok := ws.rooms.Load(roomID) + if !ok || v == nil { + v = sync.Map{} + } + conns, _ := v.(sync.Map) + return conns +} diff --git a/pkg/config/zipper_workflow.go b/pkg/config/zipper_workflow.go index 79fbd904f..7f0133308 100644 --- a/pkg/config/zipper_workflow.go +++ b/pkg/config/zipper_workflow.go @@ -18,6 +18,14 @@ type Workflow struct { Functions []App `yaml:"functions"` } +// Bridge represents a YoMo Bridge. +type Bridge struct { + // Name represents the name of the bridge. + Name string `yaml:"name"` + // Port represents the listening port of the bridge. + Port int `yaml:"port"` +} + // WorkflowConfig represents a YoMo Workflow config. type WorkflowConfig struct { // Name represents the name of the zipper. @@ -28,6 +36,8 @@ type WorkflowConfig struct { Port int `yaml:"port"` // Workflow represents the sfn workflow. Workflow `yaml:",inline"` + // Bridges represents a YoMo Bridges. + Bridges []Bridge `yaml:"bridges"` } // LoadWorkflowConfig the WorkflowConfig by path. diff --git a/zipper.go b/zipper.go index 2d2478ab3..f1f798409 100644 --- a/zipper.go +++ b/zipper.go @@ -8,6 +8,8 @@ import ( "net/http" "github.com/yomorun/yomo/core" + "github.com/yomorun/yomo/core/frame" + "github.com/yomorun/yomo/pkg/bridge" "github.com/yomorun/yomo/pkg/config" "github.com/yomorun/yomo/pkg/logger" ) @@ -130,6 +132,22 @@ func (z *zipper) ConfigWorkflow(conf string) error { } func (z *zipper) configWorkflow(config *config.WorkflowConfig) error { + // bridges + bridges := bridge.Init(config) + for _, bridge := range bridges { + z.server.AddBridge(bridge) + } + + // send DataFrame back to the connections from bridges. + z.server.SetBeforeHandlers(func(c *core.Context) error { + if c.SendDataBack != nil && c.Stream != nil && + c.Frame != nil && c.Frame.Type() == frame.TagOfDataFrame { + return c.SendDataBack(c.Frame) + } + return nil + }) + + // router return z.server.ConfigRouter(newRouter(config)) }