2 Commits

18 changed files with 592 additions and 179 deletions

214
README.md
View File

@@ -4,21 +4,20 @@ DSRPC is easy and simple RPC framework over TCP socket.
### Purpose ### Purpose
A very easy and open RPC framework with data streaming. A very easy and open RPC framework with data streaming.
### You can
### You can - Use own post and pre-execution middleware
- Use post and pre-execution middleware
- Hash-based authentication in middleware - Hash-based authentication in middleware
- Test call remote function without service organization - Test remote function without network
Socket encryption is not used at this time since framefork Socket encryption is not used at this time since framefork
is oriented to transfer large amounts of data is oriented to transfer large amounts of data.
Style of the framework is similar to that of GIN framework. Style of the framework is similar of GIN framework.
## Example ## Exec method example
### Server ### Server
@@ -135,7 +134,7 @@ package api
const HelloMethod string = "hello" const HelloMethod string = "hello"
type HelloParams struct { type HelloParams struct {
Message string `msgpack:"message" json:"message"` Message string `json:"message"`
} }
func NewHelloParams() *HelloParams { func NewHelloParams() *HelloParams {
@@ -143,7 +142,7 @@ func NewHelloParams() *HelloParams {
} }
type HelloResult struct { type HelloResult struct {
Message string `msgpack:"message" json:"message"` Message string `json:"message"`
} }
func NewHelloResult() *HelloResult { func NewHelloResult() *HelloResult {
@@ -151,3 +150,196 @@ func NewHelloResult() *HelloResult {
} }
``` ```
### Authentication and authorization
#### Client side
```
func clientHello() error {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := dsrpc.CreateAuth([]byte("login"), []byte("password"))
err = dsrpc.Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
log.Println("method err:", err)
return err
}
//...
}
```
#### Server side
```
func authMiddleware(context *dsrpc.Context) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
if reqIdent != "login" {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
ident := reqIdent
pass := []byte("password")
ok := dsrpc.CheckHash(ident, pass, reqSalt, reqHash)
log.Println("auth is ok:", ok)
if !ok {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
return err
}
func sampleServ(quiet bool) error {
var err error
if quiet {
dsrpc.SetAccessWriter(io.Discard)
dsrpc.SetMessageWriter(io.Discard)
}
serv := NewService()
serv.PreMiddleware(authMiddleware)
serv.PreMiddleware(dsrpc.LogRequest)
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.PostMiddleware(dsrpc.LogResponse)
serv.PostMiddleware(dsrpc.LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
```
### Put method
#### Client side sample
```
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
err = dsrpc.Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
```
#### Server side
```
func saveHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
if err != nil {
context.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
}
```
### Get method
#### Client side
```
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = dsrpc.Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
return err
}
//...
```
#### Server side
```
func getHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = dsrpc.CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}
```

140
client.go
View File

@@ -1,29 +1,46 @@
/* /*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org> * Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/ */
package dsrpc package dsrpc
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"sync" "sync"
encoder "github.com/vmihailenco/msgpack/v5"
) )
func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error { func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
var err error var err error
conn, err := net.Dial("tcp", address) addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil { if err != nil {
return err err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
} }
defer conn.Close() defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
return ConnPut(conn, method, reader, size, param, result, auth) return ConnPut(conn, method, reader, size, param, result, auth)
} }
@@ -47,11 +64,11 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param,
err = context.CreateRequest() err = context.CreateRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.WriteRequest() err = context.WriteRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -66,24 +83,41 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param,
wg.Wait() wg.Wait()
err = <- errChan err = <- errChan
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.BindResponse() err = context.BindResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error { func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error var err error
conn, err := net.Dial("tcp", address) addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil { if err != nil {
return err err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
} }
defer conn.Close() defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
return ConnGet(conn, method, writer, param, result, auth) return ConnGet(conn, method, writer, param, result, auth)
} }
@@ -104,41 +138,57 @@ func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any,
} }
err = context.CreateRequest() err = context.CreateRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.WriteRequest() err = context.WriteRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.ReadResponse() err = context.ReadResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.DownloadBin() err = context.DownloadBin()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.BindResponse() err = context.BindResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func Exec(address, method string, param any, result any, auth *Auth) error { func Exec(address, method string, param any, result any, auth *Auth) error {
var err error var err error
conn, err := net.Dial("tcp", address) addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil { if err != nil {
return err err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
} }
defer conn.Close() defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
err = ConnExec(conn, method, param, result, auth) err = ConnExec(conn, method, param, result, auth)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
@@ -157,21 +207,21 @@ func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) e
err = context.CreateRequest() err = context.CreateRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.WriteRequest() err = context.WriteRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.ReadResponse() err = context.ReadResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.BindResponse() err = context.BindResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
@@ -180,35 +230,35 @@ func (context *Context) CreateRequest() error {
context.reqPacket.rcpPayload, err = context.reqRPC.Pack() context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
rpcSize := int64(len(context.reqPacket.rcpPayload)) rpcSize := int64(len(context.reqPacket.rcpPayload))
context.reqHeader.rpcSize = rpcSize context.reqHeader.rpcSize = rpcSize
context.reqPacket.header, err = context.reqHeader.Pack() context.reqPacket.header, err = context.reqHeader.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) WriteRequest() error { func (context *Context) WriteRequest() error {
var err error var err error
_, err = context.sockWriter.Write(context.reqPacket.header) _, err = context.sockWriter.Write(context.reqPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload) _, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) UploadBin() error { func (context *Context) UploadBin() error {
var err error var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) _, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return err return Err(err)
} }
func (context *Context) ReadResponse() error { func (context *Context) ReadResponse() error {
@@ -216,18 +266,18 @@ func (context *Context) ReadResponse() error {
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil { if err != nil {
return err return Err(err)
} }
context.resHeader, err = UnpackHeader(context.resPacket.header) context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
rpcSize := context.resHeader.rpcSize rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) { func (context *Context) UploadBinAsync(wg *sync.WaitGroup) {
@@ -248,15 +298,18 @@ func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error
defer exitFunc() defer exitFunc()
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil { if err != nil {
err = Err(err)
return return
} }
context.resHeader, err = UnpackHeader(context.resPacket.header) context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil { if err != nil {
err = Err(err)
return return
} }
rpcSize := context.resHeader.rpcSize rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil { if err != nil {
err = Err(err)
return return
} }
return return
@@ -265,18 +318,19 @@ func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error
func (context *Context) DownloadBin() error { func (context *Context) DownloadBin() error {
var err error var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize) _, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
return err return Err(err)
} }
func (context *Context) BindResponse() error { func (context *Context) BindResponse() error {
var err error var err error
err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC) err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
if err != nil { if err != nil {
return err return Err(err)
} }
if len(context.resRPC.Error) > 0 { if len(context.resRPC.Error) > 0 {
return errors.New(context.resRPC.Error) err = errors.New(context.resRPC.Error)
return Err(err)
} }
return err return Err(err)
} }

View File

@@ -1,3 +1,7 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc package dsrpc
type any = interface{} type any = interface{}

View File

@@ -1,7 +1,5 @@
/* /*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org> * Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/ */
package dsrpc package dsrpc
@@ -60,6 +58,74 @@ func (context *Context) Request() *Request {
return context.reqRPC return context.reqRPC
} }
func (context *Context) RemoteHost() string {
return context.remoteHost
}
func (context *Context) Start() time.Time {
return context.start
}
func (context *Context) Method() string {
var method string
if context.reqRPC != nil {
method = context.reqRPC.Method
}
return method
}
func (context *Context) ReqRpcSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.rpcSize
}
return size
}
func (context *Context) ReqBinSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.binSize
}
return size
}
func (context *Context) ResBinSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.binSize
}
return size
}
func (context *Context) ResRpcSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.rpcSize
}
return size
}
func (context *Context) ReqSize() int64 {
var size int64
if context.reqHeader != nil {
size += context.reqHeader.binSize
size += context.reqHeader.rpcSize
}
return size
}
func (context *Context) ResSize() int64 {
var size int64
if context.resHeader != nil {
size += context.resHeader.binSize
size += context.resHeader.rpcSize
}
return size
}
func (context *Context) SetAuthIdent(ident []byte) { func (context *Context) SetAuthIdent(ident []byte) {
context.reqRPC.Auth.Ident = ident context.reqRPC.Auth.Ident = ident

43
error.go Normal file
View File

@@ -0,0 +1,43 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
import (
"fmt"
"runtime"
"io"
)
var develMode bool = false
var debugMode bool = false
func SetDevelMode(mode bool) {
develMode = mode
}
func SetDebugMode(mode bool) {
debugMode = mode
}
func Err(err error) error {
switch err {
case io.EOF:
return err
}
if err != nil {
switch {
case develMode == true:
pc, filename, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error())
case debugMode == true:
pc, _, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error())
default:
}
}
return err
}

