package rproxy

import (
	"context"
	"errors"
	"io"
	"sync"

	"helmet/app/logger"
)

// Proxy owns a set of Forwarders and runs each one's Listen loop in its own
// goroutine, tracked by an internal WaitGroup. Only Forwarders is
// (de)serialized; the remaining fields are runtime state set up by NewProxy.
//
// NOTE(review): Forwarders is not guarded by a lock, so AddForwarder,
// DeleteForwarder, Start and Stop presumably must not run concurrently with
// each other — confirm how callers serialize them.
//
// NOTE(review): cancel is created by NewProxy but never invoked in this file;
// confirm whether Stop (or code elsewhere in the package) should cancel ctx.
type Proxy struct {
	Forwarders []*Forwarder       `json:"forwarders" yaml:"forwarders"`
	ctx        context.Context    `json:"-" yaml:"-"`
	cancel     context.CancelFunc `json:"-" yaml:"-"`
	wg         sync.WaitGroup     `json:"-" yaml:"-"`
	log        *logger.Logger     `json:"-" yaml:"-"`
}

// NewProxy returns a Proxy with an empty (non-nil) forwarder list, a
// cancellable context derived from context.Background, and a logger named
// "proxy".
func NewProxy() *Proxy {
	ctx, cancel := context.WithCancel(context.Background())
	return &Proxy{
		Forwarders: make([]*Forwarder, 0),
		ctx:        ctx,
		cancel:     cancel,
		log:        logger.NewLogger("proxy"),
	}
}

// AddForwarder creates a Forwarder with NewForwarder(ctx, lport, dport,
// addrs...), appends it to bal.Forwarders and starts its Listen loop in a new
// goroutine tracked by the proxy's WaitGroup.
//
// lport and dport (presumably the local listen port and the destination port)
// must both be non-zero. Any error from NewForwarder is returned unchanged,
// and in that case nothing is registered.
func (bal *Proxy) AddForwarder(ctx context.Context, lport, dport uint32, addrs ...string) error {
	if lport == 0 {
		return errors.New("Zero lport")
	}
	if dport == 0 {
		return errors.New("Zero dport")
	}
	forw, err := NewForwarder(ctx, lport, dport, addrs...)
	if err != nil {
		return err
	}
	bal.Forwarders = append(bal.Forwarders, forw)
	bal.wg.Add(1)
	go forw.Listen(&bal.wg)
	return nil
}

// DeleteForwarder stops the first registered forwarder whose Lport equals
// lport and removes that forwarder from bal.Forwarders.
//
// Deleting an unknown lport is not an error: it is logged at debug level and
// nil is returned. If the forwarder's Stop fails, the error is returned and the
// forwarder stays registered. ctx is currently unused.
func (bal *Proxy) DeleteForwarder(ctx context.Context, lport uint32) error {
	idx := -1
	for i, forw := range bal.Forwarders {
		if forw.Lport == lport {
			idx = i
			break
		}
	}
	if idx < 0 {
		bal.log.Debugf("Forwarder for lport %d not found", lport)
		return nil
	}

	bal.log.Debugf("Stop forwarder for lport %d", lport)
	if err := bal.Forwarders[idx].Stop(); err != nil {
		return err
	}

	// Remove only the forwarder that was actually stopped. Dropping every entry
	// with this lport would leave duplicates running but untracked, so they
	// could never be stopped later. A fresh slice is built (rather than
	// splicing in place) so previously obtained views of Forwarders are not
	// mutated.
	forwarders := make([]*Forwarder, 0, len(bal.Forwarders)-1)
	forwarders = append(forwarders, bal.Forwarders[:idx]...)
	forwarders = append(forwarders, bal.Forwarders[idx+1:]...)
	bal.Forwarders = forwarders
	return nil
}

// Start launches Listen for every registered forwarder, each in its own
// goroutine, then blocks until the proxy's WaitGroup counter reaches zero.
// That includes goroutines started earlier by AddForwarder; Listen is
// presumably responsible for calling Done. It always returns nil.
//
// NOTE(review): a forwarder already started by AddForwarder gets a second
// Listen goroutine here — confirm that Listen tolerates this.
func (bal *Proxy) Start() error {
	for _, forw := range bal.Forwarders {
		bal.wg.Add(1)
		go forw.Listen(&bal.wg)
	}
	bal.wg.Wait()
	return nil
}

// Stop calls Stop on every registered forwarder, continuing past failures so
// that one broken forwarder does not keep the others running. It returns the
// first error reported, or nil. Forwarders are left registered.
func (bal *Proxy) Stop() error {
	var firstErr error
	for _, forw := range bal.Forwarders {
		if err := forw.Stop(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// copy streams data from reader to writer in 4 KiB chunks until reader
// returns io.EOF, a read or write fails, or ctx is done. It returns the number
// of bytes written.
//
// Reaching io.EOF is not an error. A short write (fewer bytes accepted than
// were read, with no error) is reported as io.ErrShortWrite. Cancellation is
// only observed between reads: a Read that is already blocked is not
// interrupted by ctx.
func copy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) {
	var size int64
	buffer := make([]byte, 1024*4)
	for {
		select {
		case <-ctx.Done():
			// Return directly: a bare `break` here would only leave the select
			// and keep copying.
			return size, errors.New("Break copy by context")
		default:
		}

		rsize, rerr := reader.Read(buffer)
		// Per the io.Reader contract, consume the returned bytes before
		// looking at the error (a reader may return data together with EOF or
		// with a failure).
		if rsize > 0 {
			wsize, werr := writer.Write(buffer[:rsize])
			size += int64(wsize)
			if werr != nil {
				return size, werr
			}
			if wsize != rsize {
				return size, io.ErrShortWrite
			}
		}
		if rerr == io.EOF {
			return size, nil
		}
		if rerr != nil {
			return size, rerr
		}
	}
}