// Package rproxy relays TCP streams between an accepted connection and a
// dialed upstream destination.
package rproxy

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync"
	"time"

	"helmet/app/logger"
)

// Streamer relays bytes between connections. All relays it runs are bound
// to the context given to NewStreamer.
type Streamer struct {
	ctx    context.Context
	cancel context.CancelFunc
	log    *logger.Logger
}

// NewStreamer returns a Streamer whose relays are bound to a cancellable
// child of ctx. It logs through a logger named "streamer".
func NewStreamer(ctx context.Context) *Streamer {
	ctx, cancel := context.WithCancel(ctx)
	log := logger.NewLogger("streamer")
	return &Streamer{
		ctx:    ctx,
		cancel: cancel,
		log:    log,
	}
}

// Stream dials dipaddr:dport over TCP and relays bytes in both directions
// between inconn and the upstream connection. It runs until both directions
// have finished or the streamer's context is done.
//
// The upstream connection is always closed before Stream returns. inconn is
// left open for the caller to close. When the context is done, the deadlines
// of both connections are expired so that blocked reads/writes return, which
// leaves inconn with an expired deadline.
//
// The returned error reports only a failure to dial. Copy errors are logged.
func (str *Streamer) Stream(inconn net.Conn, dipaddr string, dport uint32) error {
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:port"); plain
	// concatenation would produce an unparsable address.
	dstaddr := net.JoinHostPort(dipaddr, strconv.FormatUint(uint64(dport), 10))

	var dialer net.Dialer
	outconn, err := dialer.DialContext(str.ctx, "tcp", dstaddr)
	if err != nil {
		return err
	}
	defer outconn.Close()

	var wg sync.WaitGroup
	wg.Add(2)
	go str.stream(&wg, inconn, outconn)
	go str.stream(&wg, outconn, inconn)

	// A blocked Read does not observe ctx, so on cancellation expire the
	// deadlines to force both copy loops to return.
	stop := make(chan struct{})
	watcherDone := make(chan struct{})
	go func() {
		defer close(watcherDone)
		select {
		case <-str.ctx.Done():
			now := time.Now()
			_ = inconn.SetDeadline(now)
			_ = outconn.SetDeadline(now)
		case <-stop:
		}
	}()

	wg.Wait()
	close(stop)
	// Wait for the watcher so it cannot touch inconn after Stream returns.
	<-watcherDone
	return nil
}

// stream copies inconn into outconn via iocopy and logs any copy error.
// Afterwards it half-closes outconn's write side when outconn supports
// CloseWrite (e.g. *net.TCPConn), so the remote peer observes EOF. It calls
// wg.Done when finished.
func (str *Streamer) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) {
	defer wg.Done()
	if _, err := iocopy(str.ctx, outconn, inconn); err != nil {
		str.log.Errorf("Copy err: %v", err)
	}
	// Without propagating EOF, the opposite direction can block forever on a
	// peer that is itself waiting for our EOF (e.g. a keep-alive server).
	if cw, ok := outconn.(interface{ CloseWrite() error }); ok {
		_ = cw.CloseWrite() // best effort: the conn may already be closed
	}
}

// iocopy copies from reader to writer until reader reports io.EOF, a read or
// write fails, or ctx is done. It returns the number of bytes written.
// Reaching io.EOF is success and yields a nil error.
//
// ctx is checked before every Read. A Read that is already blocked is only
// interrupted if something else unblocks it, such as Stream expiring the
// connection deadlines. Any failure that happens after ctx is done is
// reported as a cancellation wrapping ctx.Err().
func iocopy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) {
	var written int64
	buffer := make([]byte, 4*1024)

	// fail maps an I/O error to the value iocopy returns, preferring the
	// cancellation cause when ctx is done.
	fail := func(err error) (int64, error) {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return written, fmt.Errorf("break copy by context: %w", ctxErr)
		}
		return written, err
	}

	for {
		if ctx.Err() != nil {
			return fail(nil)
		}

		nr, rerr := reader.Read(buffer)
		// Per the io.Reader contract, bytes returned together with an error
		// must be forwarded before the error is acted on.
		if nr > 0 {
			nw, werr := writer.Write(buffer[:nr])
			written += int64(nw)
			if werr == nil && nw < nr {
				werr = io.ErrShortWrite
			}
			if werr != nil {
				return fail(werr)
			}
		}

		if rerr != nil {
			if errors.Is(rerr, io.EOF) {
				return written, nil
			}
			return fail(rerr)
		}
	}
}