6 Commits

21 changed files with 1384 additions and 1203 deletions

215
README.md
View File

@@ -1,24 +1,24 @@
# dsrpc, Data RPC
DSRPC is easy and simple RPC framework over TCP socket.
### Purpose
A very easy and open RPC framework with data streaming.
A very easy and open RPC framework with data streaming.
### You can
### You can
- Use post and pre-execution middleware
- Use own post and pre-execution middleware
- Hash-based authentication in middleware
- Test call remote function without service organization
- Test remote function without network
Socket encryption is not used at this time since framefork
is oriented to transfer large amounts of data
Socket encryption is not used at this time since framefork
is oriented to transfer large amounts of data.
Style of the framework is similar to that of GIN framework.
Style of the framework is similar of GIN framework.
## Example
## Exec method example
### Server
@@ -135,7 +135,7 @@ package api
const HelloMethod string = "hello"
type HelloParams struct {
Message string `msgpack:"message" json:"message"`
Message string `json:"message"`
}
func NewHelloParams() *HelloParams {
@@ -143,7 +143,7 @@ func NewHelloParams() *HelloParams {
}
type HelloResult struct {
Message string `msgpack:"message" json:"message"`
Message string `json:"message"`
}
func NewHelloResult() *HelloResult {
@@ -151,3 +151,196 @@ func NewHelloResult() *HelloResult {
}
```
### Authentication and authorization
#### Client side
```
func clientHello() error {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := dsrpc.CreateAuth([]byte("login"), []byte("password"))
err = dsrpc.Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
log.Println("method err:", err)
return err
}
//...
}
```
#### Server side
```
func authMiddleware(context *dsrpc.Context) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
if reqIdent != "login" {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
ident := reqIdent
pass := []byte("password")
ok := dsrpc.CheckHash(ident, pass, reqSalt, reqHash)
log.Println("auth is ok:", ok)
if !ok {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
return err
}
func sampleServ(quiet bool) error {
var err error
if quiet {
dsrpc.SetAccessWriter(io.Discard)
dsrpc.SetMessageWriter(io.Discard)
}
serv := NewService()
serv.PreMiddleware(authMiddleware)
serv.PreMiddleware(dsrpc.LogRequest)
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.PostMiddleware(dsrpc.LogResponse)
serv.PostMiddleware(dsrpc.LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
```
### Put method
#### Client side sample
```
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
err = dsrpc.Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
```
#### Server side
```
func saveHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
if err != nil {
context.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
}
```
### Get method
#### Client side
```
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = dsrpc.Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
return err
}
//...
```
#### Server side
```
func getHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = dsrpc.CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}
```

460
client.go
View File

@@ -1,282 +1,306 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
"errors"
"io"
"net"
"sync"
"errors"
"fmt"
"io"
"net"
"sync"
encoder "github.com/vmihailenco/msgpack/v5"
)
func Put(address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
var err error
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
conn, err := net.Dial("tcp", address)
if err != nil {
return err
}
defer conn.Close()
return ConnPut(conn, method, reader, size, param, result, auth)
return ConnPut(conn, method, reader, binSize, param, result, auth)
}
func ConnPut(conn net.Conn, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
content := CreateContent(conn)
func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
context.binReader = reader
context.binWriter = conn
content.binReader = reader
content.binWriter = conn
context.reqHeader.binSize = size
content.reqHeader.binSize = binSize
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = content.CreateRequest()
if err != nil {
return err
}
err = content.WriteRequest()
if err != nil {
return err
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
var wg sync.WaitGroup
errChan := make(chan error, 1)
var wg sync.WaitGroup
errChan := make(chan error, 1)
wg.Add(1)
go content.ReadResponseAsync(&wg, errChan)
wg.Add(1)
go context.ReadResponseAsync(&wg, errChan)
wg.Add(1)
go content.UploadBinAsync(&wg)
wg.Add(1)
go context.UploadBinAsync(&wg)
wg.Wait()
err = <- errChan
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
wg.Wait()
err = <-errChan
if err != nil {
return err
}
err = content.BindResponse()
if err != nil {
return err
}
return err
}
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
var err error
conn, err := net.Dial("tcp", address)
if err != nil {
return err
}
defer conn.Close()
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
return ConnGet(conn, method, writer, param, result, auth)
return ConnGet(conn, method, writer, param, result, auth)
}
func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(conn)
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
context.binReader = conn
context.binWriter = writer
content.binReader = conn
content.binWriter = writer
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.DownloadBin()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
err = content.CreateRequest()
if err != nil {
return err
}
err = content.WriteRequest()
if err != nil {
return err
}
err = content.ReadResponse()
if err != nil {
return err
}
err = content.DownloadBin()
if err != nil {
return err
}
err = content.BindResponse()
if err != nil {
return err
}
return err
}
func Exec(address, method string, param any, result any, auth *Auth) error {
var err error
var err error
conn, err := net.Dial("tcp", address)
if err != nil {
return err
}
defer conn.Close()
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
err = ConnExec(conn, method, param, result, auth)
if err != nil {
return err
}
return err
err = ConnExec(conn, method, param, result, auth)
if err != nil {
return err
}
return err
}
func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error {
var err error
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(conn)
content.reqBlock.Method = method
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
if param != nil {
content.reqBlock.Params = param
}
if result != nil {
content.resBlock.Result = result
}
if auth != nil {
content.reqBlock.Auth = auth
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
err = content.CreateRequest()
if err != nil {
return err
}
err = content.WriteRequest()
if err != nil {
return err
}
err = content.ReadResponse()
if err != nil {
return err
}
err = content.BindResponse()
if err != nil {
return err
}
return err
}
func (content *Content) CreateRequest() error {
var err error
func (context *Context) CreateRequest() error {
var err error
content.reqPacket.rcpPayload, err = content.reqBlock.Pack()
if err != nil {
return err
}
rpcSize := int64(len(content.reqPacket.rcpPayload))
content.reqHeader.rpcSize = rpcSize
context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
if err != nil {
return err
}
rpcSize := int64(len(context.reqPacket.rcpPayload))
context.reqHeader.rpcSize = rpcSize
context.reqPacket.header, err = context.reqHeader.Pack()
if err != nil {
return err
}
return err
content.reqPacket.header, err = content.reqHeader.Pack()
if err != nil {
return err
}
return err
}
func (context *Context) WriteRequest() error {
var err error
_, err = context.sockWriter.Write(context.reqPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
if err != nil {
return err
}
return err
func (content *Content) WriteRequest() error {
var err error
_, err = content.sockWriter.Write(content.reqPacket.header)
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.reqPacket.rcpPayload)
if err != nil {
return err
}
return err
}
func (context *Context) UploadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return err
func (content *Content) UploadBin() error {
var err error
_, err = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize)
return err
}
func (context *Context) ReadResponse() error {
var err error
func (content *Content) ReadResponse() error {
var err error
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
return err
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil {
return err
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
return err
}
return err
content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
return err
}
content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil {
return err
}
rpcSize := content.resHeader.rpcSize
content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
return err
}
return err
}
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) {
exitFunc := func() {
wg.Done()
}
defer exitFunc()
_, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return
func (content *Content) UploadBinAsync(wg *sync.WaitGroup) {
exitFunc := func() {
wg.Done()
}
defer exitFunc()
_, _ = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize)
return
}
func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) {
var err error
exitFunc := func() {
errChan <- err
wg.Done()
}
defer exitFunc()
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
return
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil {
return
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
return
}
return
func (content *Content) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) {
var err error
exitFunc := func() {
errChan <- err
wg.Done()
}
defer exitFunc()
content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
err = err
return
}
content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil {
err = err
return
}
rpcSize := content.resHeader.rpcSize
content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
err = err
return
}
return
}
func (context *Context) DownloadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
return err
func (content *Content) DownloadBin() error {
var err error
_, err = CopyBytes(content.binReader, content.binWriter, content.resHeader.binSize)
return err
}
func (context *Context) BindResponse() error {
var err error
func (content *Content) BindResponse() error {
var err error
err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
if err != nil {
return err
}
if len(context.resRPC.Error) > 0 {
return errors.New(context.resRPC.Error)
}
return err
err = encoder.Unmarshal(content.resPacket.rcpPayload, content.resBlock)
if err != nil {
return err
}
if len(content.resBlock.Error) > 0 {
err = errors.New(content.resBlock.Error)
return err
}
return err
}

View File

@@ -1,3 +0,0 @@
package dsrpc
type any = interface{}

View File

@@ -1,90 +0,0 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"io"
"net"
"time"
)
type Context struct {
start time.Time
remoteHost string
sockReader io.Reader
sockWriter io.Writer
reqHeader *Header
reqRPC *Request
reqPacket *Packet
resPacket *Packet
resHeader *Header
resRPC *Response
binReader io.Reader
binWriter io.Writer
}
func NewContext() *Context {
context := &Context{}
context.start = time.Now()
return context
}
func CreateContext(conn net.Conn) *Context {
context := &Context{}
context.start = time.Now()
context.sockReader = conn
context.sockWriter = conn
context.reqPacket = NewPacket()
context.resPacket = NewPacket()
context.reqHeader = NewHeader()
context.reqRPC = NewRequest()
context.resHeader = NewHeader()
context.resRPC = NewResponse()
context.resRPC = NewResponse()
return context
}
func (context *Context) Request() *Request {
return context.reqRPC
}
func (context *Context) SetAuthIdent(ident []byte) {
context.reqRPC.Auth.Ident = ident
}
func (context *Context) SetAuthSalt(salt []byte) {
context.reqRPC.Auth.Salt = salt
}
func (context *Context) SetAuthHash(hash []byte) {
context.reqRPC.Auth.Hash = hash
}
func (context *Context) AuthIdent() []byte {
return context.reqRPC.Auth.Ident
}
func (context *Context) AuthSalt() []byte {
return context.reqRPC.Auth.Salt
}
func (context *Context) AuthHash() []byte {
return context.reqRPC.Auth.Hash
}
func (context *Context) Auth() *Auth {
return context.reqRPC.Auth
}

View File

@@ -1,13 +0,0 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
type Empty struct {}
func NewEmpty() *Empty {
return &Empty{}
}

View File

@@ -1,376 +1,365 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"bytes"
"encoding/json"
"errors"
"io"
"math/rand"
"testing"
"time"
"bytes"
"encoding/json"
"errors"
"io"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
func TestLocalExec(t *testing.T) {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
auth := CreateAuth([]byte("qwert"), []byte("12345"))
err = LocalExec(HelloMethod, params, result, auth, helloHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
err = LocalExec(HelloMethod, params, result, auth, helloHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
}
func TestLocalSave(t *testing.T) {
var err error
var err error
params := NewSaveParams()
params.Message = "save data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
params := NewSaveParams()
params.Message = "save data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
reader := bytes.NewReader(binBytes)
err = LocalPut(SaveMethod, reader, binSize, params, result, auth, saveHandler)
require.NoError(t, err)
err = LocalPut(SaveMethod, reader, binSize, params, result, auth, saveHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
}
func TestLocalLoad(t *testing.T) {
var err error
var err error
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = LocalGet(LoadMethod, writer, params, result, auth, loadHandler)
require.NoError(t, err)
err = LocalGet(LoadMethod, writer, params, result, auth, loadHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
logDebug("bin size:", len(writer.Bytes()))
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
logDebug("bin size:", len(writer.Bytes()))
}
func TestNetExec(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientHello()
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientHello()
require.NoError(t, err)
require.NoError(t, err)
}
func TestNetSave(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientSave()
require.NoError(t, err)
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientSave()
require.NoError(t, err)
}
func TestNetLoad(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientLoad()
require.NoError(t, err)
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientLoad()
require.NoError(t, err)
}
func BenchmarkNetPut(b *testing.B) {
go testServ(true)
time.Sleep(10 * time.Millisecond)
clientSave()
go testServ(true)
time.Sleep(10 * time.Millisecond)
clientSave()
pBench := func(pb *testing.PB) {
for pb.Next() {
clientSave()
}
}
b.SetParallelism(10)
b.RunParallel(pBench)
pBench := func(pb *testing.PB) {
for pb.Next() {
clientSave()
}
}
b.SetParallelism(10)
b.RunParallel(pBench)
}
func clientHello() error {
var err error
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
err = Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
return err
err = Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
return err
}
func clientSave() error {
var err error
var err error
params := NewSaveParams()
params.Message = "save data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
params := NewSaveParams()
params.Message = "save data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
reader := bytes.NewReader(binBytes)
err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
return err
err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
return err
}
func clientLoad() error {
var err error
var err error
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
logDebug("bin size:", len(writer.Bytes()))
return err
err = Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
logDebug("bin size:", len(writer.Bytes()))
return err
}
var testServRun bool = false
func testServ(quiet bool) error {
var err error
var err error
if testServRun {
return err
}
testServRun = true
if testServRun {
return err
}
testServRun = true
if quiet {
SetAccessWriter(io.Discard)
SetMessageWriter(io.Discard)
}
serv := NewService()
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
if quiet {
SetAccessWriter(io.Discard)
SetMessageWriter(io.Discard)
}
serv := NewService()
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.PreMiddleware(LogRequest)
serv.PreMiddleware(auth)
serv.PreMiddleware(LogRequest)
serv.PreMiddleware(auth)
serv.PostMiddleware(LogResponse)
serv.PostMiddleware(LogAccess)
serv.PostMiddleware(LogResponse)
serv.PostMiddleware(LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
func auth(context *Context) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
func auth(content *Content) error {
var err error
reqIdent := content.AuthIdent()
reqSalt := content.AuthSalt()
reqHash := content.AuthHash()
ident := reqIdent
pass := []byte("12345")
ident := reqIdent
pass := []byte("12345")
auth := context.Auth()
logDebug("auth ", string(auth.JSON()))
auth := content.Auth()
logDebug("auth ", string(auth.JSON()))
ok := CheckHash(ident, pass, reqSalt, reqHash)
logDebug("auth ok:", ok)
if !ok {
err = errors.New("auth ident or pass missmatch")
context.SendError(err)
return err
}
return err
ok := CheckHash(ident, pass, reqSalt, reqHash)
logDebug("auth ok:", ok)
if !ok {
err = errors.New("auth ident or pass missmatch")
content.SendError(err)
return err
}
return err
}
func helloHandler(context *Context) error {
var err error
params := NewHelloParams()
func helloHandler(content *Content) error {
var err error
params := NewHelloParams()
err = context.BindParams(params)
if err != nil {
return err
}
err = content.BindParams(params)
if err != nil {
return err
}
err = context.ReadBin(io.Discard)
if err != nil {
context.SendError(err)
return err
}
err = content.ReadBin(io.Discard)
if err != nil {
content.SendError(err)
return err
}
result := NewHelloResult()
result.Message = "hello, client!"
result := NewHelloResult()
result.Message = "hello, client!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
err = content.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func saveHandler(context *Context) error {
var err error
params := NewSaveParams()
func saveHandler(content *Content) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
err = content.BindParams(params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
if err != nil {
context.SendError(err)
return err
}
err = content.ReadBin(binWriter)
if err != nil {
content.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
err = content.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func loadHandler(context *Context) error {
var err error
params := NewSaveParams()
func loadHandler(content *Content) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
err = content.BindParams(params)
if err != nil {
return err
}
err = context.ReadBin(io.Discard)
if err != nil {
context.SendError(err)
return err
}
err = content.ReadBin(io.Discard)
if err != nil {
content.SendError(err)
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result.Message = "load successfully!"
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
err = content.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := content.BinWriter()
_, err = CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
return err
return err
}
const HelloMethod string = "hello"
type HelloParams struct {
Message string `json:"message" json:"message"`
Message string `json:"message" msgpack:"message"`
}
func NewHelloParams() *HelloParams {
return &HelloParams{}
return &HelloParams{}
}
type HelloResult struct {
Message string `json:"message" json:"message"`
Message string `json:"message" msgpack:"message"`
}
func NewHelloResult() *HelloResult {
return &HelloResult{}
return &HelloResult{}
}
const SaveMethod string = "save"
type SaveParams HelloParams
type SaveResult HelloResult
func NewSaveParams() *SaveParams {
return &SaveParams{}
return &SaveParams{}
}
func NewSaveResult() *SaveResult {
return &SaveResult{}
return &SaveResult{}
}
const LoadMethod string = "load"
type LoadParams HelloParams
type LoadResult HelloResult
func NewLoadParams() *LoadParams {
return &LoadParams{}
return &LoadParams{}
}
func NewLoadResult() *LoadResult {
return &LoadResult{}
return &LoadResult{}
}

View File

@@ -1,27 +1,26 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
type FAddr struct {
network string
address string
network string
address string
}
func NewFAddr() *FAddr {
var addr FAddr
addr.network = "tcp"
addr.address = "127.0.0.1:5000"
return &addr
addr := FAddr{
network: "tcp",
address: "127.0.0.1:5000",
}
return &addr
}
func (addr *FAddr) Network() string {
return addr.network
return addr.network
}
func (addr *FAddr) String() string {
return addr.address
return addr.address
}

View File

@@ -7,51 +7,52 @@
package dsrpc
import (
"net"
"testing"
"github.com/stretchr/testify/require"
"net"
"testing"
"github.com/stretchr/testify/require"
)
func TestFConn0(t *testing.T) {
var cConn, sConn net.Conn
sConn, cConn = NewFConn()
var cConn, sConn net.Conn
sConn, cConn = NewFConn()
cData := []byte("qwerty")
count := 10
cData := []byte("qwerty")
count := 10
for i := 0; i < count; i++ {
wc, err := cConn.Write(cData)
if err != nil {
t.Error(err)
}
require.Equal(t, wc, len(cData))
for i := 0; i < count; i++ {
wc, err := cConn.Write(cData)
if err != nil {
t.Error(err)
}
require.Equal(t, wc, len(cData))
sData := make([]byte, len(cData))
rc, err := sConn.Read(sData)
require.NoError(t, err)
require.Equal(t, rc, len(cData))
require.Equal(t, cData, sData)
}
sData := make([]byte, len(cData))
rc, err := sConn.Read(sData)
require.NoError(t, err)
require.Equal(t, rc, len(cData))
require.Equal(t, cData, sData)
}
}
func TestFConn1(t *testing.T) {
var cConn, sConn net.Conn
cConn, sConn = NewFConn()
var cConn, sConn net.Conn
cConn, sConn = NewFConn()
cData := []byte("qwerty")
count := 10
cData := []byte("qwerty")
count := 10
for i := 0; i < count; i++ {
wc, err := cConn.Write(cData)
if err != nil {
t.Error(err)
}
require.Equal(t, wc, len(cData))
for i := 0; i < count; i++ {
wc, err := cConn.Write(cData)
if err != nil {
t.Error(err)
}
require.Equal(t, wc, len(cData))
sData := make([]byte, len(cData))
rc, err := sConn.Read(sData)
require.NoError(t, err)
require.Equal(t, rc, len(cData))
require.Equal(t, cData, sData)
}
sData := make([]byte, len(cData))
rc, err := sConn.Read(sData)
require.NoError(t, err)
require.Equal(t, rc, len(cData))
require.Equal(t, cData, sData)
}
}

View File

@@ -7,62 +7,62 @@
package dsrpc
import (
"bytes"
"io"
"net"
"time"
"bytes"
"io"
"net"
"time"
)
type FConn struct {
reader io.Reader
writer io.Writer
reader io.Reader
writer io.Writer
}
func NewFConn() (*FConn, *FConn){
c2sBuffer := bytes.NewBuffer(make([]byte, 0))
s2cBuffer := bytes.NewBuffer(make([]byte, 0))
func NewFConn() (*FConn, *FConn) {
c2sBuffer := bytes.NewBuffer(make([]byte, 0))
s2cBuffer := bytes.NewBuffer(make([]byte, 0))
var client FConn
client.writer = c2sBuffer
client.reader = s2cBuffer
var client FConn
client.writer = c2sBuffer
client.reader = s2cBuffer
var server FConn
server.writer = s2cBuffer
server.reader = c2sBuffer
var server FConn
server.writer = s2cBuffer
server.reader = c2sBuffer
return &client, &server
return &client, &server
}
func (conn FConn) SetDeadline(t time.Time) error {
var err error
return err
var err error
return err
}
func (conn FConn) SetReadDeadline(t time.Time) error {
var err error
return err
func (conn FConn) SetReadDeadline(t time.Time) error {
var err error
return err
}
func (conn FConn) SetWriteDeadline(t time.Time) error {
var err error
return err
var err error
return err
}
func (conn FConn) LocalAddr() net.Addr {
return NewFAddr()
return NewFAddr()
}
func (conn FConn) RemoteAddr() net.Addr {
return NewFAddr()
return NewFAddr()
}
func (conn FConn) Write(data []byte) (int, error) {
return conn.writer.Write(data)
return conn.writer.Write(data)
}
func (conn FConn) Read(data []byte) (int, error) {
return conn.reader.Read(data)
return conn.reader.Read(data)
}
func (conn FConn) Close() error {
var err error
return err
var err error
return err
}

16
go.mod
View File

@@ -1,11 +1,15 @@
module github.com/kindsoldier/dsrpc
go 1.17
require github.com/stretchr/testify v1.7.1
go 1.19
require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
github.com/stretchr/testify v1.8.1
github.com/vmihailenco/msgpack/v5 v5.3.5
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

17
go.sum
View File

@@ -1,11 +1,22 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

119
header.go
View File

@@ -7,92 +7,95 @@
package dsrpc
import (
"errors"
"encoding/binary"
"encoding/json"
"bytes"
"bytes"
"encoding/binary"
"encoding/json"
"errors"
)
const headerSize int64 = 16 * 2
const sizeOfInt64 int = 8
const magicCodeA int64 = 0xEE00ABBA
const magicCodeB int64 = 0xEE44ABBA
const (
headerSize int64 = 16 * 2
sizeOfInt64 int = 8
magicCodeA int64 = 0xEE00ABBA
magicCodeB int64 = 0xEE44ABBA
)
type Header struct {
magicCodeA int64 `json:"magicCodeA"`
rpcSize int64 `json:"rpcSize"`
binSize int64 `json:"binSize"`
magicCodeB int64 `json:"magicCodeB"`
magicCodeA int64 `json:"magicCodeA"`
rpcSize int64 `json:"rpcSize"`
binSize int64 `json:"binSize"`
magicCodeB int64 `json:"magicCodeB"`
}
func NewHeader() *Header {
return &Header{
magicCodeA: magicCodeA,
magicCodeB: magicCodeB,
}
func NewEmptyHeader() *Header {
return &Header{
magicCodeA: magicCodeA,
magicCodeB: magicCodeB,
}
}
func (this *Header) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func (hdr *Header) ToJson() []byte {
jBytes, _ := json.Marshal(hdr)
return jBytes
}
func (hdr *Header) Pack() ([]byte, error) {
var err error
headerBytes := make([]byte, 0, headerSize)
headerBuffer := bytes.NewBuffer(headerBytes)
func (this *Header) Pack() ([]byte, error) {
var err error
headerBytes := make([]byte, 0, headerSize)
headerBuffer := bytes.NewBuffer(headerBytes)
magicCodeABytes := encoderI64(hdr.magicCodeA)
headerBuffer.Write(magicCodeABytes)
magicCodeABytes := encoderI64(this.magicCodeA)
headerBuffer.Write(magicCodeABytes)
rpcSizeBytes := encoderI64(hdr.rpcSize)
headerBuffer.Write(rpcSizeBytes)
rpcSizeBytes := encoderI64(this.rpcSize)
headerBuffer.Write(rpcSizeBytes)
binSizeBytes := encoderI64(hdr.binSize)
headerBuffer.Write(binSizeBytes)
binSizeBytes := encoderI64(this.binSize)
headerBuffer.Write(binSizeBytes)
magicCodeBBytes := encoderI64(hdr.magicCodeB)
headerBuffer.Write(magicCodeBBytes)
magicCodeBBytes := encoderI64(this.magicCodeB)
headerBuffer.Write(magicCodeBBytes)
return headerBuffer.Bytes(), err
return headerBuffer.Bytes(), err
}
func UnpackHeader(headerBytes []byte) (*Header, error) {
var err error
header := NewHeader()
headerReader := bytes.NewReader(headerBytes)
var err error
magicCodeABytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeABytes)
header.magicCodeA = decoderI64(magicCodeABytes)
headerReader := bytes.NewReader(headerBytes)
rpcSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(rpcSizeBytes)
header.rpcSize = decoderI64(rpcSizeBytes)
magicCodeABytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeABytes)
binSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(binSizeBytes)
header.binSize = decoderI64(binSizeBytes)
rpcSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(rpcSizeBytes)
magicCodeBBytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeBBytes)
header.magicCodeB = decoderI64(magicCodeBBytes)
binSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(binSizeBytes)
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
return header, errors.New("wrong protocol magic code")
}
magicCodeBBytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeBBytes)
return header, err
header := &Header{
magicCodeA: decoderI64(magicCodeABytes),
rpcSize: decoderI64(rpcSizeBytes),
binSize: decoderI64(binSizeBytes),
magicCodeB: decoderI64(magicCodeBBytes),
}
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
err = errors.New("Wrong protocol magic code")
return header, err
}
return header, err
}
func encoderI64(i int64) []byte {
buffer := make([]byte, sizeOfInt64)
binary.BigEndian.PutUint64(buffer, uint64(i))
return buffer
buffer := make([]byte, sizeOfInt64)
binary.BigEndian.PutUint64(buffer, uint64(i))
return buffer
}
func decoderI64(b []byte) int64 {
return int64(binary.BigEndian.Uint64(b))
return int64(binary.BigEndian.Uint64(b))
}

View File

@@ -7,39 +7,39 @@
package dsrpc
import (
"fmt"
"io"
"os"
"time"
"fmt"
"io"
"os"
"time"
)
var messageWriter io.Writer = os.Stdout
var accessWriter io.Writer = os.Stdout
func logDebug(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "debug", messages)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "debug", messages)
}
func logInfo(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "info", messages)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "info", messages)
}
func logError(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "error", messages)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "error", messages)
}
func logAccess(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(accessWriter, stamp, "access", messages)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(accessWriter, stamp, "access", messages)
}
func SetAccessWriter(writer io.Writer) {
accessWriter = writer
accessWriter = writer
}
func SetMessageWriter(writer io.Writer) {
messageWriter = writer
messageWriter = writer
}

View File

@@ -1,31 +1,29 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"time"
"time"
)
func LogRequest(context *Context) error {
var err error
logDebug("request:", string(context.reqRPC.JSON()))
return err
func LogRequest(content *Content) error {
var err error
logDebug("request:", string(content.reqBlock.ToJson()))
return err
}
func LogResponse(context *Context) error {
var err error
logDebug("response:", string(context.resRPC.JSON()))
return err
func LogResponse(content *Content) error {
var err error
logDebug("response:", string(content.resBlock.ToJson()))
return err
}
func LogAccess(context *Context) error {
var err error
execTime := time.Now().Sub(context.start)
logAccess(context.remoteHost, context.reqRPC.Method, execTime)
return err
func LogAccess(content *Content) error {
var err error
execTime := time.Now().Sub(content.start)
login := string(content.AuthIdent())
logAccess(content.remoteHost, login, content.reqBlock.Method, execTime)
return err
}

View File

@@ -1,25 +1,18 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
)
type Packet struct {
header []byte
rcpPayload []byte
header []byte
rcpPayload []byte
}
func NewPacket() *Packet {
return &Packet{}
}
func (this *Packet) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func NewEmptyPacket() *Packet {
packet := &Packet{
header: make([]byte, 0),
rcpPayload: make([]byte, 0),
}
return packet
}

View File

@@ -7,27 +7,35 @@
package dsrpc
import (
"encoding/json"
"encoding/json"
encoder "github.com/vmihailenco/msgpack/v5"
)
type EmptyParams struct{}
func NewEmptyParams() *EmptyParams {
return &EmptyParams{}
}
type Request struct {
Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params,omitempty"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth,omitempty"`
Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
}
func NewRequest() *Request {
req := &Request{}
req.Auth = &Auth{}
return req
func NewEmptyRequest() *Request {
req := &Request{}
req.Auth = &Auth{}
req.Params = NewEmptyParams()
return req
}
func (this *Request) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this)
return rBytes, err
func (req *Request) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(req)
return rBytes, err
}
func (this *Request) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func (req *Request) ToJson() []byte {
jBytes, _ := json.Marshal(req)
return jBytes
}

