From 59d850da3cf62cd7fab3c370802acb3d74ab51d6 Mon Sep 17 00:00:00 2001 From: Oleg Borodin Date: Sat, 11 Feb 2023 07:52:40 +0200 Subject: [PATCH] go fmt --- client.go | 566 +++++++++++++++++++++++++------------------------- context.go | 166 ++++++++------- empty.go | 4 +- error.go | 47 ++--- exec_test.go | 541 ++++++++++++++++++++++++----------------------- faddr.go | 16 +- faddr_test.go | 82 ++++---- fconn.go | 58 +++--- header.go | 120 ++++++----- logger.go | 28 +-- midware.go | 24 +-- packet.go | 12 +- request.go | 24 +-- response.go | 19 +- server.go | 472 +++++++++++++++++++++-------------------- tools.go | 84 ++++---- validate.go | 284 +++++++++++++------------ xauth.go | 64 +++--- 18 files changed, 1293 insertions(+), 1318 deletions(-) diff --git a/client.go b/client.go index 1a0277c..05ae16a 100644 --- a/client.go +++ b/client.go @@ -5,332 +5,328 @@ package dsrpc import ( - "errors" - "fmt" - "io" - "net" - "sync" + "errors" + "fmt" + "io" + "net" + "sync" - encoder "github.com/vmihailenco/msgpack/v5" + encoder "github.com/vmihailenco/msgpack/v5" ) - func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error { - var err error - - addr, err := net.ResolveTCPAddr("tcp", address) - if err != nil { - err = fmt.Errorf("unable to resolve adddress: %s", err) - return Err(err) - } - conn, err := net.DialTCP("tcp", nil, addr) - if err != nil { - return Err(err) - } - defer conn.Close() - - //err = conn.SetKeepAlive(true) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive: %s", err) - // return Err(err) - //} - - //err = conn.SetKeepAlivePeriod(10 * time.Second) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive period: %s", err) - // return Err(err) - //} - - return ConnPut(conn, method, reader, size, param, result, auth) + var err error + + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + err = fmt.Errorf("unable to resolve adddress: %s", err) + return Err(err) + } + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return Err(err) + } + defer conn.Close() + + //err = conn.SetKeepAlive(true) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive: %s", err) + // return Err(err) + //} + + //err = conn.SetKeepAlivePeriod(10 * time.Second) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive period: %s", err) + // return Err(err) + //} + + return ConnPut(conn, method, reader, size, param, result, auth) } - func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, result any, auth *Auth) error { - var err error - context := CreateContext(conn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result - - context.binReader = reader - context.binWriter = conn - - context.reqHeader.binSize = size - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() - } - - err = context.CreateRequest() - if err != nil { - return Err(err) - } - err = context.WriteRequest() - if err != nil { - return Err(err) - } - - var wg sync.WaitGroup - errChan := make(chan error, 1) - - wg.Add(1) - go context.ReadResponseAsync(&wg, errChan) - - wg.Add(1) - go context.UploadBinAsync(&wg) - - wg.Wait() - err = <- errChan - if err != nil { - return Err(err) - } - err = context.BindResponse() - if err != nil { - return Err(err) - } - return Err(err) + var err error + context := CreateContext(conn) + context.reqRPC.Method = method + context.reqRPC.Params = param + context.reqRPC.Auth = auth + context.resRPC.Result = result + + context.binReader = reader + context.binWriter = conn + + context.reqHeader.binSize = size + + if context.reqRPC.Params == nil { + context.reqRPC.Params = NewEmpty() + } + + err = context.CreateRequest() + if err != nil { + return Err(err) + } + err = context.WriteRequest() + if err != nil { + return Err(err) + } + + var wg sync.WaitGroup + errChan := make(chan error, 1) + + wg.Add(1) + go context.ReadResponseAsync(&wg, errChan) + + wg.Add(1) + go context.UploadBinAsync(&wg) + + wg.Wait() + err = <-errChan + if err != nil { + return Err(err) + } + err = context.BindResponse() + if err != nil { + return Err(err) + } + return Err(err) } func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error { - var err error - - addr, err := net.ResolveTCPAddr("tcp", address) - if err != nil { - err = fmt.Errorf("unable to resolve adddress: %s", err) - return Err(err) - } - conn, err := net.DialTCP("tcp", nil, addr) - if err != nil { - return Err(err) - } - defer conn.Close() - - //err = conn.SetKeepAlive(true) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive: %s", err) - // return Err(err) - //} - - //err = conn.SetKeepAlivePeriod(10 * time.Second) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive period: %s", err) - // return Err(err) - //} - - return ConnGet(conn, method, writer, param, result, auth) + var err error + + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + err = fmt.Errorf("unable to resolve adddress: %s", err) + return Err(err) + } + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return Err(err) + } + defer conn.Close() + + //err = conn.SetKeepAlive(true) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive: %s", err) + // return Err(err) + //} + + //err = conn.SetKeepAlivePeriod(10 * time.Second) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive period: %s", err) + // return Err(err) + //} + + return ConnGet(conn, method, writer, param, result, auth) } func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error { - var err error - - context := CreateContext(conn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result - - context.binReader = conn - context.binWriter = writer - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() - } - err = context.CreateRequest() - if err != nil { - return Err(err) - } - err = context.WriteRequest() - if err != nil { - return Err(err) - } - err = context.ReadResponse() - if err != nil { - return Err(err) - } - err = context.DownloadBin() - if err != nil { - return Err(err) - } - err = context.BindResponse() - if err != nil { - return Err(err) - } - return Err(err) + var err error + + context := CreateContext(conn) + context.reqRPC.Method = method + context.reqRPC.Params = param + context.reqRPC.Auth = auth + context.resRPC.Result = result + + context.binReader = conn + context.binWriter = writer + + if context.reqRPC.Params == nil { + context.reqRPC.Params = NewEmpty() + } + err = context.CreateRequest() + if err != nil { + return Err(err) + } + err = context.WriteRequest() + if err != nil { + return Err(err) + } + err = context.ReadResponse() + if err != nil { + return Err(err) + } + err = context.DownloadBin() + if err != nil { + return Err(err) + } + err = context.BindResponse() + if err != nil { + return Err(err) + } + return Err(err) } func Exec(address, method string, param any, result any, auth *Auth) error { - var err error - - addr, err := net.ResolveTCPAddr("tcp", address) - if err != nil { - err = fmt.Errorf("unable to resolve adddress: %s", err) - return Err(err) - } - conn, err := net.DialTCP("tcp", nil, addr) - if err != nil { - return Err(err) - } - defer conn.Close() - - //err = conn.SetKeepAlive(true) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive: %s", err) - // return Err(err) - //} - //err = conn.SetKeepAlivePeriod(10 * time.Second) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive period: %s", err) - // return Err(err) - //} - - err = ConnExec(conn, method, param, result, auth) - if err != nil { - return Err(err) - } - return Err(err) + var err error + + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + err = fmt.Errorf("unable to resolve adddress: %s", err) + return Err(err) + } + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return Err(err) + } + defer conn.Close() + + //err = conn.SetKeepAlive(true) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive: %s", err) + // return Err(err) + //} + //err = conn.SetKeepAlivePeriod(10 * time.Second) + //if err != nil { + // err = fmt.Errorf("unable to set keepalive period: %s", err) + // return Err(err) + //} + + err = ConnExec(conn, method, param, result, auth) + if err != nil { + return Err(err) + } + return Err(err) } - func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error { - var err error - - context := CreateContext(conn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() - } - - err = context.CreateRequest() - if err != nil { - return Err(err) - } - err = context.WriteRequest() - if err != nil { - return Err(err) - } - err = context.ReadResponse() - if err != nil { - return Err(err) - } - err = context.BindResponse() - if err != nil { - return Err(err) - } - return Err(err) + var err error + + context := CreateContext(conn) + context.reqRPC.Method = method + context.reqRPC.Params = param + context.reqRPC.Auth = auth + context.resRPC.Result = result + + if context.reqRPC.Params == nil { + context.reqRPC.Params = NewEmpty() + } + + err = context.CreateRequest() + if err != nil { + return Err(err) + } + err = context.WriteRequest() + if err != nil { + return Err(err) + } + err = context.ReadResponse() + if err != nil { + return Err(err) + } + err = context.BindResponse() + if err != nil { + return Err(err) + } + return Err(err) } - func (context *Context) CreateRequest() error { - var err error - - context.reqPacket.rcpPayload, err = context.reqRPC.Pack() - if err != nil { - return Err(err) - } - rpcSize := int64(len(context.reqPacket.rcpPayload)) - context.reqHeader.rpcSize = rpcSize - - context.reqPacket.header, err = context.reqHeader.Pack() - if err != nil { - return Err(err) - } - return Err(err) + var err error + + context.reqPacket.rcpPayload, err = context.reqRPC.Pack() + if err != nil { + return Err(err) + } + rpcSize := int64(len(context.reqPacket.rcpPayload)) + context.reqHeader.rpcSize = rpcSize + + context.reqPacket.header, err = context.reqHeader.Pack() + if err != nil { + return Err(err) + } + return Err(err) } func (context *Context) WriteRequest() error { - var err error - _, err = context.sockWriter.Write(context.reqPacket.header) - if err != nil { - return Err(err) - } - _, err = context.sockWriter.Write(context.reqPacket.rcpPayload) - if err != nil { - return Err(err) - } - return Err(err) + var err error + _, err = context.sockWriter.Write(context.reqPacket.header) + if err != nil { + return Err(err) + } + _, err = context.sockWriter.Write(context.reqPacket.rcpPayload) + if err != nil { + return Err(err) + } + return Err(err) } func (context *Context) UploadBin() error { - var err error - _, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) - return Err(err) + var err error + _, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) + return Err(err) } func (context *Context) ReadResponse() error { - var err error - - context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) - if err != nil { - return Err(err) - } - context.resHeader, err = UnpackHeader(context.resPacket.header) - if err != nil { - return Err(err) - } - rpcSize := context.resHeader.rpcSize - context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) - if err != nil { - return Err(err) - } - return Err(err) + var err error + + context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) + if err != nil { + return Err(err) + } + context.resHeader, err = UnpackHeader(context.resPacket.header) + if err != nil { + return Err(err) + } + rpcSize := context.resHeader.rpcSize + context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) + if err != nil { + return Err(err) + } + return Err(err) } func (context *Context) UploadBinAsync(wg *sync.WaitGroup) { - exitFunc := func() { - wg.Done() - } - defer exitFunc() - _, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) - return + exitFunc := func() { + wg.Done() + } + defer exitFunc() + _, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) + return } func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) { - var err error - exitFunc := func() { - errChan <- err - wg.Done() - } - defer exitFunc() - context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) - if err != nil { - err = Err(err) - return - } - context.resHeader, err = UnpackHeader(context.resPacket.header) - if err != nil { - err = Err(err) - return - } - rpcSize := context.resHeader.rpcSize - context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) - if err != nil { - err = Err(err) - return - } - return + var err error + exitFunc := func() { + errChan <- err + wg.Done() + } + defer exitFunc() + context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) + if err != nil { + err = Err(err) + return + } + context.resHeader, err = UnpackHeader(context.resPacket.header) + if err != nil { + err = Err(err) + return + } + rpcSize := context.resHeader.rpcSize + context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) + if err != nil { + err = Err(err) + return + } + return } func (context *Context) DownloadBin() error { - var err error - _, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize) - return Err(err) + var err error + _, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize) + return Err(err) } func (context *Context) BindResponse() error { - var err error - - err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC) - if err != nil { - return Err(err) - } - if len(context.resRPC.Error) > 0 { - err = errors.New(context.resRPC.Error) - return Err(err) - } - return Err(err) + var err error + + err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC) + if err != nil { + return Err(err) + } + if len(context.resRPC.Error) > 0 { + err = errors.New(context.resRPC.Error) + return Err(err) + } + return Err(err) } diff --git a/context.go b/context.go index c38b4d6..d78ea30 100644 --- a/context.go +++ b/context.go @@ -5,152 +5,148 @@ package dsrpc import ( - "io" - "net" - "time" + "io" + "net" + "time" ) type Context struct { - start time.Time - remoteHost string - sockReader io.Reader - sockWriter io.Writer + start time.Time + remoteHost string + sockReader io.Reader + sockWriter io.Writer - reqHeader *Header - reqRPC *Request - reqPacket *Packet + reqHeader *Header + reqRPC *Request + reqPacket *Packet - resPacket *Packet - resHeader *Header - resRPC *Response + resPacket *Packet + resHeader *Header + resRPC *Response - binReader io.Reader - binWriter io.Writer + binReader io.Reader + binWriter io.Writer } - func NewContext() *Context { - context := &Context{} - context.start = time.Now() - return context + context := &Context{} + context.start = time.Now() + return context } func CreateContext(conn net.Conn) *Context { - context := &Context{} - context.start = time.Now() - context.sockReader = conn - context.sockWriter = conn + context := &Context{} + context.start = time.Now() + context.sockReader = conn + context.sockWriter = conn - context.reqPacket = NewPacket() - context.resPacket = NewPacket() + context.reqPacket = NewPacket() + context.resPacket = NewPacket() - context.reqHeader = NewHeader() - context.reqRPC = NewRequest() + context.reqHeader = NewHeader() + context.reqRPC = NewRequest() - context.resHeader = NewHeader() - context.resRPC = NewResponse() - context.resRPC = NewResponse() + context.resHeader = NewHeader() + context.resRPC = NewResponse() + context.resRPC = NewResponse() - return context + return context } -func (context *Context) Request() *Request { - return context.reqRPC +func (context *Context) Request() *Request { + return context.reqRPC } func (context *Context) RemoteHost() string { - return context.remoteHost + return context.remoteHost } func (context *Context) Start() time.Time { - return context.start + return context.start } func (context *Context) Method() string { - var method string - if context.reqRPC != nil { - method = context.reqRPC.Method - } - return method + var method string + if context.reqRPC != nil { + method = context.reqRPC.Method + } + return method } func (context *Context) ReqRpcSize() int64 { - var size int64 - if context.reqHeader != nil { - size = context.reqHeader.rpcSize - } - return size + var size int64 + if context.reqHeader != nil { + size = context.reqHeader.rpcSize + } + return size } - func (context *Context) ReqBinSize() int64 { - var size int64 - if context.reqHeader != nil { - size = context.reqHeader.binSize - } - return size + var size int64 + if context.reqHeader != nil { + size = context.reqHeader.binSize + } + return size } func (context *Context) ResBinSize() int64 { - var size int64 - if context.resHeader != nil { - size = context.resHeader.binSize - } - return size + var size int64 + if context.resHeader != nil { + size = context.resHeader.binSize + } + return size } func (context *Context) ResRpcSize() int64 { - var size int64 - if context.resHeader != nil { - size = context.resHeader.rpcSize - } - return size + var size int64 + if context.resHeader != nil { + size = context.resHeader.rpcSize + } + return size } func (context *Context) ReqSize() int64 { - var size int64 - if context.reqHeader != nil { - size += context.reqHeader.binSize - size += context.reqHeader.rpcSize - } - return size + var size int64 + if context.reqHeader != nil { + size += context.reqHeader.binSize + size += context.reqHeader.rpcSize + } + return size } func (context *Context) ResSize() int64 { - var size int64 - if context.resHeader != nil { - size += context.resHeader.binSize - size += context.resHeader.rpcSize - } - return size + var size int64 + if context.resHeader != nil { + size += context.resHeader.binSize + size += context.resHeader.rpcSize + } + return size } - - -func (context *Context) SetAuthIdent(ident []byte) { - context.reqRPC.Auth.Ident = ident +func (context *Context) SetAuthIdent(ident []byte) { + context.reqRPC.Auth.Ident = ident } -func (context *Context) SetAuthSalt(salt []byte) { - context.reqRPC.Auth.Salt = salt +func (context *Context) SetAuthSalt(salt []byte) { + context.reqRPC.Auth.Salt = salt } -func (context *Context) SetAuthHash(hash []byte) { - context.reqRPC.Auth.Hash = hash +func (context *Context) SetAuthHash(hash []byte) { + context.reqRPC.Auth.Hash = hash } func (context *Context) AuthIdent() []byte { - return context.reqRPC.Auth.Ident + return context.reqRPC.Auth.Ident } func (context *Context) AuthSalt() []byte { - return context.reqRPC.Auth.Salt + return context.reqRPC.Auth.Salt } func (context *Context) AuthHash() []byte { - return context.reqRPC.Auth.Hash + return context.reqRPC.Auth.Hash } func (context *Context) Auth() *Auth { - return context.reqRPC.Auth + return context.reqRPC.Auth } diff --git a/empty.go b/empty.go index 853f1b7..4ab4e05 100644 --- a/empty.go +++ b/empty.go @@ -6,8 +6,8 @@ package dsrpc -type Empty struct {} +type Empty struct{} func NewEmpty() *Empty { - return &Empty{} + return &Empty{} } diff --git a/error.go b/error.go index 100ebb1..a25a0ea 100644 --- a/error.go +++ b/error.go @@ -5,39 +5,38 @@ package dsrpc import ( - "fmt" - "runtime" - "io" + "fmt" + "io" + "runtime" ) var develMode bool = false var debugMode bool = false - func SetDevelMode(mode bool) { - develMode = mode + develMode = mode } func SetDebugMode(mode bool) { - debugMode = mode + debugMode = mode } func Err(err error) error { - switch err { - case io.EOF: - return err - } - if err != nil { - switch { - case develMode == true: - pc, filename, line, _ := runtime.Caller(1) - funcName := runtime.FuncForPC(pc).Name() - err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error()) - case debugMode == true: - pc, _, line, _ := runtime.Caller(1) - funcName := runtime.FuncForPC(pc).Name() - err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error()) - default: - } - } - return err + switch err { + case io.EOF: + return err + } + if err != nil { + switch { + case develMode == true: + pc, filename, line, _ := runtime.Caller(1) + funcName := runtime.FuncForPC(pc).Name() + err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error()) + case debugMode == true: + pc, _, line, _ := runtime.Caller(1) + funcName := runtime.FuncForPC(pc).Name() + err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error()) + default: + } + } + return err } diff --git a/exec_test.go b/exec_test.go index 7649166..8d22e58 100644 --- a/exec_test.go +++ b/exec_test.go @@ -5,370 +5,361 @@ package dsrpc import ( - "bytes" - "encoding/json" - "errors" - "io" - "math/rand" - "testing" - "time" - - "github.com/stretchr/testify/require" + "bytes" + "encoding/json" + "errors" + "io" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" ) func TestLocalExec(t *testing.T) { - var err error - params := NewHelloParams() - params.Message = "hello server!" - result := NewHelloResult() + var err error + params := NewHelloParams() + params.Message = "hello server!" + result := NewHelloResult() - auth := CreateAuth([]byte("qwert"), []byte("12345")) + auth := CreateAuth([]byte("qwert"), []byte("12345")) - err = LocalExec(HelloMethod, params, result, auth, helloHandler) - require.NoError(t, err) - resultJSON, _ := json.Marshal(result) - logDebug("method result:", string(resultJSON)) + err = LocalExec(HelloMethod, params, result, auth, helloHandler) + require.NoError(t, err) + resultJSON, _ := json.Marshal(result) + logDebug("method result:", string(resultJSON)) } - func TestLocalSave(t *testing.T) { - var err error + var err error - params := NewSaveParams() - params.Message = "save data!" - result := NewHelloResult() - auth := CreateAuth([]byte("qwert"), []byte("12345")) + params := NewSaveParams() + params.Message = "save data!" + result := NewHelloResult() + auth := CreateAuth([]byte("qwert"), []byte("12345")) - var binSize int64 = 16 - rand.Seed(time.Now().UnixNano()) - binBytes := make([]byte, binSize) - rand.Read(binBytes) + var binSize int64 = 16 + rand.Seed(time.Now().UnixNano()) + binBytes := make([]byte, binSize) + rand.Read(binBytes) - reader := bytes.NewReader(binBytes) + reader := bytes.NewReader(binBytes) - err = LocalPut(SaveMethod, reader, binSize, params, result, auth, saveHandler) - require.NoError(t, err) + err = LocalPut(SaveMethod, reader, binSize, params, result, auth, saveHandler) + require.NoError(t, err) - resultJSON, _ := json.Marshal(result) - logDebug("method result:", string(resultJSON)) + resultJSON, _ := json.Marshal(result) + logDebug("method result:", string(resultJSON)) } - func TestLocalLoad(t *testing.T) { - var err error + var err error - params := NewLoadParams() - params.Message = "load data!" - result := NewHelloResult() - auth := CreateAuth([]byte("qwert"), []byte("12345")) + params := NewLoadParams() + params.Message = "load data!" + result := NewHelloResult() + auth := CreateAuth([]byte("qwert"), []byte("12345")) - binBytes := make([]byte, 0) - writer := bytes.NewBuffer(binBytes) + binBytes := make([]byte, 0) + writer := bytes.NewBuffer(binBytes) - err = LocalGet(LoadMethod, writer, params, result, auth, loadHandler) - require.NoError(t, err) + err = LocalGet(LoadMethod, writer, params, result, auth, loadHandler) + require.NoError(t, err) - resultJSON, _ := json.Marshal(result) - logDebug("method result:", string(resultJSON)) - logDebug("bin size:", len(writer.Bytes())) + resultJSON, _ := json.Marshal(result) + logDebug("method result:", string(resultJSON)) + logDebug("bin size:", len(writer.Bytes())) } - func TestNetExec(t *testing.T) { - go testServ(false) - time.Sleep(10 * time.Millisecond) - err := clientHello() + go testServ(false) + time.Sleep(10 * time.Millisecond) + err := clientHello() - require.NoError(t, err) + require.NoError(t, err) } func TestNetSave(t *testing.T) { - go testServ(false) - time.Sleep(10 * time.Millisecond) - err := clientSave() - require.NoError(t, err) + go testServ(false) + time.Sleep(10 * time.Millisecond) + err := clientSave() + require.NoError(t, err) } func TestNetLoad(t *testing.T) { - go testServ(false) - time.Sleep(10 * time.Millisecond) - err := clientLoad() - require.NoError(t, err) + go testServ(false) + time.Sleep(10 * time.Millisecond) + err := clientLoad() + require.NoError(t, err) } func BenchmarkNetPut(b *testing.B) { - go testServ(true) - time.Sleep(10 * time.Millisecond) - clientSave() - - pBench := func(pb *testing.PB) { - for pb.Next() { - clientSave() - } - } - b.SetParallelism(10) - b.RunParallel(pBench) + go testServ(true) + time.Sleep(10 * time.Millisecond) + clientSave() + + pBench := func(pb *testing.PB) { + for pb.Next() { + clientSave() + } + } + b.SetParallelism(10) + b.RunParallel(pBench) } func clientHello() error { - var err error - - params := NewHelloParams() - params.Message = "hello server!" - result := NewHelloResult() - auth := CreateAuth([]byte("qwert"), []byte("12345")) - - var binSize int64 = 16 - rand.Seed(time.Now().UnixNano()) - binBytes := make([]byte, binSize) - rand.Read(binBytes) - - err = Exec("127.0.0.1:8081", HelloMethod, params, result, auth) - if err != nil { - logError("method err:", err) - return err - } - resultJSON, _ := json.Marshal(result) - logDebug("method result:", string(resultJSON)) - return err + var err error + + params := NewHelloParams() + params.Message = "hello server!" + result := NewHelloResult() + auth := CreateAuth([]byte("qwert"), []byte("12345")) + + var binSize int64 = 16 + rand.Seed(time.Now().UnixNano()) + binBytes := make([]byte, binSize) + rand.Read(binBytes) + + err = Exec("127.0.0.1:8081", HelloMethod, params, result, auth) + if err != nil { + logError("method err:", err) + return err + } + resultJSON, _ := json.Marshal(result) + logDebug("method result:", string(resultJSON)) + return err } - func clientSave() error { - var err error - - params := NewSaveParams() - params.Message = "save data!" - result := NewHelloResult() - auth := CreateAuth([]byte("qwert"), []byte("12345")) - - var binSize int64 = 16 - rand.Seed(time.Now().UnixNano()) - binBytes := make([]byte, binSize) - rand.Read(binBytes) - - reader := bytes.NewReader(binBytes) - - err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth) - if err != nil { - logError("method err:", err) - return err - } - resultJSON, _ := json.Marshal(result) - logDebug("method result:", string(resultJSON)) - return err + var err error + + params := NewSaveParams() + params.Message = "save data!" + result := NewHelloResult() + auth := CreateAuth([]byte("qwert"), []byte("12345")) + + var binSize int64 = 16 + rand.Seed(time.Now().UnixNano()) + binBytes := make([]byte, binSize) + rand.Read(binBytes) + + reader := bytes.NewReader(binBytes) + + err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth) + if err != nil { + logError("method err:", err) + return err + } + resultJSON, _ := json.Marshal(result) + logDebug("method result:", string(resultJSON)) + return err } - func clientLoad() error { - var err error - - params := NewLoadParams() - params.Message = "load data!" - result := NewHelloResult() - auth := CreateAuth([]byte("qwert"), []byte("12345")) - - - binBytes := make([]byte, 0) - writer := bytes.NewBuffer(binBytes) - - err = Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth) - if err != nil { - logError("method err:", err) - return err - } - resultJSON, _ := json.Marshal(result) - logDebug("method result:", string(resultJSON)) - logDebug("bin size:", len(writer.Bytes())) - return err + var err error + + params := NewLoadParams() + params.Message = "load data!" + result := NewHelloResult() + auth := CreateAuth([]byte("qwert"), []byte("12345")) + + binBytes := make([]byte, 0) + writer := bytes.NewBuffer(binBytes) + + err = Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth) + if err != nil { + logError("method err:", err) + return err + } + resultJSON, _ := json.Marshal(result) + logDebug("method result:", string(resultJSON)) + logDebug("bin size:", len(writer.Bytes())) + return err } - var testServRun bool = false func testServ(quiet bool) error { - var err error - - if testServRun { - return err - } - testServRun = true - - if quiet { - SetAccessWriter(io.Discard) - SetMessageWriter(io.Discard) - } - serv := NewService() - serv.Handler(HelloMethod, helloHandler) - serv.Handler(SaveMethod, saveHandler) - serv.Handler(LoadMethod, loadHandler) - - serv.PreMiddleware(LogRequest) - serv.PreMiddleware(auth) - - serv.PostMiddleware(LogResponse) - serv.PostMiddleware(LogAccess) - - err = serv.Listen(":8081") - if err != nil { - return err - } - return err + var err error + + if testServRun { + return err + } + testServRun = true + + if quiet { + SetAccessWriter(io.Discard) + SetMessageWriter(io.Discard) + } + serv := NewService() + serv.Handler(HelloMethod, helloHandler) + serv.Handler(SaveMethod, saveHandler) + serv.Handler(LoadMethod, loadHandler) + + serv.PreMiddleware(LogRequest) + serv.PreMiddleware(auth) + + serv.PostMiddleware(LogResponse) + serv.PostMiddleware(LogAccess) + + err = serv.Listen(":8081") + if err != nil { + return err + } + return err } func auth(context *Context) error { - var err error - reqIdent := context.AuthIdent() - reqSalt := context.AuthSalt() - reqHash := context.AuthHash() - - ident := reqIdent - pass := []byte("12345") - - auth := context.Auth() - logDebug("auth ", string(auth.JSON())) - - ok := CheckHash(ident, pass, reqSalt, reqHash) - logDebug("auth ok:", ok) - if !ok { - err = errors.New("auth ident or pass missmatch") - context.SendError(err) - return err - } - return err + var err error + reqIdent := context.AuthIdent() + reqSalt := context.AuthSalt() + reqHash := context.AuthHash() + + ident := reqIdent + pass := []byte("12345") + + auth := context.Auth() + logDebug("auth ", string(auth.JSON())) + + ok := CheckHash(ident, pass, reqSalt, reqHash) + logDebug("auth ok:", ok) + if !ok { + err = errors.New("auth ident or pass missmatch") + context.SendError(err) + return err + } + return err } func helloHandler(context *Context) error { - var err error - params := NewHelloParams() - - err = context.BindParams(params) - if err != nil { - return err - } - - err = context.ReadBin(io.Discard) - if err != nil { - context.SendError(err) - return err - } - - result := NewHelloResult() - result.Message = "hello, client!" - - err = context.SendResult(result, 0) - if err != nil { - return err - } - return err + var err error + params := NewHelloParams() + + err = context.BindParams(params) + if err != nil { + return err + } + + err = context.ReadBin(io.Discard) + if err != nil { + context.SendError(err) + return err + } + + result := NewHelloResult() + result.Message = "hello, client!" + + err = context.SendResult(result, 0) + if err != nil { + return err + } + return err } func saveHandler(context *Context) error { - var err error - params := NewSaveParams() - - err = context.BindParams(params) - if err != nil { - return err - } - - bufferBytes := make([]byte, 0, 1024) - binWriter := bytes.NewBuffer(bufferBytes) - - err = context.ReadBin(binWriter) - if err != nil { - context.SendError(err) - return err - } - - result := NewSaveResult() - result.Message = "saved successfully!" - - err = context.SendResult(result, 0) - if err != nil { - return err - } - return err + var err error + params := NewSaveParams() + + err = context.BindParams(params) + if err != nil { + return err + } + + bufferBytes := make([]byte, 0, 1024) + binWriter := bytes.NewBuffer(bufferBytes) + + err = context.ReadBin(binWriter) + if err != nil { + context.SendError(err) + return err + } + + result := NewSaveResult() + result.Message = "saved successfully!" + + err = context.SendResult(result, 0) + if err != nil { + return err + } + return err } func loadHandler(context *Context) error { - var err error - params := NewSaveParams() - - err = context.BindParams(params) - if err != nil { - return err - } - - err = context.ReadBin(io.Discard) - if err != nil { - context.SendError(err) - return err - } - - var binSize int64 = 1024 - rand.Seed(time.Now().UnixNano()) - binBytes := make([]byte, binSize) - rand.Read(binBytes) - - binReader := bytes.NewReader(binBytes) - - result := NewSaveResult() - result.Message = "load successfully!" - - err = context.SendResult(result, binSize) - if err != nil { - return err - } - binWriter := context.BinWriter() - _, err = CopyBytes(binReader, binWriter, binSize) - if err != nil { - return err - } - - return err + var err error + params := NewSaveParams() + + err = context.BindParams(params) + if err != nil { + return err + } + + err = context.ReadBin(io.Discard) + if err != nil { + context.SendError(err) + return err + } + + var binSize int64 = 1024 + rand.Seed(time.Now().UnixNano()) + binBytes := make([]byte, binSize) + rand.Read(binBytes) + + binReader := bytes.NewReader(binBytes) + + result := NewSaveResult() + result.Message = "load successfully!" + + err = context.SendResult(result, binSize) + if err != nil { + return err + } + binWriter := context.BinWriter() + _, err = CopyBytes(binReader, binWriter, binSize) + if err != nil { + return err + } + + return err } - const HelloMethod string = "hello" type HelloParams struct { - Message string `json:"message" msgpack:"message"` + Message string `json:"message" msgpack:"message"` } func NewHelloParams() *HelloParams { - return &HelloParams{} + return &HelloParams{} } type HelloResult struct { - Message string `json:"message" msgpack:"message"` + Message string `json:"message" msgpack:"message"` } func NewHelloResult() *HelloResult { - return &HelloResult{} + return &HelloResult{} } - const SaveMethod string = "save" + type SaveParams HelloParams type SaveResult HelloResult func NewSaveParams() *SaveParams { - return &SaveParams{} + return &SaveParams{} } func NewSaveResult() *SaveResult { - return &SaveResult{} + return &SaveResult{} } - - const LoadMethod string = "load" + type LoadParams HelloParams type LoadResult HelloResult func NewLoadParams() *LoadParams { - return &LoadParams{} + return &LoadParams{} } func NewLoadResult() *LoadResult { - return &LoadResult{} + return &LoadResult{} } diff --git a/faddr.go b/faddr.go index 075df38..952e1b1 100644 --- a/faddr.go +++ b/faddr.go @@ -5,21 +5,21 @@ package dsrpc type FAddr struct { - network string - address string + network string + address string } func NewFAddr() *FAddr { - var addr FAddr - addr.network = "tcp" - addr.address = "127.0.0.1:5000" - return &addr + var addr FAddr + addr.network = "tcp" + addr.address = "127.0.0.1:5000" + return &addr } func (addr *FAddr) Network() string { - return addr.network + return addr.network } func (addr *FAddr) String() string { - return addr.address + return addr.address } diff --git a/faddr_test.go b/faddr_test.go index 75414d8..a94d0cd 100644 --- a/faddr_test.go +++ b/faddr_test.go @@ -7,51 +7,51 @@ package dsrpc import ( - "net" - "testing" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/require" + "net" + "testing" ) func TestFConn0(t *testing.T) { - var cConn, sConn net.Conn - sConn, cConn = NewFConn() - - cData := []byte("qwerty") - count := 10 - - for i := 0; i < count; i++ { - wc, err := cConn.Write(cData) - if err != nil { - t.Error(err) - } - require.Equal(t, wc, len(cData)) - - sData := make([]byte, len(cData)) - rc, err := sConn.Read(sData) - require.NoError(t, err) - require.Equal(t, rc, len(cData)) - require.Equal(t, cData, sData) - } + var cConn, sConn net.Conn + sConn, cConn = NewFConn() + + cData := []byte("qwerty") + count := 10 + + for i := 0; i < count; i++ { + wc, err := cConn.Write(cData) + if err != nil { + t.Error(err) + } + require.Equal(t, wc, len(cData)) + + sData := make([]byte, len(cData)) + rc, err := sConn.Read(sData) + require.NoError(t, err) + require.Equal(t, rc, len(cData)) + require.Equal(t, cData, sData) + } } func TestFConn1(t *testing.T) { - var cConn, sConn net.Conn - cConn, sConn = NewFConn() - - cData := []byte("qwerty") - count := 10 - - for i := 0; i < count; i++ { - wc, err := cConn.Write(cData) - if err != nil { - t.Error(err) - } - require.Equal(t, wc, len(cData)) - - sData := make([]byte, len(cData)) - rc, err := sConn.Read(sData) - require.NoError(t, err) - require.Equal(t, rc, len(cData)) - require.Equal(t, cData, sData) - } + var cConn, sConn net.Conn + cConn, sConn = NewFConn() + + cData := []byte("qwerty") + count := 10 + + for i := 0; i < count; i++ { + wc, err := cConn.Write(cData) + if err != nil { + t.Error(err) + } + require.Equal(t, wc, len(cData)) + + sData := make([]byte, len(cData)) + rc, err := sConn.Read(sData) + require.NoError(t, err) + require.Equal(t, rc, len(cData)) + require.Equal(t, cData, sData) + } } diff --git a/fconn.go b/fconn.go index 1c042d9..1398e3b 100644 --- a/fconn.go +++ b/fconn.go @@ -7,62 +7,62 @@ package dsrpc import ( - "bytes" - "io" - "net" - "time" + "bytes" + "io" + "net" + "time" ) type FConn struct { - reader io.Reader - writer io.Writer + reader io.Reader + writer io.Writer } -func NewFConn() (*FConn, *FConn){ - c2sBuffer := bytes.NewBuffer(make([]byte, 0)) - s2cBuffer := bytes.NewBuffer(make([]byte, 0)) +func NewFConn() (*FConn, *FConn) { + c2sBuffer := bytes.NewBuffer(make([]byte, 0)) + s2cBuffer := bytes.NewBuffer(make([]byte, 0)) - var client FConn - client.writer = c2sBuffer - client.reader = s2cBuffer + var client FConn + client.writer = c2sBuffer + client.reader = s2cBuffer - var server FConn - server.writer = s2cBuffer - server.reader = c2sBuffer + var server FConn + server.writer = s2cBuffer + server.reader = c2sBuffer - return &client, &server + return &client, &server } func (conn FConn) SetDeadline(t time.Time) error { - var err error - return err + var err error + return err } -func (conn FConn) SetReadDeadline(t time.Time) error { - var err error - return err +func (conn FConn) SetReadDeadline(t time.Time) error { + var err error + return err } func (conn FConn) SetWriteDeadline(t time.Time) error { - var err error - return err + var err error + return err } func (conn FConn) LocalAddr() net.Addr { - return NewFAddr() + return NewFAddr() } func (conn FConn) RemoteAddr() net.Addr { - return NewFAddr() + return NewFAddr() } func (conn FConn) Write(data []byte) (int, error) { - return conn.writer.Write(data) + return conn.writer.Write(data) } func (conn FConn) Read(data []byte) (int, error) { - return conn.reader.Read(data) + return conn.reader.Read(data) } func (conn FConn) Close() error { - var err error - return err + var err error + return err } diff --git a/header.go b/header.go index c4104cd..cee49bf 100644 --- a/header.go +++ b/header.go @@ -7,92 +7,90 @@ package dsrpc import ( - "errors" - "encoding/binary" - "encoding/json" - "bytes" + "bytes" + "encoding/binary" + "encoding/json" + "errors" ) -const headerSize int64 = 16 * 2 -const sizeOfInt64 int = 8 -const magicCodeA int64 = 0xEE00ABBA -const magicCodeB int64 = 0xEE44ABBA +const headerSize int64 = 16 * 2 +const sizeOfInt64 int = 8 +const magicCodeA int64 = 0xEE00ABBA +const magicCodeB int64 = 0xEE44ABBA type Header struct { - magicCodeA int64 `json:"magicCodeA"` - rpcSize int64 `json:"rpcSize"` - binSize int64 `json:"binSize"` - magicCodeB int64 `json:"magicCodeB"` + magicCodeA int64 `json:"magicCodeA"` + rpcSize int64 `json:"rpcSize"` + binSize int64 `json:"binSize"` + magicCodeB int64 `json:"magicCodeB"` } - func NewHeader() *Header { - return &Header{ - magicCodeA: magicCodeA, - magicCodeB: magicCodeB, - } + return &Header{ + magicCodeA: magicCodeA, + magicCodeB: magicCodeB, + } } func (hdr *Header) JSON() []byte { - jBytes, _ := json.Marshal(hdr) - return jBytes + jBytes, _ := json.Marshal(hdr) + return jBytes } - func (hdr *Header) Pack() ([]byte, error) { - var err error - headerBytes := make([]byte, 0, headerSize) - headerBuffer := bytes.NewBuffer(headerBytes) + var err error + headerBytes := make([]byte, 0, headerSize) + headerBuffer := bytes.NewBuffer(headerBytes) - magicCodeABytes := encoderI64(hdr.magicCodeA) - headerBuffer.Write(magicCodeABytes) + magicCodeABytes := encoderI64(hdr.magicCodeA) + headerBuffer.Write(magicCodeABytes) - rpcSizeBytes := encoderI64(hdr.rpcSize) - headerBuffer.Write(rpcSizeBytes) + rpcSizeBytes := encoderI64(hdr.rpcSize) + headerBuffer.Write(rpcSizeBytes) - binSizeBytes := encoderI64(hdr.binSize) - headerBuffer.Write(binSizeBytes) + binSizeBytes := encoderI64(hdr.binSize) + headerBuffer.Write(binSizeBytes) - magicCodeBBytes := encoderI64(hdr.magicCodeB) - headerBuffer.Write(magicCodeBBytes) + magicCodeBBytes := encoderI64(hdr.magicCodeB) + headerBuffer.Write(magicCodeBBytes) - return headerBuffer.Bytes(), Err(err) + return headerBuffer.Bytes(), Err(err) } func UnpackHeader(headerBytes []byte) (*Header, error) { - var err error - header := NewHeader() - headerReader := bytes.NewReader(headerBytes) - - magicCodeABytes := make([]byte, sizeOfInt64) - headerReader.Read(magicCodeABytes) - header.magicCodeA = decoderI64(magicCodeABytes) - - rpcSizeBytes := make([]byte, sizeOfInt64) - headerReader.Read(rpcSizeBytes) - header.rpcSize = decoderI64(rpcSizeBytes) - - binSizeBytes := make([]byte, sizeOfInt64) - headerReader.Read(binSizeBytes) - header.binSize = decoderI64(binSizeBytes) - - magicCodeBBytes := make([]byte, sizeOfInt64) - headerReader.Read(magicCodeBBytes) - header.magicCodeB = decoderI64(magicCodeBBytes) - - if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB { - err = errors.New("wrong protocol magic code") - return header, Err(err) - } - return header, Err(err) + var err error + header := NewHeader() + headerReader := bytes.NewReader(headerBytes) + + magicCodeABytes := make([]byte, sizeOfInt64) + headerReader.Read(magicCodeABytes) + header.magicCodeA = decoderI64(magicCodeABytes) + + rpcSizeBytes := make([]byte, sizeOfInt64) + headerReader.Read(rpcSizeBytes) + header.rpcSize = decoderI64(rpcSizeBytes) + + binSizeBytes := make([]byte, sizeOfInt64) + headerReader.Read(binSizeBytes) + header.binSize = decoderI64(binSizeBytes) + + magicCodeBBytes := make([]byte, sizeOfInt64) + headerReader.Read(magicCodeBBytes) + header.magicCodeB = decoderI64(magicCodeBBytes) + + if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB { + err = errors.New("wrong protocol magic code") + return header, Err(err) + } + return header, Err(err) } func encoderI64(i int64) []byte { - buffer := make([]byte, sizeOfInt64) - binary.BigEndian.PutUint64(buffer, uint64(i)) - return buffer + buffer := make([]byte, sizeOfInt64) + binary.BigEndian.PutUint64(buffer, uint64(i)) + return buffer } func decoderI64(b []byte) int64 { - return int64(binary.BigEndian.Uint64(b)) + return int64(binary.BigEndian.Uint64(b)) } diff --git a/logger.go b/logger.go index d0632e6..a296a08 100644 --- a/logger.go +++ b/logger.go @@ -7,39 +7,39 @@ package dsrpc import ( - "fmt" - "io" - "os" - "time" + "fmt" + "io" + "os" + "time" ) var messageWriter io.Writer = os.Stdout var accessWriter io.Writer = os.Stdout func logDebug(messages ...any) { - stamp := time.Now().Format(time.RFC3339) - fmt.Fprintln(messageWriter, stamp, "debug", messages) + stamp := time.Now().Format(time.RFC3339) + fmt.Fprintln(messageWriter, stamp, "debug", messages) } func logInfo(messages ...any) { - stamp := time.Now().Format(time.RFC3339) - fmt.Fprintln(messageWriter, stamp, "info", messages) + stamp := time.Now().Format(time.RFC3339) + fmt.Fprintln(messageWriter, stamp, "info", messages) } func logError(messages ...any) { - stamp := time.Now().Format(time.RFC3339) - fmt.Fprintln(messageWriter, stamp, "error", messages) + stamp := time.Now().Format(time.RFC3339) + fmt.Fprintln(messageWriter, stamp, "error", messages) } func logAccess(messages ...any) { - stamp := time.Now().Format(time.RFC3339) - fmt.Fprintln(accessWriter, stamp, "access", messages) + stamp := time.Now().Format(time.RFC3339) + fmt.Fprintln(accessWriter, stamp, "access", messages) } func SetAccessWriter(writer io.Writer) { - accessWriter = writer + accessWriter = writer } func SetMessageWriter(writer io.Writer) { - messageWriter = writer + messageWriter = writer } diff --git a/midware.go b/midware.go index 838381e..56e9fef 100644 --- a/midware.go +++ b/midware.go @@ -5,25 +5,25 @@ package dsrpc import ( - "time" + "time" ) func LogRequest(context *Context) error { - var err error - logDebug("request:", string(context.reqRPC.JSON())) - return Err(err) + var err error + logDebug("request:", string(context.reqRPC.JSON())) + return Err(err) } func LogResponse(context *Context) error { - var err error - logDebug("response:", string(context.resRPC.JSON())) - return Err(err) + var err error + logDebug("response:", string(context.resRPC.JSON())) + return Err(err) } func LogAccess(context *Context) error { - var err error - execTime := time.Now().Sub(context.start) - login := string(context.AuthIdent()) - logAccess(context.remoteHost, login, context.reqRPC.Method, execTime) - return Err(err) + var err error + execTime := time.Now().Sub(context.start) + login := string(context.AuthIdent()) + logAccess(context.remoteHost, login, context.reqRPC.Method, execTime) + return Err(err) } diff --git a/packet.go b/packet.go index 9d88b55..4631a8e 100644 --- a/packet.go +++ b/packet.go @@ -5,19 +5,19 @@ package dsrpc import ( - "encoding/json" + "encoding/json" ) type Packet struct { - header []byte - rcpPayload []byte + header []byte + rcpPayload []byte } func NewPacket() *Packet { - return &Packet{} + return &Packet{} } func (pkt *Packet) JSON() []byte { - jBytes, _ := json.Marshal(pkt) - return jBytes + jBytes, _ := json.Marshal(pkt) + return jBytes } diff --git a/request.go b/request.go index 2143813..35b7798 100644 --- a/request.go +++ b/request.go @@ -7,28 +7,28 @@ package dsrpc import ( - "encoding/json" - encoder "github.com/vmihailenco/msgpack/v5" + "encoding/json" + encoder "github.com/vmihailenco/msgpack/v5" ) type Request struct { - Method string `json:"method" msgpack:"method"` - Params any `json:"params,omitempty" msgpack:"params"` - Auth *Auth `json:"auth,omitempty" msgpack:"auth"` + Method string `json:"method" msgpack:"method"` + Params any `json:"params,omitempty" msgpack:"params"` + Auth *Auth `json:"auth,omitempty" msgpack:"auth"` } func NewRequest() *Request { - req := &Request{} - req.Auth = &Auth{} - return req + req := &Request{} + req.Auth = &Auth{} + return req } func (req *Request) Pack() ([]byte, error) { - rBytes, err := encoder.Marshal(req) - return rBytes, Err(err) + rBytes, err := encoder.Marshal(req) + return rBytes, Err(err) } func (req *Request) JSON() []byte { - jBytes, _ := json.Marshal(req) - return jBytes + jBytes, _ := json.Marshal(req) + return jBytes } diff --git a/response.go b/response.go index a010500..db5fa96 100644 --- a/response.go +++ b/response.go @@ -5,26 +5,25 @@ package dsrpc import ( - "encoding/json" - encoder "github.com/vmihailenco/msgpack/v5" + "encoding/json" + encoder "github.com/vmihailenco/msgpack/v5" ) - type Response struct { - Error string `json:"error" msgpack:"error"` - Result any `json:"result" msgpack:"result"` + Error string `json:"error" msgpack:"error"` + Result any `json:"result" msgpack:"result"` } func NewResponse() *Response { - return &Response{} + return &Response{} } func (resp *Response) JSON() []byte { - jBytes, _ := json.Marshal(resp) - return jBytes + jBytes, _ := json.Marshal(resp) + return jBytes } func (resp *Response) Pack() ([]byte, error) { - rBytes, err := encoder.Marshal(resp) - return rBytes, Err(err) + rBytes, err := encoder.Marshal(resp) + return rBytes, Err(err) } diff --git a/server.go b/server.go index dfe7f79..cb9adba 100644 --- a/server.go +++ b/server.go @@ -5,305 +5,303 @@ package dsrpc import ( - "context" - "errors" - "fmt" - "io" - "net" - "sync" - "time" - - encoder "github.com/vmihailenco/msgpack/v5" + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + encoder "github.com/vmihailenco/msgpack/v5" ) -type HandlerFunc = func(*Context) error +type HandlerFunc = func(*Context) error type Service struct { - handlers map[string]HandlerFunc - ctx context.Context - cancel context.CancelFunc - wg *sync.WaitGroup - preMw []HandlerFunc - postMw []HandlerFunc - keepalive bool - kaTime time.Duration - kaMtx sync.Mutex + handlers map[string]HandlerFunc + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup + preMw []HandlerFunc + postMw []HandlerFunc + keepalive bool + kaTime time.Duration + kaMtx sync.Mutex } func NewService() *Service { - rdrpc := &Service{} - rdrpc.handlers = make(map[string]HandlerFunc) - ctx, cancel := context.WithCancel(context.Background()) - rdrpc.ctx = ctx - rdrpc.cancel = cancel - var wg sync.WaitGroup - rdrpc.wg = &wg - rdrpc.preMw = make([]HandlerFunc, 0) - rdrpc.postMw = make([]HandlerFunc, 0) - - return rdrpc + rdrpc := &Service{} + rdrpc.handlers = make(map[string]HandlerFunc) + ctx, cancel := context.WithCancel(context.Background()) + rdrpc.ctx = ctx + rdrpc.cancel = cancel + var wg sync.WaitGroup + rdrpc.wg = &wg + rdrpc.preMw = make([]HandlerFunc, 0) + rdrpc.postMw = make([]HandlerFunc, 0) + + return rdrpc } func (svc *Service) PreMiddleware(mw HandlerFunc) { - svc.preMw = append(svc.preMw, mw) + svc.preMw = append(svc.preMw, mw) } func (svc *Service) PostMiddleware(mw HandlerFunc) { - svc.postMw = append(svc.postMw, mw) + svc.postMw = append(svc.postMw, mw) } func (svc *Service) Handler(method string, handler HandlerFunc) { - svc.handlers[method] = handler + svc.handlers[method] = handler } func (svc *Service) SetKeepAlive(flag bool) { - svc.kaMtx.Lock() - defer svc.kaMtx.Unlock() - svc.keepalive = true + svc.kaMtx.Lock() + defer svc.kaMtx.Unlock() + svc.keepalive = true } func (svc *Service) SetKeepAlivePeriod(interval time.Duration) { - svc.kaMtx.Lock() - defer svc.kaMtx.Unlock() - svc.kaTime = interval + svc.kaMtx.Lock() + defer svc.kaMtx.Unlock() + svc.kaTime = interval } func (svc *Service) Listen(address string) error { - var err error - logInfo("server listen:", address) - - addr, err := net.ResolveTCPAddr("tcp", address) - if err != nil { - err = fmt.Errorf("unable to resolve adddress: %s", err) - return err - } - listener, err := net.ListenTCP("tcp", addr) - if err != nil { - err = fmt.Errorf("unable to start listener: %s", err) - return err - } - - for { - conn, err := listener.AcceptTCP() - if err != nil { - logError("conn accept err:", err) - } - select { - case <-svc.ctx.Done(): - return err - default: - } - svc.wg.Add(1) - go svc.handleConn(conn, svc.wg) - } - return err + var err error + logInfo("server listen:", address) + + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + err = fmt.Errorf("unable to resolve adddress: %s", err) + return err + } + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + err = fmt.Errorf("unable to start listener: %s", err) + return err + } + + for { + conn, err := listener.AcceptTCP() + if err != nil { + logError("conn accept err:", err) + } + select { + case <-svc.ctx.Done(): + return err + default: + } + svc.wg.Add(1) + go svc.handleConn(conn, svc.wg) + } + return err } func notFound(context *Context) error { - execErr := errors.New("method not found") - err := context.SendError(execErr) - return err + execErr := errors.New("method not found") + err := context.SendError(execErr) + return err } func (svc *Service) Stop() error { - var err error - // Disable new connection - logInfo("cancel rpc accept loop") - svc.cancel() - // Wait handlers - logInfo("wait rpc handlers") - svc.wg.Wait() - return err + var err error + // Disable new connection + logInfo("cancel rpc accept loop") + svc.cancel() + // Wait handlers + logInfo("wait rpc handlers") + svc.wg.Wait() + return err } func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) { - var err error - - if svc.keepalive { - err = conn.SetKeepAlive(true) - if err != nil { - err = fmt.Errorf("unable to set keepalive: %s", err) - return - } - if svc.kaTime > 0 { - err = conn.SetKeepAlivePeriod(svc.kaTime) - if err != nil { - err = fmt.Errorf("unable to set keepalive period: %s", err) - return - } - } - } - context := CreateContext(conn) - - remoteAddr := conn.RemoteAddr().String() - remoteHost, _, _ := net.SplitHostPort(remoteAddr) - context.remoteHost = remoteHost - - context.binReader = conn - context.binWriter = io.Discard - - exitFunc := func() { - conn.Close() - wg.Done() - if err != nil { - logError("conn handler err:", err) - } - } - defer exitFunc() - - recovFunc := func () { - panicMsg := recover() - if panicMsg != nil { - logError("handler panic message:", panicMsg) - } - } - defer recovFunc() - - err = context.ReadRequest() - if err != nil { - err = Err(err) - return - } - - err = context.BindMethod() - if err != nil { - err = Err(err) - return - } - for _, mw := range svc.preMw { - err = mw(context) - if err != nil { - err = Err(err) - return - } - } - err = svc.Route(context) - if err != nil { - err = Err(err) - return - } - for _, mw := range svc.postMw { - err = mw(context) - if err != nil { - err = Err(err) - return - } - } - return + var err error + + if svc.keepalive { + err = conn.SetKeepAlive(true) + if err != nil { + err = fmt.Errorf("unable to set keepalive: %s", err) + return + } + if svc.kaTime > 0 { + err = conn.SetKeepAlivePeriod(svc.kaTime) + if err != nil { + err = fmt.Errorf("unable to set keepalive period: %s", err) + return + } + } + } + context := CreateContext(conn) + + remoteAddr := conn.RemoteAddr().String() + remoteHost, _, _ := net.SplitHostPort(remoteAddr) + context.remoteHost = remoteHost + + context.binReader = conn + context.binWriter = io.Discard + + exitFunc := func() { + conn.Close() + wg.Done() + if err != nil { + logError("conn handler err:", err) + } + } + defer exitFunc() + + recovFunc := func() { + panicMsg := recover() + if panicMsg != nil { + logError("handler panic message:", panicMsg) + } + } + defer recovFunc() + + err = context.ReadRequest() + if err != nil { + err = Err(err) + return + } + + err = context.BindMethod() + if err != nil { + err = Err(err) + return + } + for _, mw := range svc.preMw { + err = mw(context) + if err != nil { + err = Err(err) + return + } + } + err = svc.Route(context) + if err != nil { + err = Err(err) + return + } + for _, mw := range svc.postMw { + err = mw(context) + if err != nil { + err = Err(err) + return + } + } + return } func (svc *Service) Route(context *Context) error { - handler, ok := svc.handlers[context.reqRPC.Method] - if ok { - return Err(handler(context)) - } - return Err(notFound(context)) + handler, ok := svc.handlers[context.reqRPC.Method] + if ok { + return Err(handler(context)) + } + return Err(notFound(context)) } func (context *Context) ReadRequest() error { - var err error - - context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize) - if err != nil { - return Err(err) - } - context.reqHeader, err = UnpackHeader(context.reqPacket.header) - if err != nil { - return Err(err) - } - - rpcSize := context.reqHeader.rpcSize - context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) - if err != nil { - return Err(err) - } - return Err(err) + var err error + + context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize) + if err != nil { + return Err(err) + } + context.reqHeader, err = UnpackHeader(context.reqPacket.header) + if err != nil { + return Err(err) + } + + rpcSize := context.reqHeader.rpcSize + context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) + if err != nil { + return Err(err) + } + return Err(err) } func (context *Context) BinWriter() io.Writer { - return context.sockWriter + return context.sockWriter } func (context *Context) BinReader() io.Reader { - return context.sockReader + return context.sockReader } func (context *Context) BinSize() int64 { - return context.reqHeader.binSize + return context.reqHeader.binSize } func (context *Context) ReadBin(writer io.Writer) error { - var err error - _, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize) - return Err(err) + var err error + _, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize) + return Err(err) } - func (context *Context) BindMethod() error { - var err error - err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) - return Err(err) + var err error + err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) + return Err(err) } func (context *Context) BindParams(params any) error { - var err error - context.reqRPC.Params = params - err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) - if err != nil { - return Err(err) - } - return Err(err) + var err error + context.reqRPC.Params = params + err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) + if err != nil { + return Err(err) + } + return Err(err) } func (context *Context) SendResult(result any, binSize int64) error { - var err error - context.resRPC.Result = result - - context.resPacket.rcpPayload, err = context.resRPC.Pack() - if err != nil { - return Err(err) - } - context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) - context.resHeader.binSize = binSize - - context.resPacket.header, err = context.resHeader.Pack() - if err != nil { - return Err(err) - } - _, err = context.sockWriter.Write(context.resPacket.header) - if err != nil { - return Err(err) - } - _, err = context.sockWriter.Write(context.resPacket.rcpPayload) - if err != nil { - return Err(err) - } - return Err(err) + var err error + context.resRPC.Result = result + + context.resPacket.rcpPayload, err = context.resRPC.Pack() + if err != nil { + return Err(err) + } + context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) + context.resHeader.binSize = binSize + + context.resPacket.header, err = context.resHeader.Pack() + if err != nil { + return Err(err) + } + _, err = context.sockWriter.Write(context.resPacket.header) + if err != nil { + return Err(err) + } + _, err = context.sockWriter.Write(context.resPacket.rcpPayload) + if err != nil { + return Err(err) + } + return Err(err) } - func (context *Context) SendError(execErr error) error { - var err error - - context.resRPC.Error = execErr.Error() - context.resRPC.Result = NewEmpty() - - context.resPacket.rcpPayload, err = context.resRPC.Pack() - if err != nil { - return Err(err) - } - context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) - context.resPacket.header, err = context.resHeader.Pack() - if err != nil { - return Err(err) - } - _, err = context.sockWriter.Write(context.resPacket.header) - if err != nil { - return Err(err) - } - _, err = context.sockWriter.Write(context.resPacket.rcpPayload) - if err != nil { - return Err(err) - } - return Err(err) + var err error + + context.resRPC.Error = execErr.Error() + context.resRPC.Result = NewEmpty() + + context.resPacket.rcpPayload, err = context.resRPC.Pack() + if err != nil { + return Err(err) + } + context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) + context.resPacket.header, err = context.resHeader.Pack() + if err != nil { + return Err(err) + } + _, err = context.sockWriter.Write(context.resPacket.header) + if err != nil { + return Err(err) + } + _, err = context.sockWriter.Write(context.resPacket.rcpPayload) + if err != nil { + return Err(err) + } + return Err(err) } diff --git a/tools.go b/tools.go index e067353..ccc7bde 100644 --- a/tools.go +++ b/tools.go @@ -5,53 +5,53 @@ package dsrpc import ( - "errors" - "io" - "fmt" + "errors" + "fmt" + "io" ) func ReadBytes(reader io.Reader, size int64) ([]byte, error) { - buffer := make([]byte, size) - read, err := io.ReadFull(reader, buffer) - return buffer[0:read], Err(err) + buffer := make([]byte, size) + read, err := io.ReadFull(reader, buffer) + return buffer[0:read], Err(err) } func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { - var err error - var bSize int64 = 1024 * 16 - var total int64 = 0 - var remains int64 = dataSize - buffer := make([]byte, bSize) + var err error + var bSize int64 = 1024 * 16 + var total int64 = 0 + var remains int64 = dataSize + buffer := make([]byte, bSize) - for { - if reader == nil { - return total, errors.New("reader is nil") - } - if writer == nil { - return total, errors.New("writer is nil") - } - if remains == 0 { - return total, err - } - if remains < bSize { - bSize = remains - } - received, err := reader.Read(buffer[0:bSize]) - if err != nil { - err = fmt.Errorf("read error: %v", err) - return total, Err(err) - } - recorded, err := writer.Write(buffer[0:received]) - if err != nil { - err = fmt.Errorf("write error: %v", err) - return total, Err(err) - } - if recorded != received { - err = errors.New("size mismatch") - return total, Err(err) - } - total += int64(recorded) - remains -= int64(recorded) - } - return total, Err(err) + for { + if reader == nil { + return total, errors.New("reader is nil") + } + if writer == nil { + return total, errors.New("writer is nil") + } + if remains == 0 { + return total, err + } + if remains < bSize { + bSize = remains + } + received, err := reader.Read(buffer[0:bSize]) + if err != nil { + err = fmt.Errorf("read error: %v", err) + return total, Err(err) + } + recorded, err := writer.Write(buffer[0:received]) + if err != nil { + err = fmt.Errorf("write error: %v", err) + return total, Err(err) + } + if recorded != received { + err = errors.New("size mismatch") + return total, Err(err) + } + total += int64(recorded) + remains -= int64(recorded) + } + return total, Err(err) } diff --git a/validate.go b/validate.go index 9584373..f98f1c2 100644 --- a/validate.go +++ b/validate.go @@ -4,161 +4,159 @@ package dsrpc - import ( - "io" - "net" + "io" + "net" ) func LocalExec(method string, param any, result any, auth *Auth, handler HandlerFunc) error { - var err error - - cliConn, srvConn := NewFConn() - - context := CreateContext(cliConn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() - } - err = context.CreateRequest() - if err != nil { - return err - } - err = context.WriteRequest() - if err != nil { - return err - } - err = LocalService(srvConn, handler) - if err != nil { - return err - } - err = context.ReadResponse() - if err != nil { - return err - } - err = context.BindResponse() - if err != nil { - return err - } - - return err + var err error + + cliConn, srvConn := NewFConn() + + context := CreateContext(cliConn) + context.reqRPC.Method = method + context.reqRPC.Params = param + context.reqRPC.Auth = auth + context.resRPC.Result = result + + if context.reqRPC.Params == nil { + context.reqRPC.Params = NewEmpty() + } + err = context.CreateRequest() + if err != nil { + return err + } + err = context.WriteRequest() + if err != nil { + return err + } + err = LocalService(srvConn, handler) + if err != nil { + return err + } + err = context.ReadResponse() + if err != nil { + return err + } + err = context.BindResponse() + if err != nil { + return err + } + + return err } func LocalPut(method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error { - var err error - - cliConn, srvConn := NewFConn() - - context := CreateContext(cliConn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result - - context.binReader = reader - context.binWriter = cliConn - - context.reqHeader.binSize = size - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() - } - err = context.CreateRequest() - if err != nil { - return err - } - err = context.WriteRequest() - if err != nil { - return err - } - err = context.UploadBin() - if err != nil { - return err - } - err = LocalService(srvConn, handler) - if err != nil { - return err - } - err = context.ReadResponse() - if err != nil { - return err - } - err = context.BindResponse() - if err != nil { - return err - } - return err + var err error + + cliConn, srvConn := NewFConn() + + context := CreateContext(cliConn) + context.reqRPC.Method = method + context.reqRPC.Params = param + context.reqRPC.Auth = auth + context.resRPC.Result = result + + context.binReader = reader + context.binWriter = cliConn + + context.reqHeader.binSize = size + + if context.reqRPC.Params == nil { + context.reqRPC.Params = NewEmpty() + } + err = context.CreateRequest() + if err != nil { + return err + } + err = context.WriteRequest() + if err != nil { + return err + } + err = context.UploadBin() + if err != nil { + return err + } + err = LocalService(srvConn, handler) + if err != nil { + return err + } + err = context.ReadResponse() + if err != nil { + return err + } + err = context.BindResponse() + if err != nil { + return err + } + return err } - func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error { - var err error - - cliConn, srvConn := NewFConn() - - context := CreateContext(cliConn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result - - context.binReader = cliConn - context.binWriter = writer - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() - } - err = context.CreateRequest() - if err != nil { - return err - } - err = context.WriteRequest() - if err != nil { - return err - } - - err = LocalService(srvConn, handler) - if err != nil { - return err - } - err = context.ReadResponse() - if err != nil { - return err - } - err = context.DownloadBin() - if err != nil { - return err - } - err = context.BindResponse() - if err != nil { - return err - } - return err + var err error + + cliConn, srvConn := NewFConn() + + context := CreateContext(cliConn) + context.reqRPC.Method = method + context.reqRPC.Params = param + context.reqRPC.Auth = auth + context.resRPC.Result = result + + context.binReader = cliConn + context.binWriter = writer + + if context.reqRPC.Params == nil { + context.reqRPC.Params = NewEmpty() + } + err = context.CreateRequest() + if err != nil { + return err + } + err = context.WriteRequest() + if err != nil { + return err + } + + err = LocalService(srvConn, handler) + if err != nil { + return err + } + err = context.ReadResponse() + if err != nil { + return err + } + err = context.DownloadBin() + if err != nil { + return err + } + err = context.BindResponse() + if err != nil { + return err + } + return err } func LocalService(conn net.Conn, handler HandlerFunc) error { - var err error - context := CreateContext(conn) - - remoteAddr := conn.RemoteAddr().String() - remoteHost, _, _ := net.SplitHostPort(remoteAddr) - context.remoteHost = remoteHost - - context.binReader = conn - context.binWriter = io.Discard - - err = context.ReadRequest() - if err != nil { - return err - } - err = context.BindMethod() - if err != nil { - return err - } - return handler(context) + var err error + context := CreateContext(conn) + + remoteAddr := conn.RemoteAddr().String() + remoteHost, _, _ := net.SplitHostPort(remoteAddr) + context.remoteHost = remoteHost + + context.binReader = conn + context.binWriter = io.Discard + + err = context.ReadRequest() + if err != nil { + return err + } + err = context.BindMethod() + if err != nil { + return err + } + return handler(context) } diff --git a/xauth.go b/xauth.go index 91c0b55..81f7690 100644 --- a/xauth.go +++ b/xauth.go @@ -5,60 +5,60 @@ package dsrpc import ( - "encoding/json" - "bytes" - "math/rand" - "time" - "crypto/sha256" + "bytes" + "crypto/sha256" + "encoding/json" + "math/rand" + "time" ) func init() { - rand.Seed(time.Now().UnixNano()) + rand.Seed(time.Now().UnixNano()) } type Auth struct { - Ident []byte `msgpack:"ident" json:"ident"` - Salt []byte `msgpack:"salt" json:"salt"` - Hash []byte `msgpack:"hash" json:"hash"` + Ident []byte `msgpack:"ident" json:"ident"` + Salt []byte `msgpack:"salt" json:"salt"` + Hash []byte `msgpack:"hash" json:"hash"` } func NewAuth() *Auth { - return &Auth{} + return &Auth{} } func (this *Auth) JSON() []byte { - jBytes, _ := json.Marshal(this) - return jBytes + jBytes, _ := json.Marshal(this) + return jBytes } func CreateAuth(ident, pass []byte) *Auth { - salt := CreateSalt() - hash := CreateHash(ident, pass, salt) - auth := &Auth{} - auth.Ident = ident - auth.Salt = salt - auth.Hash = hash - return auth + salt := CreateSalt() + hash := CreateHash(ident, pass, salt) + auth := &Auth{} + auth.Ident = ident + auth.Salt = salt + auth.Hash = hash + return auth } func CreateSalt() []byte { - const saltSize = 16 - randBytes := make([]byte, saltSize) - rand.Read(randBytes) - return randBytes + const saltSize = 16 + randBytes := make([]byte, saltSize) + rand.Read(randBytes) + return randBytes } func CreateHash(ident, pass, salt []byte) []byte { - vec := make([]byte, 0, len(ident) + len(salt) + len(pass)) - vec = append(vec, ident...) - vec = append(vec, salt...) - vec = append(vec, pass...) - hasher := sha256.New() - hash := hasher.Sum(vec) - return hash + vec := make([]byte, 0, len(ident)+len(salt)+len(pass)) + vec = append(vec, ident...) + vec = append(vec, salt...) + vec = append(vec, pass...) + hasher := sha256.New() + hash := hasher.Sum(vec) + return hash } func CheckHash(ident, pass, reqSalt, reqHash []byte) bool { - localHash := CreateHash(ident, pass, reqSalt) - return bytes.Equal(reqHash, localHash) + localHash := CreateHash(ident, pass, reqSalt) + return bytes.Equal(reqHash, localHash) }