working commit

This commit is contained in:
2026-03-25 19:00:25 +02:00
parent 5e7b1f312d
commit 7142e32a39
7 changed files with 264 additions and 196 deletions

View File

@@ -1,9 +1,8 @@
package operator
import (
"helmet/app/logger"
"helmet/app/config"
"helmet/app/logger"
"helmet/app/rproxy"
)

View File

@@ -33,6 +33,9 @@ func (oper *Operator) CreateForwarder(ctx context.Context, params *mlbctl.Create
var err error
res := &mlbctl.CreateForwarderResult{}
err = oper.proxy.AddForwarder(ctx, params.Lport, params.Dport, params.Destinations...)
if err != err {
return res, err
}
return res, err
}
@@ -40,5 +43,8 @@ func (oper *Operator) DeleteForwarder(ctx context.Context, params *mlbctl.Delete
var err error
res := &mlbctl.DeleteForwarderResult{}
err = oper.proxy.DeleteForwarder(ctx, params.Lport)
if err != err {
return res, err
}
return res, err
}

124
app/rproxy/forwarder.go Normal file
View File

@@ -0,0 +1,124 @@
package rproxy
import (
"context"
"io"
"math/rand"
"net"
"strconv"
"sync"
"helmet/app/logger"
)
type Forwarder struct {
listen net.Listener `json:"-" yaml:"-"`
ctx context.Context `json:"-" yaml:"-"`
cancel context.CancelFunc `json:"-" yaml:"-"`
Lport uint32 `json:"lport" yaml:"lport"`
Dport uint32 `json:"dport" yaml:"dport"`
Dests []*Destination `json:"dests" yaml:"dests"`
log *logger.Logger
}
func NewForwarder(ctx context.Context, lport, dport uint32, addrs ...string) (*Forwarder, error) {
ctx, cancel := context.WithCancel(context.Background())
forw := &Forwarder{
Dests: make([]*Destination, 0),
Lport: lport,
Dport: dport,
ctx: ctx,
cancel: cancel,
}
id := strconv.FormatUint(uint64(lport), 10)
forw.log = logger.NewLogger("forwarder:" + id)
for _, addr := range addrs {
dest := NewDestination(addr)
forw.Dests = append(forw.Dests, dest)
}
portinfo := ":" + strconv.FormatUint(uint64(forw.Lport), 10)
laddr, err := net.ResolveTCPAddr("tcp", portinfo)
if err != nil {
return forw, err
}
listen, err := net.ListenTCP("tcp", laddr)
if err != nil {
return forw, err
}
forw.listen = listen
return forw, err
}
func (forw *Forwarder) Listen(wg *sync.WaitGroup) {
forw.log.Debugf("Start listening on %d", forw.Lport)
defer wg.Done()
for {
conn, err := forw.listen.Accept()
if err != nil {
forw.log.Errorf("Listen err: %v", err)
return
}
go forw.handle(forw.ctx, conn)
}
}
func (forw *Forwarder) Stop() error {
return forw.listen.Close()
}
type Streamer struct {
source string
dest string
ctx context.Context
cancel context.CancelFunc
}
func NewStreamer(ctx context.Context) *Streamer {
ctx, cancel := context.WithCancel(ctx)
return &Streamer{
ctx: ctx,
cancel: cancel,
}
}
func (forw *Forwarder) handle(ctx context.Context, inconn net.Conn) {
forw.log.Debugf("%s: Handle on %d started", inconn.RemoteAddr(), forw.Lport)
defer inconn.Close()
if len(forw.Dests) == 0 {
return
}
addrnum := rand.Uint32() % uint32(len(forw.Dests))
ipaddr := forw.Dests[addrnum].Address
dstaddr := ipaddr + ":" + strconv.FormatUint(uint64(forw.Dport), 10)
outconn, err := net.Dial("tcp", dstaddr)
if err != nil {
return
}
defer outconn.Close()
var wg sync.WaitGroup
wg.Add(1)
go forw.stream(&wg, inconn, outconn)
wg.Add(1)
go forw.stream(&wg, outconn, inconn)
wg.Wait()
forw.log.Debugf("Handler on %d stopped", forw.Lport)
}
func (forw *Forwarder) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) {
defer wg.Done()
_, err := copy(forw.ctx, outconn, inconn)
if err != nil {
forw.log.Errorf("Copy err: %v", err)
}
}
type Destination struct {
Address string `json:"address" yaml:"address"`
}
func NewDestination(addr string) *Destination {
return &Destination{
Address: addr,
}
}