View File

@@ -1,31 +1,38 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
"encoding/json"
encoder "github.com/vmihailenco/msgpack/v5"
)
type EmptyResult struct{}
func NewEmptyResult() *EmptyResult {
return &EmptyResult{}
}
type Response struct {
Error string `json:"error,omitempty" msgpack:"error,omitempty"`
Result any `json:"result,omitemty" msgpack:"result,omitemty"`
Error string `json:"error" msgpack:"error"`
Result any `json:"result" msgpack:"result"`
}
func NewResponse() *Response {
return &Response{}
func NewEmptyResponse() *Response {
return &Response{
Result: NewEmptyResult(),
}
}
func (this *Response) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func (resp *Response) ToJson() []byte {
jBytes, _ := json.Marshal(resp)
return jBytes
}
func (this *Response) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this)
return rBytes, err
func (resp *Response) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(resp)
return rBytes, err
}

466
server.go
View File

@@ -1,263 +1,307 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"context"
"encoding/json"
"errors"
"io"
"net"
"sync"
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
encoder "github.com/vmihailenco/msgpack/v5"
)
type HandlerFunc = func(*Context) error
type HandlerFunc = func(*Content) error
type Service struct {
handlers map[string]HandlerFunc
ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup
preMw []HandlerFunc
postMw []HandlerFunc
handlers map[string]HandlerFunc
ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup
preMw []HandlerFunc
postMw []HandlerFunc
keepalive bool
kaTime time.Duration
kaMtx sync.Mutex
}
func NewService() *Service {
rdrpc := &Service{}
rdrpc.handlers = make(map[string]HandlerFunc)
ctx, cancel := context.WithCancel(context.Background())
rdrpc.ctx = ctx
rdrpc.cancel = cancel
var wg sync.WaitGroup
rdrpc.wg = &wg
rdrpc.preMw = make([]HandlerFunc, 0)
rdrpc.postMw = make([]HandlerFunc, 0)
rdrpc := &Service{}
rdrpc.handlers = make(map[string]HandlerFunc)
ctx, cancel := context.WithCancel(context.Background())
rdrpc.ctx = ctx
rdrpc.cancel = cancel
var wg sync.WaitGroup
rdrpc.wg = &wg
rdrpc.preMw = make([]HandlerFunc, 0)
rdrpc.postMw = make([]HandlerFunc, 0)
return rdrpc
return rdrpc
}
func (this *Service) PreMiddleware(mw HandlerFunc) {
this.preMw = append(this.preMw, mw)
func (svc *Service) PreMiddleware(mw HandlerFunc) {
svc.preMw = append(svc.preMw, mw)
}
func (this *Service) PostMiddleware(mw HandlerFunc) {
this.postMw = append(this.postMw, mw)
func (svc *Service) PostMiddleware(mw HandlerFunc) {
svc.postMw = append(svc.postMw, mw)
}
func (this *Service) Handler(method string, handler HandlerFunc) {
this.handlers[method] = handler
func (svc *Service) Handler(method string, handler HandlerFunc) {
svc.handlers[method] = handler
}
func (this *Service) Listen(address string) error {
var err error
logInfo("server listen:", address)
listener, err := net.Listen("tcp", address)
if err != nil {
return err
}
this.wg.Add(1)
for {
select {
case <- this.ctx.Done():
this.wg.Done()
return err
default:
}
conn, err := listener.Accept()
if err != nil {
logError("conn accept err:", err)
}
go this.handleConn(conn)
}
func (svc *Service) SetKeepAlive(flag bool) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.keepalive = true
}
func notFound(context *Context) error {
execErr := errors.New("method not found")
err := context.SendError(execErr)
return err
func (svc *Service) SetKeepAlivePeriod(interval time.Duration) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.kaTime = interval
}
func (this *Service) Stop() error {
var err error
this.cancel()
this.wg.Wait()
return err
func (svc *Service) Listen(address string) error {
var err error
logInfo("server listen:", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
err = fmt.Errorf("unable to start listener: %s", err)
return err
}
for {
conn, err := listener.AcceptTCP()
if err != nil {
logError("conn accept err:", err)
}
select {
case <-svc.ctx.Done():
return err
default:
}
svc.wg.Add(1)
go svc.handleConn(conn, svc.wg)
}
return err
}
func (this *Service) handleConn(conn net.Conn) {
var err error
context := CreateContext(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost
context.binReader = conn
context.binWriter = io.Discard
exitFunc := func() {
conn.Close()
if err != nil {
logError("conn handler err:", err)
}
}
defer exitFunc()
recovFunc := func () {
panicMsg := recover()
if panicMsg != nil {
logError("handler panic message:", panicMsg)
}
}
defer recovFunc()
err = context.ReadRequest()
if err != nil {
return
}
err = context.BindMethod()
if err != nil {
return
}
for _, mw := range this.preMw {
err = mw(context)
if err != nil {
return
}
}
err = this.Route(context)
if err != nil {
return
}
for _, mw := range this.postMw {
err = mw(context)
if err != nil {
return
}
}
return
func notFound(content *Content) error {
execErr := errors.New("method not found")
err := content.SendError(execErr)
return err
}
func (this *Service) Route(context *Context) error {
handler, ok := this.handlers[context.reqRPC.Method]
if ok {
return handler(context)
}
return notFound(context)
func (svc *Service) Stop() error {
var err error
// Disable new connection
logInfo("cancel rpc accept loop")
svc.cancel()
// Wait handlers
logInfo("wait rpc handlers")
svc.wg.Wait()
return err
}
func (context *Context) ReadRequest() error {
var err error
func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) {
var err error
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
return err
}
context.reqHeader, err = UnpackHeader(context.reqPacket.header)
if err != nil {
return err
}
if svc.keepalive {
err = conn.SetKeepAlive(true)
if err != nil {
err = fmt.Errorf("unable to set keepalive: %s", err)
return
}
if svc.kaTime > 0 {
err = conn.SetKeepAlivePeriod(svc.kaTime)
if err != nil {
err = fmt.Errorf("unable to set keepalive period: %s", err)
return
}
}
}
content := CreateContent(conn)
rpcSize := context.reqHeader.rpcSize
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
return err
}
return err
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
content.remoteHost = remoteHost
content.binReader = conn
content.binWriter = io.Discard
exitFunc := func() {
conn.Close()
wg.Done()
if err != nil {
logError("conn handler err:", err)
}
}
defer exitFunc()
recovFunc := func() {
panicMsg := recover()
if panicMsg != nil {
logError("handler panic message:", panicMsg)
}
}
defer recovFunc()
err = content.ReadRequest()
if err != nil {
err = err
return
}
err = content.BindMethod()
if err != nil {
err = err
return
}
for _, mw := range svc.preMw {
err = mw(content)
if err != nil {
err = err
return
}
}
err = svc.Route(content)
if err != nil {
err = err
return
}
for _, mw := range svc.postMw {
err = mw(content)
if err != nil {
err = err
return
}
}
return
}
func (context *Context) BinWriter() io.Writer {
return context.sockWriter
func (svc *Service) Route(content *Content) error {
handler, ok := svc.handlers[content.reqBlock.Method]
if ok {
return handler(content)
}
return notFound(content)
}
func (context *Context) BinReader() io.Reader {
return context.sockReader
func (content *Content) ReadRequest() error {
var err error
content.reqPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
return err
}
content.reqHeader, err = UnpackHeader(content.reqPacket.header)
if err != nil {
return err
}
rpcSize := content.reqHeader.rpcSize
content.reqPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
return err
}
return err
}
func (context *Context) BinSize() int64 {
return context.reqHeader.binSize
func (content *Content) BinWriter() io.Writer {
return content.sockWriter
}
func (context *Context) ReadBin(writer io.Writer) error {
var err error
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
return err
func (content *Content) BinReader() io.Reader {
return content.sockReader
}
func (context *Context) BindMethod() error {
var err error
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
return err
func (content *Content) BinSize() int64 {
return content.reqHeader.binSize
}
func (context *Context) BindParams(params any) error {
var err error
context.reqRPC.Params = params
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
if err != nil {
return err
}
return err
func (content *Content) ReadBin(writer io.Writer) error {
var err error
_, err = CopyBytes(content.sockReader, writer, content.reqHeader.binSize)
return err
}
func (context *Context) SendResult(result any, binSize int64) error {
var err error
context.resRPC.Result = result
context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil {
return err
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resHeader.binSize = binSize
context.resPacket.header, err = context.resHeader.Pack()
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil {
return err
}
return err
func (content *Content) BindMethod() error {
var err error
err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
return err
}
func (context *Context) SendError(execErr error) error {
var err error
context.resRPC.Error = execErr.Error()
context.resRPC.Result = NewEmpty()
context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil {
return err
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resPacket.header, err = context.resHeader.Pack()
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil {
return err
}
return err
func (content *Content) BindParams(params any) error {
var err error
content.reqBlock.Params = params
err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
if err != nil {
return err
}
return err
}
func (content *Content) SendResult(result any, binSize int64) error {
var err error
content.resBlock.Result = result
content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil {
return err
}
content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
content.resHeader.binSize = binSize
content.resPacket.header, err = content.resHeader.Pack()
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.resPacket.header)
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil {
return err
}
return err
}
func (content *Content) SendError(execErr error) error {
var err error
content.resBlock.Error = execErr.Error()
content.resBlock.Result = NewEmptyResult()
content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil {
return err
}
content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
content.resPacket.header, err = content.resHeader.Pack()
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.resPacket.header)
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil {
return err
}
return err
}

View File

@@ -5,50 +5,53 @@
package dsrpc
import (
"errors"
"io"
"fmt"
"errors"
"fmt"
"io"
)
func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer)
return buffer[0:read], err
buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer)
return buffer[0:read], err
}
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
var err error
var bSize int64 = 1024 * 4
var total int64 = 0
var remains int64 = dataSize
buffer := make([]byte, bSize)
var err error
var bSize int64 = 1024 * 16
var total int64 = 0
var remains int64 = dataSize
buffer := make([]byte, bSize)
for {
if reader == nil {
return total, errors.New("reader is nil")
}
if writer == nil {
return total, errors.New("writer is nil")
}
if remains == 0 {
return total, err
}
if remains < bSize {
bSize = remains
}
received, err := reader.Read(buffer[0:bSize])
if err != nil {
return total, fmt.Errorf("read error: %v", err)
}
recorded, err := writer.Write(buffer[0:received])
if err != nil {
return total, fmt.Errorf("write error: %v", err)
}
if recorded != received {
return total, errors.New("size mismatch")
}
total += int64(recorded)
remains -= int64(recorded)
}
return total, err
for {
if reader == nil {
return total, errors.New("reader is nil")
}
if writer == nil {
return total, errors.New("writer is nil")
}
if remains == 0 {
return total, err
}
if remains < bSize {
bSize = remains
}
received, err := reader.Read(buffer[0:bSize])
if err != nil {
err = fmt.Errorf("read error: %v", err)
return total, err
}
recorded, err := writer.Write(buffer[0:received])
if err != nil {
err = fmt.Errorf("write error: %v", err)
return total, err
}
if recorded != received {
err = errors.New("size mismatch")
return total, err
}
total += int64(recorded)
remains -= int64(recorded)
}
return total, err
}

View File

@@ -4,161 +4,171 @@
package dsrpc
import (
"io"
"net"
"io"
"net"
)
func LocalExec(method string, param any, result any, auth *Auth, handler HandlerFunc) error {
var err error
func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
cliConn, srvConn := NewFConn()
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
return err
err = content.CreateRequest()
if err != nil {
return err
}
err = content.WriteRequest()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = content.ReadResponse()
if err != nil {
return err
}
err = content.BindResponse()
if err != nil {
return err
}
return err
}
func LocalPut(method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
var err error
cliConn, srvConn := NewFConn()
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
context.binReader = reader
context.binWriter = cliConn
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
context.reqHeader.binSize = size
content.binReader = reader
content.binWriter = cliConn
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = context.UploadBin()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
content.reqHeader.binSize = size
err = content.CreateRequest()
if err != nil {
return err
}
err = content.WriteRequest()
if err != nil {
return err
}
err = content.UploadBin()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = content.ReadResponse()
if err != nil {
return err
}
err = content.BindResponse()
if err != nil {
return err
}
return err
}
func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
var err error
cliConn, srvConn := NewFConn()
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
context.binReader = cliConn
context.binWriter = writer
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
content.binReader = cliConn
content.binWriter = writer
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.DownloadBin()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
err = content.CreateRequest()
if err != nil {
return err
}
err = content.WriteRequest()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = content.ReadResponse()
if err != nil {
return err
}
err = content.DownloadBin()
if err != nil {
return err
}
err = content.BindResponse()
if err != nil {
return err
}
return err
}
func LocalService(conn net.Conn, handler HandlerFunc) error {
var err error
context := CreateContext(conn)
var err error
content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
content.remoteHost = remoteHost
context.binReader = conn
context.binWriter = io.Discard
content.binReader = conn
content.binWriter = io.Discard
err = context.ReadRequest()
if err != nil {
return err
}
err = context.BindMethod()
if err != nil {
return err
}
return handler(context)
err = content.ReadRequest()
if err != nil {
return err
}
err = content.BindMethod()
if err != nil {
return err
}
return handler(content)
}

View File

@@ -5,60 +5,60 @@
package dsrpc
import (
"encoding/json"
"bytes"
"math/rand"
"time"
"crypto/sha256"
"bytes"
"crypto/sha256"
"encoding/json"
"math/rand"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
rand.Seed(time.Now().UnixNano())
}
type Auth struct {
Ident []byte `json:"ident,omitempty"`
Salt []byte `json:"salt,omitempty"`
Hash []byte `json:"hash,omitempty"`
Ident []byte `msgpack:"ident" json:"ident"`
Salt []byte `msgpack:"salt" json:"salt"`
Hash []byte `msgpack:"hash" json:"hash"`
}
func NewAuth() *Auth {
return &Auth{}
return &Auth{}
}
func (this *Auth) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
jBytes, _ := json.Marshal(this)
return jBytes
}
func CreateAuth(ident, pass []byte) *Auth {
salt := CreateSalt()
hash := CreateHash(ident, pass, salt)
auth := &Auth{}
auth.Ident = ident
auth.Salt = salt
auth.Hash = hash
return auth
salt := CreateSalt()
hash := CreateHash(ident, pass, salt)
auth := &Auth{}
auth.Ident = ident
auth.Salt = salt
auth.Hash = hash
return auth
}
func CreateSalt() []byte {
const saltSize = 16
randBytes := make([]byte, saltSize)
rand.Read(randBytes)
return randBytes
const saltSize = 16
randBytes := make([]byte, saltSize)
rand.Read(randBytes)
return randBytes
}
func CreateHash(ident, pass, salt []byte) []byte {
vec := make([]byte, 0, len(ident) + len(salt) + len(pass))
vec = append(vec, ident...)
vec = append(vec, salt...)
vec = append(vec, pass...)
hasher := sha256.New()
hash := hasher.Sum(vec)
return hash
vec := make([]byte, 0, len(ident)+len(salt)+len(pass))
vec = append(vec, ident...)
vec = append(vec, salt...)
vec = append(vec, pass...)
hasher := sha256.New()
hash := hasher.Sum(vec)
return hash
}
func CheckHash(ident, pass, reqSalt, reqHash []byte) bool {
localHash := CreateHash(ident, pass, reqSalt)
return bytes.Equal(reqHash, localHash)
localHash := CreateHash(ident, pass, reqSalt)
return bytes.Equal(reqHash, localHash)
}