View File

@@ -1,7 +1,5 @@
/* /*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org> * Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/ */
package dsrpc package dsrpc
@@ -335,7 +333,7 @@ func loadHandler(context *Context) error {
const HelloMethod string = "hello" const HelloMethod string = "hello"
type HelloParams struct { type HelloParams struct {
Message string `json:"message" json:"message"` Message string `json:"message" msgpack:"message"`
} }
func NewHelloParams() *HelloParams { func NewHelloParams() *HelloParams {
@@ -343,7 +341,7 @@ func NewHelloParams() *HelloParams {
} }
type HelloResult struct { type HelloResult struct {
Message string `json:"message" json:"message"` Message string `json:"message" msgpack:"message"`
} }
func NewHelloResult() *HelloResult { func NewHelloResult() *HelloResult {

View File

@@ -1,7 +1,5 @@
/* /*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org> * Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/ */
package dsrpc package dsrpc

10
go.mod
View File

@@ -2,10 +2,14 @@ module github.com/kindsoldier/dsrpc
go 1.17 go 1.17
require github.com/stretchr/testify v1.7.1 require (
github.com/stretchr/testify v1.8.0
github.com/vmihailenco/msgpack/v5 v5.3.5
)
require ( require (
github.com/davecgh/go-spew v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

15
go.sum
View File

@@ -1,11 +1,20 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -33,30 +33,30 @@ func NewHeader() *Header {
} }
} }
func (this *Header) JSON() []byte { func (hdr *Header) JSON() []byte {
jBytes, _ := json.Marshal(this) jBytes, _ := json.Marshal(hdr)
return jBytes return jBytes
} }
func (this *Header) Pack() ([]byte, error) { func (hdr *Header) Pack() ([]byte, error) {
var err error var err error
headerBytes := make([]byte, 0, headerSize) headerBytes := make([]byte, 0, headerSize)
headerBuffer := bytes.NewBuffer(headerBytes) headerBuffer := bytes.NewBuffer(headerBytes)
magicCodeABytes := encoderI64(this.magicCodeA) magicCodeABytes := encoderI64(hdr.magicCodeA)
headerBuffer.Write(magicCodeABytes) headerBuffer.Write(magicCodeABytes)
rpcSizeBytes := encoderI64(this.rpcSize) rpcSizeBytes := encoderI64(hdr.rpcSize)
headerBuffer.Write(rpcSizeBytes) headerBuffer.Write(rpcSizeBytes)
binSizeBytes := encoderI64(this.binSize) binSizeBytes := encoderI64(hdr.binSize)
headerBuffer.Write(binSizeBytes) headerBuffer.Write(binSizeBytes)
magicCodeBBytes := encoderI64(this.magicCodeB) magicCodeBBytes := encoderI64(hdr.magicCodeB)
headerBuffer.Write(magicCodeBBytes) headerBuffer.Write(magicCodeBBytes)
return headerBuffer.Bytes(), err return headerBuffer.Bytes(), Err(err)
} }
func UnpackHeader(headerBytes []byte) (*Header, error) { func UnpackHeader(headerBytes []byte) (*Header, error) {
@@ -81,10 +81,10 @@ func UnpackHeader(headerBytes []byte) (*Header, error) {
header.magicCodeB = decoderI64(magicCodeBBytes) header.magicCodeB = decoderI64(magicCodeBBytes)
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB { if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
return header, errors.New("wrong protocol magic code") err = errors.New("wrong protocol magic code")
return header, Err(err)
} }
return header, Err(err)
return header, err
} }
func encoderI64(i int64) []byte { func encoderI64(i int64) []byte {

View File

@@ -17,22 +17,22 @@ var messageWriter io.Writer = os.Stdout
var accessWriter io.Writer = os.Stdout var accessWriter io.Writer = os.Stdout
func logDebug(messages ...any) { func logDebug(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano) stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "debug", messages) fmt.Fprintln(messageWriter, stamp, "debug", messages)
} }
func logInfo(messages ...any) { func logInfo(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano) stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "info", messages) fmt.Fprintln(messageWriter, stamp, "info", messages)
} }
func logError(messages ...any) { func logError(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano) stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "error", messages) fmt.Fprintln(messageWriter, stamp, "error", messages)
} }
func logAccess(messages ...any) { func logAccess(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano) stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(accessWriter, stamp, "access", messages) fmt.Fprintln(accessWriter, stamp, "access", messages)
} }

