From f76cbc32aa93569255d8a9258d9afaafd7ccedbe Mon Sep 17 00:00:00 2001 From: Oleg Borodin Date: Tue, 26 Jul 2022 14:26:40 +0200 Subject: [PATCH] changed: server stopping, dial/listen to tcp dial/listen, json rpc encoding to msgpack, info getters for logging ; added keepalive setter to server, etc --- client.go | 65 +++++++++++++++++++++++++---- compat.go | 4 ++ context.go | 70 ++++++++++++++++++++++++++++++- error.go | 4 ++ exec_test.go | 6 +-- faddr.go | 2 - go.mod | 10 +++-- go.sum | 15 +++++-- header.go | 14 +++---- logger.go | 8 ++-- midware.go | 6 +-- packet.go | 6 +-- request.go | 15 +++---- response.go | 15 ++++--- server.go | 113 +++++++++++++++++++++++++++++++++++---------------- tools.go | 2 +- xauth.go | 6 +-- 17 files changed, 266 insertions(+), 95 deletions(-) diff --git a/client.go b/client.go index 6b72051..1a0277c 100644 --- a/client.go +++ b/client.go @@ -1,29 +1,46 @@ /* - * * Copyright 2022 Oleg Borodin - * */ package dsrpc import ( - "encoding/json" "errors" + "fmt" "io" "net" "sync" + + encoder "github.com/vmihailenco/msgpack/v5" ) func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error { var err error - conn, err := net.Dial("tcp", address) + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + err = fmt.Errorf("unable to resolve adddress: %s", err) + return Err(err) + } + conn, err := net.DialTCP("tcp", nil, addr) if err != nil { return Err(err) } defer conn.Close() + //err = conn.SetKeepAlive(true) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive: %s", err) + // return Err(err) + //} + + //err = conn.SetKeepAlivePeriod(10 * time.Second) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive period: %s", err) + // return Err(err) + //} + return ConnPut(conn, method, reader, size, param, result, auth) } @@ -78,12 +95,29 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error { var err error - conn, err := net.Dial("tcp", address) + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + err = fmt.Errorf("unable to resolve adddress: %s", err) + return Err(err) + } + conn, err := net.DialTCP("tcp", nil, addr) if err != nil { return Err(err) } defer conn.Close() + //err = conn.SetKeepAlive(true) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive: %s", err) + // return Err(err) + //} + + //err = conn.SetKeepAlivePeriod(10 * time.Second) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive period: %s", err) + // return Err(err) + //} + return ConnGet(conn, method, writer, param, result, auth) } @@ -127,12 +161,29 @@ func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, func Exec(address, method string, param any, result any, auth *Auth) error { var err error - conn, err := net.Dial("tcp", address) + + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + err = fmt.Errorf("unable to resolve adddress: %s", err) + return Err(err) + } + conn, err := net.DialTCP("tcp", nil, addr) if err != nil { return Err(err) } defer conn.Close() + //err = conn.SetKeepAlive(true) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive: %s", err) + // return Err(err) + //} + //err = conn.SetKeepAlivePeriod(10 * time.Second) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive period: %s", err) + // return Err(err) + //} + err = ConnExec(conn, method, param, result, auth) if err != nil { return Err(err) @@ -273,7 +324,7 @@ func (context *Context) DownloadBin() error { func (context *Context) BindResponse() error { var err error - err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC) + err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC) if err != nil { return Err(err) } diff --git a/compat.go b/compat.go index 506e4b6..4f993be 100644 --- a/compat.go +++ b/compat.go @@ -1,3 +1,7 @@ +/* + * Copyright 2022 Oleg Borodin + */ + package dsrpc type any = interface{} diff --git a/context.go b/context.go index 7aeebeb..c38b4d6 100644 --- a/context.go +++ b/context.go @@ -1,7 +1,5 @@ /* - * * Copyright 2022 Oleg Borodin - * */ package dsrpc @@ -60,6 +58,74 @@ func (context *Context) Request() *Request { return context.reqRPC } +func (context *Context) RemoteHost() string { + return context.remoteHost +} + +func (context *Context) Start() time.Time { + return context.start +} + +func (context *Context) Method() string { + var method string + if context.reqRPC != nil { + method = context.reqRPC.Method + } + return method +} + +func (context *Context) ReqRpcSize() int64 { + var size int64 + if context.reqHeader != nil { + size = context.reqHeader.rpcSize + } + return size +} + + +func (context *Context) ReqBinSize() int64 { + var size int64 + if context.reqHeader != nil { + size = context.reqHeader.binSize + } + return size +} + +func (context *Context) ResBinSize() int64 { + var size int64 + if context.resHeader != nil { + size = context.resHeader.binSize + } + return size +} + +func (context *Context) ResRpcSize() int64 { + var size int64 + if context.resHeader != nil { + size = context.resHeader.rpcSize + } + return size +} + +func (context *Context) ReqSize() int64 { + var size int64 + if context.reqHeader != nil { + size += context.reqHeader.binSize + size += context.reqHeader.rpcSize + } + return size +} + +func (context *Context) ResSize() int64 { + var size int64 + if context.resHeader != nil { + size += context.resHeader.binSize + size += context.resHeader.rpcSize + } + return size +} + + func (context *Context) SetAuthIdent(ident []byte) { context.reqRPC.Auth.Ident = ident diff --git a/error.go b/error.go index fe473e8..100ebb1 100644 --- a/error.go +++ b/error.go @@ -1,3 +1,7 @@ +/* + * Copyright 2022 Oleg Borodin + */ + package dsrpc import ( diff --git a/exec_test.go b/exec_test.go index 3736f47..7649166 100644 --- a/exec_test.go +++ b/exec_test.go @@ -1,7 +1,5 @@ /* - * * Copyright 2022 Oleg Borodin - * */ package dsrpc @@ -335,7 +333,7 @@ func loadHandler(context *Context) error { const HelloMethod string = "hello" type HelloParams struct { - Message string `json:"message" json:"message"` + Message string `json:"message" msgpack:"message"` } func NewHelloParams() *HelloParams { @@ -343,7 +341,7 @@ func NewHelloParams() *HelloParams { } type HelloResult struct { - Message string `json:"message" json:"message"` + Message string `json:"message" msgpack:"message"` } func NewHelloResult() *HelloResult { diff --git a/faddr.go b/faddr.go index e67c248..075df38 100644 --- a/faddr.go +++ b/faddr.go @@ -1,7 +1,5 @@ /* - * * Copyright 2022 Oleg Borodin - * */ package dsrpc diff --git a/go.mod b/go.mod index 8ab80d4..6288e12 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,14 @@ module github.com/kindsoldier/dsrpc go 1.17 -require github.com/stretchr/testify v1.7.1 +require ( + github.com/stretchr/testify v1.8.0 + github.com/vmihailenco/msgpack/v5 v5.3.5 +) require ( - github.com/davecgh/go-spew v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2dca7c9..11b3d5a 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,20 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/header.go b/header.go index 33f22c4..c4104cd 100644 --- a/header.go +++ b/header.go @@ -33,27 +33,27 @@ func NewHeader() *Header { } } -func (this *Header) JSON() []byte { - jBytes, _ := json.Marshal(this) +func (hdr *Header) JSON() []byte { + jBytes, _ := json.Marshal(hdr) return jBytes } -func (this *Header) Pack() ([]byte, error) { +func (hdr *Header) Pack() ([]byte, error) { var err error headerBytes := make([]byte, 0, headerSize) headerBuffer := bytes.NewBuffer(headerBytes) - magicCodeABytes := encoderI64(this.magicCodeA) + magicCodeABytes := encoderI64(hdr.magicCodeA) headerBuffer.Write(magicCodeABytes) - rpcSizeBytes := encoderI64(this.rpcSize) + rpcSizeBytes := encoderI64(hdr.rpcSize) headerBuffer.Write(rpcSizeBytes) - binSizeBytes := encoderI64(this.binSize) + binSizeBytes := encoderI64(hdr.binSize) headerBuffer.Write(binSizeBytes) - magicCodeBBytes := encoderI64(this.magicCodeB) + magicCodeBBytes := encoderI64(hdr.magicCodeB) headerBuffer.Write(magicCodeBBytes) return headerBuffer.Bytes(), Err(err) diff --git a/logger.go b/logger.go index d8aab96..d0632e6 100644 --- a/logger.go +++ b/logger.go @@ -18,22 +18,22 @@ var accessWriter io.Writer = os.Stdout func logDebug(messages ...any) { stamp := time.Now().Format(time.RFC3339) - fmt.Fprintln(messageWriter, stamp, "dsrpc debug", messages) + fmt.Fprintln(messageWriter, stamp, "debug", messages) } func logInfo(messages ...any) { stamp := time.Now().Format(time.RFC3339) - fmt.Fprintln(messageWriter, stamp, "dsrpc info", messages) + fmt.Fprintln(messageWriter, stamp, "info", messages) } func logError(messages ...any) { stamp := time.Now().Format(time.RFC3339) - fmt.Fprintln(messageWriter, stamp, "dsrpc error", messages) + fmt.Fprintln(messageWriter, stamp, "error", messages) } func logAccess(messages ...any) { stamp := time.Now().Format(time.RFC3339) - fmt.Fprintln(accessWriter, stamp, "dsrpc access", messages) + fmt.Fprintln(accessWriter, stamp, "access", messages) } func SetAccessWriter(writer io.Writer) { diff --git a/midware.go b/midware.go index 75f871b..838381e 100644 --- a/midware.go +++ b/midware.go @@ -1,10 +1,7 @@ /* - * * Copyright 2022 Oleg Borodin - * */ - package dsrpc import ( @@ -26,6 +23,7 @@ func LogResponse(context *Context) error { func LogAccess(context *Context) error { var err error execTime := time.Now().Sub(context.start) - logAccess(context.remoteHost, context.reqRPC.Method, execTime) + login := string(context.AuthIdent()) + logAccess(context.remoteHost, login, context.reqRPC.Method, execTime) return Err(err) } diff --git a/packet.go b/packet.go index d347037..9d88b55 100644 --- a/packet.go +++ b/packet.go @@ -1,7 +1,5 @@ /* - * * Copyright 2022 Oleg Borodin - * */ package dsrpc @@ -19,7 +17,7 @@ func NewPacket() *Packet { return &Packet{} } -func (this *Packet) JSON() []byte { - jBytes, _ := json.Marshal(this) +func (pkt *Packet) JSON() []byte { + jBytes, _ := json.Marshal(pkt) return jBytes } diff --git a/request.go b/request.go index 18cf183..2143813 100644 --- a/request.go +++ b/request.go @@ -8,12 +8,13 @@ package dsrpc import ( "encoding/json" + encoder "github.com/vmihailenco/msgpack/v5" ) type Request struct { - Method string `json:"method" msgpack:"method"` - Params any `json:"params,omitempty" msgpack:"params,omitempty"` - Auth *Auth `json:"auth,omitempty" msgpack:"auth,omitempty"` + Method string `json:"method" msgpack:"method"` + Params any `json:"params,omitempty" msgpack:"params"` + Auth *Auth `json:"auth,omitempty" msgpack:"auth"` } func NewRequest() *Request { @@ -22,12 +23,12 @@ func NewRequest() *Request { return req } -func (this *Request) Pack() ([]byte, error) { - rBytes, err := json.Marshal(this) +func (req *Request) Pack() ([]byte, error) { + rBytes, err := encoder.Marshal(req) return rBytes, Err(err) } -func (this *Request) JSON() []byte { - jBytes, _ := json.Marshal(this) +func (req *Request) JSON() []byte { + jBytes, _ := json.Marshal(req) return jBytes } diff --git a/response.go b/response.go index a03bb03..a010500 100644 --- a/response.go +++ b/response.go @@ -1,31 +1,30 @@ /* - * * Copyright 2022 Oleg Borodin - * */ package dsrpc import ( "encoding/json" + encoder "github.com/vmihailenco/msgpack/v5" ) type Response struct { - Error string `json:"error,omitempty" msgpack:"error,omitempty"` - Result any `json:"result,omitemty" msgpack:"result,omitemty"` + Error string `json:"error" msgpack:"error"` + Result any `json:"result" msgpack:"result"` } func NewResponse() *Response { return &Response{} } -func (this *Response) JSON() []byte { - jBytes, _ := json.Marshal(this) +func (resp *Response) JSON() []byte { + jBytes, _ := json.Marshal(resp) return jBytes } -func (this *Response) Pack() ([]byte, error) { - rBytes, err := json.Marshal(this) +func (resp *Response) Pack() ([]byte, error) { + rBytes, err := encoder.Marshal(resp) return rBytes, Err(err) } diff --git a/server.go b/server.go index 791a19b..dfe7f79 100644 --- a/server.go +++ b/server.go @@ -1,29 +1,33 @@ /* - * * Copyright 2022 Oleg Borodin - * */ package dsrpc import ( "context" - "encoding/json" "errors" + "fmt" "io" "net" "sync" + "time" + + encoder "github.com/vmihailenco/msgpack/v5" ) type HandlerFunc = func(*Context) error type Service struct { - handlers map[string]HandlerFunc - ctx context.Context - cancel context.CancelFunc - wg *sync.WaitGroup - preMw []HandlerFunc - postMw []HandlerFunc + handlers map[string]HandlerFunc + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup + preMw []HandlerFunc + postMw []HandlerFunc + keepalive bool + kaTime time.Duration + kaMtx sync.Mutex } func NewService() *Service { @@ -40,41 +44,59 @@ func NewService() *Service { return rdrpc } -func (this *Service) PreMiddleware(mw HandlerFunc) { - this.preMw = append(this.preMw, mw) +func (svc *Service) PreMiddleware(mw HandlerFunc) { + svc.preMw = append(svc.preMw, mw) +} + +func (svc *Service) PostMiddleware(mw HandlerFunc) { + svc.postMw = append(svc.postMw, mw) } -func (this *Service) PostMiddleware(mw HandlerFunc) { - this.postMw = append(this.postMw, mw) +func (svc *Service) Handler(method string, handler HandlerFunc) { + svc.handlers[method] = handler } +func (svc *Service) SetKeepAlive(flag bool) { + svc.kaMtx.Lock() + defer svc.kaMtx.Unlock() + svc.keepalive = true +} -func (this *Service) Handler(method string, handler HandlerFunc) { - this.handlers[method] = handler +func (svc *Service) SetKeepAlivePeriod(interval time.Duration) { + svc.kaMtx.Lock() + defer svc.kaMtx.Unlock() + svc.kaTime = interval } -func (this *Service) Listen(address string) error { +func (svc *Service) Listen(address string) error { var err error logInfo("server listen:", address) - listener, err := net.Listen("tcp", address) + addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { + err = fmt.Errorf("unable to resolve adddress: %s", err) return err } - this.wg.Add(1) + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + err = fmt.Errorf("unable to start listener: %s", err) + return err + } + for { + conn, err := listener.AcceptTCP() + if err != nil { + logError("conn accept err:", err) + } select { - case <- this.ctx.Done(): - this.wg.Done() + case <-svc.ctx.Done(): return err default: } - conn, err := listener.Accept() - if err != nil { - logError("conn accept err:", err) - } - go this.handleConn(conn) + svc.wg.Add(1) + go svc.handleConn(conn, svc.wg) } + return err } func notFound(context *Context) error { @@ -83,16 +105,34 @@ func notFound(context *Context) error { return err } -func (this *Service) Stop() error { +func (svc *Service) Stop() error { var err error - this.cancel() - this.wg.Wait() + // Disable new connection + logInfo("cancel rpc accept loop") + svc.cancel() + // Wait handlers + logInfo("wait rpc handlers") + svc.wg.Wait() return err } -func (this *Service) handleConn(conn net.Conn) { +func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) { var err error + if svc.keepalive { + err = conn.SetKeepAlive(true) + if err != nil { + err = fmt.Errorf("unable to set keepalive: %s", err) + return + } + if svc.kaTime > 0 { + err = conn.SetKeepAlivePeriod(svc.kaTime) + if err != nil { + err = fmt.Errorf("unable to set keepalive period: %s", err) + return + } + } + } context := CreateContext(conn) remoteAddr := conn.RemoteAddr().String() @@ -104,6 +144,7 @@ func (this *Service) handleConn(conn net.Conn) { exitFunc := func() { conn.Close() + wg.Done() if err != nil { logError("conn handler err:", err) } @@ -129,19 +170,19 @@ func (this *Service) handleConn(conn net.Conn) { err = Err(err) return } - for _, mw := range this.preMw { + for _, mw := range svc.preMw { err = mw(context) if err != nil { err = Err(err) return } } - err = this.Route(context) + err = svc.Route(context) if err != nil { err = Err(err) return } - for _, mw := range this.postMw { + for _, mw := range svc.postMw { err = mw(context) if err != nil { err = Err(err) @@ -151,8 +192,8 @@ func (this *Service) handleConn(conn net.Conn) { return } -func (this *Service) Route(context *Context) error { - handler, ok := this.handlers[context.reqRPC.Method] +func (svc *Service) Route(context *Context) error { + handler, ok := svc.handlers[context.reqRPC.Method] if ok { return Err(handler(context)) } @@ -200,14 +241,14 @@ func (context *Context) ReadBin(writer io.Writer) error { func (context *Context) BindMethod() error { var err error - err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) + err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) return Err(err) } func (context *Context) BindParams(params any) error { var err error context.reqRPC.Params = params - err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) + err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) if err != nil { return Err(err) } diff --git a/tools.go b/tools.go index d7e1da6..e067353 100644 --- a/tools.go +++ b/tools.go @@ -18,7 +18,7 @@ func ReadBytes(reader io.Reader, size int64) ([]byte, error) { func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { var err error - var bSize int64 = 1024 * 4 + var bSize int64 = 1024 * 16 var total int64 = 0 var remains int64 = dataSize buffer := make([]byte, bSize) diff --git a/xauth.go b/xauth.go index 28f399e..91c0b55 100644 --- a/xauth.go +++ b/xauth.go @@ -17,9 +17,9 @@ func init() { } type Auth struct { - Ident []byte `json:"ident,omitempty"` - Salt []byte `json:"salt,omitempty"` - Hash []byte `json:"hash,omitempty"` + Ident []byte `msgpack:"ident" json:"ident"` + Salt []byte `msgpack:"salt" json:"salt"` + Hash []byte `msgpack:"hash" json:"hash"` } func NewAuth() *Auth {