diff --git a/client.go b/client.go index 05ae16a..b0906d6 100644 --- a/client.go +++ b/client.go @@ -14,80 +14,71 @@ import ( encoder "github.com/vmihailenco/msgpack/v5" ) -func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error { +func Put(address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error { var err error addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { err = fmt.Errorf("unable to resolve adddress: %s", err) - return Err(err) + return err } conn, err := net.DialTCP("tcp", nil, addr) if err != nil { - return Err(err) + return err } defer conn.Close() - //err = conn.SetKeepAlive(true) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive: %s", err) - // return Err(err) - //} - - //err = conn.SetKeepAlivePeriod(10 * time.Second) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive period: %s", err) - // return Err(err) - //} - - return ConnPut(conn, method, reader, size, param, result, auth) + return ConnPut(conn, method, reader, binSize, param, result, auth) } -func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, result any, auth *Auth) error { +func ConnPut(conn net.Conn, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error { var err error - context := CreateContext(conn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result + content := CreateContent(conn) - context.binReader = reader - context.binWriter = conn - - context.reqHeader.binSize = size - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() + content.reqBlock.Method = method + if param != nil { + content.reqBlock.Params = param + } + if auth != nil { + content.reqBlock.Auth = auth + } + if result != nil { + content.resBlock.Result = result } - err = context.CreateRequest() + content.binReader = reader + content.binWriter = conn + + content.reqHeader.binSize = binSize + + err = content.CreateRequest() if err != nil { - return Err(err) + return err } - err = context.WriteRequest() + err = content.WriteRequest() if err != nil { - return Err(err) + return err } var wg sync.WaitGroup errChan := make(chan error, 1) wg.Add(1) - go context.ReadResponseAsync(&wg, errChan) + go content.ReadResponseAsync(&wg, errChan) wg.Add(1) - go context.UploadBinAsync(&wg) + go content.UploadBinAsync(&wg) wg.Wait() err = <-errChan if err != nil { - return Err(err) + return err } - err = context.BindResponse() + err = content.BindResponse() if err != nil { - return Err(err) + return err } - return Err(err) + return err } func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error { @@ -96,65 +87,56 @@ func Get(address string, method string, writer io.Writer, param, result any, aut addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { err = fmt.Errorf("unable to resolve adddress: %s", err) - return Err(err) + return err } conn, err := net.DialTCP("tcp", nil, addr) if err != nil { - return Err(err) + return err } defer conn.Close() - //err = conn.SetKeepAlive(true) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive: %s", err) - // return Err(err) - //} - - //err = conn.SetKeepAlivePeriod(10 * time.Second) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive period: %s", err) - // return Err(err) - //} - return ConnGet(conn, method, writer, param, result, auth) } func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error { var err error - context := CreateContext(conn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result + content := CreateContent(conn) + content.reqBlock.Method = method + if param != nil { + content.reqBlock.Params = param + } + if auth != nil { + content.reqBlock.Auth = auth + } + if result != nil { + content.resBlock.Result = result + } - context.binReader = conn - context.binWriter = writer + content.binReader = conn + content.binWriter = writer - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() - } - err = context.CreateRequest() + err = content.CreateRequest() if err != nil { - return Err(err) + return err } - err = context.WriteRequest() + err = content.WriteRequest() if err != nil { - return Err(err) + return err } - err = context.ReadResponse() + err = content.ReadResponse() if err != nil { - return Err(err) + return err } - err = context.DownloadBin() + err = content.DownloadBin() if err != nil { - return Err(err) + return err } - err = context.BindResponse() + err = content.BindResponse() if err != nil { - return Err(err) + return err } - return Err(err) + return err } func Exec(address, method string, param any, result any, auth *Auth) error { @@ -163,170 +145,162 @@ func Exec(address, method string, param any, result any, auth *Auth) error { addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { err = fmt.Errorf("unable to resolve adddress: %s", err) - return Err(err) + return err } conn, err := net.DialTCP("tcp", nil, addr) if err != nil { - return Err(err) + return err } defer conn.Close() - //err = conn.SetKeepAlive(true) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive: %s", err) - // return Err(err) - //} - //err = conn.SetKeepAlivePeriod(10 * time.Second) - //if err != nil { - // err = fmt.Errorf("unable to set keepalive period: %s", err) - // return Err(err) - //} - err = ConnExec(conn, method, param, result, auth) if err != nil { - return Err(err) + return err } - return Err(err) + return err } func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error { var err error - context := CreateContext(conn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result + content := CreateContent(conn) + content.reqBlock.Method = method - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() + if param != nil { + content.reqBlock.Params = param + } + if result != nil { + content.resBlock.Result = result + } + if auth != nil { + content.reqBlock.Auth = auth } - err = context.CreateRequest() + err = content.CreateRequest() if err != nil { - return Err(err) + return err } - err = context.WriteRequest() + err = content.WriteRequest() if err != nil { - return Err(err) + return err } - err = context.ReadResponse() + err = content.ReadResponse() if err != nil { - return Err(err) + return err } - err = context.BindResponse() + err = content.BindResponse() if err != nil { - return Err(err) + return err } - return Err(err) + return err } -func (context *Context) CreateRequest() error { +func (content *Content) CreateRequest() error { var err error - context.reqPacket.rcpPayload, err = context.reqRPC.Pack() + content.reqPacket.rcpPayload, err = content.reqBlock.Pack() if err != nil { - return Err(err) + return err } - rpcSize := int64(len(context.reqPacket.rcpPayload)) - context.reqHeader.rpcSize = rpcSize + rpcSize := int64(len(content.reqPacket.rcpPayload)) + content.reqHeader.rpcSize = rpcSize - context.reqPacket.header, err = context.reqHeader.Pack() + content.reqPacket.header, err = content.reqHeader.Pack() if err != nil { - return Err(err) + return err } - return Err(err) + return err } -func (context *Context) WriteRequest() error { +func (content *Content) WriteRequest() error { var err error - _, err = context.sockWriter.Write(context.reqPacket.header) + _, err = content.sockWriter.Write(content.reqPacket.header) if err != nil { - return Err(err) + return err } - _, err = context.sockWriter.Write(context.reqPacket.rcpPayload) + _, err = content.sockWriter.Write(content.reqPacket.rcpPayload) if err != nil { - return Err(err) + return err } - return Err(err) + return err } -func (context *Context) UploadBin() error { +func (content *Content) UploadBin() error { var err error - _, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) - return Err(err) + _, err = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize) + return err } -func (context *Context) ReadResponse() error { +func (content *Content) ReadResponse() error { var err error - context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) + content.resPacket.header, err = ReadBytes(content.sockReader, headerSize) if err != nil { - return Err(err) + return err } - context.resHeader, err = UnpackHeader(context.resPacket.header) + content.resHeader, err = UnpackHeader(content.resPacket.header) if err != nil { - return Err(err) + return err } - rpcSize := context.resHeader.rpcSize - context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) + rpcSize := content.resHeader.rpcSize + content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize) if err != nil { - return Err(err) + return err } - return Err(err) + return err } -func (context *Context) UploadBinAsync(wg *sync.WaitGroup) { +func (content *Content) UploadBinAsync(wg *sync.WaitGroup) { exitFunc := func() { wg.Done() } defer exitFunc() - _, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) + _, _ = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize) return } -func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) { +func (content *Content) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) { var err error exitFunc := func() { errChan <- err wg.Done() } defer exitFunc() - context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) + content.resPacket.header, err = ReadBytes(content.sockReader, headerSize) if err != nil { - err = Err(err) + err = err return } - context.resHeader, err = UnpackHeader(context.resPacket.header) + content.resHeader, err = UnpackHeader(content.resPacket.header) if err != nil { - err = Err(err) + err = err return } - rpcSize := context.resHeader.rpcSize - context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) + rpcSize := content.resHeader.rpcSize + content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize) if err != nil { - err = Err(err) + err = err return } return } -func (context *Context) DownloadBin() error { +func (content *Content) DownloadBin() error { var err error - _, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize) - return Err(err) + _, err = CopyBytes(content.binReader, content.binWriter, content.resHeader.binSize) + return err } -func (context *Context) BindResponse() error { +func (content *Content) BindResponse() error { var err error - err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC) + err = encoder.Unmarshal(content.resPacket.rcpPayload, content.resBlock) if err != nil { - return Err(err) + return err } - if len(context.resRPC.Error) > 0 { - err = errors.New(context.resRPC.Error) - return Err(err) + if len(content.resBlock.Error) > 0 { + err = errors.New(content.resBlock.Error) + return err } - return Err(err) + return err } diff --git a/compat.go b/compat.go deleted file mode 100644 index b9ea6ba..0000000 --- a/compat.go +++ /dev/null @@ -1,5 +0,0 @@ -/* - * Copyright 2022 Oleg Borodin - */ - -package dsrpc diff --git a/context.go b/context.go deleted file mode 100644 index d78ea30..0000000 --- a/context.go +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Copyright 2022 Oleg Borodin - */ - -package dsrpc - -import ( - "io" - "net" - "time" -) - -type Context struct { - start time.Time - remoteHost string - sockReader io.Reader - sockWriter io.Writer - - reqHeader *Header - reqRPC *Request - reqPacket *Packet - - resPacket *Packet - resHeader *Header - resRPC *Response - - binReader io.Reader - binWriter io.Writer -} - -func NewContext() *Context { - context := &Context{} - context.start = time.Now() - return context -} - -func CreateContext(conn net.Conn) *Context { - context := &Context{} - context.start = time.Now() - context.sockReader = conn - context.sockWriter = conn - - context.reqPacket = NewPacket() - context.resPacket = NewPacket() - - context.reqHeader = NewHeader() - context.reqRPC = NewRequest() - - context.resHeader = NewHeader() - context.resRPC = NewResponse() - context.resRPC = NewResponse() - - return context -} - -func (context *Context) Request() *Request { - return context.reqRPC -} - -func (context *Context) RemoteHost() string { - return context.remoteHost -} - -func (context *Context) Start() time.Time { - return context.start -} - -func (context *Context) Method() string { - var method string - if context.reqRPC != nil { - method = context.reqRPC.Method - } - return method -} - -func (context *Context) ReqRpcSize() int64 { - var size int64 - if context.reqHeader != nil { - size = context.reqHeader.rpcSize - } - return size -} - -func (context *Context) ReqBinSize() int64 { - var size int64 - if context.reqHeader != nil { - size = context.reqHeader.binSize - } - return size -} - -func (context *Context) ResBinSize() int64 { - var size int64 - if context.resHeader != nil { - size = context.resHeader.binSize - } - return size -} - -func (context *Context) ResRpcSize() int64 { - var size int64 - if context.resHeader != nil { - size = context.resHeader.rpcSize - } - return size -} - -func (context *Context) ReqSize() int64 { - var size int64 - if context.reqHeader != nil { - size += context.reqHeader.binSize - size += context.reqHeader.rpcSize - } - return size -} - -func (context *Context) ResSize() int64 { - var size int64 - if context.resHeader != nil { - size += context.resHeader.binSize - size += context.resHeader.rpcSize - } - return size -} - -func (context *Context) SetAuthIdent(ident []byte) { - context.reqRPC.Auth.Ident = ident -} - -func (context *Context) SetAuthSalt(salt []byte) { - context.reqRPC.Auth.Salt = salt -} - -func (context *Context) SetAuthHash(hash []byte) { - context.reqRPC.Auth.Hash = hash -} - -func (context *Context) AuthIdent() []byte { - return context.reqRPC.Auth.Ident -} - -func (context *Context) AuthSalt() []byte { - return context.reqRPC.Auth.Salt -} - -func (context *Context) AuthHash() []byte { - return context.reqRPC.Auth.Hash -} - -func (context *Context) Auth() *Auth { - return context.reqRPC.Auth -} diff --git a/empty.go b/empty.go deleted file mode 100644 index 4ab4e05..0000000 --- a/empty.go +++ /dev/null @@ -1,13 +0,0 @@ -/* - * - * Copyright 2022 Oleg Borodin - * - */ - -package dsrpc - -type Empty struct{} - -func NewEmpty() *Empty { - return &Empty{} -} diff --git a/error.go b/error.go deleted file mode 100644 index a25a0ea..0000000 --- a/error.go +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2022 Oleg Borodin - */ - -package dsrpc - -import ( - "fmt" - "io" - "runtime" -) - -var develMode bool = false -var debugMode bool = false - -func SetDevelMode(mode bool) { - develMode = mode -} -func SetDebugMode(mode bool) { - debugMode = mode -} - -func Err(err error) error { - switch err { - case io.EOF: - return err - } - if err != nil { - switch { - case develMode == true: - pc, filename, line, _ := runtime.Caller(1) - funcName := runtime.FuncForPC(pc).Name() - err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error()) - case debugMode == true: - pc, _, line, _ := runtime.Caller(1) - funcName := runtime.FuncForPC(pc).Name() - err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error()) - default: - } - } - return err -} diff --git a/exec_test.go b/exec_test.go index 8d22e58..ca9262d 100644 --- a/exec_test.go +++ b/exec_test.go @@ -209,58 +209,58 @@ func testServ(quiet bool) error { return err } -func auth(context *Context) error { +func auth(content *Content) error { var err error - reqIdent := context.AuthIdent() - reqSalt := context.AuthSalt() - reqHash := context.AuthHash() + reqIdent := content.AuthIdent() + reqSalt := content.AuthSalt() + reqHash := content.AuthHash() ident := reqIdent pass := []byte("12345") - auth := context.Auth() + auth := content.Auth() logDebug("auth ", string(auth.JSON())) ok := CheckHash(ident, pass, reqSalt, reqHash) logDebug("auth ok:", ok) if !ok { err = errors.New("auth ident or pass missmatch") - context.SendError(err) + content.SendError(err) return err } return err } -func helloHandler(context *Context) error { +func helloHandler(content *Content) error { var err error params := NewHelloParams() - err = context.BindParams(params) + err = content.BindParams(params) if err != nil { return err } - err = context.ReadBin(io.Discard) + err = content.ReadBin(io.Discard) if err != nil { - context.SendError(err) + content.SendError(err) return err } result := NewHelloResult() result.Message = "hello, client!" - err = context.SendResult(result, 0) + err = content.SendResult(result, 0) if err != nil { return err } return err } -func saveHandler(context *Context) error { +func saveHandler(content *Content) error { var err error params := NewSaveParams() - err = context.BindParams(params) + err = content.BindParams(params) if err != nil { return err } @@ -268,34 +268,34 @@ func saveHandler(context *Context) error { bufferBytes := make([]byte, 0, 1024) binWriter := bytes.NewBuffer(bufferBytes) - err = context.ReadBin(binWriter) + err = content.ReadBin(binWriter) if err != nil { - context.SendError(err) + content.SendError(err) return err } result := NewSaveResult() result.Message = "saved successfully!" - err = context.SendResult(result, 0) + err = content.SendResult(result, 0) if err != nil { return err } return err } -func loadHandler(context *Context) error { +func loadHandler(content *Content) error { var err error params := NewSaveParams() - err = context.BindParams(params) + err = content.BindParams(params) if err != nil { return err } - err = context.ReadBin(io.Discard) + err = content.ReadBin(io.Discard) if err != nil { - context.SendError(err) + content.SendError(err) return err } @@ -309,11 +309,11 @@ func loadHandler(context *Context) error { result := NewSaveResult() result.Message = "load successfully!" - err = context.SendResult(result, binSize) + err = content.SendResult(result, binSize) if err != nil { return err } - binWriter := context.BinWriter() + binWriter := content.BinWriter() _, err = CopyBytes(binReader, binWriter, binSize) if err != nil { return err diff --git a/faddr.go b/faddr.go index 952e1b1..a03c652 100644 --- a/faddr.go +++ b/faddr.go @@ -10,9 +10,10 @@ type FAddr struct { } func NewFAddr() *FAddr { - var addr FAddr - addr.network = "tcp" - addr.address = "127.0.0.1:5000" + addr := FAddr{ + network: "tcp", + address: "127.0.0.1:5000", + } return &addr } diff --git a/faddr_test.go b/faddr_test.go index a94d0cd..48c7b1b 100644 --- a/faddr_test.go +++ b/faddr_test.go @@ -7,9 +7,10 @@ package dsrpc import ( - "github.com/stretchr/testify/require" "net" "testing" + + "github.com/stretchr/testify/require" ) func TestFConn0(t *testing.T) { diff --git a/header.go b/header.go index cee49bf..61493d9 100644 --- a/header.go +++ b/header.go @@ -13,10 +13,12 @@ import ( "errors" ) -const headerSize int64 = 16 * 2 -const sizeOfInt64 int = 8 -const magicCodeA int64 = 0xEE00ABBA -const magicCodeB int64 = 0xEE44ABBA +const ( + headerSize int64 = 16 * 2 + sizeOfInt64 int = 8 + magicCodeA int64 = 0xEE00ABBA + magicCodeB int64 = 0xEE44ABBA +) type Header struct { magicCodeA int64 `json:"magicCodeA"` @@ -25,14 +27,14 @@ type Header struct { magicCodeB int64 `json:"magicCodeB"` } -func NewHeader() *Header { +func NewEmptyHeader() *Header { return &Header{ magicCodeA: magicCodeA, magicCodeB: magicCodeB, } } -func (hdr *Header) JSON() []byte { +func (hdr *Header) ToJson() []byte { jBytes, _ := json.Marshal(hdr) return jBytes } @@ -54,35 +56,38 @@ func (hdr *Header) Pack() ([]byte, error) { magicCodeBBytes := encoderI64(hdr.magicCodeB) headerBuffer.Write(magicCodeBBytes) - return headerBuffer.Bytes(), Err(err) + return headerBuffer.Bytes(), err } func UnpackHeader(headerBytes []byte) (*Header, error) { var err error - header := NewHeader() + headerReader := bytes.NewReader(headerBytes) magicCodeABytes := make([]byte, sizeOfInt64) headerReader.Read(magicCodeABytes) - header.magicCodeA = decoderI64(magicCodeABytes) rpcSizeBytes := make([]byte, sizeOfInt64) headerReader.Read(rpcSizeBytes) - header.rpcSize = decoderI64(rpcSizeBytes) binSizeBytes := make([]byte, sizeOfInt64) headerReader.Read(binSizeBytes) - header.binSize = decoderI64(binSizeBytes) magicCodeBBytes := make([]byte, sizeOfInt64) headerReader.Read(magicCodeBBytes) - header.magicCodeB = decoderI64(magicCodeBBytes) + + header := &Header{ + magicCodeA: decoderI64(magicCodeABytes), + rpcSize: decoderI64(rpcSizeBytes), + binSize: decoderI64(binSizeBytes), + magicCodeB: decoderI64(magicCodeBBytes), + } if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB { - err = errors.New("wrong protocol magic code") - return header, Err(err) + err = errors.New("Wrong protocol magic code") + return header, err } - return header, Err(err) + return header, err } func encoderI64(i int64) []byte { diff --git a/midware.go b/midware.go index 56e9fef..8ffe553 100644 --- a/midware.go +++ b/midware.go @@ -8,22 +8,22 @@ import ( "time" ) -func LogRequest(context *Context) error { +func LogRequest(content *Content) error { var err error - logDebug("request:", string(context.reqRPC.JSON())) - return Err(err) + logDebug("request:", string(content.reqBlock.ToJson())) + return err } -func LogResponse(context *Context) error { +func LogResponse(content *Content) error { var err error - logDebug("response:", string(context.resRPC.JSON())) - return Err(err) + logDebug("response:", string(content.resBlock.ToJson())) + return err } -func LogAccess(context *Context) error { +func LogAccess(content *Content) error { var err error - execTime := time.Now().Sub(context.start) - login := string(context.AuthIdent()) - logAccess(context.remoteHost, login, context.reqRPC.Method, execTime) - return Err(err) + execTime := time.Now().Sub(content.start) + login := string(content.AuthIdent()) + logAccess(content.remoteHost, login, content.reqBlock.Method, execTime) + return err } diff --git a/packet.go b/packet.go index 4631a8e..a0c46f9 100644 --- a/packet.go +++ b/packet.go @@ -4,20 +4,15 @@ package dsrpc -import ( - "encoding/json" -) - type Packet struct { header []byte rcpPayload []byte } -func NewPacket() *Packet { - return &Packet{} -} - -func (pkt *Packet) JSON() []byte { - jBytes, _ := json.Marshal(pkt) - return jBytes +func NewEmptyPacket() *Packet { + packet := &Packet{ + header: make([]byte, 0), + rcpPayload: make([]byte, 0), + } + return packet } diff --git a/request.go b/request.go index 35b7798..f50e4ed 100644 --- a/request.go +++ b/request.go @@ -11,24 +11,31 @@ import ( encoder "github.com/vmihailenco/msgpack/v5" ) +type EmptyParams struct{} + +func NewEmptyParams() *EmptyParams { + return &EmptyParams{} +} + type Request struct { Method string `json:"method" msgpack:"method"` Params any `json:"params,omitempty" msgpack:"params"` Auth *Auth `json:"auth,omitempty" msgpack:"auth"` } -func NewRequest() *Request { +func NewEmptyRequest() *Request { req := &Request{} req.Auth = &Auth{} + req.Params = NewEmptyParams() return req } func (req *Request) Pack() ([]byte, error) { rBytes, err := encoder.Marshal(req) - return rBytes, Err(err) + return rBytes, err } -func (req *Request) JSON() []byte { +func (req *Request) ToJson() []byte { jBytes, _ := json.Marshal(req) return jBytes } diff --git a/response.go b/response.go index db5fa96..cc05001 100644 --- a/response.go +++ b/response.go @@ -6,24 +6,33 @@ package dsrpc import ( "encoding/json" + encoder "github.com/vmihailenco/msgpack/v5" ) +type EmptyResult struct{} + +func NewEmptyResult() *EmptyResult { + return &EmptyResult{} +} + type Response struct { Error string `json:"error" msgpack:"error"` Result any `json:"result" msgpack:"result"` } -func NewResponse() *Response { - return &Response{} +func NewEmptyResponse() *Response { + return &Response{ + Result: NewEmptyResult(), + } } -func (resp *Response) JSON() []byte { +func (resp *Response) ToJson() []byte { jBytes, _ := json.Marshal(resp) return jBytes } func (resp *Response) Pack() ([]byte, error) { rBytes, err := encoder.Marshal(resp) - return rBytes, Err(err) + return rBytes, err } diff --git a/server.go b/server.go index cb9adba..cdfb847 100644 --- a/server.go +++ b/server.go @@ -16,7 +16,7 @@ import ( encoder "github.com/vmihailenco/msgpack/v5" ) -type HandlerFunc = func(*Context) error +type HandlerFunc = func(*Content) error type Service struct { handlers map[string]HandlerFunc @@ -99,9 +99,9 @@ func (svc *Service) Listen(address string) error { return err } -func notFound(context *Context) error { +func notFound(content *Content) error { execErr := errors.New("method not found") - err := context.SendError(execErr) + err := content.SendError(execErr) return err } @@ -133,14 +133,14 @@ func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) { } } } - context := CreateContext(conn) + content := CreateContent(conn) remoteAddr := conn.RemoteAddr().String() remoteHost, _, _ := net.SplitHostPort(remoteAddr) - context.remoteHost = remoteHost + content.remoteHost = remoteHost - context.binReader = conn - context.binWriter = io.Discard + content.binReader = conn + content.binWriter = io.Discard exitFunc := func() { conn.Close() @@ -159,149 +159,149 @@ func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) { } defer recovFunc() - err = context.ReadRequest() + err = content.ReadRequest() if err != nil { - err = Err(err) + err = err return } - err = context.BindMethod() + err = content.BindMethod() if err != nil { - err = Err(err) + err = err return } for _, mw := range svc.preMw { - err = mw(context) + err = mw(content) if err != nil { - err = Err(err) + err = err return } } - err = svc.Route(context) + err = svc.Route(content) if err != nil { - err = Err(err) + err = err return } for _, mw := range svc.postMw { - err = mw(context) + err = mw(content) if err != nil { - err = Err(err) + err = err return } } return } -func (svc *Service) Route(context *Context) error { - handler, ok := svc.handlers[context.reqRPC.Method] +func (svc *Service) Route(content *Content) error { + handler, ok := svc.handlers[content.reqBlock.Method] if ok { - return Err(handler(context)) + return handler(content) } - return Err(notFound(context)) + return notFound(content) } -func (context *Context) ReadRequest() error { +func (content *Content) ReadRequest() error { var err error - context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize) + content.reqPacket.header, err = ReadBytes(content.sockReader, headerSize) if err != nil { - return Err(err) + return err } - context.reqHeader, err = UnpackHeader(context.reqPacket.header) + content.reqHeader, err = UnpackHeader(content.reqPacket.header) if err != nil { - return Err(err) + return err } - rpcSize := context.reqHeader.rpcSize - context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) + rpcSize := content.reqHeader.rpcSize + content.reqPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize) if err != nil { - return Err(err) + return err } - return Err(err) + return err } -func (context *Context) BinWriter() io.Writer { - return context.sockWriter +func (content *Content) BinWriter() io.Writer { + return content.sockWriter } -func (context *Context) BinReader() io.Reader { - return context.sockReader +func (content *Content) BinReader() io.Reader { + return content.sockReader } -func (context *Context) BinSize() int64 { - return context.reqHeader.binSize +func (content *Content) BinSize() int64 { + return content.reqHeader.binSize } -func (context *Context) ReadBin(writer io.Writer) error { +func (content *Content) ReadBin(writer io.Writer) error { var err error - _, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize) - return Err(err) + _, err = CopyBytes(content.sockReader, writer, content.reqHeader.binSize) + return err } -func (context *Context) BindMethod() error { +func (content *Content) BindMethod() error { var err error - err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) - return Err(err) + err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock) + return err } -func (context *Context) BindParams(params any) error { +func (content *Content) BindParams(params any) error { var err error - context.reqRPC.Params = params - err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) + content.reqBlock.Params = params + err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock) if err != nil { - return Err(err) + return err } - return Err(err) + return err } -func (context *Context) SendResult(result any, binSize int64) error { +func (content *Content) SendResult(result any, binSize int64) error { var err error - context.resRPC.Result = result + content.resBlock.Result = result - context.resPacket.rcpPayload, err = context.resRPC.Pack() + content.resPacket.rcpPayload, err = content.resBlock.Pack() if err != nil { - return Err(err) + return err } - context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) - context.resHeader.binSize = binSize + content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload)) + content.resHeader.binSize = binSize - context.resPacket.header, err = context.resHeader.Pack() + content.resPacket.header, err = content.resHeader.Pack() if err != nil { - return Err(err) + return err } - _, err = context.sockWriter.Write(context.resPacket.header) + _, err = content.sockWriter.Write(content.resPacket.header) if err != nil { - return Err(err) + return err } - _, err = context.sockWriter.Write(context.resPacket.rcpPayload) + _, err = content.sockWriter.Write(content.resPacket.rcpPayload) if err != nil { - return Err(err) + return err } - return Err(err) + return err } -func (context *Context) SendError(execErr error) error { +func (content *Content) SendError(execErr error) error { var err error - context.resRPC.Error = execErr.Error() - context.resRPC.Result = NewEmpty() + content.resBlock.Error = execErr.Error() + content.resBlock.Result = NewEmptyResult() - context.resPacket.rcpPayload, err = context.resRPC.Pack() + content.resPacket.rcpPayload, err = content.resBlock.Pack() if err != nil { - return Err(err) + return err } - context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) - context.resPacket.header, err = context.resHeader.Pack() + content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload)) + content.resPacket.header, err = content.resHeader.Pack() if err != nil { - return Err(err) + return err } - _, err = context.sockWriter.Write(context.resPacket.header) + _, err = content.sockWriter.Write(content.resPacket.header) if err != nil { - return Err(err) + return err } - _, err = context.sockWriter.Write(context.resPacket.rcpPayload) + _, err = content.sockWriter.Write(content.resPacket.rcpPayload) if err != nil { - return Err(err) + return err } - return Err(err) + return err } diff --git a/tools.go b/tools.go index ccc7bde..10721c4 100644 --- a/tools.go +++ b/tools.go @@ -13,7 +13,7 @@ import ( func ReadBytes(reader io.Reader, size int64) ([]byte, error) { buffer := make([]byte, size) read, err := io.ReadFull(reader, buffer) - return buffer[0:read], Err(err) + return buffer[0:read], err } func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { @@ -39,19 +39,19 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error received, err := reader.Read(buffer[0:bSize]) if err != nil { err = fmt.Errorf("read error: %v", err) - return total, Err(err) + return total, err } recorded, err := writer.Write(buffer[0:received]) if err != nil { err = fmt.Errorf("write error: %v", err) - return total, Err(err) + return total, err } if recorded != received { err = errors.New("size mismatch") - return total, Err(err) + return total, err } total += int64(recorded) remains -= int64(recorded) } - return total, Err(err) + return total, err } diff --git a/validate.go b/validate.go index f98f1c2..55bc762 100644 --- a/validate.go +++ b/validate.go @@ -9,25 +9,29 @@ import ( "net" ) -func LocalExec(method string, param any, result any, auth *Auth, handler HandlerFunc) error { +func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc) error { var err error cliConn, srvConn := NewFConn() - context := CreateContext(cliConn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result + content := CreateContent(cliConn) + content.reqBlock.Method = method - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() + if param != nil { + content.reqBlock.Params = param } - err = context.CreateRequest() + if auth != nil { + content.reqBlock.Auth = auth + } + if result != nil { + content.resBlock.Result = result + } + + err = content.CreateRequest() if err != nil { return err } - err = context.WriteRequest() + err = content.WriteRequest() if err != nil { return err } @@ -35,11 +39,11 @@ func LocalExec(method string, param any, result any, auth *Auth, handler Handler if err != nil { return err } - err = context.ReadResponse() + err = content.ReadResponse() if err != nil { return err } - err = context.BindResponse() + err = content.BindResponse() if err != nil { return err } @@ -53,29 +57,33 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au cliConn, srvConn := NewFConn() - context := CreateContext(cliConn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result + content := CreateContent(cliConn) + content.reqBlock.Method = method - context.binReader = reader - context.binWriter = cliConn - - context.reqHeader.binSize = size - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() + if param != nil { + content.reqBlock.Params = param } - err = context.CreateRequest() + if auth != nil { + content.reqBlock.Auth = auth + } + if result != nil { + content.resBlock.Result = result + } + + content.binReader = reader + content.binWriter = cliConn + + content.reqHeader.binSize = size + + err = content.CreateRequest() if err != nil { return err } - err = context.WriteRequest() + err = content.WriteRequest() if err != nil { return err } - err = context.UploadBin() + err = content.UploadBin() if err != nil { return err } @@ -83,11 +91,11 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au if err != nil { return err } - err = context.ReadResponse() + err = content.ReadResponse() if err != nil { return err } - err = context.BindResponse() + err = content.BindResponse() if err != nil { return err } @@ -99,23 +107,27 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha cliConn, srvConn := NewFConn() - context := CreateContext(cliConn) - context.reqRPC.Method = method - context.reqRPC.Params = param - context.reqRPC.Auth = auth - context.resRPC.Result = result + content := CreateContent(cliConn) + content.reqBlock.Method = method - context.binReader = cliConn - context.binWriter = writer - - if context.reqRPC.Params == nil { - context.reqRPC.Params = NewEmpty() + if param != nil { + content.reqBlock.Params = param } - err = context.CreateRequest() + if auth != nil { + content.reqBlock.Auth = auth + } + if result != nil { + content.resBlock.Result = result + } + + content.binReader = cliConn + content.binWriter = writer + + err = content.CreateRequest() if err != nil { return err } - err = context.WriteRequest() + err = content.WriteRequest() if err != nil { return err } @@ -124,15 +136,15 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha if err != nil { return err } - err = context.ReadResponse() + err = content.ReadResponse() if err != nil { return err } - err = context.DownloadBin() + err = content.DownloadBin() if err != nil { return err } - err = context.BindResponse() + err = content.BindResponse() if err != nil { return err } @@ -141,22 +153,22 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha func LocalService(conn net.Conn, handler HandlerFunc) error { var err error - context := CreateContext(conn) + content := CreateContent(conn) remoteAddr := conn.RemoteAddr().String() remoteHost, _, _ := net.SplitHostPort(remoteAddr) - context.remoteHost = remoteHost + content.remoteHost = remoteHost - context.binReader = conn - context.binWriter = io.Discard + content.binReader = conn + content.binWriter = io.Discard - err = context.ReadRequest() + err = content.ReadRequest() if err != nil { return err } - err = context.BindMethod() + err = content.BindMethod() if err != nil { return err } - return handler(context) + return handler(content) }