From cddeee14128504708b4dc4050b94f04f75f29423 Mon Sep 17 00:00:00 2001 From: Oleg Borodin Date: Fri, 8 Jul 2022 10:58:42 +0200 Subject: [PATCH] added put/get and auth examples to README, added Err(err) debug wrapper --- README.md | 214 +++++++++++++++++++++++++++++++++++++++++++++++++--- client.go | 77 ++++++++++--------- error.go | 39 ++++++++++ header.go | 8 +- logger.go | 16 ++-- midware.go | 6 +- request.go | 2 +- response.go | 2 +- server.go | 47 ++++++------ tools.go | 13 ++-- 10 files changed, 333 insertions(+), 91 deletions(-) create mode 100644 error.go diff --git a/README.md b/README.md index c9196b7..17c571f 100644 --- a/README.md +++ b/README.md @@ -4,21 +4,20 @@ DSRPC is easy and simple RPC framework over TCP socket. ### Purpose -A very easy and open RPC framework with data streaming. +A very easy and open RPC framework with data streaming. +### You can -### You can - -- Use post and pre-execution middleware +- Use own post and pre-execution middleware - Hash-based authentication in middleware -- Test call remote function without service organization +- Test remote function without network -Socket encryption is not used at this time since framefork -is oriented to transfer large amounts of data +Socket encryption is not used at this time since framefork +is oriented to transfer large amounts of data. -Style of the framework is similar to that of GIN framework. +Style of the framework is similar of GIN framework. -## Example +## Exec method example ### Server @@ -135,7 +134,7 @@ package api const HelloMethod string = "hello" type HelloParams struct { - Message string `msgpack:"message" json:"message"` + Message string `json:"message"` } func NewHelloParams() *HelloParams { @@ -143,7 +142,7 @@ func NewHelloParams() *HelloParams { } type HelloResult struct { - Message string `msgpack:"message" json:"message"` + Message string `json:"message"` } func NewHelloResult() *HelloResult { @@ -151,3 +150,196 @@ func NewHelloResult() *HelloResult { } ``` + +### Authentication and authorization + +#### Client side + +``` + +func clientHello() error { + var err error + + params := NewHelloParams() + params.Message = "hello server!" + result := NewHelloResult() + + auth := dsrpc.CreateAuth([]byte("login"), []byte("password")) + + err = dsrpc.Exec("127.0.0.1:8081", HelloMethod, params, result, auth) + if err != nil { + log.Println("method err:", err) + return err + } + + //... +} + + +``` + +#### Server side + +``` + +func authMiddleware(context *dsrpc.Context) error { + var err error + reqIdent := context.AuthIdent() + reqSalt := context.AuthSalt() + reqHash := context.AuthHash() + + if reqIdent != "login" { + err = errors.New("auth ident or pass mismatch") + context.SendError(err) + return err + } + + ident := reqIdent + pass := []byte("password") + + ok := dsrpc.CheckHash(ident, pass, reqSalt, reqHash) + log.Println("auth is ok:", ok) + if !ok { + err = errors.New("auth ident or pass mismatch") + context.SendError(err) + return err + } + return err +} + +func sampleServ(quiet bool) error { + var err error + + if quiet { + dsrpc.SetAccessWriter(io.Discard) + dsrpc.SetMessageWriter(io.Discard) + } + serv := NewService() + + serv.PreMiddleware(authMiddleware) + serv.PreMiddleware(dsrpc.LogRequest) + + serv.Handler(HelloMethod, helloHandler) + serv.Handler(SaveMethod, saveHandler) + serv.Handler(LoadMethod, loadHandler) + + serv.PostMiddleware(dsrpc.LogResponse) + serv.PostMiddleware(dsrpc.LogAccess) + + err = serv.Listen(":8081") + if err != nil { + return err + } + return err +} + +``` + +### Put method + +#### Client side sample + +``` + var binSize int64 = 16 + rand.Seed(time.Now().UnixNano()) + binBytes := make([]byte, binSize) + rand.Read(binBytes) + + reader := bytes.NewReader(binBytes) + + err = dsrpc.Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth) + +``` +#### Server side + +``` +func saveHandler(context *dsrpc.Context) error { + var err error + params := NewSaveParams() + + err = context.BindParams(params) + if err != nil { + return err + } + + bufferBytes := make([]byte, 0, 1024) + binWriter := bytes.NewBuffer(bufferBytes) + + err = context.ReadBin(binWriter) + if err != nil { + context.SendError(err) + return err + } + + result := NewSaveResult() + result.Message = "saved successfully!" + + err = context.SendResult(result, 0) + if err != nil { + return err + } + return err +} + +``` + +### Get method + +#### Client side + +``` + params := NewLoadParams() + params.Message = "load data!" + result := NewHelloResult() + auth := CreateAuth([]byte("qwert"), []byte("12345")) + + binBytes := make([]byte, 0) + writer := bytes.NewBuffer(binBytes) + + err = dsrpc.Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth) + if err != nil { + return err + } + + //... + +``` + +#### Server side + +``` + +func getHandler(context *dsrpc.Context) error { + var err error + params := NewSaveParams() + + err = context.BindParams(params) + if err != nil { + return err + } + + var binSize int64 = 1024 + + rand.Seed(time.Now().UnixNano()) + binBytes := make([]byte, binSize) + rand.Read(binBytes) + + binReader := bytes.NewReader(binBytes) + + result := NewSaveResult() + result.Message = "load successfully!" + + err = context.SendResult(result, binSize) + if err != nil { + return err + } + binWriter := context.BinWriter() + _, err = dsrpc.CopyBytes(binReader, binWriter, binSize) + if err != nil { + return err + } + + return err +} + +``` diff --git a/client.go b/client.go index 74d0806..6b72051 100644 --- a/client.go +++ b/client.go @@ -20,7 +20,7 @@ func Put(address string, method string, reader io.Reader, size int64, param, res conn, err := net.Dial("tcp", address) if err != nil { - return err + return Err(err) } defer conn.Close() @@ -47,11 +47,11 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, err = context.CreateRequest() if err != nil { - return err + return Err(err) } err = context.WriteRequest() if err != nil { - return err + return Err(err) } var wg sync.WaitGroup @@ -66,13 +66,13 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, wg.Wait() err = <- errChan if err != nil { - return err + return Err(err) } err = context.BindResponse() if err != nil { - return err + return Err(err) } - return err + return Err(err) } func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error { @@ -80,7 +80,7 @@ func Get(address string, method string, writer io.Writer, param, result any, aut conn, err := net.Dial("tcp", address) if err != nil { - return err + return Err(err) } defer conn.Close() @@ -104,41 +104,40 @@ func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, } err = context.CreateRequest() if err != nil { - return err + return Err(err) } err = context.WriteRequest() if err != nil { - return err + return Err(err) } err = context.ReadResponse() if err != nil { - return err + return Err(err) } err = context.DownloadBin() if err != nil { - return err + return Err(err) } err = context.BindResponse() if err != nil { - return err + return Err(err) } - return err + return Err(err) } func Exec(address, method string, param any, result any, auth *Auth) error { var err error - conn, err := net.Dial("tcp", address) if err != nil { - return err + return Err(err) } defer conn.Close() err = ConnExec(conn, method, param, result, auth) if err != nil { - return err + return Err(err) } - return err + return Err(err) } @@ -157,21 +156,21 @@ func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) e err = context.CreateRequest() if err != nil { - return err + return Err(err) } err = context.WriteRequest() if err != nil { - return err + return Err(err) } err = context.ReadResponse() if err != nil { - return err + return Err(err) } err = context.BindResponse() if err != nil { - return err + return Err(err) } - return err + return Err(err) } @@ -180,35 +179,35 @@ func (context *Context) CreateRequest() error { context.reqPacket.rcpPayload, err = context.reqRPC.Pack() if err != nil { - return err + return Err(err) } rpcSize := int64(len(context.reqPacket.rcpPayload)) context.reqHeader.rpcSize = rpcSize context.reqPacket.header, err = context.reqHeader.Pack() if err != nil { - return err + return Err(err) } - return err + return Err(err) } func (context *Context) WriteRequest() error { var err error _, err = context.sockWriter.Write(context.reqPacket.header) if err != nil { - return err + return Err(err) } _, err = context.sockWriter.Write(context.reqPacket.rcpPayload) if err != nil { - return err + return Err(err) } - return err + return Err(err) } func (context *Context) UploadBin() error { var err error _, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) - return err + return Err(err) } func (context *Context) ReadResponse() error { @@ -216,18 +215,18 @@ func (context *Context) ReadResponse() error { context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) if err != nil { - return err + return Err(err) } context.resHeader, err = UnpackHeader(context.resPacket.header) if err != nil { - return err + return Err(err) } rpcSize := context.resHeader.rpcSize context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) if err != nil { - return err + return Err(err) } - return err + return Err(err) } func (context *Context) UploadBinAsync(wg *sync.WaitGroup) { @@ -248,15 +247,18 @@ func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error defer exitFunc() context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) if err != nil { + err = Err(err) return } context.resHeader, err = UnpackHeader(context.resPacket.header) if err != nil { + err = Err(err) return } rpcSize := context.resHeader.rpcSize context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) if err != nil { + err = Err(err) return } return @@ -265,7 +267,7 @@ func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error func (context *Context) DownloadBin() error { var err error _, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize) - return err + return Err(err) } func (context *Context) BindResponse() error { @@ -273,10 +275,11 @@ func (context *Context) BindResponse() error { err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC) if err != nil { - return err + return Err(err) } if len(context.resRPC.Error) > 0 { - return errors.New(context.resRPC.Error) + err = errors.New(context.resRPC.Error) + return Err(err) } - return err + return Err(err) } diff --git a/error.go b/error.go new file mode 100644 index 0000000..fe473e8 --- /dev/null +++ b/error.go @@ -0,0 +1,39 @@ +package dsrpc + +import ( + "fmt" + "runtime" + "io" +) + +var develMode bool = false +var debugMode bool = false + + +func SetDevelMode(mode bool) { + develMode = mode +} +func SetDebugMode(mode bool) { + debugMode = mode +} + +func Err(err error) error { + switch err { + case io.EOF: + return err + } + if err != nil { + switch { + case develMode == true: + pc, filename, line, _ := runtime.Caller(1) + funcName := runtime.FuncForPC(pc).Name() + err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error()) + case debugMode == true: + pc, _, line, _ := runtime.Caller(1) + funcName := runtime.FuncForPC(pc).Name() + err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error()) + default: + } + } + return err +} diff --git a/header.go b/header.go index 1c98286..33f22c4 100644 --- a/header.go +++ b/header.go @@ -56,7 +56,7 @@ func (this *Header) Pack() ([]byte, error) { magicCodeBBytes := encoderI64(this.magicCodeB) headerBuffer.Write(magicCodeBBytes) - return headerBuffer.Bytes(), err + return headerBuffer.Bytes(), Err(err) } func UnpackHeader(headerBytes []byte) (*Header, error) { @@ -81,10 +81,10 @@ func UnpackHeader(headerBytes []byte) (*Header, error) { header.magicCodeB = decoderI64(magicCodeBBytes) if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB { - return header, errors.New("wrong protocol magic code") + err = errors.New("wrong protocol magic code") + return header, Err(err) } - - return header, err + return header, Err(err) } func encoderI64(i int64) []byte { diff --git a/logger.go b/logger.go index ec59bfb..d8aab96 100644 --- a/logger.go +++ b/logger.go @@ -17,23 +17,23 @@ var messageWriter io.Writer = os.Stdout var accessWriter io.Writer = os.Stdout func logDebug(messages ...any) { - stamp := time.Now().Format(time.RFC3339Nano) - fmt.Fprintln(messageWriter, stamp, "debug", messages) + stamp := time.Now().Format(time.RFC3339) + fmt.Fprintln(messageWriter, stamp, "dsrpc debug", messages) } func logInfo(messages ...any) { - stamp := time.Now().Format(time.RFC3339Nano) - fmt.Fprintln(messageWriter, stamp, "info", messages) + stamp := time.Now().Format(time.RFC3339) + fmt.Fprintln(messageWriter, stamp, "dsrpc info", messages) } func logError(messages ...any) { - stamp := time.Now().Format(time.RFC3339Nano) - fmt.Fprintln(messageWriter, stamp, "error", messages) + stamp := time.Now().Format(time.RFC3339) + fmt.Fprintln(messageWriter, stamp, "dsrpc error", messages) } func logAccess(messages ...any) { - stamp := time.Now().Format(time.RFC3339Nano) - fmt.Fprintln(accessWriter, stamp, "access", messages) + stamp := time.Now().Format(time.RFC3339) + fmt.Fprintln(accessWriter, stamp, "dsrpc access", messages) } func SetAccessWriter(writer io.Writer) { diff --git a/midware.go b/midware.go index fda341a..75f871b 100644 --- a/midware.go +++ b/midware.go @@ -14,18 +14,18 @@ import ( func LogRequest(context *Context) error { var err error logDebug("request:", string(context.reqRPC.JSON())) - return err + return Err(err) } func LogResponse(context *Context) error { var err error logDebug("response:", string(context.resRPC.JSON())) - return err + return Err(err) } func LogAccess(context *Context) error { var err error execTime := time.Now().Sub(context.start) logAccess(context.remoteHost, context.reqRPC.Method, execTime) - return err + return Err(err) } diff --git a/request.go b/request.go index 8090608..18cf183 100644 --- a/request.go +++ b/request.go @@ -24,7 +24,7 @@ func NewRequest() *Request { func (this *Request) Pack() ([]byte, error) { rBytes, err := json.Marshal(this) - return rBytes, err + return rBytes, Err(err) } func (this *Request) JSON() []byte { diff --git a/response.go b/response.go index 084c892..a03bb03 100644 --- a/response.go +++ b/response.go @@ -27,5 +27,5 @@ func (this *Response) JSON() []byte { func (this *Response) Pack() ([]byte, error) { rBytes, err := json.Marshal(this) - return rBytes, err + return rBytes, Err(err) } diff --git a/server.go b/server.go index 4f4add6..791a19b 100644 --- a/server.go +++ b/server.go @@ -56,6 +56,7 @@ func (this *Service) Handler(method string, handler HandlerFunc) { func (this *Service) Listen(address string) error { var err error logInfo("server listen:", address) + listener, err := net.Listen("tcp", address) if err != nil { return err @@ -72,7 +73,6 @@ func (this *Service) Listen(address string) error { if err != nil { logError("conn accept err:", err) } - go this.handleConn(conn) } } @@ -120,26 +120,31 @@ func (this *Service) handleConn(conn net.Conn) { err = context.ReadRequest() if err != nil { + err = Err(err) return } err = context.BindMethod() if err != nil { + err = Err(err) return } for _, mw := range this.preMw { err = mw(context) if err != nil { + err = Err(err) return } } err = this.Route(context) if err != nil { + err = Err(err) return } for _, mw := range this.postMw { err = mw(context) if err != nil { + err = Err(err) return } } @@ -149,9 +154,9 @@ func (this *Service) handleConn(conn net.Conn) { func (this *Service) Route(context *Context) error { handler, ok := this.handlers[context.reqRPC.Method] if ok { - return handler(context) + return Err(handler(context)) } - return notFound(context) + return Err(notFound(context)) } func (context *Context) ReadRequest() error { @@ -159,19 +164,19 @@ func (context *Context) ReadRequest() error { context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize) if err != nil { - return err + return Err(err) } context.reqHeader, err = UnpackHeader(context.reqPacket.header) if err != nil { - return err + return Err(err) } rpcSize := context.reqHeader.rpcSize context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) if err != nil { - return err + return Err(err) } - return err + return Err(err) } func (context *Context) BinWriter() io.Writer { @@ -189,14 +194,14 @@ func (context *Context) BinSize() int64 { func (context *Context) ReadBin(writer io.Writer) error { var err error _, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize) - return err + return Err(err) } func (context *Context) BindMethod() error { var err error err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) - return err + return Err(err) } func (context *Context) BindParams(params any) error { @@ -204,9 +209,9 @@ func (context *Context) BindParams(params any) error { context.reqRPC.Params = params err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) if err != nil { - return err + return Err(err) } - return err + return Err(err) } func (context *Context) SendResult(result any, binSize int64) error { @@ -215,24 +220,24 @@ func (context *Context) SendResult(result any, binSize int64) error { context.resPacket.rcpPayload, err = context.resRPC.Pack() if err != nil { - return err + return Err(err) } context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) context.resHeader.binSize = binSize context.resPacket.header, err = context.resHeader.Pack() if err != nil { - return err + return Err(err) } _, err = context.sockWriter.Write(context.resPacket.header) if err != nil { - return err + return Err(err) } _, err = context.sockWriter.Write(context.resPacket.rcpPayload) if err != nil { - return err + return Err(err) } - return err + return Err(err) } @@ -244,20 +249,20 @@ func (context *Context) SendError(execErr error) error { context.resPacket.rcpPayload, err = context.resRPC.Pack() if err != nil { - return err + return Err(err) } context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) context.resPacket.header, err = context.resHeader.Pack() if err != nil { - return err + return Err(err) } _, err = context.sockWriter.Write(context.resPacket.header) if err != nil { - return err + return Err(err) } _, err = context.sockWriter.Write(context.resPacket.rcpPayload) if err != nil { - return err + return Err(err) } - return err + return Err(err) } diff --git a/tools.go b/tools.go index 554983b..d7e1da6 100644 --- a/tools.go +++ b/tools.go @@ -13,7 +13,7 @@ import ( func ReadBytes(reader io.Reader, size int64) ([]byte, error) { buffer := make([]byte, size) read, err := io.ReadFull(reader, buffer) - return buffer[0:read], err + return buffer[0:read], Err(err) } func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { @@ -38,17 +38,20 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error } received, err := reader.Read(buffer[0:bSize]) if err != nil { - return total, fmt.Errorf("read error: %v", err) + err = fmt.Errorf("read error: %v", err) + return total, Err(err) } recorded, err := writer.Write(buffer[0:received]) if err != nil { - return total, fmt.Errorf("write error: %v", err) + err = fmt.Errorf("write error: %v", err) + return total, Err(err) } if recorded != received { - return total, errors.New("size mismatch") + err = errors.New("size mismatch") + return total, Err(err) } total += int64(recorded) remains -= int64(recorded) } - return total, err + return total, Err(err) }