package rproxy

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"

	"helmet/app/logger"
)

// Proxy owns a set of port forwarders and the goroutines that run them.
//
// Forwarders is the serialisable part (JSON/YAML). The remaining fields are
// runtime state. Every method below takes mtx before reading or writing
// Forwarders. wg counts the Listen goroutines started for the forwarders.
type Proxy struct {
	Forwarders []*Forwarder `json:"forwarders" yaml:"forwarders"`

	// NOTE(review): ctx/cancel are created in NewProxy, but nothing in this
	// file reads ctx or calls cancel. Confirm whether other code relies on them.
	ctx    context.Context    `json:"-" yaml:"-"`
	cancel context.CancelFunc `json:"-" yaml:"-"`
	wg     sync.WaitGroup     `json:"-" yaml:"-"`
	log    *logger.Logger     `json:"-" yaml:"-"`
	mtx    sync.Mutex         `json:"-" yaml:"-"`
}

// NewProxy returns an empty Proxy with a cancellable background context and
// a logger named "proxy".
func NewProxy() *Proxy {
	ctx, cancel := context.WithCancel(context.Background())
	return &Proxy{
		Forwarders: make([]*Forwarder, 0),
		ctx:        ctx,
		cancel:     cancel,
		log:        logger.NewLogger("proxy"),
	}
}

// ListForwarders returns a snapshot of the registered forwarders.
//
// Each entry is a fresh Forwarder rebuilt with NewForwarder from the original's
// Type, Lport, Dport and destination addresses. Callers therefore cannot mutate
// the proxy's own forwarders through the result. Entries that NewForwarder
// rejects are skipped (and logged) instead of being returned as nil.
//
// NOTE(review): confirm NewForwarder has no side effects (e.g. binding a
// socket), since it is called here purely to make copies.
func (prox *Proxy) ListForwarders(ctx context.Context) []*Forwarder {
	prox.mtx.Lock()
	defer prox.mtx.Unlock()

	forwards := make([]*Forwarder, 0, len(prox.Forwarders))
	for _, forw := range prox.Forwarders {
		dests := make([]string, 0, len(forw.Dests))
		for _, dest := range forw.Dests {
			dests = append(dests, dest.Address)
		}
		newForw, err := NewForwarder(ctx, forw.Type, forw.Lport, forw.Dport, dests...)
		if err != nil {
			// Appending here would hand callers a nil/partial forwarder.
			prox.log.Debugf("Copy forwarder %s:%d: %v", forw.Type, forw.Lport, err)
			continue
		}
		forwards = append(forwards, newForw)
	}
	return forwards
}

// CreateOrUpdateForwarder registers a forwarder for (proto, lport), or updates
// the existing one.
//
// proto is case-insensitive and must be TCP or UDP after lowercasing. lport
// and dport must be non-zero. There are two cases:
//   - No forwarder with the same lport and proto exists: a new one is built
//     with NewForwarder, appended to Forwarders, and its Listen goroutine is
//     started immediately (tracked by prox.wg).
//   - One already exists: its Dport is replaced and its destinations are
//     replaced by one Destination per entry in addrs.
//
// It returns a validation error, or the error from NewForwarder.
func (prox *Proxy) CreateOrUpdateForwarder(ctx context.Context, proto string, lport, dport uint32, addrs ...string) error {
	if lport == 0 {
		return errors.New("Zero forwarder lport")
	}
	if dport == 0 {
		return errors.New("Zero forwarder dport")
	}
	if proto == "" {
		return errors.New("Empty forwarder type")
	}
	proto = strings.ToLower(proto)
	if proto != TCP && proto != UDP {
		return fmt.Errorf("Unknown forwarder protocol %s", proto)
	}

	prox.mtx.Lock()
	defer prox.mtx.Unlock()

	var existing *Forwarder
	for _, iforw := range prox.Forwarders {
		if iforw.Lport == lport && iforw.Type == proto {
			existing = iforw
			break
		}
	}

	if existing == nil {
		prox.log.Debugf("Create forwarder %s:%d", proto, lport)
		// NOTE(review): the forwarder is built with the caller's ctx, not
		// prox.ctx. If NewForwarder ties the forwarder's lifetime to ctx, a
		// request-scoped ctx would stop it early. Confirm.
		forw, err := NewForwarder(ctx, proto, lport, dport, addrs...)
		if err != nil {
			return err
		}
		prox.Forwarders = append(prox.Forwarders, forw)
		prox.wg.Add(1)
		go forw.Listen(&prox.wg)
		return nil
	}

	prox.log.Debugf("Update forwarder %s:%d", proto, lport)
	dests := make([]*Destination, 0, len(addrs))
	for _, addr := range addrs {
		dests = append(dests, NewDestination(addr))
	}
	// NOTE(review): these fields are written while the forwarder's Listen
	// goroutine may be reading them. Confirm Forwarder synchronises access.
	existing.Dport = dport
	existing.Dests = dests
	return nil
}

// DeleteForwarder stops and unregisters the forwarder for (proto, lport).
//
// proto is lowercased so it matches how CreateOrUpdateForwarder stores it.
// A missing forwarder is not an error: the call is logged and returns nil.
// If the forwarder's Stop fails, the error is returned and the forwarder stays
// registered. Otherwise every entry with that lport and proto is removed.
func (prox *Proxy) DeleteForwarder(ctx context.Context, proto string, lport uint32) error {
	proto = strings.ToLower(proto)

	prox.mtx.Lock()
	defer prox.mtx.Unlock()

	var target *Forwarder
	for _, iforw := range prox.Forwarders {
		if iforw.Lport == lport && iforw.Type == proto {
			target = iforw
			break
		}
	}
	if target == nil {
		prox.log.Debugf("Forwarder for %s:%d not found", proto, lport)
		return nil
	}

	prox.log.Debugf("Stop forwarder for %s:%d", proto, lport)
	if err := target.Stop(); err != nil {
		return err
	}

	// Build a new slice rather than filtering in place: Forwarders is exported,
	// so its backing array may be shared.
	kept := make([]*Forwarder, 0, len(prox.Forwarders))
	for _, iforw := range prox.Forwarders {
		if iforw.Lport == lport && iforw.Type == proto {
			continue
		}
		kept = append(kept, iforw)
	}
	prox.Forwarders = kept
	return nil
}

// Start launches a Listen goroutine for every registered forwarder. It then
// blocks until prox.wg's counter reaches zero, which presumably happens once
// each Listen goroutine has called Done on it.
//
// The mutex is held only while the goroutines are launched, not across
// wg.Wait. Otherwise Stop (and every other method) would block on the mutex
// for as long as Start runs, Stop could never stop the forwarders, and Start
// would never return.
//
// NOTE(review): a forwarder already started by CreateOrUpdateForwarder gets a
// second Listen goroutine if Start is called afterwards. Confirm the intended
// call order.
func (prox *Proxy) Start() error {
	prox.mtx.Lock()
	for _, forw := range prox.Forwarders {
		prox.wg.Add(1)
		go forw.Listen(&prox.wg)
	}
	prox.mtx.Unlock()

	prox.wg.Wait()
	return nil
}

// Stop calls Stop on every registered forwarder. It keeps going after a
// failure so that every forwarder is still asked to stop. Each failure is
// logged and the first error encountered is returned.
func (prox *Proxy) Stop() error {
	prox.mtx.Lock()
	defer prox.mtx.Unlock()

	var firstErr error
	for _, forw := range prox.Forwarders {
		if err := forw.Stop(); err != nil {
			prox.log.Debugf("Stop forwarder %s:%d: %v", forw.Type, forw.Lport, err)
			if firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}