From 8b2c1d03906e00c3adecb0ae1b9819103d6cf9ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9E=D0=BB=D0=B5=D0=B3=20=D0=91=D0=BE=D1=80=D0=BE=D0=B4?= =?UTF-8?q?=D0=B8=D0=BD?= Date: Thu, 9 Apr 2026 12:50:26 +0200 Subject: [PATCH] added unlinked controller; change create forwarded to createOrUpdate; added global proxy mutex --- app/control/kubeset.go | 25 +++++ app/control/svccont.go | 226 ++++++++++++++++++++++++++++++++++++++++ app/logger/logger.go | 2 +- app/operator/proxy.go | 11 +- app/rproxy/forwarder.go | 224 ++++++++++++++++++++++++--------------- app/rproxy/rproxy.go | 138 +++++++++++++----------- app/rproxy/streamer.go | 85 +++++++++++++++ app/service/service.go | 5 +- 8 files changed, 560 insertions(+), 156 deletions(-) create mode 100644 app/control/kubeset.go create mode 100644 app/control/svccont.go create mode 100644 app/rproxy/streamer.go diff --git a/app/control/kubeset.go b/app/control/kubeset.go new file mode 100644 index 0000000..8bc0cab --- /dev/null +++ b/app/control/kubeset.go @@ -0,0 +1,25 @@ +package control + +import ( + kubeclient "k8s.io/client-go/kubernetes" + kubeclicmd "k8s.io/client-go/tools/clientcmd" +) + +func makeClientset(kubeconf []byte) (kubeclient.Interface, error) { + var res kubeclient.Interface + var err error + clientConfig, err := kubeclicmd.NewClientConfigFromBytes(kubeconf) + if err != nil { + return res, err + } + restConfig, err := clientConfig.ClientConfig() + if err != nil { + return res, err + } + kubeClient, err := kubeclient.NewForConfig(restConfig) + if err != nil { + return res, err + } + res = kubeClient + return res, err +} diff --git a/app/control/svccont.go b/app/control/svccont.go new file mode 100644 index 0000000..1476bfa --- /dev/null +++ b/app/control/svccont.go @@ -0,0 +1,226 @@ +package control + +import ( + "context" + "encoding/json" + "time" + + kubecore "k8s.io/api/core/v1" + k8inform "k8s.io/client-go/informers" + k8client "k8s.io/client-go/kubernetes" + k8cache "k8s.io/client-go/tools/cache" + + kubemeta "k8s.io/apimachinery/pkg/apis/meta/v1" + kubetypes "k8s.io/apimachinery/pkg/types" + kubepatch "k8s.io/apimachinery/pkg/util/strategicpatch" + + "helmet/app/logger" +) + +type Controller struct { + lbaddr string + clientset k8client.Interface + informer k8cache.SharedIndexInformer + ctx context.Context + cancel context.CancelFunc + log *logger.Logger +} + +func NewController(clientset k8client.Interface, lbaddr string) *Controller { + ctx, cancel := context.WithCancel(context.Background()) + cont := &Controller{ + clientset: clientset, + ctx: ctx, + cancel: cancel, + lbaddr: lbaddr, + } + cont.log = logger.NewLogger("controller") + return cont +} + +func (cont *Controller) Run() { + cont.log.Debugf("Start controller") + factory := k8inform.NewSharedInformerFactory(cont.clientset, time.Minute*10) + defer factory.Shutdown() + + serviceInformer := factory.Core().V1().Services().Informer() + handler := k8cache.ResourceEventHandlerFuncs{ + AddFunc: cont.addService, + UpdateFunc: cont.updateService, + DeleteFunc: cont.deleteService, + } + serviceInformer.AddEventHandler(handler) + + ctx, cancel := context.WithCancel(cont.ctx) + defer cancel() + factory.Start(ctx.Done()) + synced := factory.WaitForCacheSync(ctx.Done()) + for _, sync := range synced { + if !sync { + return + } + } + cont.informer = serviceInformer + <-ctx.Done() + cont.log.Debugf("Stop controller") +} + +func (cont *Controller) Stop() { + cont.cancel() +} + +// https://pkg.go.dev/k8s.io/api/core/v1#Service + +func (cont *Controller) addService(obj any) { + service := obj.(*kubecore.Service) + service = service.DeepCopy() + cont.log.Debugf("Service %s/%s created", service.Namespace, service.Name) + if service.Spec.Type == kubecore.ServiceTypeLoadBalancer { + _ = makeForwarding(service) + err := cont.patchService(service) + if err != nil { + cont.log.Debugf("Error patch service %s/%s: %v", service.Namespace, service.Name, err) + } + } +} + +func (cont *Controller) updateService(oldObj, newObj any) { + newService := newObj.(*kubecore.Service) + newService = newService.DeepCopy() + oldService := oldObj.(*kubecore.Service) + oldService = oldService.DeepCopy() + cont.log.Debugf("Service %s/%s updated", newService.Namespace, newService.Name) + + if newService.Spec.Type == kubecore.ServiceTypeLoadBalancer { + if cont.serviceDifferent(oldService, newService) { + _ = makeForwarding(newService) + cont.log.Debugf("Updated service %s/%s have new adresses", newService.Namespace, newService.Name) + err := cont.patchService(newService) + if err != nil { + cont.log.Debugf("Error patch service %s/%s: %v", newService.Namespace, newService.Name, err) + } + } + } +} + +func (cont *Controller) deleteService(obj any) { + service := obj.(*kubecore.Service) + service = service.DeepCopy() + cont.log.Debugf("Service %s/%s deleted", service.Namespace, service.Name) +} + +type Forwarding struct { + Proto string + Port uint32 + Destinations []string +} + +func makeForwarding(service *kubecore.Service) []*Forwarding { + forwardings := make([]*Forwarding, 0) + for _, port := range service.Spec.Ports { + forwarding := &Forwarding{ + Proto: string(port.Protocol), + Port: uint32(port.Port), + Destinations: make([]string, 0), + } + for _, ipaddr := range service.Spec.ClusterIPs { + forwarding.Destinations = append(forwarding.Destinations, ipaddr) + } + } + return forwardings +} + +func (cont *Controller) patchService(svc *kubecore.Service) error { + var err error + oldData, err := json.Marshal(svc) + if err != nil { + return err + } + ingressMode := kubecore.LoadBalancerIPModeProxy + ingresses := make([]kubecore.LoadBalancerIngress, 0) + ingresses = append(ingresses, kubecore.LoadBalancerIngress{ + IP: cont.lbaddr, + IPMode: &ingressMode, + }) + status := kubecore.LoadBalancerStatus{ + Ingress: ingresses, + } + svc.Status.LoadBalancer = status + + newData, err := json.Marshal(svc) + if err != nil { + return err + } + + patchBytes, err := kubepatch.CreateTwoWayMergePatch(oldData, newData, kubecore.Service{}) + if err != nil { + return err + } + cont.log.Debugf("patch: %s\n", string(patchBytes)) + + ctx, _ := context.WithTimeout(cont.ctx, 5*time.Second) + patchServiceOpts := kubemeta.PatchOptions{} + strategy := kubetypes.StrategicMergePatchType + _, err = cont.clientset.CoreV1().Services(svc.Namespace). + Patch(ctx, svc.Name, strategy, patchBytes, patchServiceOpts, "status") + if err != nil { + return err + } + return err +} + +func (cont *Controller) serviceDifferent(oldService, newService *kubecore.Service) bool { + type Address struct { + Port int32 + Proto kubecore.Protocol + Address string + } + oldAddreses := make([]Address, 0) + for _, port := range oldService.Spec.Ports { + for _, ipaddr := range oldService.Spec.ClusterIPs { + address := Address{ + Port: port.Port, + Proto: port.Protocol, + Address: ipaddr, + } + oldAddreses = append(oldAddreses, address) + } + } + newAddreses := make([]Address, 0) + for _, port := range newService.Spec.Ports { + for _, ipaddr := range newService.Spec.ClusterIPs { + address := Address{ + Port: port.Port, + Proto: port.Protocol, + Address: ipaddr, + } + newAddreses = append(newAddreses, address) + } + } + if len(newAddreses) != len(oldAddreses) { + return true + } + for _, newAddress := range newAddreses { + newFound := false + for _, oldAddress := range oldAddreses { + if oldAddress == newAddress { + newFound = true + } + } + if !newFound { + return true + } + } + for _, oldAddress := range oldAddreses { + oldFound := false + for _, newAddress := range newAddreses { + if newAddress == oldAddress { + oldFound = true + } + } + if !oldFound { + return true + } + } + return false +} diff --git a/app/logger/logger.go b/app/logger/logger.go index 24d34de..69bea30 100644 --- a/app/logger/logger.go +++ b/app/logger/logger.go @@ -37,7 +37,7 @@ func NewLogger(subj string) *Logger { } } -func xxxNewLogger() *Logger { +func NewUnamedLogger() *Logger { return &Logger{ writer: output, mtx: &mtx, diff --git a/app/operator/proxy.go b/app/operator/proxy.go index 0d2fdf4..569d8e4 100644 --- a/app/operator/proxy.go +++ b/app/operator/proxy.go @@ -12,7 +12,8 @@ func (oper *Operator) ListForwarders(ctx context.Context, params *mlbctl.ListFor res := &mlbctl.ListForwardersResult{ Forwarders: make([]*mlbctl.Forwarder, 0), } - for _, forw := range oper.proxy.Forwarders { + forws := oper.proxy.ListForwarders(ctx) + for _, forw := range forws { oForw := &mlbctl.Forwarder{ Type: forw.Type, Lport: forw.Lport, @@ -33,8 +34,8 @@ func (oper *Operator) ListForwarders(ctx context.Context, params *mlbctl.ListFor func (oper *Operator) CreateForwarder(ctx context.Context, params *mlbctl.CreateForwarderParams) (*mlbctl.CreateForwarderResult, error) { var err error res := &mlbctl.CreateForwarderResult{} - err = oper.proxy.AddForwarder(ctx, params.Type, params.Lport, params.Dport, params.Destinations...) - if err != err { + err = oper.proxy.CreateOrUpdateForwarder(ctx, params.Type, params.Lport, params.Dport, params.Destinations...) + if err != nil { return res, err } return res, err @@ -43,8 +44,8 @@ func (oper *Operator) CreateForwarder(ctx context.Context, params *mlbctl.Create func (oper *Operator) DeleteForwarder(ctx context.Context, params *mlbctl.DeleteForwarderParams) (*mlbctl.DeleteForwarderResult, error) { var err error res := &mlbctl.DeleteForwarderResult{} - err = oper.proxy.DeleteForwarder(ctx, params.Lport) - if err != err { + err = oper.proxy.DeleteForwarder(ctx, params.Type, params.Lport) + if err != nil { return res, err } return res, err diff --git a/app/rproxy/forwarder.go b/app/rproxy/forwarder.go index 4bf8d64..f118854 100644 --- a/app/rproxy/forwarder.go +++ b/app/rproxy/forwarder.go @@ -2,11 +2,12 @@ package rproxy import ( "context" - "io" + "fmt" "math/rand" "net" "strconv" "sync" + "time" "helmet/app/logger" ) @@ -14,20 +15,27 @@ import ( const ( TCP = "tcp" UDP = "udp" + + ForwStarted = "started" + ForwStopped = "stopped" ) type Forwarder struct { - Type string `json:"type" yaml:"type"` - listen net.Listener `json:"-" yaml:"-"` - ctx context.Context `json:"-" yaml:"-"` - cancel context.CancelFunc `json:"-" yaml:"-"` - Lport uint32 `json:"lport" yaml:"lport"` - Dport uint32 `json:"dport" yaml:"dport"` - Dests []*Destination `json:"dests" yaml:"dests"` - log *logger.Logger + State string `json:"state" yaml:"state"` + Type string `json:"type" yaml:"type"` + Lport uint32 `json:"lport" yaml:"lport"` + Dport uint32 `json:"dport" yaml:"dport"` + Dests []*Destination `json:"dests" yaml:"dests"` + + listenTCP *net.TCPListener `json:"-" yaml:"-"` + listenUDP *net.UDPConn `json:"-" yaml:"-"` + ctx context.Context `json:"-" yaml:"-"` + cancel context.CancelFunc `json:"-" yaml:"-"` + log *logger.Logger } -func NewForwarder(ctx context.Context, typ string, lport, dport uint32, addrs ...string) (*Forwarder, error) { +func NewForwarder(ctx context.Context, proto string, lport, dport uint32, addrs ...string) (*Forwarder, error) { + var err error ctx, cancel := context.WithCancel(context.Background()) forw := &Forwarder{ Dests: make([]*Destination, 0), @@ -35,33 +43,67 @@ func NewForwarder(ctx context.Context, typ string, lport, dport uint32, addrs .. Dport: dport, ctx: ctx, cancel: cancel, - Type: typ, + Type: proto, } - id := strconv.FormatUint(uint64(lport), 10) + id := forw.Type + strconv.FormatUint(uint64(lport), 10) forw.log = logger.NewLogger("forwarder:" + id) for _, addr := range addrs { dest := NewDestination(addr) forw.Dests = append(forw.Dests, dest) } - - portinfo := ":" + strconv.FormatUint(uint64(forw.Lport), 10) - laddr, err := net.ResolveTCPAddr("tcp", portinfo) - if err != nil { - return forw, err - } - listen, err := net.ListenTCP("tcp", laddr) - if err != nil { + switch proto { + case TCP: + portinfo := ":" + strconv.FormatUint(uint64(forw.Lport), 10) + laddr, err := net.ResolveTCPAddr("tcp", portinfo) + if err != nil { + return forw, err + } + listen, err := net.ListenTCP("tcp", laddr) + if err != nil { + return forw, err + } + forw.listenTCP = listen + case UDP: + portinfo := ":" + strconv.FormatUint(uint64(forw.Lport), 10) + laddr, err := net.ResolveUDPAddr("udp", portinfo) + if err != nil { + return forw, err + } + listen, err := net.ListenUDP("udp", laddr) + if err != nil { + return forw, err + } + forw.listenUDP = listen + default: + err = fmt.Errorf("Unknown net type: %s", proto) return forw, err } - forw.listen = listen return forw, err } +func (forw *Forwarder) Listen(wg *sync.WaitGroup) { + switch forw.Type { + case TCP: + if forw.listenTCP != nil { + forw.ListenTCP(wg) + } + case UDP: + if forw.listenUDP != nil { + forw.ListenUDP(wg) + } + } +} + func (forw *Forwarder) ListenTCP(wg *sync.WaitGroup) { - forw.log.Debugf("Start listening on %d", forw.Lport) + forw.log.Debugf("Start listening on %s:%d", forw.Type, forw.Lport) + forw.State = ForwStarted defer wg.Done() + stater := func() { + forw.State = ForwStopped + } + defer stater() for { - conn, err := forw.listen.Accept() + conn, err := forw.listenTCP.Accept() if err != nil { forw.log.Errorf("Listen err: %v", err) return @@ -70,88 +112,102 @@ func (forw *Forwarder) ListenTCP(wg *sync.WaitGroup) { } } -func (forw *Forwarder) Stop() error { - return forw.listen.Close() +func (forw *Forwarder) ListenUDP(wg *sync.WaitGroup) { + forw.log.Debugf("Start listening on %s:%d", forw.Type, forw.Lport) + forw.State = ForwStarted + defer wg.Done() + stater := func() { + forw.State = ForwStopped + } + defer stater() + for { + buffer := make([]byte, 2048) + size, srcAddr, err := forw.listenUDP.ReadFromUDP(buffer) + if err != nil { + forw.log.Errorf("Error reading: %v", err) + continue + } + go forw.handleUDP(forw.listenUDP, srcAddr, buffer[:size]) + } } -func (forw *Forwarder) handleTCP(ctx context.Context, inconn net.Conn) { - forw.log.Debugf("%s: Handle on %d started", inconn.RemoteAddr(), forw.Lport) - defer inconn.Close() +func (forw *Forwarder) handleUDP(listConn *net.UDPConn, srcAddr *net.UDPAddr, data []byte) { + forw.log.Debugf("Handle on %d started", forw.Lport) if len(forw.Dests) == 0 { return } + // Select dest address addrnum := rand.Uint32() % uint32(len(forw.Dests)) ipaddr := forw.Dests[addrnum].Address + destInfo := ipaddr + ":" + strconv.FormatUint(uint64(forw.Dport), 10) - /* - dstaddr := ipaddr + ":" + strconv.FormatUint(uint64(forw.Dport), 10) - outconn, err := net.Dial("tcp", dstaddr) - if err != nil { - return - } - defer outconn.Close() - var wg sync.WaitGroup - wg.Add(1) - go forw.stream(&wg, inconn, outconn) - wg.Add(1) - go forw.stream(&wg, outconn, inconn) - wg.Wait() - */ - str := NewStreamer(forw.ctx) - err := str.Stream(inconn, ipaddr, forw.Dport) + destAddr, err := net.ResolveUDPAddr("udp", destInfo) if err != nil { - forw.log.Errorf("Handler on %d error: %v", forw.Lport, err) + forw.log.Debugf("Error resolving server address: %v", err) + return } - forw.log.Debugf("Handler on %d stopped", forw.Lport) -} - -func (forw *Forwarder) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) { - defer wg.Done() - _, err := iocopy(forw.ctx, outconn, inconn) + // Write to destination + destConn, err := net.DialUDP("udp", nil, destAddr) if err != nil { - forw.log.Errorf("Copy err: %v", err) + forw.log.Debugf("Error dialing: %v", err) + return } -} - -type Streamer struct { - ctx context.Context - cancel context.CancelFunc - log *logger.Logger -} - -func NewStreamer(ctx context.Context) *Streamer { - ctx, cancel := context.WithCancel(ctx) - log := logger.NewLogger("streamer") - return &Streamer{ - ctx: ctx, - cancel: cancel, - log: log, + defer destConn.Close() + _, err = destConn.Write(data) + if err != nil { + forw.log.Debugf("Error sending message: %v", err) + return + } + const deadlinePeriod = 5 * time.Second + destConn.SetReadDeadline(time.Now().Add(deadlinePeriod)) + // Read from destination and resend to initiator + const readCount = 1 + for i := 0; i < readCount; i++ { + buffer := make([]byte, 1024*2) + size, _, err := destConn.ReadFromUDP(buffer) + if err != nil { + forw.log.Debugf("Error reading response: %v", err) + return + } + _, err = listConn.WriteToUDP(buffer[:size], srcAddr) + if err != nil { + forw.log.Errorf("Error writing to back: %v", err) + return + } } } -func (str *Streamer) Stream(inconn net.Conn, dipaddr string, dport uint32) error { +func (forw *Forwarder) Stop() error { var err error - dstaddr := dipaddr + ":" + strconv.FormatUint(uint64(dport), 10) - outconn, err := net.Dial("tcp", dstaddr) - if err != nil { - return err - } - defer outconn.Close() - var wg sync.WaitGroup - wg.Add(1) - go str.stream(&wg, inconn, outconn) - wg.Add(1) - go str.stream(&wg, outconn, inconn) - wg.Wait() + switch forw.Type { + case TCP: + if forw.listenTCP != nil { + return forw.listenTCP.Close() + } + case UDP: + if forw.listenUDP != nil { + return forw.listenUDP.Close() + } + } return err } -func (str *Streamer) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) { - defer wg.Done() - _, err := iocopy(str.ctx, outconn, inconn) +func (forw *Forwarder) handleTCP(ctx context.Context, inconn net.Conn) { + forw.log.Debugf("%s: Handle on %d started", inconn.RemoteAddr(), forw.Lport) + defer inconn.Close() + if len(forw.Dests) == 0 { + return + } + // Select dest address + addrnum := rand.Uint32() % uint32(len(forw.Dests)) + ipaddr := forw.Dests[addrnum].Address + + str := NewStreamer(forw.ctx) + err := str.Stream(inconn, ipaddr, forw.Dport) if err != nil { - str.log.Errorf("Copy err: %v", err) + forw.log.Errorf("Handler on %d error: %v", forw.Lport, err) } + forw.log.Debugf("Handler on %d stopped", forw.Lport) } type Destination struct { diff --git a/app/rproxy/rproxy.go b/app/rproxy/rproxy.go index 5572f6e..e6fd6cd 100644 --- a/app/rproxy/rproxy.go +++ b/app/rproxy/rproxy.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "strings" "sync" @@ -13,10 +12,11 @@ import ( type Proxy struct { Forwarders []*Forwarder `json:"forwarders" yaml:"forwarders"` - ctx context.Context `jsreturn errors.New("Zero dport")on:"-" yaml:"-"` + ctx context.Context `json:"-" yaml:"-"` cancel context.CancelFunc `json:"-" yaml:"-"` wg sync.WaitGroup `json:"-" yaml:"-"` log *logger.Logger `json:"-" yaml:"-"` + mtx sync.Mutex `json:"-" yaml:"-"` } func NewProxy() *Proxy { @@ -29,106 +29,118 @@ func NewProxy() *Proxy { } } -func (bal *Proxy) AddForwarder(ctx context.Context, typ string, lport, dport uint32, addrs ...string) error { +func (prox *Proxy) ListForwarders(ctx context.Context) []*Forwarder { + prox.mtx.Lock() + defer prox.mtx.Unlock() + forwards := make([]*Forwarder, 0) + for _, forw := range prox.Forwarders { + dests := make([]string, 0) + for _, dest := range forw.Dests { + dests = append(dests, dest.Address) + } + newForw, _ := NewForwarder(ctx, forw.Type, forw.Lport, forw.Dport, dests...) + forwards = append(forwards, newForw) + } + return forwards +} + +func (prox *Proxy) CreateOrUpdateForwarder(ctx context.Context, proto string, lport, dport uint32, addrs ...string) error { var err error if lport == 0 { - return errors.New("Zero lport") + return errors.New("Zero forwarder lport") } if dport == 0 { - return errors.New("Zero dport") + return errors.New("Zero forwarder dport") } - if typ == "" { - return errors.New("Empty type") + if proto == "" { + return errors.New("Empty forwarder type") } - typ = strings.ToLower(typ) - if typ != TCP { - return fmt.Errorf("Unknown type %s", typ) + proto = strings.ToLower(proto) + if proto != TCP && proto != UDP { + return fmt.Errorf("Unknown forwarder protocol %s", proto) } - forw, err := NewForwarder(ctx, typ, lport, dport, addrs...) - if err != nil { - return err + prox.mtx.Lock() + defer prox.mtx.Unlock() + + var forw *Forwarder + for _, iforw := range prox.Forwarders { + if iforw.Lport == lport && iforw.Type == proto { + forw = iforw + break + } + } + switch { + case forw == nil: + prox.log.Debugf("Create rorwarder %s:%d", proto, lport) + forw, err = NewForwarder(ctx, proto, lport, dport, addrs...) + if err != nil { + return err + } + prox.Forwarders = append(prox.Forwarders, forw) + prox.wg.Add(1) + go forw.Listen(&prox.wg) + default: + prox.log.Debugf("Update forwarder %s:%d", proto, lport) + forw.Dport = dport + dests := make([]*Destination, 0) + for _, addr := range addrs { + dest := NewDestination(addr) + forw.Dests = append(dests, dest) + } + forw.Dests = dests } - bal.Forwarders = append(bal.Forwarders, forw) - bal.wg.Add(1) - go forw.ListenTCP(&bal.wg) return err } -func (bal *Proxy) DeleteForwarder(ctx context.Context, lport uint32) error { +func (prox *Proxy) DeleteForwarder(ctx context.Context, proto string, lport uint32) error { var err error var forw *Forwarder - for _, iforw := range bal.Forwarders { - if iforw.Lport == lport { + prox.mtx.Lock() + defer prox.mtx.Unlock() + for _, iforw := range prox.Forwarders { + if iforw.Lport == lport && iforw.Type == proto { forw = iforw break } } if forw == nil { - bal.log.Debugf("Forwarder for lport %d not found", lport) + prox.log.Debugf("Forwarder for %s:%d not found", proto, lport) return err } - bal.log.Debugf("Stop forwarder for lport %d", lport) + prox.log.Debugf("Stop forwarder for %s:%d", proto, lport) err = forw.Stop() if err != nil { return err } forwarders := make([]*Forwarder, 0) - for _, forw := range bal.Forwarders { - if forw.Lport == lport { + for _, forw := range prox.Forwarders { + if forw.Lport == lport && forw.Type == proto { continue } forwarders = append(forwarders, forw) } - bal.Forwarders = forwarders + prox.Forwarders = forwarders return err } -func (bal *Proxy) Start() error { +func (prox *Proxy) Start() error { var err error - for _, forw := range bal.Forwarders { - bal.wg.Add(1) - go forw.ListenTCP(&bal.wg) + prox.mtx.Lock() + defer prox.mtx.Unlock() + for _, forw := range prox.Forwarders { + prox.wg.Add(1) + go forw.Listen(&prox.wg) } - bal.wg.Wait() + prox.wg.Wait() return err } -func (bal *Proxy) Stop() error { +func (prox *Proxy) Stop() error { var err error - for _, forw := range bal.Forwarders { + prox.mtx.Lock() + defer prox.mtx.Unlock() + for _, forw := range prox.Forwarders { forw.Stop() } return err } - -func iocopy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) { - var err error - var size int64 - var halt bool - buffer := make([]byte, 1024*4) - for { - select { - case <-ctx.Done(): - err = errors.New("Break copy by context") - break - default: - } - rsize, err := reader.Read(buffer) - if err == io.EOF { - err = nil - halt = true - } - if err != nil { - return size, err - } - wsize, err := writer.Write(buffer[0:rsize]) - size += int64(wsize) - if err != nil { - return size, err - } - if halt { - break - } - } - return size, err -} diff --git a/app/rproxy/streamer.go b/app/rproxy/streamer.go new file mode 100644 index 0000000..0cb7b7f --- /dev/null +++ b/app/rproxy/streamer.go @@ -0,0 +1,85 @@ +package rproxy + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "sync" + + "helmet/app/logger" +) + +type Streamer struct { + ctx context.Context + cancel context.CancelFunc + log *logger.Logger +} + +func NewStreamer(ctx context.Context) *Streamer { + ctx, cancel := context.WithCancel(ctx) + log := logger.NewLogger("streamer") + return &Streamer{ + ctx: ctx, + cancel: cancel, + log: log, + } +} + +func (str *Streamer) Stream(inconn net.Conn, dipaddr string, dport uint32) error { + var err error + dstaddr := dipaddr + ":" + strconv.FormatUint(uint64(dport), 10) + outconn, err := net.Dial("tcp", dstaddr) + if err != nil { + return err + } + defer outconn.Close() + var wg sync.WaitGroup + wg.Add(1) + go str.stream(&wg, inconn, outconn) + wg.Add(1) + go str.stream(&wg, outconn, inconn) + wg.Wait() + return err +} + +func (str *Streamer) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) { + defer wg.Done() + _, err := iocopy(str.ctx, outconn, inconn) + if err != nil { + str.log.Errorf("Copy err: %v", err) + } +} + +func iocopy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) { + var err error + var size int64 + var halt bool + buffer := make([]byte, 1024*4) + for { + select { + case <-ctx.Done(): + err = errors.New("Break copy by context") + break + default: + } + rsize, err := reader.Read(buffer) + if err == io.EOF { + err = nil + halt = true + } + if err != nil { + return size, err + } + wsize, err := writer.Write(buffer[0:rsize]) + size += int64(wsize) + if err != nil { + return size, err + } + if halt { + break + } + } + return size, err +} diff --git a/app/service/service.go b/app/service/service.go index 3df2a06..8463d56 100644 --- a/app/service/service.go +++ b/app/service/service.go @@ -109,13 +109,12 @@ func (svc *Service) logInterceptor(ctx context.Context, req any, info *grpc.Unar if err == nil { svc.log.Debugf("Request: %s", string(reqData)) } - res, err := handler(ctx, req) + res, handErr := handler(ctx, req) resData, err := json.Marshal(res) if err == nil { svc.log.Debugf("Response: %s", string(resData)) } - - return res, err + return res, handErr } func (svc *Service) Stop() {