2 Commits

18 changed files with 592 additions and 179 deletions

208
README.md
View File

@@ -6,19 +6,18 @@ DSRPC is easy and simple RPC framework over TCP socket.
A very easy and open RPC framework with data streaming.
### You can
- Use post and pre-execution middleware
- Use own post and pre-execution middleware
- Hash-based authentication in middleware
- Test call remote function without service organization
- Test remote function without network
Socket encryption is not used at this time since framefork
is oriented to transfer large amounts of data
is oriented to transfer large amounts of data.
Style of the framework is similar to that of GIN framework.
Style of the framework is similar of GIN framework.
## Example
## Exec method example
### Server
@@ -135,7 +134,7 @@ package api
const HelloMethod string = "hello"
type HelloParams struct {
Message string `msgpack:"message" json:"message"`
Message string `json:"message"`
}
func NewHelloParams() *HelloParams {
@@ -143,7 +142,7 @@ func NewHelloParams() *HelloParams {
}
type HelloResult struct {
Message string `msgpack:"message" json:"message"`
Message string `json:"message"`
}
func NewHelloResult() *HelloResult {
@@ -151,3 +150,196 @@ func NewHelloResult() *HelloResult {
}
```
### Authentication and authorization
#### Client side
```
func clientHello() error {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := dsrpc.CreateAuth([]byte("login"), []byte("password"))
err = dsrpc.Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
log.Println("method err:", err)
return err
}
//...
}
```
#### Server side
```
func authMiddleware(context *dsrpc.Context) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
if reqIdent != "login" {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
ident := reqIdent
pass := []byte("password")
ok := dsrpc.CheckHash(ident, pass, reqSalt, reqHash)
log.Println("auth is ok:", ok)
if !ok {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
return err
}
func sampleServ(quiet bool) error {
var err error
if quiet {
dsrpc.SetAccessWriter(io.Discard)
dsrpc.SetMessageWriter(io.Discard)
}
serv := NewService()
serv.PreMiddleware(authMiddleware)
serv.PreMiddleware(dsrpc.LogRequest)
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.PostMiddleware(dsrpc.LogResponse)
serv.PostMiddleware(dsrpc.LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
```
### Put method
#### Client side sample
```
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
err = dsrpc.Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
```
#### Server side
```
func saveHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
if err != nil {
context.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
}
```
### Get method
#### Client side
```
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = dsrpc.Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
return err
}
//...
```
#### Server side
```
func getHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = dsrpc.CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}
```

140
client.go
View File

@@ -1,29 +1,46 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"sync"
encoder "github.com/vmihailenco/msgpack/v5"
)
func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
var err error
conn, err := net.Dial("tcp", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return err
err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
}
defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
return ConnPut(conn, method, reader, size, param, result, auth)
}
@@ -47,11 +64,11 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param,
err = context.CreateRequest()
if err != nil {
return err
return Err(err)
}
err = context.WriteRequest()
if err != nil {
return err
return Err(err)
}
var wg sync.WaitGroup
@@ -66,24 +83,41 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param,
wg.Wait()
err = <- errChan
if err != nil {
return err
return Err(err)
}
err = context.BindResponse()
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
conn, err := net.Dial("tcp", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return err
err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
}
defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
return ConnGet(conn, method, writer, param, result, auth)
}
@@ -104,41 +138,57 @@ func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any,
}
err = context.CreateRequest()
if err != nil {
return err
return Err(err)
}
err = context.WriteRequest()
if err != nil {
return err
return Err(err)
}
err = context.ReadResponse()
if err != nil {
return err
return Err(err)
}
err = context.DownloadBin()
if err != nil {
return err
return Err(err)
}
err = context.BindResponse()
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
func Exec(address, method string, param any, result any, auth *Auth) error {
var err error
conn, err := net.Dial("tcp", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return err
err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
}
defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
err = ConnExec(conn, method, param, result, auth)
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
@@ -157,21 +207,21 @@ func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) e
err = context.CreateRequest()
if err != nil {
return err
return Err(err)
}
err = context.WriteRequest()
if err != nil {
return err
return Err(err)
}
err = context.ReadResponse()
if err != nil {
return err
return Err(err)
}
err = context.BindResponse()
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
@@ -180,35 +230,35 @@ func (context *Context) CreateRequest() error {
context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
if err != nil {
return err
return Err(err)
}
rpcSize := int64(len(context.reqPacket.rcpPayload))
context.reqHeader.rpcSize = rpcSize
context.reqPacket.header, err = context.reqHeader.Pack()
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
func (context *Context) WriteRequest() error {
var err error
_, err = context.sockWriter.Write(context.reqPacket.header)
if err != nil {
return err
return Err(err)
}
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
func (context *Context) UploadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return err
return Err(err)
}
func (context *Context) ReadResponse() error {
@@ -216,18 +266,18 @@ func (context *Context) ReadResponse() error {
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
return err
return Err(err)
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil {
return err
return Err(err)
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) {
@@ -248,15 +298,18 @@ func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error
defer exitFunc()
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
err = Err(err)
return
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil {
err = Err(err)
return
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
err = Err(err)
return
}
return
@@ -265,18 +318,19 @@ func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error
func (context *Context) DownloadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
return err
return Err(err)
}
func (context *Context) BindResponse() error {
var err error
err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
if err != nil {
return err
return Err(err)
}
if len(context.resRPC.Error) > 0 {
return errors.New(context.resRPC.Error)
err = errors.New(context.resRPC.Error)
return Err(err)
}
return err
return Err(err)
}

View File

@@ -1,3 +1,7 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
type any = interface{}

View File

@@ -1,7 +1,5 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
@@ -60,6 +58,74 @@ func (context *Context) Request() *Request {
return context.reqRPC
}
func (context *Context) RemoteHost() string {
return context.remoteHost
}
func (context *Context) Start() time.Time {
return context.start
}
func (context *Context) Method() string {
var method string
if context.reqRPC != nil {
method = context.reqRPC.Method
}
return method
}
func (context *Context) ReqRpcSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.rpcSize
}
return size
}
func (context *Context) ReqBinSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.binSize
}
return size
}
func (context *Context) ResBinSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.binSize
}
return size
}
func (context *Context) ResRpcSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.rpcSize
}
return size
}
func (context *Context) ReqSize() int64 {
var size int64
if context.reqHeader != nil {
size += context.reqHeader.binSize
size += context.reqHeader.rpcSize
}
return size
}
func (context *Context) ResSize() int64 {
var size int64
if context.resHeader != nil {
size += context.resHeader.binSize
size += context.resHeader.rpcSize
}
return size
}
func (context *Context) SetAuthIdent(ident []byte) {
context.reqRPC.Auth.Ident = ident

43
error.go Normal file
View File

@@ -0,0 +1,43 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
import (
"fmt"
"runtime"
"io"
)
var develMode bool = false
var debugMode bool = false
func SetDevelMode(mode bool) {
develMode = mode
}
func SetDebugMode(mode bool) {
debugMode = mode
}
func Err(err error) error {
switch err {
case io.EOF:
return err
}
if err != nil {
switch {
case develMode == true:
pc, filename, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error())
case debugMode == true:
pc, _, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error())
default:
}
}
return err
}

View File

@@ -1,7 +1,5 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
@@ -335,7 +333,7 @@ func loadHandler(context *Context) error {
const HelloMethod string = "hello"
type HelloParams struct {
Message string `json:"message" json:"message"`
Message string `json:"message" msgpack:"message"`
}
func NewHelloParams() *HelloParams {
@@ -343,7 +341,7 @@ func NewHelloParams() *HelloParams {
}
type HelloResult struct {
Message string `json:"message" json:"message"`
Message string `json:"message" msgpack:"message"`
}
func NewHelloResult() *HelloResult {

View File

@@ -1,7 +1,5 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc

10
go.mod
View File

@@ -2,10 +2,14 @@ module github.com/kindsoldier/dsrpc
go 1.17
require github.com/stretchr/testify v1.7.1
require (
github.com/stretchr/testify v1.8.0
github.com/vmihailenco/msgpack/v5 v5.3.5
)
require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

15
go.sum
View File

@@ -1,11 +1,20 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -33,30 +33,30 @@ func NewHeader() *Header {
}
}
func (this *Header) JSON() []byte {
jBytes, _ := json.Marshal(this)
func (hdr *Header) JSON() []byte {
jBytes, _ := json.Marshal(hdr)
return jBytes
}
func (this *Header) Pack() ([]byte, error) {
func (hdr *Header) Pack() ([]byte, error) {
var err error
headerBytes := make([]byte, 0, headerSize)
headerBuffer := bytes.NewBuffer(headerBytes)
magicCodeABytes := encoderI64(this.magicCodeA)
magicCodeABytes := encoderI64(hdr.magicCodeA)
headerBuffer.Write(magicCodeABytes)
rpcSizeBytes := encoderI64(this.rpcSize)
rpcSizeBytes := encoderI64(hdr.rpcSize)
headerBuffer.Write(rpcSizeBytes)
binSizeBytes := encoderI64(this.binSize)
binSizeBytes := encoderI64(hdr.binSize)
headerBuffer.Write(binSizeBytes)
magicCodeBBytes := encoderI64(this.magicCodeB)
magicCodeBBytes := encoderI64(hdr.magicCodeB)
headerBuffer.Write(magicCodeBBytes)
return headerBuffer.Bytes(), err
return headerBuffer.Bytes(), Err(err)
}
func UnpackHeader(headerBytes []byte) (*Header, error) {
@@ -81,10 +81,10 @@ func UnpackHeader(headerBytes []byte) (*Header, error) {
header.magicCodeB = decoderI64(magicCodeBBytes)
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
return header, errors.New("wrong protocol magic code")
err = errors.New("wrong protocol magic code")
return header, Err(err)
}
return header, err
return header, Err(err)
}
func encoderI64(i int64) []byte {

View File

@@ -17,22 +17,22 @@ var messageWriter io.Writer = os.Stdout
var accessWriter io.Writer = os.Stdout
func logDebug(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "debug", messages)
}
func logInfo(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "info", messages)
}
func logError(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "error", messages)
}
func logAccess(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(accessWriter, stamp, "access", messages)
}

View File

@@ -1,10 +1,7 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
@@ -14,18 +11,19 @@ import (
func LogRequest(context *Context) error {
var err error
logDebug("request:", string(context.reqRPC.JSON()))
return err
return Err(err)
}
func LogResponse(context *Context) error {
var err error
logDebug("response:", string(context.resRPC.JSON()))
return err
return Err(err)
}
func LogAccess(context *Context) error {
var err error
execTime := time.Now().Sub(context.start)
logAccess(context.remoteHost, context.reqRPC.Method, execTime)
return err
login := string(context.AuthIdent())
logAccess(context.remoteHost, login, context.reqRPC.Method, execTime)
return Err(err)
}

View File

@@ -1,7 +1,5 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
@@ -19,7 +17,7 @@ func NewPacket() *Packet {
return &Packet{}
}
func (this *Packet) JSON() []byte {
jBytes, _ := json.Marshal(this)
func (pkt *Packet) JSON() []byte {
jBytes, _ := json.Marshal(pkt)
return jBytes
}

View File

@@ -8,12 +8,13 @@ package dsrpc
import (
"encoding/json"
encoder "github.com/vmihailenco/msgpack/v5"
)
type Request struct {
Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params,omitempty"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth,omitempty"`
Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
}
func NewRequest() *Request {
@@ -22,12 +23,12 @@ func NewRequest() *Request {
return req
}
func (this *Request) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this)
return rBytes, err
func (req *Request) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(req)
return rBytes, Err(err)
}
func (this *Request) JSON() []byte {
jBytes, _ := json.Marshal(this)
func (req *Request) JSON() []byte {
jBytes, _ := json.Marshal(req)
return jBytes
}

View File

@@ -1,31 +1,30 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
encoder "github.com/vmihailenco/msgpack/v5"
)
type Response struct {
Error string `json:"error,omitempty" msgpack:"error,omitempty"`
Result any `json:"result,omitemty" msgpack:"result,omitemty"`
Error string `json:"error" msgpack:"error"`
Result any `json:"result" msgpack:"result"`
}
func NewResponse() *Response {
return &Response{}
}
func (this *Response) JSON() []byte {
jBytes, _ := json.Marshal(this)
func (resp *Response) JSON() []byte {
jBytes, _ := json.Marshal(resp)
return jBytes
}
func (this *Response) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this)
return rBytes, err
func (resp *Response) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(resp)
return rBytes, Err(err)
}