View File

@@ -1,192 +0,0 @@
package rproxy
import (
"context"
"errors"
"io"
"log"
"math/rand"
"net"
"strconv"
"sync"
)
type Proxy struct {
Forwarders []*Forwarder `json:"forwarders" yaml:"forwarders"`
ctx context.Context `json:"-" yaml:"-"`
cancel context.CancelFunc `json:"-" yaml:"-"`
wg sync.WaitGroup `json:"-" yaml:"-"`
}
func NewProxy() *Proxy {
ctx, cancel := context.WithCancel(context.Background())
return &Proxy{
Forwarders: make([]*Forwarder, 0),
ctx: ctx,
cancel: cancel,
}
}
func (bal *Proxy) AddForwarder(ctx context.Context, lport, dport uint32, addrs ...string) error {
var err error
forw, err := NewForwarder(ctx, lport, dport, addrs...)
if err != nil {
return err
}
bal.Forwarders = append(bal.Forwarders, forw)
bal.wg.Add(1)
go forw.Listen(&bal.wg)
return err
}
func (bal *Proxy) DeleteForwarder(ctx context.Context, lport uint32) error {
var err error
forwarders := make([]*Forwarder, 0)
for _, forw := range bal.Forwarders {
if forw.Lport == lport {
forw.Stop()
continue
}
forwarders = append(forwarders, forw)
}
bal.Forwarders = forwarders
return err
}
func (bal *Proxy) Start() error {
var err error
for _, forw := range bal.Forwarders {
bal.wg.Add(1)
go forw.Listen(&bal.wg)
}
bal.wg.Wait()
return err
}
func (bal *Proxy) Stop() error {
var err error
for _, forw := range bal.Forwarders {
forw.Stop()
}
return err
}
type Forwarder struct {
listen net.Listener `json:"-" yaml:"-"`
ctx context.Context `json:"-" yaml:"-"`
cancel context.CancelFunc `json:"-" yaml:"-"`
Lport uint32 `json:"lport" yaml:"lport"`
Dport uint32 `json:"dport" yaml:"dport"`
Dests []*Destination `json:"dests" yaml:"dests"`
}
func NewForwarder(ctx context.Context, lport, dport uint32, addrs ...string) (*Forwarder, error) {
ctx, cancel := context.WithCancel(ctx)
forw := &Forwarder{
Dests: make([]*Destination, 0),
Lport: lport,
Dport: dport,
ctx: ctx,
cancel: cancel,
}
for _, addr := range addrs {
dest := NewDestination(addr)
forw.Dests = append(forw.Dests, dest)
}
port := ":" + strconv.FormatUint(uint64(forw.Lport), 10)
listen, err := net.Listen("tcp", port)
if err != nil {
return forw, err
}
forw.listen = listen
return forw, err
}
func (forw *Forwarder) Listen(wg *sync.WaitGroup) {
log.Printf("Start listening on %d\n", forw.Lport)
defer wg.Done()
for {
conn, err := forw.listen.Accept()
if err != nil {
log.Printf("Listen err: %v\n", err)
return
}
go forw.handle(forw.ctx, conn)
}
}
func (forw *Forwarder) Stop() error {
return forw.listen.Close()
}
func (forw *Forwarder) handle(ctx context.Context, inconn net.Conn) {
log.Printf("Handler on %d started\n", forw.Lport)
defer inconn.Close()
if len(forw.Dests) == 0 {
return
}
addrnum := rand.Uint32() % uint32(len(forw.Dests))
ipaddr := forw.Dests[addrnum].Address
dstaddr := ipaddr + ":" + strconv.FormatUint(uint64(forw.Dport), 10)
outconn, err := net.Dial("tcp", dstaddr)
if err != nil {
return
}
var wg sync.WaitGroup
wg.Add(1)
go forw.stream(&wg, inconn, outconn)
wg.Add(1)
go forw.stream(&wg, outconn, inconn)
wg.Wait()
log.Printf("Handler on %d stopped\n", forw.Lport)
}
func (forw *Forwarder) stream(wg *sync.WaitGroup, inconn io.Reader, outconn io.Writer) {
defer wg.Done()
_, err := copy(forw.ctx, outconn, inconn)
if err != nil {
log.Printf("Copy err: %v\n", err)
}
}
type Destination struct {
Address string `json:"address" yaml:"address"`
}
func NewDestination(addr string) *Destination {
return &Destination{
Address: addr,
}
}
func copy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) {
var err error
var size int64
var halt bool
buffer := make([]byte, 1024*4)
for {
select {
case <-ctx.Done():
err = errors.New("Break copy by context")
break
default:
}
rsize, err := reader.Read(buffer)
if err == io.EOF {
err = nil
halt = true
}
if err != nil {
return size, err
}
wsize, err := writer.Write(buffer[0:rsize])
size += int64(wsize)
if err != nil {
return size, err
}
if halt {
break
}
}
return size, err
}

