19 Commits

Author SHA1 Message Date
498abc18b4 delete timestamp from logger 2023-05-18 13:44:39 +02:00
5ea1af7fb1 disable default log timestamp 2023-05-18 13:34:54 +02:00
bc578cf25c fixed mistakes 2023-05-18 12:56:37 +02:00
cb6d26a7cc added listener close() 2023-05-18 12:53:37 +02:00
f814bb7d24 go fmt 2023-05-18 11:12:35 +02:00
aed740c933 added TLS transport 2023-05-18 11:12:16 +02:00
4716c25cf6 rename Handler to Handle 2023-05-17 22:35:35 +02:00
9bc3aef167 updated samples 2023-05-17 13:06:13 +02:00
0e5321d51a encoder changed to json 2023-05-17 13:03:52 +02:00
6f9835f399 context added 2023-05-17 12:59:01 +02:00
ef8db1d198 context added 2023-05-17 12:52:41 +02:00
c20faaecb3 Little refactoring 2023-04-01 01:03:38 +02:00
b5b293329e Little refactoring 2023-04-01 00:33:59 +02:00
8b3e722ea5 Little refactoring 2023-04-01 00:18:44 +02:00
8dc753cd95 update to go 1.19 2023-02-11 07:55:46 +02:00
59d850da3c go fmt 2023-02-11 07:52:40 +02:00
5366bd7ff7 Update README.md 2022-12-07 10:35:06 +02:00
f76cbc32aa changed: server stopping, dial/listen to tcp dial/listen, json rpc encoding to msgpack, info getters for logging ; added keepalive setter to server, etc 2022-07-26 14:26:40 +02:00
cddeee1412 added put/get and auth examples to README, added Err(err) debug wrapper 2022-07-08 10:58:42 +02:00
27 changed files with 1786 additions and 1371 deletions

211
README.md
View File

