| 
 | 1 | +package rpcstream  | 
 | 2 | + | 
 | 3 | +import (  | 
 | 4 | +	"context"  | 
 | 5 | +	"errors"  | 
 | 6 | + | 
 | 7 | +	"github.com/aperturerobotics/starpc/srpc"  | 
 | 8 | +)  | 
 | 9 | + | 
 | 10 | +// RpcStream implements a RPC call stream over a RPC call. Used to implement  | 
 | 11 | +// sub-components which have a different set of services & calls available.  | 
 | 12 | +type RpcStream interface {  | 
 | 13 | +	srpc.Stream  | 
 | 14 | +	Send(*Packet) error  | 
 | 15 | +	Recv() (*Packet, error)  | 
 | 16 | +}  | 
 | 17 | + | 
 | 18 | +// RpcStreamGetter returns the Mux for the component ID from the remote.  | 
 | 19 | +type RpcStreamGetter func(ctx context.Context, componentID string) (srpc.Mux, error)  | 
 | 20 | + | 
 | 21 | +// RpcStreamCaller is a function which starts the RpcStream call.  | 
 | 22 | +type RpcStreamCaller func(ctx context.Context) (RpcStream, error)  | 
 | 23 | + | 
 | 24 | +// NewRpcStreamOpenStream constructs an OpenStream function with a RpcStream.  | 
 | 25 | +func NewRpcStreamOpenStream(componentID string, rpcCaller RpcStreamCaller) srpc.OpenStreamFunc {  | 
 | 26 | +	return func(ctx context.Context, msgHandler srpc.PacketHandler) (srpc.Writer, error) {  | 
 | 27 | +		// open the rpc stream  | 
 | 28 | +		rpcStream, err := rpcCaller(ctx)  | 
 | 29 | +		if err != nil {  | 
 | 30 | +			return nil, err  | 
 | 31 | +		}  | 
 | 32 | + | 
 | 33 | +		// write the component id  | 
 | 34 | +		err = rpcStream.Send(&Packet{  | 
 | 35 | +			Body: &Packet_Init{  | 
 | 36 | +				Init: &RpcStreamInit{  | 
 | 37 | +					ComponentId: componentID,  | 
 | 38 | +				},  | 
 | 39 | +			},  | 
 | 40 | +		})  | 
 | 41 | +		if err != nil {  | 
 | 42 | +			_ = rpcStream.Close()  | 
 | 43 | +			return nil, err  | 
 | 44 | +		}  | 
 | 45 | + | 
 | 46 | +		// initialize the rpc  | 
 | 47 | +		rw := NewRpcStreamReadWriter(rpcStream, msgHandler)  | 
 | 48 | + | 
 | 49 | +		// start the read pump  | 
 | 50 | +		go func() {  | 
 | 51 | +			err := rw.ReadPump()  | 
 | 52 | +			if err != nil {  | 
 | 53 | +				_ = rw.Close()  | 
 | 54 | +			}  | 
 | 55 | +		}()  | 
 | 56 | + | 
 | 57 | +		// return the writer  | 
 | 58 | +		return rw, nil  | 
 | 59 | +	}  | 
 | 60 | +}  | 
 | 61 | + | 
 | 62 | +// HandleRpcStream handles an incoming RPC stream (remote is the initiator).  | 
 | 63 | +func HandleRpcStream(stream RpcStream, getter RpcStreamGetter) error {  | 
 | 64 | +	// Read the "init" packet.  | 
 | 65 | +	initPkt, err := stream.Recv()  | 
 | 66 | +	if err != nil {  | 
 | 67 | +		return err  | 
 | 68 | +	}  | 
 | 69 | +	initInner, ok := initPkt.GetBody().(*Packet_Init)  | 
 | 70 | +	if !ok || initInner.Init == nil {  | 
 | 71 | +		return errors.New("expected init packet")  | 
 | 72 | +	}  | 
 | 73 | +	componentID := initInner.Init.GetComponentId()  | 
 | 74 | +	if componentID == "" {  | 
 | 75 | +		return errors.New("invalid init packet: empty component id")  | 
 | 76 | +	}  | 
 | 77 | + | 
 | 78 | +	// lookup the server for this component id  | 
 | 79 | +	ctx := stream.Context()  | 
 | 80 | +	mux, err := getter(ctx, componentID)  | 
 | 81 | +	if err != nil {  | 
 | 82 | +		return err  | 
 | 83 | +	}  | 
 | 84 | +	if mux == nil {  | 
 | 85 | +		return errors.New("no server for that component")  | 
 | 86 | +	}  | 
 | 87 | + | 
 | 88 | +	// handle the rpc  | 
 | 89 | +	serverRPC := srpc.NewServerRPC(ctx, mux)  | 
 | 90 | +	prw := NewRpcStreamReadWriter(stream, serverRPC.HandlePacket)  | 
 | 91 | +	serverRPC.SetWriter(prw)  | 
 | 92 | +	err = prw.ReadPump()  | 
 | 93 | +	_ = prw.Close()  | 
 | 94 | +	return err  | 
 | 95 | +}  | 
 | 96 | + | 
 | 97 | +// RpcStreamReadWriter reads and writes packets from a RpcStream.  | 
 | 98 | +type RpcStreamReadWriter struct {  | 
 | 99 | +	// stream is the RpcStream  | 
 | 100 | +	stream RpcStream  | 
 | 101 | +	// cb is the callback  | 
 | 102 | +	cb srpc.PacketHandler  | 
 | 103 | +}  | 
 | 104 | + | 
 | 105 | +// NewRpcStreamReadWriter constructs a new read/writer.  | 
 | 106 | +func NewRpcStreamReadWriter(stream RpcStream, cb srpc.PacketHandler) *RpcStreamReadWriter {  | 
 | 107 | +	return &RpcStreamReadWriter{stream: stream, cb: cb}  | 
 | 108 | +}  | 
 | 109 | + | 
 | 110 | +// WritePacket writes a packet to the writer.  | 
 | 111 | +func (r *RpcStreamReadWriter) WritePacket(p *srpc.Packet) error {  | 
 | 112 | +	data, err := p.MarshalVT()  | 
 | 113 | +	if err != nil {  | 
 | 114 | +		return err  | 
 | 115 | +	}  | 
 | 116 | +	return r.stream.Send(&Packet{  | 
 | 117 | +		Body: &Packet_Data{  | 
 | 118 | +			Data: data,  | 
 | 119 | +		},  | 
 | 120 | +	})  | 
 | 121 | +}  | 
 | 122 | + | 
 | 123 | +// ReadPump executes the read pump in a goroutine.  | 
 | 124 | +func (r *RpcStreamReadWriter) ReadPump() error {  | 
 | 125 | +	for {  | 
 | 126 | +		rpcStreamPkt, err := r.stream.Recv()  | 
 | 127 | +		if err != nil {  | 
 | 128 | +			return err  | 
 | 129 | +		}  | 
 | 130 | +		dataPkt, ok := rpcStreamPkt.GetBody().(*Packet_Data)  | 
 | 131 | +		if !ok {  | 
 | 132 | +			return errors.New("expected data packet")  | 
 | 133 | +		}  | 
 | 134 | +		pkt := &srpc.Packet{}  | 
 | 135 | +		if err := pkt.UnmarshalVT(dataPkt.Data); err != nil {  | 
 | 136 | +			return err  | 
 | 137 | +		}  | 
 | 138 | +		if err := r.cb(pkt); err != nil {  | 
 | 139 | +			return err  | 
 | 140 | +		}  | 
 | 141 | +	}  | 
 | 142 | +}  | 
 | 143 | + | 
 | 144 | +// Close closes the packet rw.  | 
 | 145 | +func (r *RpcStreamReadWriter) Close() error {  | 
 | 146 | +	return r.stream.Close()  | 
 | 147 | +}  | 
 | 148 | + | 
 | 149 | +// _ is a type assertion  | 
 | 150 | +var _ srpc.Writer = (*RpcStreamReadWriter)(nil)  | 
0 commit comments