164
server.go
View File

@@ -1,29 +1,33 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
encoder "github.com/vmihailenco/msgpack/v5"
)
type HandlerFunc = func(*Context) error
type Service struct {
handlers map[string]HandlerFunc
ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup
preMw []HandlerFunc
postMw []HandlerFunc
handlers map[string]HandlerFunc
ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup
preMw []HandlerFunc
postMw []HandlerFunc
keepalive bool
kaTime time.Duration
kaMtx sync.Mutex
}
func NewService() *Service {
@@ -40,41 +44,59 @@ func NewService() *Service {
return rdrpc
}
func (this *Service) PreMiddleware(mw HandlerFunc) {
this.preMw = append(this.preMw, mw)
func (svc *Service) PreMiddleware(mw HandlerFunc) {
svc.preMw = append(svc.preMw, mw)
}
func (this *Service) PostMiddleware(mw HandlerFunc) {
this.postMw = append(this.postMw, mw)
func (svc *Service) PostMiddleware(mw HandlerFunc) {
svc.postMw = append(svc.postMw, mw)
}
func (this *Service) Handler(method string, handler HandlerFunc) {
this.handlers[method] = handler
func (svc *Service) Handler(method string, handler HandlerFunc) {
svc.handlers[method] = handler
}
func (this *Service) Listen(address string) error {
func (svc *Service) SetKeepAlive(flag bool) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.keepalive = true
}
func (svc *Service) SetKeepAlivePeriod(interval time.Duration) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.kaTime = interval
}
func (svc *Service) Listen(address string) error {
var err error
logInfo("server listen:", address)
listener, err := net.Listen("tcp", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
this.wg.Add(1)
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
err = fmt.Errorf("unable to start listener: %s", err)
return err
}
for {
select {
case <- this.ctx.Done():
this.wg.Done()
return err
default:
}
conn, err := listener.Accept()
conn, err := listener.AcceptTCP()
if err != nil {
logError("conn accept err:", err)
}
go this.handleConn(conn)
select {
case <-svc.ctx.Done():
return err
default:
}
svc.wg.Add(1)
go svc.handleConn(conn, svc.wg)
}
return err
}
func notFound(context *Context) error {
@@ -83,16 +105,34 @@ func notFound(context *Context) error {
return err
}
func (this *Service) Stop() error {
func (svc *Service) Stop() error {
var err error
this.cancel()
this.wg.Wait()
// Disable new connection
logInfo("cancel rpc accept loop")
svc.cancel()
// Wait handlers
logInfo("wait rpc handlers")
svc.wg.Wait()
return err
}
func (this *Service) handleConn(conn net.Conn) {
func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) {
var err error
if svc.keepalive {
err = conn.SetKeepAlive(true)
if err != nil {
err = fmt.Errorf("unable to set keepalive: %s", err)
return
}
if svc.kaTime > 0 {
err = conn.SetKeepAlivePeriod(svc.kaTime)
if err != nil {
err = fmt.Errorf("unable to set keepalive period: %s", err)
return
}
}
}
context := CreateContext(conn)
remoteAddr := conn.RemoteAddr().String()
@@ -104,6 +144,7 @@ func (this *Service) handleConn(conn net.Conn) {
exitFunc := func() {
conn.Close()
wg.Done()
if err != nil {
logError("conn handler err:", err)
}
@@ -120,38 +161,43 @@ func (this *Service) handleConn(conn net.Conn) {
err = context.ReadRequest()
if err != nil {
err = Err(err)
return
}
err = context.BindMethod()
if err != nil {
err = Err(err)
return
}
for _, mw := range this.preMw {
for _, mw := range svc.preMw {
err = mw(context)
if err != nil {
err = Err(err)
return
}
}
err = this.Route(context)
err = svc.Route(context)
if err != nil {
err = Err(err)
return
}
for _, mw := range this.postMw {
for _, mw := range svc.postMw {
err = mw(context)
if err != nil {
err = Err(err)
return
}
}
return
}
func (this *Service) Route(context *Context) error {
handler, ok := this.handlers[context.reqRPC.Method]
func (svc *Service) Route(context *Context) error {
handler, ok := svc.handlers[context.reqRPC.Method]
if ok {
return handler(context)
return Err(handler(context))
}
return notFound(context)
return Err(notFound(context))
}
func (context *Context) ReadRequest() error {
@@ -159,19 +205,19 @@ func (context *Context) ReadRequest() error {
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
return err
return Err(err)
}
context.reqHeader, err = UnpackHeader(context.reqPacket.header)
if err != nil {
return err
return Err(err)
}
rpcSize := context.reqHeader.rpcSize
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
func (context *Context) BinWriter() io.Writer {
@@ -189,24 +235,24 @@ func (context *Context) BinSize() int64 {
func (context *Context) ReadBin(writer io.Writer) error {
var err error
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
return err
return Err(err)
}
func (context *Context) BindMethod() error {
var err error
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
return err
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
return Err(err)
}
func (context *Context) BindParams(params any) error {
var err error
context.reqRPC.Params = params
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
func (context *Context) SendResult(result any, binSize int64) error {
@@ -215,24 +261,24 @@ func (context *Context) SendResult(result any, binSize int64) error {
context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil {
return err
return Err(err)
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resHeader.binSize = binSize
context.resPacket.header, err = context.resHeader.Pack()
if err != nil {
return err
return Err(err)
}
_, err = context.sockWriter.Write(context.resPacket.header)
if err != nil {
return err
return Err(err)
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}
@@ -244,20 +290,20 @@ func (context *Context) SendError(execErr error) error {
context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil {
return err
return Err(err)
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resPacket.header, err = context.resHeader.Pack()
if err != nil {
return err
return Err(err)
}
_, err = context.sockWriter.Write(context.resPacket.header)
if err != nil {
return err
return Err(err)
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil {
return err
return Err(err)
}
return err
return Err(err)
}

View File

@@ -13,12 +13,12 @@ import (
func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer)
return buffer[0:read], err
return buffer[0:read], Err(err)
}
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
var err error
var bSize int64 = 1024 * 4
var bSize int64 = 1024 * 16
var total int64 = 0
var remains int64 = dataSize
buffer := make([]byte, bSize)
@@ -38,17 +38,20 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error
}
received, err := reader.Read(buffer[0:bSize])
if err != nil {
return total, fmt.Errorf("read error: %v", err)
err = fmt.Errorf("read error: %v", err)
return total, Err(err)
}
recorded, err := writer.Write(buffer[0:received])
if err != nil {
return total, fmt.Errorf("write error: %v", err)
err = fmt.Errorf("write error: %v", err)
return total, Err(err)
}
if recorded != received {
return total, errors.New("size mismatch")
err = errors.New("size mismatch")
return total, Err(err)
}
total += int64(recorded)
remains -= int64(recorded)
}
return total, err
return total, Err(err)
}

View File

@@ -17,9 +17,9 @@ func init() {
}
type Auth struct {
Ident []byte `json:"ident,omitempty"`
Salt []byte `json:"salt,omitempty"`
Hash []byte `json:"hash,omitempty"`
Ident []byte `msgpack:"ident" json:"ident"`
Salt []byte `msgpack:"salt" json:"salt"`
Hash []byte `msgpack:"hash" json:"hash"`
}
func NewAuth() *Auth {