@@ -1,24 +1,24 @@
# dsrpc, Data RPC
DSRPC is easy and simple RPC framework over TCP socket.
### Purpose
A very easy and open RPC framework with data streaming.
### You can
- Use post and pre-execution middleware
- Use own post and pre-execution middleware
- Hash-based authentication in middleware
- Test call remote function without service organization
- Test remote function without network
Socket encryption is not used at this time since framefork
is oriented to transfer large amounts of data
is oriented to transfer large amounts of data.
Style of the framework is similar to that of GIN framework.
Style of the framework is similar of GIN framework.
## Example
## Exec method example
### Server
@@ -44,7 +44,7 @@ func server() error {
serv := dsrpc.NewService()
cont := NewController()
serv.Handler(api.HelloMethod, cont.HelloHandler)
serv.Handle(api.HelloMethod, cont.HelloHandler)
serv.PreMiddleware(dsrpc.LogRequest)
serv.PostMiddleware(dsrpc.LogResponse)
@@ -135,7 +135,7 @@ package api
const HelloMethod string = "hello"
type HelloParams struct {
Message string `msgpack:"message" json:"message"`
Message string `json:"message"`
}
func NewHelloParams() *HelloParams {
@@ -143,7 +143,7 @@ func NewHelloParams() *HelloParams {
}
type HelloResult struct {
Message string `msgpack:"message" json:"message"`
Message string `json:"message"`
}
func NewHelloResult() *HelloResult {
@@ -151,3 +151,196 @@ func NewHelloResult() *HelloResult {
}
```
### Authentication and authorization
#### Client side
```
func clientHello() error {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := dsrpc.CreateAuth([]byte("login"), []byte("password"))
err = dsrpc.Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
log.Println("method err:", err)
return err
}
//...
}
```
#### Server side
```
func authMiddleware(context *dsrpc.Context) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
if reqIdent != "login" {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
ident := reqIdent
pass := []byte("password")
ok := dsrpc.CheckHash(ident, pass, reqSalt, reqHash)
log.Println("auth is ok:", ok)
if !ok {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
return err
}
func sampleServ(quiet bool) error {
var err error
if quiet {
dsrpc.SetAccessWriter(io.Discard)
dsrpc.SetMessageWriter(io.Discard)
}
serv := NewService()
serv.PreMiddleware(authMiddleware)
serv.PreMiddleware(dsrpc.LogRequest)
serv.Handle(HelloMethod, helloHandler)
serv.Handle(SaveMethod, saveHandler)
serv.Handle(LoadMethod, loadHandler)
serv.PostMiddleware(dsrpc.LogResponse)
serv.PostMiddleware(dsrpc.LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
```
### Put method
#### Client side sample
```
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
err = dsrpc.Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
```
#### Server side
```
func saveHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
if err != nil {
context.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
}
```
### Get method
#### Client side
```
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = dsrpc.Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
return err
}
//...
```
#### Server side
```
func getHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = dsrpc.CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}
```

340
client.go
View File

@@ -1,55 +1,75 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
encoder "encoding/json"
)
func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
func Put(ctx context.Context, address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
conn, err := net.Dial("tcp", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
return ConnPut(conn, method, reader, size, param, result, auth)
return ConnPut(ctx, conn, method, reader, binSize, param, result, auth)
}
func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
func PutTLS(ctx context.Context, tlsConfig *tls.Config, address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
context.binReader = reader
context.binWriter = conn
context.reqHeader.binSize = size
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
conn, err := tls.Dial("tcp", address, tlsConfig)
if err != nil {
return err
}
err = context.WriteRequest()
defer conn.Close()
return ConnPut(ctx, conn, method, reader, binSize, param, result, auth)
}
func ConnPut(ctx context.Context, conn net.Conn, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
content := CreateContent(conn)
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
content.binReader = reader
content.binWriter = conn
content.reqHeader.binSize = binSize
err = content.createRequest()
if err != nil {
return err
}
err = content.writeRequest()
if err != nil {
return err
}
@@ -58,225 +78,271 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param,
errChan := make(chan error, 1)
wg.Add(1)
go context.ReadResponseAsync(&wg, errChan)
go content.readResponseAsync(&wg, errChan)
wg.Add(1)
go context.UploadBinAsync(&wg)
go content.uploadBinAsync(ctx, &wg)
wg.Wait()
err = <- errChan
err = <-errChan
if err != nil {
return err
}
err = context.BindResponse()
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
func Get(ctx context.Context, address string, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
conn, err := net.Dial("tcp", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
return ConnGet(conn, method, writer, param, result, auth)
return ConnGet(ctx, conn, method, writer, param, result, auth)
}
func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error {
func GetTLS(ctx context.Context, tlsConfig *tls.Config, address string, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
context.binReader = conn
context.binWriter = writer
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.DownloadBin()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
}
func Exec(address, method string, param any, result any, auth *Auth) error {
var err error
conn, err := net.Dial("tcp", address)
conn, err := tls.Dial("tcp", address, tlsConfig)
if err != nil {
return err
}
defer conn.Close()
err = ConnExec(conn, method, param, result, auth)
if err != nil {
return err
}
return err
return ConnGet(ctx, conn, method, writer, param, result, auth)
}
func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error {
func ConnGet(ctx context.Context, conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
content := CreateContent(conn)
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
err = context.CreateRequest()
content.binReader = conn
content.binWriter = writer
err = content.createRequest()
if err != nil {
return err
}
err = context.WriteRequest()
err = content.writeRequest()
if err != nil {
return err
}
err = context.ReadResponse()
err = content.readResponse()
if err != nil {
return err
}
err = context.BindResponse()
err = content.downloadBin(ctx)
if err != nil {
return err
}
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func (context *Context) CreateRequest() error {
func Exec(ctx context.Context, address, method string, param any, result any, auth *Auth) error {
var err error
context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
rpcSize := int64(len(context.reqPacket.rcpPayload))
context.reqHeader.rpcSize = rpcSize
defer conn.Close()
context.reqPacket.header, err = context.reqHeader.Pack()
err = ConnExec(ctx, conn, method, param, result, auth)
if err != nil {
return err
}
return err
}
func (context *Context) WriteRequest() error {
var err error
_, err = context.sockWriter.Write(context.reqPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
if err != nil {
return err
}
return err
}
func (context *Context) UploadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return err
}
func (context *Context) ReadResponse() error {
func ExecTLS(ctx context.Context, tlsConfig *tls.Config, address, method string, param any, result any, auth *Auth) error {
var err error
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
conn, err := tls.Dial("tcp", address, tlsConfig)
if err != nil {
return err
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil {
return err
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
defer conn.Close()
err = ConnExec(ctx, conn, method, param, result, auth)
if err != nil {
return err
}
return err
}
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) {
func ConnExec(ctx context.Context, conn net.Conn, method string, param any, result any, auth *Auth) error {
var err error
content := CreateContent(conn)
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if result != nil {
content.resBlock.Result = result
}
if auth != nil {
content.reqBlock.Auth = auth
}
err = content.createRequest()
if err != nil {
return err
}
err = content.writeRequest()
if err != nil {
return err
}
err = content.readResponse()
if err != nil {
return err
}
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func (content *Content) createRequest() error {
var err error
content.reqPacket.rcpPayload, err = content.reqBlock.Pack()
if err != nil {
return err
}
rpcSize := int64(len(content.reqPacket.rcpPayload))
content.reqHeader.rpcSize = rpcSize
content.reqPacket.header, err = content.reqHeader.Pack()
if err != nil {
return err
}
return err
}
func (content *Content) writeRequest() error {
var err error
_, err = content.sockWriter.Write(content.reqPacket.header)
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.reqPacket.rcpPayload)
if err != nil {
return err
}
return err
}
func (content *Content) uploadBin(ctx context.Context) error {
var err error
_, err = CopyBytes(ctx, content.binReader, content.binWriter, content.reqHeader.binSize)
return err
}
func (content *Content) readResponse() error {
var err error
content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
return err
}
content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil {
return err
}
rpcSize := content.resHeader.rpcSize
content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
return err
}
return err
}
func (content *Content) uploadBinAsync(ctx context.Context, wg *sync.WaitGroup) {
exitFunc := func() {
wg.Done()
}
defer exitFunc()
_, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
_, _ = CopyBytes(ctx, content.binReader, content.binWriter, content.reqHeader.binSize)
return
}
func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) {
func (content *Content) readResponseAsync(wg *sync.WaitGroup, errChan chan error) {
var err error
exitFunc := func() {
errChan <- err
wg.Done()
}
defer exitFunc()
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
err = err
return
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil {
err = err
return
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
rpcSize := content.resHeader.rpcSize
content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
err = err
return
}
return
}
func (context *Context) DownloadBin() error {
func (content *Content) downloadBin(ctx context.Context) error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
_, err = CopyBytes(ctx, content.binReader, content.binWriter, content.resHeader.binSize)
return err
}
func (context *Context) BindResponse() error {
func (content *Content) bindResponse() error {
var err error
err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
err = encoder.Unmarshal(content.resPacket.rcpPayload, content.resBlock)
if err != nil {
return err
}
if len(context.resRPC.Error) > 0 {
return errors.New(context.resRPC.Error)
if len(content.resBlock.Error) > 0 {
err = errors.New(content.resBlock.Error)
return err
}
return err
}

View File

@@ -1,3 +0,0 @@
package dsrpc
type any = interface{}

145
content.go Normal file
View File

@@ -0,0 +1,145 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
import (
"io"
"net"
"time"
)
type Content struct {
start time.Time
remoteHost string
sockReader io.Reader
sockWriter io.Writer
reqPacket *Packet
reqHeader *Header
reqBlock *Request
resPacket *Packet
resHeader *Header
resBlock *Response
binReader io.Reader
binWriter io.Writer
}
func CreateContent(conn net.Conn) *Content {
context := &Content{
start: time.Now(),
sockReader: conn,
sockWriter: conn,
reqPacket: NewEmptyPacket(),
reqHeader: NewEmptyHeader(),
reqBlock: NewEmptyRequest(),
resPacket: NewEmptyPacket(),
resHeader: NewEmptyHeader(),
resBlock: NewEmptyResponse(),
}
return context
}
func (context *Content) Request() *Request {
return context.reqBlock
}
func (context *Content) RemoteHost() string {
return context.remoteHost
}
func (context *Content) Start() time.Time {
return context.start
}
func (context *Content) Method() string {
var method string
if context.reqBlock != nil {
method = context.reqBlock.Method
}
return method
}
func (context *Content) ReqRpcSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.rpcSize
}
return size
}
func (context *Content) ReqBinSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.binSize
}
return size
}
func (context *Content) ResBinSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.binSize
}
return size
}
func (context *Content) ResRpcSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.rpcSize
}
return size
}
func (context *Content) ReqSize() int64 {
var size int64
if context.reqHeader != nil {
size += context.reqHeader.binSize
size += context.reqHeader.rpcSize
}
return size
}
func (context *Content) ResSize() int64 {
var size int64
if context.resHeader != nil {
size += context.resHeader.binSize
size += context.resHeader.rpcSize
}
return size
}
func (context *Content) SetAuthIdent(ident []byte) {
context.reqBlock.Auth.Ident = ident
}
func (context *Content) SetAuthSalt(salt []byte) {
context.reqBlock.Auth.Salt = salt
}
func (context *Content) SetAuthHash(hash []byte) {
context.reqBlock.Auth.Hash = hash
}
func (context *Content) AuthIdent() []byte {
return context.reqBlock.Auth.Ident
}
func (context *Content) AuthSalt() []byte {
return context.reqBlock.Auth.Salt
}
func (context *Content) AuthHash() []byte {
return context.reqBlock.Auth.Hash
}
func (context *Content) Auth() *Auth {
return context.reqBlock.Auth
}

View File

@@ -1,90 +0,0 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"io"
"net"
"time"
)
type Context struct {
start time.Time
remoteHost string
sockReader io.Reader
sockWriter io.Writer
reqHeader *Header
reqRPC *Request
reqPacket *Packet
resPacket *Packet
resHeader *Header
resRPC *Response
binReader io.Reader
binWriter io.Writer
}
func NewContext() *Context {
context := &Context{}
context.start = time.Now()
return context
}
func CreateContext(conn net.Conn) *Context {
context := &Context{}
context.start = time.Now()
context.sockReader = conn
context.sockWriter = conn
context.reqPacket = NewPacket()
context.resPacket = NewPacket()
context.reqHeader = NewHeader()
context.reqRPC = NewRequest()
context.resHeader = NewHeader()
context.resRPC = NewResponse()
context.resRPC = NewResponse()
return context
}
func (context *Context) Request() *Request {
return context.reqRPC
}
func (context *Context) SetAuthIdent(ident []byte) {
context.reqRPC.Auth.Ident = ident
}
func (context *Context) SetAuthSalt(salt []byte) {
context.reqRPC.Auth.Salt = salt
}
func (context *Context) SetAuthHash(hash []byte) {
context.reqRPC.Auth.Hash = hash
}
func (context *Context) AuthIdent() []byte {
return context.reqRPC.Auth.Ident
}
func (context *Context) AuthSalt() []byte {
return context.reqRPC.Auth.Salt
}
func (context *Context) AuthHash() []byte {
return context.reqRPC.Auth.Hash
}
func (context *Context) Auth() *Auth {
return context.reqRPC.Auth
}

View File

@@ -1,13 +0,0 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
type Empty struct {}
func NewEmpty() *Empty {
return &Empty{}
}

View File

@@ -12,14 +12,7 @@ type HelloParams struct {
Message string `msgpack:"message" json:"message"`
}
func NewHelloParams() *HelloParams {
return &HelloParams{}
}
type HelloResult struct {
Message string `msgpack:"message" json:"message"`
}
func NewHelloResult() *HelloResult {
return &HelloResult{}
}

View File

@@ -7,8 +7,12 @@
package main
import (
"context"
"fmt"
"time"
"github.com/kindsoldier/dsrpc"
"netsrv/api"
)
@@ -23,12 +27,16 @@ func main() {
func exec() error {
var err error
params := api.NewHelloParams()
params.Message = "hello, server!"
params := api.HelloParams{
Message: "hello, server!",
}
result := api.NewHelloResult()
result := api.HelloResult{}
err = dsrpc.Exec("127.0.0.1:8081", api.HelloMethod, params, result, nil)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = dsrpc.Exec(ctx, "127.0.0.1:8081", api.HelloMethod, &params, &result, nil)
if err != nil {
return err
}

View File

@@ -1,7 +1,5 @@
module netsrv
go 1.17
go 1.19
require github.com/kindsoldier/dsrpc v0.0.1
replace github.com/kindsoldier/dsrpc => ../
require github.com/kindsoldier/dsrpc v1.2.1

View File

@@ -1,10 +1,6 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/kindsoldier/dsrpc v1.2.1 h1:sw1a3MAD83Do1Fu+Dh+AHArrwVMgZ/KTLUWWTkQ6vj8=
github.com/kindsoldier/dsrpc v1.2.1/go.mod h1:zYb5yYfE/18BYK+iCUNcpkZ4uArwUNNhwYkUK8xDHQk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -8,8 +8,9 @@ package main
import (
"log"
"github.com/kindsoldier/dsrpc"
"netsrv/api"
"github.com/kindsoldier/dsrpc"
)
func main() {
@@ -25,7 +26,7 @@ func server() error {
serv := dsrpc.NewService()
cont := NewController()
serv.Handler(api.HelloMethod, cont.HelloHandler)
serv.Handle(api.HelloMethod, cont.HelloHandler)
serv.PreMiddleware(dsrpc.LogRequest)
serv.PostMiddleware(dsrpc.LogResponse)
@@ -46,20 +47,19 @@ func NewController() *Controller {
return &Controller{}
}
func (cont *Controller) HelloHandler(context *dsrpc.Context) error {
func (cont *Controller) HelloHandler(content *dsrpc.Content) error {
var err error
params := api.NewHelloParams()
err = context.BindParams(params)
params := api.HelloParams{}
err = content.BindParams(&params)
if err != nil {
return err
}
log.Println("hello message:", params.Message)
result := api.NewHelloResult()
result.Message = "hello!"
err = context.SendResult(result, 0)
result := api.HelloResult{
Message: "hello, client!",
}
err = content.SendResult(&result, 0)
if err != nil {
return err
}

View File

@@ -1,13 +1,12 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
@@ -18,27 +17,47 @@ import (
"github.com/stretchr/testify/require"
)
const HelloMethod string = "hello"
type HelloParams struct {
Message string `json:"message" msgpack:"message"`
}
type HelloResult struct {
Message string `json:"message" msgpack:"message"`
}
const SaveMethod string = "save"
type SaveParams HelloParams
type SaveResult HelloResult
const LoadMethod string = "load"
type LoadParams HelloParams
type LoadResult HelloResult
func TestLocalExec(t *testing.T) {
var err error
params := NewHelloParams()
params := HelloParams{}
params.Message = "hello server!"
result := NewHelloResult()
result := HelloResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
err = LocalExec(HelloMethod, params, result, auth, helloHandler)
err = LocalExec(HelloMethod, &params, &result, auth, helloHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
}
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
}
func TestLocalSave(t *testing.T) {
var err error
params := NewSaveParams()
params := SaveParams{}
params.Message = "save data!"
result := NewHelloResult()
result := SaveResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
@@ -48,37 +67,43 @@ func TestLocalSave(t *testing.T) {
reader := bytes.NewReader(binBytes)
err = LocalPut(SaveMethod, reader, binSize, params, result, auth, saveHandler)
timeout := time.Duration(5 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = LocalPut(ctx, SaveMethod, reader, binSize, &params, &result, auth, saveHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
}
func TestLocalLoad(t *testing.T) {
var err error
params := NewLoadParams()
params := LoadParams{}
params.Message = "load data!"
result := NewHelloResult()
result := LoadResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = LocalGet(LoadMethod, writer, params, result, auth, loadHandler)
timeout := time.Duration(5 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = LocalGet(ctx, LoadMethod, writer, &params, &result, auth, loadHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
logDebug("bin size:", len(writer.Bytes()))
}
func TestNetExec(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
time.Sleep(100 * time.Millisecond)
err := clientHello()
require.NoError(t, err)
@@ -86,21 +111,21 @@ func TestNetExec(t *testing.T) {
func TestNetSave(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
time.Sleep(100 * time.Millisecond)
err := clientSave()
require.NoError(t, err)
}
func TestNetLoad(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
time.Sleep(100 * time.Millisecond)
err := clientLoad()
require.NoError(t, err)
}
func BenchmarkNetPut(b *testing.B) {
go testServ(true)
time.Sleep(10 * time.Millisecond)
time.Sleep(1000 * time.Millisecond)
clientSave()
pBench := func(pb *testing.PB) {
@@ -108,16 +133,16 @@ func BenchmarkNetPut(b *testing.B) {
clientSave()
}
}
b.SetParallelism(10)
b.SetParallelism(2000)
b.RunParallel(pBench)
}
func clientHello() error {
var err error
params := NewHelloParams()
params := HelloParams{}
params.Message = "hello server!"
result := NewHelloResult()
result := HelloResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
@@ -125,23 +150,26 @@ func clientHello() error {
binBytes := make([]byte, binSize)
rand.Read(binBytes)
err = Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
timeout := time.Duration(5 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = Exec(ctx, "127.0.0.1:18081", HelloMethod, &params, &result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
return err
}
func clientSave() error {
var err error
params := NewSaveParams()
params := SaveParams{}
params.Message = "save data!"
result := NewHelloResult()
result := SaveResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
@@ -151,41 +179,46 @@ func clientSave() error {
reader := bytes.NewReader(binBytes)
err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
timeout := time.Duration(5 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = Put(ctx, "127.0.0.1:18081", SaveMethod, reader, binSize, &params, &result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
return err
}
func clientLoad() error {
var err error
params := NewLoadParams()
params := LoadParams{}
params.Message = "load data!"
result := NewHelloResult()
result := LoadResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
timeout := time.Duration(5 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = Get(ctx, "127.0.0.1:18081", LoadMethod, writer, &params, &result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
logDebug("bin size:", len(writer.Bytes()))
return err
}
var testServRun bool = false
func testServ(quiet bool) error {
@@ -200,10 +233,11 @@ func testServ(quiet bool) error {
SetAccessWriter(io.Discard)
SetMessageWriter(io.Discard)
}
serv := NewService()
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.Handle(HelloMethod, helloHandler)
serv.Handle(SaveMethod, saveHandler)
serv.Handle(LoadMethod, loadHandler)
serv.PreMiddleware(LogRequest)
serv.PreMiddleware(auth)
@@ -211,65 +245,69 @@ func testServ(quiet bool) error {
serv.PostMiddleware(LogResponse)
serv.PostMiddleware(LogAccess)
err = serv.Listen(":8081")
err = serv.Listen(":18081")
if err != nil {
return err
}
return err
}
func auth(context *Context) error {
func auth(content *Content) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
reqIdent := content.AuthIdent()
reqSalt := content.AuthSalt()
reqHash := content.AuthHash()
ident := reqIdent
pass := []byte("12345")
auth := context.Auth()
logDebug("auth ", string(auth.JSON()))
auth := content.Auth()
logDebug("auth ", string(auth.Json()))
ok := CheckHash(ident, pass, reqSalt, reqHash)
logDebug("auth ok:", ok)
if !ok {
err = errors.New("auth ident or pass missmatch")
context.SendError(err)
content.SendError(err)
return err
}
return err
}
func helloHandler(context *Context) error {
func helloHandler(content *Content) error {
var err error
params := NewHelloParams()
params := HelloParams{}
err = context.BindParams(params)
err = content.BindParams(&params)
if err != nil {
return err
}
err = context.ReadBin(io.Discard)
timeout := time.Duration(5 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = content.ReadBin(ctx, io.Discard)
if err != nil {
context.SendError(err)
content.SendError(err)
return err
}
result := NewHelloResult()
result := HelloResult{}
result.Message = "hello, client!"
err = context.SendResult(result, 0)
err = content.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func saveHandler(context *Context) error {
func saveHandler(content *Content) error {
var err error
params := NewSaveParams()
params := SaveParams{}
err = context.BindParams(params)
err = content.BindParams(&params)
if err != nil {
return err
}
@@ -277,34 +315,42 @@ func saveHandler(context *Context) error {
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
timeout := time.Duration(5 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = content.ReadBin(ctx, binWriter)
if err != nil {
context.SendError(err)
content.SendError(err)
return err
}
result := NewSaveResult()
result := SaveResult{}
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
err = content.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func loadHandler(context *Context) error {
func loadHandler(content *Content) error {
var err error
params := NewSaveParams()
params := SaveParams{}
err = context.BindParams(params)
err = content.BindParams(&params)
if err != nil {
return err
}
err = context.ReadBin(io.Discard)
timeout := time.Duration(5 * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err = content.ReadBin(ctx, io.Discard)
if err != nil {
context.SendError(err)
content.SendError(err)
return err
}
@@ -315,62 +361,19 @@ func loadHandler(context *Context) error {
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result := SaveResult{}
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
err = content.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = CopyBytes(binReader, binWriter, binSize)
binWriter := content.BinWriter()
_, err = CopyBytes(ctx, binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}
const HelloMethod string = "hello"
type HelloParams struct {
Message string `json:"message" json:"message"`
}
func NewHelloParams() *HelloParams {
return &HelloParams{}
}
type HelloResult struct {
Message string `json:"message" json:"message"`
}
func NewHelloResult() *HelloResult {
return &HelloResult{}
}
const SaveMethod string = "save"
type SaveParams HelloParams
type SaveResult HelloResult
func NewSaveParams() *SaveParams {
return &SaveParams{}
}
func NewSaveResult() *SaveResult {
return &SaveResult{}
}
const LoadMethod string = "load"
type LoadParams HelloParams
type LoadResult HelloResult
func NewLoadParams() *LoadParams {
return &LoadParams{}
}
func NewLoadResult() *LoadResult {
return &LoadResult{}
}

View File

@@ -1,7 +1,5 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
@@ -12,9 +10,10 @@ type FAddr struct {
}
func NewFAddr() *FAddr {
var addr FAddr
addr.network = "tcp"
addr.address = "127.0.0.1:5000"
addr := FAddr{
network: "tcp",
address: "127.0.0.1:5000",
}
return &addr
}

View File

@@ -9,6 +9,7 @@ package dsrpc
import (
"net"
"testing"
"github.com/stretchr/testify/require"
)

View File

@@ -18,7 +18,7 @@ type FConn struct {
writer io.Writer
}
func NewFConn() (*FConn, *FConn){
func NewFConn() (*FConn, *FConn) {
c2sBuffer := bytes.NewBuffer(make([]byte, 0))
s2cBuffer := bytes.NewBuffer(make([]byte, 0))

8
go.mod
View File

@@ -1,11 +1,11 @@
module github.com/kindsoldier/dsrpc
go 1.17
go 1.19
require github.com/stretchr/testify v1.7.1
require github.com/stretchr/testify v1.8.2
require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

12
go.sum
View File

@@ -1,11 +1,17 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -7,16 +7,18 @@
package dsrpc
import (
"errors"
"bytes"
"encoding/binary"
"encoding/json"
"bytes"
"errors"
)
const headerSize int64 = 16 * 2
const sizeOfInt64 int = 8
const magicCodeA int64 = 0xEE00ABBA
const magicCodeB int64 = 0xEE44ABBA
const (
headerSize int64 = 16 * 2
sizeOfInt64 int = 8
magicCodeA int64 = 0xEE00ABBA
magicCodeB int64 = 0xEE44ABBA
)
type Header struct {
magicCodeA int64 `json:"magicCodeA"`
@@ -25,35 +27,33 @@ type Header struct {
magicCodeB int64 `json:"magicCodeB"`
}
func NewHeader() *Header {
func NewEmptyHeader() *Header {
return &Header{
magicCodeA: magicCodeA,
magicCodeB: magicCodeB,
}
}
func (this *Header) JSON() []byte {
jBytes, _ := json.Marshal(this)
func (hdr *Header) ToJson() []byte {
jBytes, _ := json.Marshal(hdr)
return jBytes
}
func (this *Header) Pack() ([]byte, error) {
func (hdr *Header) Pack() ([]byte, error) {
var err error
headerBytes := make([]byte, 0, headerSize)
headerBuffer := bytes.NewBuffer(headerBytes)
magicCodeABytes := encoderI64(this.magicCodeA)
magicCodeABytes := EncoderI64(hdr.magicCodeA)
headerBuffer.Write(magicCodeABytes)
rpcSizeBytes := encoderI64(this.rpcSize)
rpcSizeBytes := EncoderI64(hdr.rpcSize)
headerBuffer.Write(rpcSizeBytes)
binSizeBytes := encoderI64(this.binSize)
binSizeBytes := EncoderI64(hdr.binSize)
headerBuffer.Write(binSizeBytes)
magicCodeBBytes := encoderI64(this.magicCodeB)
magicCodeBBytes := EncoderI64(hdr.magicCodeB)
headerBuffer.Write(magicCodeBBytes)
return headerBuffer.Bytes(), err
@@ -61,38 +61,41 @@ func (this *Header) Pack() ([]byte, error) {
func UnpackHeader(headerBytes []byte) (*Header, error) {
var err error
header := NewHeader()
headerReader := bytes.NewReader(headerBytes)
magicCodeABytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeABytes)
header.magicCodeA = decoderI64(magicCodeABytes)
rpcSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(rpcSizeBytes)
header.rpcSize = decoderI64(rpcSizeBytes)
binSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(binSizeBytes)
header.binSize = decoderI64(binSizeBytes)
magicCodeBBytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeBBytes)
header.magicCodeB = decoderI64(magicCodeBBytes)
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
return header, errors.New("wrong protocol magic code")
header := &Header{
magicCodeA: DecoderI64(magicCodeABytes),
rpcSize: DecoderI64(rpcSizeBytes),
binSize: DecoderI64(binSizeBytes),
magicCodeB: DecoderI64(magicCodeBBytes),
}
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
err = errors.New("Wrong protocol magic code")
return header, err
}
return header, err
}
func encoderI64(i int64) []byte {
func EncoderI64(i int64) []byte {
buffer := make([]byte, sizeOfInt64)
binary.BigEndian.PutUint64(buffer, uint64(i))
return buffer
}
func decoderI64(b []byte) int64 {
func DecoderI64(b []byte) int64 {
return int64(binary.BigEndian.Uint64(b))
}

View File

@@ -10,30 +10,27 @@ import (
"fmt"
"io"
"os"
"time"
)
var messageWriter io.Writer = os.Stdout
var accessWriter io.Writer = os.Stdout
var (
messageWriter io.Writer = os.Stdout
accessWriter io.Writer = os.Stdout
)
func logDebug(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "debug", messages)
fmt.Fprintln(messageWriter, "debug:", messages)
}
func logInfo(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "info", messages)
fmt.Fprintln(messageWriter, "info:", messages)
}
func logError(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "error", messages)
fmt.Fprintln(messageWriter, "error:", messages)
}
func logAccess(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(accessWriter, stamp, "access", messages)
fmt.Fprintln(accessWriter, "access:", messages)
}
func SetAccessWriter(writer io.Writer) {

View File

@@ -1,31 +1,29 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"time"
)
func LogRequest(context *Context) error {
func LogRequest(content *Content) error {
var err error
logDebug("request:", string(context.reqRPC.JSON()))
logDebug("request:", string(content.reqBlock.ToJson()))
return err
}
func LogResponse(context *Context) error {
func LogResponse(content *Content) error {
var err error
logDebug("response:", string(context.resRPC.JSON()))
logDebug("response:", string(content.resBlock.ToJson()))
return err
}
func LogAccess(context *Context) error {
func LogAccess(content *Content) error {
var err error
execTime := time.Now().Sub(context.start)
logAccess(context.remoteHost, context.reqRPC.Method, execTime)
execTime := time.Now().Sub(content.start)
login := string(content.AuthIdent())
logAccess(content.remoteHost, login, content.reqBlock.Method, execTime)
return err
}

View File

@@ -1,25 +1,18 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
)
type Packet struct {
header []byte
rcpPayload []byte
}
func NewPacket() *Packet {
return &Packet{}
}
func (this *Packet) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func NewEmptyPacket() *Packet {
packet := &Packet{
header: make([]byte, 0),
rcpPayload: make([]byte, 0),
}
return packet
}

View File

@@ -8,26 +8,35 @@ package dsrpc
import (
"encoding/json"
encoder "encoding/json"
)
type EmptyParams struct{}
func NewEmptyParams() *EmptyParams {
return &EmptyParams{}
}
type Request struct {
Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params,omitempty"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth,omitempty"`
Params any `json:"params,omitempty" msgpack:"params"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
}
func NewRequest() *Request {
func NewEmptyRequest() *Request {
req := &Request{}
req.Auth = &Auth{}
req.Params = NewEmptyParams()
return req
}
func (this *Request) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this)
func (req *Request) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(req)
return rBytes, err
}
func (this *Request) JSON() []byte {
jBytes, _ := json.Marshal(this)
func (req *Request) ToJson() []byte {
jBytes, _ := json.Marshal(req)
return jBytes
}

View File

@@ -1,31 +1,38 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
encoder "encoding/json"
)
type EmptyResult struct{}
func NewEmptyResult() *EmptyResult {
return &EmptyResult{}
}
type Response struct {
Error string `json:"error,omitempty" msgpack:"error,omitempty"`
Result any `json:"result,omitemty" msgpack:"result,omitemty"`
Error string `json:"error" msgpack:"error"`
Result any `json:"result" msgpack:"result"`
}
func NewResponse() *Response {
return &Response{}
func NewEmptyResponse() *Response {
return &Response{
Result: NewEmptyResult(),
}
}
func (this *Response) JSON() []byte {
jBytes, _ := json.Marshal(this)
func (resp *Response) ToJson() []byte {
jBytes, _ := json.Marshal(resp)
return jBytes
}
func (this *Response) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this)
func (resp *Response) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(resp)
return rBytes, err
}

247
server.go
View File

@@ -1,21 +1,23 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"context"
"encoding/json"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
encoder "encoding/json"
)
type HandlerFunc = func(*Context) error
type HandlerFunc = func(*Content) error
type Service struct {
handlers map[string]HandlerFunc
@@ -24,6 +26,11 @@ type Service struct {
wg *sync.WaitGroup
preMw []HandlerFunc
postMw []HandlerFunc
keepalive bool
kaTime time.Duration
kaMtx sync.Mutex
listener net.Listener
tcpListener *net.TCPListener
}
func NewService() *Service {
@@ -40,77 +47,152 @@ func NewService() *Service {
return rdrpc
}
func (this *Service) PreMiddleware(mw HandlerFunc) {
this.preMw = append(this.preMw, mw)
func (svc *Service) PreMiddleware(mw HandlerFunc) {
svc.preMw = append(svc.preMw, mw)
}
func (this *Service) PostMiddleware(mw HandlerFunc) {
this.postMw = append(this.postMw, mw)
func (svc *Service) PostMiddleware(mw HandlerFunc) {
svc.postMw = append(svc.postMw, mw)
}
func (this *Service) Handler(method string, handler HandlerFunc) {
this.handlers[method] = handler
func (svc *Service) Handle(method string, handler HandlerFunc) {
svc.handlers[method] = handler
}
func (this *Service) Listen(address string) error {
func (svc *Service) SetKeepAlive(flag bool) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.keepalive = true
}
func (svc *Service) SetKeepAlivePeriod(interval time.Duration) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.kaTime = interval
}
func (svc *Service) Listen(address string) error {
var err error
logInfo("server listen:", address)
listener, err := net.Listen("tcp", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
this.wg.Add(1)
svc.tcpListener, err = net.ListenTCP("tcp", addr)
if err != nil {
err = fmt.Errorf("unable to start listener: %s", err)
return err
}
for {
select {
case <- this.ctx.Done():
this.wg.Done()
return err
default:
}
conn, err := listener.Accept()
conn, err := svc.tcpListener.AcceptTCP()
if err != nil {
logError("conn accept err:", err)
}
go this.handleConn(conn)
select {
case <-svc.ctx.Done():
return err
default:
}
svc.wg.Add(1)
go svc.handleTCPConn(conn, svc.wg)
}
return err
}
func notFound(context *Context) error {
func (svc *Service) ListenTLS(address string, tlsConfig *tls.Config) error {
var err error
logInfo("server listen:", address)
svc.listener, err = tls.Listen("tcp", address, tlsConfig)
if err != nil {
err = fmt.Errorf("unable to start listener: %s", err)
return err
}
for {
conn, err := svc.listener.Accept()
if err != nil {
logError("conn accept err:", err)
}
select {
case <-svc.ctx.Done():
logInfo("accept loop done")
return err
default:
}
svc.wg.Add(1)
go svc.handleConn(conn, svc.wg)
}
return err
}
func notFound(content *Content) error {
execErr := errors.New("method not found")
err := context.SendError(execErr)
err := content.SendError(execErr)
return err
}
func (this *Service) Stop() error {
func (svc *Service) Stop() error {
var err error
this.cancel()
this.wg.Wait()
// Disable new connection
logInfo("cancel rpc accept loop")
if svc.listener != nil {
svc.listener.Close()
}
if svc.tcpListener != nil {
svc.tcpListener.Close()
}
svc.cancel()
// Wait handlers
logInfo("wait rpc handlers")
svc.wg.Wait()
return err
}
func (this *Service) handleConn(conn net.Conn) {
func (svc *Service) handleTCPConn(conn *net.TCPConn, wg *sync.WaitGroup) {
var err error
if svc.keepalive {
err = conn.SetKeepAlive(true)
if err != nil {
err = fmt.Errorf("unable to set keepalive: %s", err)
return
}
if svc.kaTime > 0 {
err = conn.SetKeepAlivePeriod(svc.kaTime)
if err != nil {
err = fmt.Errorf("unable to set keepalive period: %s", err)
return
}
}
}
svc.handleConn(conn, wg)
}
func (svc *Service) handleConn(conn net.Conn, wg *sync.WaitGroup) {
var err error
context := CreateContext(conn)
content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost
content.remoteHost = remoteHost
context.binReader = conn
context.binWriter = io.Discard
content.binReader = conn
content.binWriter = io.Discard
exitFunc := func() {
conn.Close()
wg.Done()
if err != nil {
logError("conn handler err:", err)
}
}
defer exitFunc()
recovFunc := func () {
recovFunc := func() {
panicMsg := recover()
if panicMsg != nil {
logError("handler panic message:", panicMsg)
@@ -118,144 +200,147 @@ func (this *Service) handleConn(conn net.Conn) {
}
defer recovFunc()
err = context.ReadRequest()
err = content.ReadRequest()
if err != nil {
err = err
return
}
err = context.BindMethod()
err = content.BindMethod()
if err != nil {
err = err
return
}
for _, mw := range this.preMw {
err = mw(context)
for _, mw := range svc.preMw {
err = mw(content)
if err != nil {
err = err
return
}
}
err = this.Route(context)
err = svc.Route(content)
if err != nil {
err = err
return
}
for _, mw := range this.postMw {
err = mw(context)
for _, mw := range svc.postMw {
err = mw(content)
if err != nil {
err = err
return
}
}
return
}
func (this *Service) Route(context *Context) error {
handler, ok := this.handlers[context.reqRPC.Method]
func (svc *Service) Route(content *Content) error {
handler, ok := svc.handlers[content.reqBlock.Method]
if ok {
return handler(context)
return handler(content)
}
return notFound(context)
return notFound(content)
}
func (context *Context) ReadRequest() error {
func (content *Content) ReadRequest() error {
var err error
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
content.reqPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
return err
}
context.reqHeader, err = UnpackHeader(context.reqPacket.header)
content.reqHeader, err = UnpackHeader(content.reqPacket.header)
if err != nil {
return err
}
rpcSize := context.reqHeader.rpcSize
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
rpcSize := content.reqHeader.rpcSize
content.reqPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
return err
}
return err
}
func (context *Context) BinWriter() io.Writer {
return context.sockWriter
func (content *Content) BinWriter() io.Writer {
return content.sockWriter
}
func (context *Context) BinReader() io.Reader {
return context.sockReader
func (content *Content) BinReader() io.Reader {
return content.sockReader
}
func (context *Context) BinSize() int64 {
return context.reqHeader.binSize
func (content *Content) BinSize() int64 {
return content.reqHeader.binSize
}
func (context *Context) ReadBin(writer io.Writer) error {
func (content *Content) ReadBin(ctx context.Context, writer io.Writer) error {
var err error
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
_, err = CopyBytes(ctx, content.sockReader, writer, content.reqHeader.binSize)
return err
}
func (context *Context) BindMethod() error {
func (content *Content) BindMethod() error {
var err error
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
return err
}
func (context *Context) BindParams(params any) error {
func (content *Content) BindParams(params any) error {
var err error
context.reqRPC.Params = params
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
content.reqBlock.Params = params
err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
if err != nil {
return err
}
return err
}
func (context *Context) SendResult(result any, binSize int64) error {
func (content *Content) SendResult(result any, binSize int64) error {
var err error
context.resRPC.Result = result
content.resBlock.Result = result
context.resPacket.rcpPayload, err = context.resRPC.Pack()
content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil {
return err
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resHeader.binSize = binSize
content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
content.resHeader.binSize = binSize
context.resPacket.header, err = context.resHeader.Pack()
content.resPacket.header, err = content.resHeader.Pack()
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.header)
_, err = content.sockWriter.Write(content.resPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
_, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil {
return err
}
return err
}
func (context *Context) SendError(execErr error) error {
func (content *Content) SendError(execErr error) error {
var err error
context.resRPC.Error = execErr.Error()
context.resRPC.Result = NewEmpty()
content.resBlock.Error = execErr.Error()
content.resBlock.Result = NewEmptyResult()
context.resPacket.rcpPayload, err = context.resRPC.Pack()
content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil {
return err
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resPacket.header, err = context.resHeader.Pack()
content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
content.resPacket.header, err = content.resHeader.Pack()
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.header)
_, err = content.sockWriter.Write(content.resPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
_, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil {
return err
}

View File

@@ -5,9 +5,10 @@
package dsrpc
import (
"context"
"errors"
"io"
"fmt"
"io"
)
func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
@@ -16,14 +17,20 @@ func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
return buffer[0:read], err
}
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
func CopyBytes(ctx context.Context, reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
var err error
var bSize int64 = 1024 * 4
var bSize int64 = 1024 * 16
var total int64 = 0
var remains int64 = dataSize
buffer := make([]byte, bSize)
for {
select {
case <-ctx.Done():
return total, errors.New("break by context")
default:
}
if reader == nil {
return total, errors.New("reader is nil")
}
@@ -38,14 +45,17 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error
}
received, err := reader.Read(buffer[0:bSize])
if err != nil {
return total, fmt.Errorf("read error: %v", err)
err = fmt.Errorf("read error: %v", err)
return total, err
}
recorded, err := writer.Write(buffer[0:received])
if err != nil {
return total, fmt.Errorf("write error: %v", err)
err = fmt.Errorf("write error: %v", err)
return total, err
}
if recorded != received {
return total, errors.New("size mismatch")
err = errors.New("size mismatch")
return total, err
}
total += int64(recorded)
remains -= int64(recorded)

View File

@@ -4,31 +4,35 @@
package dsrpc
import (
"context"
"io"
"net"
)
func LocalExec(method string, param any, result any, auth *Auth, handler HandlerFunc) error {
func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
if param != nil {
content.reqBlock.Params = param
}
err = context.CreateRequest()
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
err = content.createRequest()
if err != nil {
return err
}
err = context.WriteRequest()
err = content.writeRequest()
if err != nil {
return err
}
@@ -36,11 +40,11 @@ func LocalExec(method string, param any, result any, auth *Auth, handler Handler
if err != nil {
return err
}
err = context.ReadResponse()
err = content.readResponse()
if err != nil {
return err
}
err = context.BindResponse()
err = content.bindResponse()
if err != nil {
return err
}
@@ -48,35 +52,39 @@ func LocalExec(method string, param any, result any, auth *Auth, handler Handler
return err
}
func LocalPut(method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error {
func LocalPut(ctx context.Context, method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
context.binReader = reader
context.binWriter = cliConn
context.reqHeader.binSize = size
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
if param != nil {
content.reqBlock.Params = param
}
err = context.CreateRequest()
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
content.binReader = reader
content.binWriter = cliConn
content.reqHeader.binSize = size
err = content.createRequest()
if err != nil {
return err
}
err = context.WriteRequest()
err = content.writeRequest()
if err != nil {
return err
}
err = context.UploadBin()
err = content.uploadBin(ctx)
if err != nil {
return err
}
@@ -84,40 +92,43 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au
if err != nil {
return err
}
err = context.ReadResponse()
err = content.readResponse()
if err != nil {
return err
}
err = context.BindResponse()
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error {
func LocalGet(ctx context.Context, method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
context.binReader = cliConn
context.binWriter = writer
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
if param != nil {
content.reqBlock.Params = param
}
err = context.CreateRequest()
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
content.binReader = cliConn
content.binWriter = writer
err = content.createRequest()
if err != nil {
return err
}
err = context.WriteRequest()
err = content.writeRequest()
if err != nil {
return err
}
@@ -126,15 +137,15 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha
if err != nil {
return err
}
err = context.ReadResponse()
err = content.readResponse()
if err != nil {
return err
}
err = context.DownloadBin()
err = content.downloadBin(ctx)
if err != nil {
return err
}
err = context.BindResponse()
err = content.bindResponse()
if err != nil {
return err
}
@@ -143,22 +154,22 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha
func LocalService(conn net.Conn, handler HandlerFunc) error {
var err error
context := CreateContext(conn)
content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost
content.remoteHost = remoteHost
context.binReader = conn
context.binWriter = io.Discard
content.binReader = conn
content.binWriter = io.Discard
err = context.ReadRequest()
err = content.ReadRequest()
if err != nil {
return err
}
err = context.BindMethod()
err = content.BindMethod()
if err != nil {
return err
}
return handler(context)
return handler(content)
}

View File

@@ -5,11 +5,11 @@
package dsrpc
import (
"encoding/json"
"bytes"
"crypto/sha256"
"encoding/json"
"math/rand"
"time"
"crypto/sha256"
)
func init() {
@@ -17,16 +17,16 @@ func init() {
}
type Auth struct {
Ident []byte `json:"ident,omitempty"`
Salt []byte `json:"salt,omitempty"`
Hash []byte `json:"hash,omitempty"`
Ident []byte `msgpack:"ident" json:"ident"`
Salt []byte `msgpack:"salt" json:"salt"`
Hash []byte `msgpack:"hash" json:"hash"`
}
func NewAuth() *Auth {
return &Auth{}
}
func (this *Auth) JSON() []byte {
func (this *Auth) Json() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
}
@@ -49,7 +49,7 @@ func CreateSalt() []byte {
}
func CreateHash(ident, pass, salt []byte) []byte {
vec := make([]byte, 0, len(ident) + len(salt) + len(pass))
vec := make([]byte, 0, len(ident)+len(salt)+len(pass))
vec = append(vec, ident...)
vec = append(vec, salt...)
vec = append(vec, pass...)