package rproxy

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"strconv"
	"sync"
	"time"

	"helmet/app/logger"
)

const (
	// TCP and UDP are the protocols a Forwarder can be created for
	// (the accepted values of Forwarder.Type).
	TCP = "tcp"
	UDP = "udp"

	// ForwStarted and ForwStopped are the values of Forwarder.State while a
	// Listen loop is running and after it has returned.
	ForwStarted = "started"
	ForwStopped = "stopped"
)

// maxDatagramSize is large enough to hold any UDP payload, so datagrams are
// never silently truncated while being relayed.
const maxDatagramSize = 65535

// Forwarder accepts traffic on local port Lport and relays it to port Dport
// of a randomly chosen entry of Dests.
//
// NOTE(review): State is written by the Listen goroutine without
// synchronization; concurrent readers (e.g. while serializing) race with it.
type Forwarder struct {
	State string         `json:"state" yaml:"state"`
	Type  string         `json:"type" yaml:"type"`
	Lport uint32         `json:"lport" yaml:"lport"`
	Dport uint32         `json:"dport" yaml:"dport"`
	Dests []*Destination `json:"dests" yaml:"dests"`

	listenTCP *net.TCPListener   `json:"-" yaml:"-"`
	listenUDP *net.UDPConn       `json:"-" yaml:"-"`
	ctx       context.Context    `json:"-" yaml:"-"`
	cancel    context.CancelFunc `json:"-" yaml:"-"`
	log       *logger.Logger
}

// NewForwarder creates a forwarder for proto (TCP or UDP) that listens on
// lport and relays to dport on one of addrs. The listening socket is opened
// immediately.
//
// The forwarder's context is derived from ctx (a nil ctx is treated as
// context.Background()) and is handed to the TCP streamers; it is cancelled
// when ctx is cancelled or Stop is called.
//
// On error the partially initialised forwarder is still returned together
// with the error; its listener is nil, so Listen will not serve on it.
func NewForwarder(ctx context.Context, proto string, lport, dport uint32, addrs ...string) (*Forwarder, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	ctx, cancel := context.WithCancel(ctx)

	forw := &Forwarder{
		Type:   proto,
		Lport:  lport,
		Dport:  dport,
		Dests:  make([]*Destination, 0, len(addrs)), // non-nil so it encodes as [] rather than null
		ctx:    ctx,
		cancel: cancel,
	}
	id := forw.Type + strconv.FormatUint(uint64(lport), 10)
	forw.log = logger.NewLogger("forwarder:" + id)

	for _, addr := range addrs {
		forw.Dests = append(forw.Dests, NewDestination(addr))
	}

	portinfo := ":" + strconv.FormatUint(uint64(forw.Lport), 10)
	switch proto {
	case TCP:
		laddr, err := net.ResolveTCPAddr("tcp", portinfo)
		if err != nil {
			return forw, err
		}
		listen, err := net.ListenTCP("tcp", laddr)
		if err != nil {
			return forw, err
		}
		forw.listenTCP = listen
	case UDP:
		laddr, err := net.ResolveUDPAddr("udp", portinfo)
		if err != nil {
			return forw, err
		}
		listen, err := net.ListenUDP("udp", laddr)
		if err != nil {
			return forw, err
		}
		forw.listenUDP = listen
	default:
		return forw, fmt.Errorf("Unknown net type: %s", proto)
	}
	return forw, nil
}

// Listen serves the forwarder's listener on the calling goroutine until it
// is closed. The caller is expected to have called wg.Add(1) beforehand:
// wg.Done is called exactly once when Listen returns, including when there
// is no listener to serve (e.g. NewForwarder failed), so wg.Wait never hangs.
func (forw *Forwarder) Listen(wg *sync.WaitGroup) {
	switch {
	case forw.Type == TCP && forw.listenTCP != nil:
		forw.ListenTCP(wg)
	case forw.Type == UDP && forw.listenUDP != nil:
		forw.ListenUDP(wg)
	default:
		wg.Done()
	}
}

// ListenTCP accepts connections on the TCP listener and relays each one in
// its own goroutine. It returns when Accept fails (normally because Stop
// closed the listener), after which State is ForwStopped and wg.Done has
// been called.
func (forw *Forwarder) ListenTCP(wg *sync.WaitGroup) {
	forw.log.Debugf("Start listening on %s:%d", forw.Type, forw.Lport)
	forw.State = ForwStarted
	defer wg.Done()
	defer func() { forw.State = ForwStopped }()

	for {
		conn, err := forw.listenTCP.Accept()
		if err != nil {
			if forw.closedByStop(err) {
				forw.log.Debugf("Stop listening on %s:%d", forw.Type, forw.Lport)
			} else {
				forw.log.Errorf("Listen err: %v", err)
			}
			return
		}
		go forw.handleTCP(forw.ctx, conn)
	}
}