View File

@@ -1,10 +1,7 @@
/* /*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org> * Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/ */
package dsrpc package dsrpc
import ( import (
@@ -14,18 +11,19 @@ import (
func LogRequest(context *Context) error { func LogRequest(context *Context) error {
var err error var err error
logDebug("request:", string(context.reqRPC.JSON())) logDebug("request:", string(context.reqRPC.JSON()))
return err return Err(err)
} }
func LogResponse(context *Context) error { func LogResponse(context *Context) error {
var err error var err error
logDebug("response:", string(context.resRPC.JSON())) logDebug("response:", string(context.resRPC.JSON()))
return err return Err(err)
} }
func LogAccess(context *Context) error { func LogAccess(context *Context) error {
var err error var err error
execTime := time.Now().Sub(context.start) execTime := time.Now().Sub(context.start)
logAccess(context.remoteHost, context.reqRPC.Method, execTime) login := string(context.AuthIdent())
return err logAccess(context.remoteHost, login, context.reqRPC.Method, execTime)
return Err(err)
} }

View File

@@ -1,7 +1,5 @@
/* /*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org> * Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/ */
package dsrpc package dsrpc
@@ -19,7 +17,7 @@ func NewPacket() *Packet {
return &Packet{} return &Packet{}
} }
func (this *Packet) JSON() []byte { func (pkt *Packet) JSON() []byte {
jBytes, _ := json.Marshal(this) jBytes, _ := json.Marshal(pkt)
return jBytes return jBytes
} }

View File

@@ -8,12 +8,13 @@ package dsrpc
import ( import (
"encoding/json" "encoding/json"
encoder "github.com/vmihailenco/msgpack/v5"
) )
type Request struct { type Request struct {
Method string `json:"method" msgpack:"method"` Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params,omitempty"` Params any `json:"params,omitempty" msgpack:"params"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth,omitempty"` Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
} }
func NewRequest() *Request { func NewRequest() *Request {
@@ -22,12 +23,12 @@ func NewRequest() *Request {
return req return req
} }
func (this *Request) Pack() ([]byte, error) { func (req *Request) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this) rBytes, err := encoder.Marshal(req)
return rBytes, err return rBytes, Err(err)
} }
func (this *Request) JSON() []byte { func (req *Request) JSON() []byte {
jBytes, _ := json.Marshal(this) jBytes, _ := json.Marshal(req)
return jBytes return jBytes
} }

View File

@@ -1,31 +1,30 @@
/* /*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org> * Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/ */
package dsrpc package dsrpc
import ( import (
"encoding/json" "encoding/json"
encoder "github.com/vmihailenco/msgpack/v5"
) )
type Response struct { type Response struct {
Error string `json:"error,omitempty" msgpack:"error,omitempty"` Error string `json:"error" msgpack:"error"`
Result any `json:"result,omitemty" msgpack:"result,omitemty"` Result any `json:"result" msgpack:"result"`
} }
func NewResponse() *Response { func NewResponse() *Response {
return &Response{} return &Response{}
} }
func (this *Response) JSON() []byte { func (resp *Response) JSON() []byte {
jBytes, _ := json.Marshal(this) jBytes, _ := json.Marshal(resp)
return jBytes return jBytes
} }
func (this *Response) Pack() ([]byte, error) { func (resp *Response) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this) rBytes, err := encoder.Marshal(resp)
return rBytes, err return rBytes, Err(err)
} }

164
server.go
View File

@@ -1,29 +1,33 @@
/* /*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org> * Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/ */
package dsrpc package dsrpc
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"sync" "sync"
"time"
encoder "github.com/vmihailenco/msgpack/v5"
) )
type HandlerFunc = func(*Context) error type HandlerFunc = func(*Context) error
type Service struct { type Service struct {
handlers map[string]HandlerFunc handlers map[string]HandlerFunc
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wg *sync.WaitGroup wg *sync.WaitGroup
preMw []HandlerFunc preMw []HandlerFunc
postMw []HandlerFunc postMw []HandlerFunc
keepalive bool
kaTime time.Duration
kaMtx sync.Mutex
} }
func NewService() *Service { func NewService() *Service {
@@ -40,41 +44,59 @@ func NewService() *Service {
return rdrpc return rdrpc
} }
func (this *Service) PreMiddleware(mw HandlerFunc) { func (svc *Service) PreMiddleware(mw HandlerFunc) {
this.preMw = append(this.preMw, mw) svc.preMw = append(svc.preMw, mw)
} }
func (this *Service) PostMiddleware(mw HandlerFunc) { func (svc *Service) PostMiddleware(mw HandlerFunc) {
this.postMw = append(this.postMw, mw) svc.postMw = append(svc.postMw, mw)
} }
func (svc *Service) Handler(method string, handler HandlerFunc) {
func (this *Service) Handler(method string, handler HandlerFunc) { svc.handlers[method] = handler
this.handlers[method] = handler
} }
func (this *Service) Listen(address string) error { func (svc *Service) SetKeepAlive(flag bool) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.keepalive = true
}
func (svc *Service) SetKeepAlivePeriod(interval time.Duration) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.kaTime = interval
}
func (svc *Service) Listen(address string) error {
var err error var err error
logInfo("server listen:", address) logInfo("server listen:", address)
listener, err := net.Listen("tcp", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err return err
} }
this.wg.Add(1) listener, err := net.ListenTCP("tcp", addr)
if err != nil {
err = fmt.Errorf("unable to start listener: %s", err)
return err
}
for { for {
select { conn, err := listener.AcceptTCP()
case <- this.ctx.Done():
this.wg.Done()
return err
default:
}
conn, err := listener.Accept()
if err != nil { if err != nil {
logError("conn accept err:", err) logError("conn accept err:", err)
} }
select {
go this.handleConn(conn) case <-svc.ctx.Done():
return err
default:
}
svc.wg.Add(1)
go svc.handleConn(conn, svc.wg)
} }
return err
} }
func notFound(context *Context) error { func notFound(context *Context) error {
@@ -83,16 +105,34 @@ func notFound(context *Context) error {
return err return err
} }
func (this *Service) Stop() error { func (svc *Service) Stop() error {
var err error var err error
this.cancel() // Disable new connection
this.wg.Wait() logInfo("cancel rpc accept loop")
svc.cancel()
// Wait handlers
logInfo("wait rpc handlers")
svc.wg.Wait()
return err return err
} }
func (this *Service) handleConn(conn net.Conn) { func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) {
var err error var err error
if svc.keepalive {
err = conn.SetKeepAlive(true)
if err != nil {
err = fmt.Errorf("unable to set keepalive: %s", err)
return
}
if svc.kaTime > 0 {
err = conn.SetKeepAlivePeriod(svc.kaTime)
if err != nil {
err = fmt.Errorf("unable to set keepalive period: %s", err)
return
}
}
}
context := CreateContext(conn) context := CreateContext(conn)
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
@@ -104,6 +144,7 @@ func (this *Service) handleConn(conn net.Conn) {
exitFunc := func() { exitFunc := func() {
conn.Close() conn.Close()
wg.Done()
if err != nil { if err != nil {
logError("conn handler err:", err) logError("conn handler err:", err)
} }
@@ -120,38 +161,43 @@ func (this *Service) handleConn(conn net.Conn) {
err = context.ReadRequest() err = context.ReadRequest()
if err != nil { if err != nil {
err = Err(err)
return return
} }
err = context.BindMethod() err = context.BindMethod()
if err != nil { if err != nil {
err = Err(err)
return return
} }
for _, mw := range this.preMw { for _, mw := range svc.preMw {
err = mw(context) err = mw(context)
if err != nil { if err != nil {
err = Err(err)
return return
} }
} }
err = this.Route(context) err = svc.Route(context)
if err != nil { if err != nil {
err = Err(err)
return return
} }
for _, mw := range this.postMw { for _, mw := range svc.postMw {
err = mw(context) err = mw(context)
if err != nil { if err != nil {
err = Err(err)
return return
} }
} }
return return
} }
func (this *Service) Route(context *Context) error { func (svc *Service) Route(context *Context) error {
handler, ok := this.handlers[context.reqRPC.Method] handler, ok := svc.handlers[context.reqRPC.Method]
if ok { if ok {
return handler(context) return Err(handler(context))
} }
return notFound(context) return Err(notFound(context))
} }
func (context *Context) ReadRequest() error { func (context *Context) ReadRequest() error {
@@ -159,19 +205,19 @@ func (context *Context) ReadRequest() error {
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize) context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil { if err != nil {
return err return Err(err)
} }
context.reqHeader, err = UnpackHeader(context.reqPacket.header) context.reqHeader, err = UnpackHeader(context.reqPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
rpcSize := context.reqHeader.rpcSize rpcSize := context.reqHeader.rpcSize
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) BinWriter() io.Writer { func (context *Context) BinWriter() io.Writer {
@@ -189,24 +235,24 @@ func (context *Context) BinSize() int64 {
func (context *Context) ReadBin(writer io.Writer) error { func (context *Context) ReadBin(writer io.Writer) error {
var err error var err error
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize) _, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
return err return Err(err)
} }
func (context *Context) BindMethod() error { func (context *Context) BindMethod() error {
var err error var err error
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
return err return Err(err)
} }
func (context *Context) BindParams(params any) error { func (context *Context) BindParams(params any) error {
var err error var err error
context.reqRPC.Params = params context.reqRPC.Params = params
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) SendResult(result any, binSize int64) error { func (context *Context) SendResult(result any, binSize int64) error {
@@ -215,24 +261,24 @@ func (context *Context) SendResult(result any, binSize int64) error {
context.resPacket.rcpPayload, err = context.resRPC.Pack() context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resHeader.binSize = binSize context.resHeader.binSize = binSize
context.resPacket.header, err = context.resHeader.Pack() context.resPacket.header, err = context.resHeader.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.resPacket.header) _, err = context.sockWriter.Write(context.resPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.resPacket.rcpPayload) _, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
@@ -244,20 +290,20 @@ func (context *Context) SendError(execErr error) error {
context.resPacket.rcpPayload, err = context.resRPC.Pack() context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resPacket.header, err = context.resHeader.Pack() context.resPacket.header, err = context.resHeader.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.resPacket.header) _, err = context.sockWriter.Write(context.resPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.resPacket.rcpPayload) _, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }

View File

@@ -13,12 +13,12 @@ import (
func ReadBytes(reader io.Reader, size int64) ([]byte, error) { func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
buffer := make([]byte, size) buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer) read, err := io.ReadFull(reader, buffer)
return buffer[0:read], err return buffer[0:read], Err(err)
} }
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
var err error var err error
var bSize int64 = 1024 * 4 var bSize int64 = 1024 * 16
var total int64 = 0 var total int64 = 0
var remains int64 = dataSize var remains int64 = dataSize
buffer := make([]byte, bSize) buffer := make([]byte, bSize)
@@ -38,17 +38,20 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error
} }
received, err := reader.Read(buffer[0:bSize]) received, err := reader.Read(buffer[0:bSize])
if err != nil { if err != nil {
return total, fmt.Errorf("read error: %v", err) err = fmt.Errorf("read error: %v", err)
return total, Err(err)
} }
recorded, err := writer.Write(buffer[0:received]) recorded, err := writer.Write(buffer[0:received])
if err != nil { if err != nil {
return total, fmt.Errorf("write error: %v", err) err = fmt.Errorf("write error: %v", err)
return total, Err(err)
} }
if recorded != received { if recorded != received {
return total, errors.New("size mismatch") err = errors.New("size mismatch")
return total, Err(err)
} }
total += int64(recorded) total += int64(recorded)
remains -= int64(recorded) remains -= int64(recorded)
} }
return total, err return total, Err(err)
} }

View File

@@ -17,9 +17,9 @@ func init() {
} }
type Auth struct { type Auth struct {
Ident []byte `json:"ident,omitempty"` Ident []byte `msgpack:"ident" json:"ident"`
Salt []byte `json:"salt,omitempty"` Salt []byte `msgpack:"salt" json:"salt"`
Hash []byte `json:"hash,omitempty"` Hash []byte `msgpack:"hash" json:"hash"`
} }
func NewAuth() *Auth { func NewAuth() *Auth {