diff --git a/app/operator/oper.go b/app/operator/oper.go index 1dc6efc..8cf184f 100644 --- a/app/operator/oper.go +++ b/app/operator/oper.go @@ -1,9 +1,8 @@ package operator import ( - "helmet/app/logger" - "helmet/app/config" + "helmet/app/logger" "helmet/app/rproxy" ) diff --git a/app/operator/proxy.go b/app/operator/proxy.go index e850b19..1de889e 100644 --- a/app/operator/proxy.go +++ b/app/operator/proxy.go @@ -33,6 +33,9 @@ func (oper *Operator) CreateForwarder(ctx context.Context, params *mlbctl.Create var err error res := &mlbctl.CreateForwarderResult{} err = oper.proxy.AddForwarder(ctx, params.Lport, params.Dport, params.Destinations...) + if err != err { + return res, err + } return res, err } @@ -40,5 +43,8 @@ func (oper *Operator) DeleteForwarder(ctx context.Context, params *mlbctl.Delete var err error res := &mlbctl.DeleteForwarderResult{} err = oper.proxy.DeleteForwarder(ctx, params.Lport) + if err != err { + return res, err + } return res, err } diff --git a/app/rproxy/forwarder.go b/app/rproxy/forwarder.go new file mode 100644 index 0000000..6c7efd7 --- /dev/null +++ b/app/rproxy/forwarder.go @@ -0,0 +1,124 @@ +package rproxy + +import ( + "context" + "io" + "math/rand" + "net" + "strconv" + "sync" + + "helmet/app/logger" +) + +type Forwarder struct { + listen net.Listener `json:"-" yaml:"-"` + ctx context.Context `json:"-" yaml:"-"` + cancel context.CancelFunc `json:"-" yaml:"-"` + Lport uint32 `json:"lport" yaml:"lport"` + Dport uint32 `json:"dport" yaml:"dport"` + Dests []*Destination `json:"dests" yaml:"dests"` + log *logger.Logger +} + +func NewForwarder(ctx context.Context, lport, dport uint32, addrs ...string) (*Forwarder, error) { + ctx, cancel := context.WithCancel(context.Background()) + forw := &Forwarder{ + Dests: make([]*Destination, 0), + Lport: lport, + Dport: dport, + ctx: ctx, + cancel: cancel, + } + id := strconv.FormatUint(uint64(lport), 10) + forw.log = logger.NewLogger("forwarder:" + id) + for _, addr := range addrs { + dest := NewDestination(addr) + forw.Dests = append(forw.Dests, dest) + } + + portinfo := ":" + strconv.FormatUint(uint64(forw.Lport), 10) + laddr, err := net.ResolveTCPAddr("tcp", portinfo) + if err != nil { + return forw, err + } + listen, err := net.ListenTCP("tcp", laddr) + if err != nil { + return forw, err + } + forw.listen = listen + return forw, err +} + +func (forw *Forwarder) Listen(wg *sync.WaitGroup) { + forw.log.Debugf("Start listening on %d", forw.Lport) + defer wg.Done() + for { + conn, err := forw.listen.Accept() + if err != nil { + forw.log.Errorf("Listen err: %v", err) + return + } + go forw.handle(forw.ctx, conn) + } +} + +func (forw *Forwarder) Stop() error { + return forw.listen.Close() +} + +type Streamer struct { + source string + dest string + ctx context.Context + cancel context.CancelFunc +} + +func NewStreamer(ctx context.Context) *Streamer { + ctx, cancel := context.WithCancel(ctx) + return &Streamer{ + ctx: ctx, + cancel: cancel, + } +} + +func (forw *Forwarder) handle(ctx context.Context, inconn net.Conn) { + forw.log.Debugf("%s: Handle on %d started", inconn.RemoteAddr(), forw.Lport) + defer inconn.Close() + if len(forw.Dests) == 0 { + return + } + addrnum := rand.Uint32() % uint32(len(forw.Dests)) + ipaddr := forw.Dests[addrnum].Address + dstaddr := ipaddr + ":" + strconv.FormatUint(uint64(forw.Dport), 10) + outconn, err := net.Dial("tcp", dstaddr) + if err != nil { + return + } + defer outconn.Close() + var wg sync.WaitGroup + wg.Add(1) + go forw.stream(&wg, inconn, outconn) + wg.Add(1) + go forw.stream(&wg, outconn, inconn) + wg.Wait() + forw.log.Debugf("Handler on %d stopped", forw.Lport) +} + +func (forw *Forwarder) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) { + defer wg.Done() + _, err := copy(forw.ctx, outconn, inconn) + if err != nil { + forw.log.Errorf("Copy err: %v", err) + } +} + +type Destination struct { + Address string `json:"address" yaml:"address"` +} + +func NewDestination(addr string) *Destination { + return &Destination{ + Address: addr, + } +} diff --git a/app/rproxy/proxy.go b/app/rproxy/proxy.go deleted file mode 100644 index 3d0982d..0000000 --- a/app/rproxy/proxy.go +++ /dev/null @@ -1,192 +0,0 @@ -package rproxy - -import ( - "context" - "errors" - "io" - "log" - "math/rand" - "net" - "strconv" - "sync" -) - -type Proxy struct { - Forwarders []*Forwarder `json:"forwarders" yaml:"forwarders"` - ctx context.Context `json:"-" yaml:"-"` - cancel context.CancelFunc `json:"-" yaml:"-"` - wg sync.WaitGroup `json:"-" yaml:"-"` -} - -func NewProxy() *Proxy { - ctx, cancel := context.WithCancel(context.Background()) - return &Proxy{ - Forwarders: make([]*Forwarder, 0), - ctx: ctx, - cancel: cancel, - } -} - -func (bal *Proxy) AddForwarder(ctx context.Context, lport, dport uint32, addrs ...string) error { - var err error - forw, err := NewForwarder(ctx, lport, dport, addrs...) - if err != nil { - return err - } - bal.Forwarders = append(bal.Forwarders, forw) - bal.wg.Add(1) - go forw.Listen(&bal.wg) - return err -} - -func (bal *Proxy) DeleteForwarder(ctx context.Context, lport uint32) error { - var err error - forwarders := make([]*Forwarder, 0) - for _, forw := range bal.Forwarders { - if forw.Lport == lport { - forw.Stop() - continue - } - forwarders = append(forwarders, forw) - } - bal.Forwarders = forwarders - return err -} - -func (bal *Proxy) Start() error { - var err error - for _, forw := range bal.Forwarders { - bal.wg.Add(1) - go forw.Listen(&bal.wg) - } - bal.wg.Wait() - return err -} - -func (bal *Proxy) Stop() error { - var err error - for _, forw := range bal.Forwarders { - forw.Stop() - } - return err -} - -type Forwarder struct { - listen net.Listener `json:"-" yaml:"-"` - ctx context.Context `json:"-" yaml:"-"` - cancel context.CancelFunc `json:"-" yaml:"-"` - Lport uint32 `json:"lport" yaml:"lport"` - Dport uint32 `json:"dport" yaml:"dport"` - Dests []*Destination `json:"dests" yaml:"dests"` -} - -func NewForwarder(ctx context.Context, lport, dport uint32, addrs ...string) (*Forwarder, error) { - ctx, cancel := context.WithCancel(ctx) - forw := &Forwarder{ - Dests: make([]*Destination, 0), - Lport: lport, - Dport: dport, - ctx: ctx, - cancel: cancel, - } - for _, addr := range addrs { - dest := NewDestination(addr) - forw.Dests = append(forw.Dests, dest) - } - port := ":" + strconv.FormatUint(uint64(forw.Lport), 10) - listen, err := net.Listen("tcp", port) - if err != nil { - return forw, err - } - forw.listen = listen - return forw, err -} - -func (forw *Forwarder) Listen(wg *sync.WaitGroup) { - log.Printf("Start listening on %d\n", forw.Lport) - defer wg.Done() - for { - conn, err := forw.listen.Accept() - if err != nil { - log.Printf("Listen err: %v\n", err) - return - } - go forw.handle(forw.ctx, conn) - } -} - -func (forw *Forwarder) Stop() error { - return forw.listen.Close() -} - -func (forw *Forwarder) handle(ctx context.Context, inconn net.Conn) { - log.Printf("Handler on %d started\n", forw.Lport) - defer inconn.Close() - if len(forw.Dests) == 0 { - return - } - addrnum := rand.Uint32() % uint32(len(forw.Dests)) - ipaddr := forw.Dests[addrnum].Address - dstaddr := ipaddr + ":" + strconv.FormatUint(uint64(forw.Dport), 10) - outconn, err := net.Dial("tcp", dstaddr) - if err != nil { - return - } - var wg sync.WaitGroup - wg.Add(1) - go forw.stream(&wg, inconn, outconn) - wg.Add(1) - go forw.stream(&wg, outconn, inconn) - wg.Wait() - log.Printf("Handler on %d stopped\n", forw.Lport) -} - -func (forw *Forwarder) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) { - defer wg.Done() - _, err := copy(forw.ctx, outconn, inconn) - if err != nil { - log.Printf("Copy err: %v\n", err) - } -} - -type Destination struct { - Address string `json:"address" yaml:"address"` -} - -func NewDestination(addr string) *Destination { - return &Destination{ - Address: addr, - } -} - -func copy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) { - var err error - var size int64 - var halt bool - buffer := make([]byte, 1024*4) - for { - select { - case <-ctx.Done(): - err = errors.New("Break copy by context") - break - default: - } - rsize, err := reader.Read(buffer) - if err == io.EOF { - err = nil - halt = true - } - if err != nil { - return size, err - } - wsize, err := writer.Write(buffer[0:rsize]) - size += int64(wsize) - if err != nil { - return size, err - } - if halt { - break - } - } - return size, err -} diff --git a/app/rproxy/rproxy.go b/app/rproxy/rproxy.go new file mode 100644 index 0000000..861b1e2 --- /dev/null +++ b/app/rproxy/rproxy.go @@ -0,0 +1,125 @@ +package rproxy + +import ( + "context" + "errors" + "io" + "sync" + + "helmet/app/logger" +) + +type Proxy struct { + Forwarders []*Forwarder `json:"forwarders" yaml:"forwarders"` + ctx context.Context `json:"-" yaml:"-"` + cancel context.CancelFunc `json:"-" yaml:"-"` + wg sync.WaitGroup `json:"-" yaml:"-"` + log *logger.Logger `json:"-" yaml:"-"` +} + +func NewProxy() *Proxy { + ctx, cancel := context.WithCancel(context.Background()) + return &Proxy{ + Forwarders: make([]*Forwarder, 0), + ctx: ctx, + cancel: cancel, + log: logger.NewLogger("proxy"), + } +} + +func (bal *Proxy) AddForwarder(ctx context.Context, lport, dport uint32, addrs ...string) error { + var err error + if lport == 0 { + return errors.New("Zero lport") + } + if dport == 0 { + return errors.New("Zero dport") + } + forw, err := NewForwarder(ctx, lport, dport, addrs...) + if err != nil { + return err + } + bal.Forwarders = append(bal.Forwarders, forw) + bal.wg.Add(1) + go forw.Listen(&bal.wg) + return err +} + +func (bal *Proxy) DeleteForwarder(ctx context.Context, lport uint32) error { + var err error + var forw *Forwarder + for _, iforw := range bal.Forwarders { + if iforw.Lport == lport { + forw = iforw + break + } + } + if forw == nil { + bal.log.Debugf("Forwarder for lport %d not found", lport) + return err + } + bal.log.Debugf("Stop forwarder for lport %d", lport) + err = forw.Stop() + if err != nil { + return err + } + forwarders := make([]*Forwarder, 0) + for _, forw := range bal.Forwarders { + if forw.Lport == lport { + continue + } + forwarders = append(forwarders, forw) + } + bal.Forwarders = forwarders + return err +} + +func (bal *Proxy) Start() error { + var err error + for _, forw := range bal.Forwarders { + bal.wg.Add(1) + go forw.Listen(&bal.wg) + } + bal.wg.Wait() + return err +} + +func (bal *Proxy) Stop() error { + var err error + for _, forw := range bal.Forwarders { + forw.Stop() + } + return err +} + +func copy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) { + var err error + var size int64 + var halt bool + buffer := make([]byte, 1024*4) + for { + select { + case <-ctx.Done(): + err = errors.New("Break copy by context") + break + default: + } + rsize, err := reader.Read(buffer) + if err == io.EOF { + err = nil + halt = true + } + if err != nil { + return size, err + } + wsize, err := writer.Write(buffer[0:rsize]) + size += int64(wsize) + if err != nil { + return size, err + } + if halt { + break + } + } + return size, err +} diff --git a/app/service/service.go b/app/service/service.go index 24aed18..3df2a06 100644 --- a/app/service/service.go +++ b/app/service/service.go @@ -109,7 +109,13 @@ func (svc *Service) logInterceptor(ctx context.Context, req any, info *grpc.Unar if err == nil { svc.log.Debugf("Request: %s", string(reqData)) } - return handler(ctx, req) + res, err := handler(ctx, req) + resData, err := json.Marshal(res) + if err == nil { + svc.log.Debugf("Response: %s", string(resData)) + } + + return res, err } func (svc *Service) Stop() { diff --git a/cmd/minilbctl/forwarder/forwcmd.go b/cmd/minilbctl/forwarder/forwcmd.go index 2ba3b4c..47c8c59 100644 --- a/cmd/minilbctl/forwarder/forwcmd.go +++ b/cmd/minilbctl/forwarder/forwcmd.go @@ -47,7 +47,7 @@ func NewTool() *Tool { Use: "delete", Args: cobra.ExactArgs(1), Short: "Delete forwarder", - Run: tool.CreateForwarder, + Run: tool.DeleteForwarder, } deleteForwarderCmd.Flags().Uint32VarP(&tool.deleteForwarderParams.Lport, "lport", "L", 0, "Listening port") deleteForwarderCmd.MarkFlagRequired("lport")