// Package rproxy implements a simple TCP port forwarder that relays each
// accepted connection to one of a set of destination hosts.
package rproxy

import (
	"context"
	"io"
	"math/rand"
	"net"
	"strconv"
	"sync"

	"helmet/app/logger"
)

// Forwarder types.
//
// NOTE(review): only TCP is implemented in this file; a Forwarder created with
// Type UDP still listens and dials over TCP.
const (
	TCP = "tcp"
	UDP = "udp"
)

// Forwarder accepts TCP connections on Lport and relays each one to port Dport
// of a randomly chosen entry in Dests.
type Forwarder struct {
	Type   string             `json:"type" yaml:"type"`
	listen net.Listener       `json:"-" yaml:"-"`
	ctx    context.Context    `json:"-" yaml:"-"`
	cancel context.CancelFunc `json:"-" yaml:"-"`
	Lport  uint32             `json:"lport" yaml:"lport"`
	Dport  uint32             `json:"dport" yaml:"dport"`
	Dests  []*Destination     `json:"dests" yaml:"dests"`
	log    *logger.Logger
}

// NewForwarder creates a Forwarder of type typ that listens on TCP port lport
// on all interfaces and relays to port dport of each address in addrs.
//
// The forwarder's context is derived from ctx (a nil ctx is treated as
// context.Background()). Cancelling ctx, or calling Stop, closes the listener
// and tears down active relays.
//
// On error the returned Forwarder is non-nil but has no listener and its
// context is already cancelled; it must not be used.
func NewForwarder(ctx context.Context, typ string, lport, dport uint32, addrs ...string) (*Forwarder, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	// Derive from the caller's context so its cancellation propagates.
	ctx, cancel := context.WithCancel(ctx)

	id := strconv.FormatUint(uint64(lport), 10)
	forw := &Forwarder{
		Type:  typ,
		Lport: lport,
		Dport: dport,
		// Non-nil so an empty list marshals to JSON as [] rather than null.
		Dests:  make([]*Destination, 0, len(addrs)),
		ctx:    ctx,
		cancel: cancel,
		log:    logger.NewLogger("forwarder:" + id),
	}
	for _, addr := range addrs {
		forw.Dests = append(forw.Dests, NewDestination(addr))
	}

	laddr, err := net.ResolveTCPAddr("tcp", ":"+id)
	if err != nil {
		cancel() // release the context; this forwarder is unusable
		return forw, err
	}
	listen, err := net.ListenTCP("tcp", laddr)
	if err != nil {
		cancel()
		return forw, err
	}
	forw.listen = listen

	// Close the listener once the context is cancelled (by Stop or by the
	// caller's parent context) so a blocked Accept in ListenTCP returns.
	// A second Close after Stop is harmless; its error is ignored.
	go func() {
		<-ctx.Done()
		_ = listen.Close()
	}()
	return forw, nil
}

// ListenTCP runs the accept loop, starting one relay goroutine per accepted
// connection. It calls wg.Done when the loop exits, which happens on the first
// Accept error (including the error produced when the listener is closed).
func (forw *Forwarder) ListenTCP(wg *sync.WaitGroup) {
	defer wg.Done()
	if forw.listen == nil {
		forw.log.Errorf("Listen err: no listener on %d", forw.Lport)
		return
	}
	forw.log.Debugf("Start listening on %d", forw.Lport)
	for {
		conn, err := forw.listen.Accept()
		if err != nil {
			forw.log.Errorf("Listen err: %v", err)
			return
		}
		go forw.handle(forw.ctx, conn)
	}
}

// Stop closes the listener, which makes ListenTCP return. It then cancels the
// forwarder's context, which tears down active relays. It returns the
// listener's Close error, or nil if there is no listener.
func (forw *Forwarder) Stop() error {
	var err error
	// Close before cancelling so the returned error reflects this call's Close
	// rather than racing with the context watcher started by NewForwarder.
	if forw.listen != nil {
		err = forw.listen.Close()
	}
	if forw.cancel != nil {
		forw.cancel()
	}
	return err
}

// Streamer holds a cancellable context for a source/destination pair.
//
// NOTE(review): nothing in this file reads source or dest or uses Streamer;
// confirm it is used elsewhere in the package.
type Streamer struct {
	source string
	dest   string
	ctx    context.Context
	cancel context.CancelFunc
}

// NewStreamer returns a Streamer whose context is a cancellable child of ctx.
func NewStreamer(ctx context.Context) *Streamer {
	ctx, cancel := context.WithCancel(ctx)
	return &Streamer{
		ctx:    ctx,
		cancel: cancel,
	}
}

// handle relays inconn to a randomly chosen destination until both directions
// have finished or ctx is cancelled. It closes inconn, and the outbound
// connection if one was made, before returning.
func (forw *Forwarder) handle(ctx context.Context, inconn net.Conn) {
	forw.log.Debugf("%s: Handle on %d started", inconn.RemoteAddr(), forw.Lport)
	defer inconn.Close()
	if len(forw.Dests) == 0 {
		return
	}
	dest := forw.Dests[rand.Intn(len(forw.Dests))]
	dstaddr := destDialAddr(dest.Address, forw.Dport)

	var dialer net.Dialer
	outconn, err := dialer.DialContext(ctx, "tcp", dstaddr)
	if err != nil {
		forw.log.Errorf("Dial %s err: %v", dstaddr, err)
		return
	}
	defer outconn.Close()

	// A Read blocked on a net.Conn cannot observe ctx, so close both
	// connections on cancellation to unblock the copy goroutines.
	finished := make(chan struct{})
	defer close(finished)
	go func() {
		select {
		case <-ctx.Done():
			_ = inconn.Close()
			_ = outconn.Close()
		case <-finished:
		}
	}()

	var wg sync.WaitGroup
	wg.Add(2)
	go forw.stream(&wg, inconn, outconn)
	go forw.stream(&wg, outconn, inconn)
	wg.Wait()
	forw.log.Debugf("Handler on %d stopped", forw.Lport)
}

// destDialAddr builds a "host:port" dial address. Unlike plain concatenation
// it brackets IPv6 literals ("::1" becomes "[::1]:80"); a host that is already
// bracketed ("[::1]") is accepted as-is.
func destDialAddr(host string, port uint32) string {
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10))
}

// stream copies from inconn to outconn and then, when outconn supports
// CloseWrite (as *net.TCPConn does), half-closes it. The half-close lets the
// peer see EOF so the opposite direction can finish. It calls wg.Done on
// return.
func (forw *Forwarder) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) {
	defer wg.Done()
	// copy is a package-level helper defined outside this view (it shadows the
	// builtin) and takes the forwarder's context.
	_, err := copy(forw.ctx, outconn, inconn)
	if err != nil {
		forw.log.Errorf("Copy err: %v", err)
	}
	if cw, ok := outconn.(interface{ CloseWrite() error }); ok {
		// Best effort: the connection may already be closed.
		_ = cw.CloseWrite()
	}
}

// Destination is a relay target host.
type Destination struct {
	Address string `json:"address" yaml:"address"`
}

// NewDestination returns a Destination for addr.
func NewDestination(addr string) *Destination {
	return &Destination{
		Address: addr,
	}
}