// ListenUDP reads datagrams from the UDP socket and relays each one in its
// own goroutine. Transient read errors are logged and skipped; the loop
// returns once the socket has been closed (or the forwarder's context is
// cancelled), after which State is ForwStopped and wg.Done has been called.
func (forw *Forwarder) ListenUDP(wg *sync.WaitGroup) {
	forw.log.Debugf("Start listening on %s:%d", forw.Type, forw.Lport)
	forw.State = ForwStarted
	defer wg.Done()
	defer func() { forw.State = ForwStopped }()

	// One reusable buffer sized for the largest possible datagram; each
	// packet is copied out at its exact size before being handed off, so the
	// handler goroutine never sees the buffer being overwritten.
	buffer := make([]byte, maxDatagramSize)
	for {
		size, srcAddr, err := forw.listenUDP.ReadFromUDP(buffer)
		if err != nil {
			if forw.closedByStop(err) {
				// Without this exit, a closed socket fails every read and
				// the loop would spin forever.
				forw.log.Debugf("Stop listening on %s:%d", forw.Type, forw.Lport)
				return
			}
			forw.log.Errorf("Error reading: %v", err)
			continue
		}
		packet := make([]byte, size)
		copy(packet, buffer[:size])
		go forw.handleUDP(forw.listenUDP, srcAddr, packet)
	}
}

// closedByStop reports whether err, returned by the listening socket, means
// the socket was shut down: either it is closed or the forwarder's context
// has been cancelled.
func (forw *Forwarder) closedByStop(err error) bool {
	if errors.Is(err, net.ErrClosed) {
		return true
	}
	return forw.ctx != nil && forw.ctx.Err() != nil
}

// pickDestination returns the address of a uniformly random entry of Dests,
// or ok == false when there are no destinations.
func (forw *Forwarder) pickDestination() (addr string, ok bool) {
	if len(forw.Dests) == 0 {
		return "", false
	}
	return forw.Dests[rand.Intn(len(forw.Dests))].Address, true
}

// handleUDP sends one datagram to Dport on a randomly chosen destination,
// waits up to 5 seconds for a single reply and writes that reply back to
// srcAddr through listConn. Failures are logged and the datagram is dropped.
func (forw *Forwarder) handleUDP(listConn *net.UDPConn, srcAddr *net.UDPAddr, data []byte) {
	forw.log.Debugf("Handle on %d started", forw.Lport)

	ipaddr, ok := forw.pickDestination()
	if !ok {
		return
	}
	destInfo := ipaddr + ":" + strconv.FormatUint(uint64(forw.Dport), 10)
	destAddr, err := net.ResolveUDPAddr("udp", destInfo)
	if err != nil {
		forw.log.Debugf("Error resolving server address: %v", err)
		return
	}

	// Write to destination.
	destConn, err := net.DialUDP("udp", nil, destAddr)
	if err != nil {
		forw.log.Debugf("Error dialing: %v", err)
		return
	}
	defer destConn.Close()

	if _, err := destConn.Write(data); err != nil {
		forw.log.Debugf("Error sending message: %v", err)
		return
	}

	// Bound the wait so a silent destination does not pin this goroutine
	// and socket forever.
	const deadlinePeriod = 5 * time.Second
	if err := destConn.SetReadDeadline(time.Now().Add(deadlinePeriod)); err != nil {
		forw.log.Debugf("Error setting read deadline: %v", err)
		return
	}

	// Read one reply from the destination and resend it to the initiator.
	buffer := make([]byte, maxDatagramSize)
	size, _, err := destConn.ReadFromUDP(buffer)
	if err != nil {
		forw.log.Debugf("Error reading response: %v", err)
		return
	}
	if _, err := listConn.WriteToUDP(buffer[:size], srcAddr); err != nil {
		forw.log.Errorf("Error writing to back: %v", err)
	}
}

// Stop cancels the forwarder's context (the one handed to TCP streamers) and
// closes the listening socket, which makes the running Listen loop return.
// It returns the error from closing the socket, if any.
func (forw *Forwarder) Stop() error {
	// Cancel first so the Listen loop can tell the resulting read/accept
	// error is a shutdown rather than a failure.
	if forw.cancel != nil {
		forw.cancel()
	}
	switch forw.Type {
	case TCP:
		if forw.listenTCP != nil {
			return forw.listenTCP.Close()
		}
	case UDP:
		if forw.listenUDP != nil {
			return forw.listenUDP.Close()
		}
	}
	return nil
}

// handleTCP relays a single accepted connection to Dport on a randomly
// chosen destination, using a Streamer bound to ctx. inconn is always closed
// when handleTCP returns.
func (forw *Forwarder) handleTCP(ctx context.Context, inconn net.Conn) {
	forw.log.Debugf("%s: Handle on %d started", inconn.RemoteAddr(), forw.Lport)
	defer inconn.Close()

	ipaddr, ok := forw.pickDestination()
	if !ok {
		return
	}

	str := NewStreamer(ctx)
	if err := str.Stream(inconn, ipaddr, forw.Dport); err != nil {
		forw.log.Errorf("Handler on %d error: %v", forw.Lport, err)
	}
	forw.log.Debugf("Handler on %d stopped", forw.Lport)
}

// Destination is one upstream host a Forwarder may relay traffic to.
type Destination struct {
	Address string `json:"address" yaml:"address"`
}

// NewDestination returns a Destination for addr.
func NewDestination(addr string) *Destination {
	return &Destination{
		Address: addr,
	}
}