125
app/rproxy/rproxy.go Normal file
View File

@@ -0,0 +1,125 @@
package rproxy
import (
"context"
"errors"
"io"
"sync"
"helmet/app/logger"
)
type Proxy struct {
Forwarders []*Forwarder `json:"forwarders" yaml:"forwarders"`
ctx context.Context `json:"-" yaml:"-"`
cancel context.CancelFunc `json:"-" yaml:"-"`
wg sync.WaitGroup `json:"-" yaml:"-"`
log *logger.Logger `json:"-" yaml:"-"`
}
func NewProxy() *Proxy {
ctx, cancel := context.WithCancel(context.Background())
return &Proxy{
Forwarders: make([]*Forwarder, 0),
ctx: ctx,
cancel: cancel,
log: logger.NewLogger("proxy"),
}
}
func (bal *Proxy) AddForwarder(ctx context.Context, lport, dport uint32, addrs ...string) error {
var err error
if lport == 0 {
return errors.New("Zero lport")
}
if dport == 0 {
return errors.New("Zero dport")
}
forw, err := NewForwarder(ctx, lport, dport, addrs...)
if err != nil {
return err
}
bal.Forwarders = append(bal.Forwarders, forw)
bal.wg.Add(1)
go forw.Listen(&bal.wg)
return err
}
func (bal *Proxy) DeleteForwarder(ctx context.Context, lport uint32) error {
var err error
var forw *Forwarder
for _, iforw := range bal.Forwarders {
if iforw.Lport == lport {
forw = iforw
break
}
}
if forw == nil {
bal.log.Debugf("Forwarder for lport %d not found", lport)
return err
}
bal.log.Debugf("Stop forwarder for lport %d", lport)
err = forw.Stop()
if err != nil {
return err
}
forwarders := make([]*Forwarder, 0)
for _, forw := range bal.Forwarders {
if forw.Lport == lport {
continue
}
forwarders = append(forwarders, forw)
}
bal.Forwarders = forwarders
return err
}
func (bal *Proxy) Start() error {
var err error
for _, forw := range bal.Forwarders {
bal.wg.Add(1)
go forw.Listen(&bal.wg)
}
bal.wg.Wait()
return err
}
func (bal *Proxy) Stop() error {
var err error
for _, forw := range bal.Forwarders {
forw.Stop()
}
return err
}
func copy(ctx context.Context, writer io.Writer, reader io.Reader) (int64, error) {
var err error
var size int64
var halt bool
buffer := make([]byte, 1024*4)
for {
select {
case <-ctx.Done():
err = errors.New("Break copy by context")
break
default:
}
rsize, err := reader.Read(buffer)
if err == io.EOF {
err = nil
halt = true
}
if err != nil {
return size, err
}
wsize, err := writer.Write(buffer[0:rsize])
size += int64(wsize)
if err != nil {
return size, err
}
if halt {
break
}
}
return size, err
}

View File

@@ -109,7 +109,13 @@ func (svc *Service) logInterceptor(ctx context.Context, req any, info *grpc.Unar
if err == nil {
svc.log.Debugf("Request: %s", string(reqData))
}
return handler(ctx, req)
res, err := handler(ctx, req)
resData, err := json.Marshal(res)
if err == nil {
svc.log.Debugf("Response: %s", string(resData))
}
return res, err
}
func (svc *Service) Stop() {

View File

@@ -47,7 +47,7 @@ func NewTool() *Tool {
Use: "delete",
Args: cobra.ExactArgs(1),
Short: "Delete forwarder",
Run: tool.CreateForwarder,
Run: tool.DeleteForwarder,
}
deleteForwarderCmd.Flags().Uint32VarP(&tool.deleteForwarderParams.Lport, "lport", "L", 0, "Listening port")
deleteForwarderCmd.MarkFlagRequired("lport")