package rproxy import ( "context" "errors" "io" "log" "math/rand" "net" "strconv" "sync" ) type Proxy struct { Forwarders []*Forwarder `json:"forwarders" yaml:"forwarders"` ctx context.Context `json:"-" yaml:"-"` cancel context.CancelFunc `json:"-" yaml:"-"` wg sync.WaitGroup `json:"-" yaml:"-"` } func NewProxy() *Proxy { ctx, cancel := context.WithCancel(context.Background()) return &Proxy{ Forwarders: make([]*Forwarder, 0), ctx: ctx, cancel: cancel, } } func (bal *Proxy) AddForwarder(ctx context.Context, lport, dport uint32, addrs ...string) error { var err error forw, err := NewForwarder(ctx, lport, dport, addrs...) if err != nil { return err } bal.Forwarders = append(bal.Forwarders, forw) bal.wg.Add(1) go forw.Listen(&bal.wg) return err } func (bal *Proxy) DeleteForwarder(ctx context.Context, lport uint32) error { var err error forwarders := make([]*Forwarder, 0) for _, forw := range bal.Forwarders { if forw.Lport == lport { forw.Stop() continue } forwarders = append(forwarders, forw) } bal.Forwarders = forwarders return err } func (bal *Proxy) Start() error { var err error for _, forw := range bal.Forwarders { bal.wg.Add(1) go forw.Listen(&bal.wg) } bal.wg.Wait() return err } func (bal *Proxy) Stop() error { var err error for _, forw := range bal.Forwarders { forw.Stop() } return err } type Forwarder struct { listen net.Listener `json:"-" yaml:"-"` ctx context.Context `json:"-" yaml:"-"` cancel context.CancelFunc `json:"-" yaml:"-"` Lport uint32 `json:"lport" yaml:"lport"` Dport uint32 `json:"dport" yaml:"dport"` Dests []*Destination `json:"dests" yaml:"dests"` } func NewForwarder(ctx context.Context, lport, dport uint32, addrs ...string) (*Forwarder, error) { ctx, cancel := context.WithCancel(ctx) forw := &Forwarder{ Dests: make([]*Destination, 0), Lport: lport, Dport: dport, ctx: ctx, cancel: cancel, } for _, addr := range addrs { dest := NewDestination(addr) forw.Dests = append(forw.Dests, dest) } port := ":" + strconv.FormatUint(uint64(forw.Lport), 10) listen, err := net.Listen("tcp", port) if err != nil { return forw, err } forw.listen = listen return forw, err } func (forw *Forwarder) Listen(wg *sync.WaitGroup) { log.Printf("Start listening on %d\n", forw.Lport) defer wg.Done() for { conn, err := forw.listen.Accept() if err != nil { log.Printf("Listen err: %v\n", err) return } go forw.handle(forw.ctx, conn) } } func (forw *Forwarder) Stop() error { return forw.listen.Close() } func (forw *Forwarder) handle(ctx context.Context, inconn net.Conn) { log.Printf("Handler on %d started\n", forw.Lport) defer inconn.Close() if len(forw.Dests) == 0 { return } addrnum := rand.Uint32() % uint32(len(forw.Dests)) ipaddr := forw.Dests[addrnum].Address dstaddr := ipaddr + ":" + strconv.FormatUint(uint64(forw.Dport), 10) outconn, err := net.Dial("tcp", dstaddr) if err != nil { return } var wg sync.WaitGroup wg.Add(1) go forw.stream(&wg, inconn, outconn) wg.Add(1) go forw.stream(&wg, outconn, inconn) wg.Wait() log.Printf("Handler on %d stopped\n", forw.Lport) } func (forw *Forwarder) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) { defer wg.Done() _, err := copy(forw.ctx, outconn, inconn) if err != nil { log.Printf("Copy err: %v\n", err) } } type Destination struct { Address string `json:"address" yaml:"address"` } func NewDestination(addr string) *Destination { return &Destination{ Address: addr, } } func copy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) { var err error var size int64 var halt bool buffer := make([]byte, 1024*4) for { select { case <-ctx.Done(): err = errors.New("Break copy by context") break default: } rsize, err := reader.Read(buffer) if err == io.EOF { err = nil halt = true } if err != nil { return size, err } wsize, err := writer.Write(buffer[0:rsize]) size += int64(wsize) if err != nil { return size, err } if halt { break } } return size, err }