16 Commits

27 changed files with 1762 additions and 1353 deletions

209
README.md
View File

@@ -1,24 +1,24 @@
# dsrpc, Data RPC
DSRPC is easy and simple RPC framework over TCP socket.
### Purpose
A very easy and open RPC framework with data streaming.
### You can
- Use post and pre-execution middleware
- Use own post and pre-execution middleware
- Hash-based authentication in middleware
- Test call remote function without service organization
- Test remote function without network
Socket encryption is not used at this time since framefork
is oriented to transfer large amounts of data
is oriented to transfer large amounts of data.
Style of the framework is similar to that of GIN framework.
Style of the framework is similar of GIN framework.
## Example
## Exec method example
### Server
@@ -135,7 +135,7 @@ package api
const HelloMethod string = "hello"
type HelloParams struct {
Message string `msgpack:"message" json:"message"`
Message string `json:"message"`
}
func NewHelloParams() *HelloParams {
@@ -143,7 +143,7 @@ func NewHelloParams() *HelloParams {
}
type HelloResult struct {
Message string `msgpack:"message" json:"message"`
Message string `json:"message"`
}
func NewHelloResult() *HelloResult {
@@ -151,3 +151,196 @@ func NewHelloResult() *HelloResult {
}
```
### Authentication and authorization
#### Client side
```
func clientHello() error {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := dsrpc.CreateAuth([]byte("login"), []byte("password"))
err = dsrpc.Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
log.Println("method err:", err)
return err
}
//...
}
```
#### Server side
```
func authMiddleware(context *dsrpc.Context) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
if reqIdent != "login" {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
ident := reqIdent
pass := []byte("password")
ok := dsrpc.CheckHash(ident, pass, reqSalt, reqHash)
log.Println("auth is ok:", ok)
if !ok {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
return err
}
func sampleServ(quiet bool) error {
var err error
if quiet {
dsrpc.SetAccessWriter(io.Discard)
dsrpc.SetMessageWriter(io.Discard)
}
serv := NewService()
serv.PreMiddleware(authMiddleware)
serv.PreMiddleware(dsrpc.LogRequest)
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.PostMiddleware(dsrpc.LogResponse)
serv.PostMiddleware(dsrpc.LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
```
### Put method
#### Client side sample
```
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
err = dsrpc.Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
```
#### Server side
```
func saveHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
if err != nil {
context.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
}
```
### Get method
#### Client side
```
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = dsrpc.Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
return err
}
//...
```
#### Server side
```
func getHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = dsrpc.CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}
```

526
client.go
View File

@@ -1,282 +1,348 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
"errors"
"io"
"net"
"sync"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
encoder "encoding/json"
)
func Put(ctx context.Context, address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
var err error
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
conn, err := net.Dial("tcp", address)
if err != nil {
return err
}
defer conn.Close()
return ConnPut(conn, method, reader, size, param, result, auth)
return ConnPut(ctx, conn, method, reader, binSize, param, result, auth)
}
func PutTLS(ctx context.Context, tlsConfig *tls.Config, address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
conn, err := tls.Dial("tcp", address, tlsConfig)
if err != nil {
return err
}
defer conn.Close()
context.binReader = reader
context.binWriter = conn
context.reqHeader.binSize = size
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
var wg sync.WaitGroup
errChan := make(chan error, 1)
wg.Add(1)
go context.ReadResponseAsync(&wg, errChan)
wg.Add(1)
go context.UploadBinAsync(&wg)
wg.Wait()
err = <- errChan
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
return ConnPut(ctx, conn, method, reader, binSize, param, result, auth)
}
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
func ConnPut(ctx context.Context, conn net.Conn, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
content := CreateContent(conn)
conn, err := net.Dial("tcp", address)
if err != nil {
return err
}
defer conn.Close()
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
return ConnGet(conn, method, writer, param, result, auth)
content.binReader = reader
content.binWriter = conn
content.reqHeader.binSize = binSize
err = content.createRequest()
if err != nil {
return err
}
err = content.writeRequest()
if err != nil {
return err
}
var wg sync.WaitGroup
errChan := make(chan error, 1)
wg.Add(1)
go content.readResponseAsync(&wg, errChan)
wg.Add(1)
go content.uploadBinAsync(ctx, &wg)
wg.Wait()
err = <-errChan
if err != nil {
return err
}
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
func Get(ctx context.Context, address string, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
context.binReader = conn
context.binWriter = writer
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.DownloadBin()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
return ConnGet(ctx, conn, method, writer, param, result, auth)
}
func Exec(address, method string, param any, result any, auth *Auth) error {
var err error
func GetTLS(ctx context.Context, tlsConfig *tls.Config, address string, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
conn, err := net.Dial("tcp", address)
if err != nil {
return err
}
defer conn.Close()
conn, err := tls.Dial("tcp", address, tlsConfig)
if err != nil {
return err
}
defer conn.Close()
err = ConnExec(conn, method, param, result, auth)
if err != nil {
return err
}
return err
return ConnGet(ctx, conn, method, writer, param, result, auth)
}
func ConnGet(ctx context.Context, conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error {
var err error
content := CreateContent(conn)
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content.binReader = conn
content.binWriter = writer
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
err = content.createRequest()
if err != nil {
return err
}
err = content.writeRequest()
if err != nil {
return err
}
err = content.readResponse()
if err != nil {
return err
}
err = content.downloadBin(ctx)
if err != nil {
return err
}
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func Exec(ctx context.Context, address, method string, param any, result any, auth *Auth) error {
var err error
func (context *Context) CreateRequest() error {
var err error
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return err
}
defer conn.Close()
context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
if err != nil {
return err
}
rpcSize := int64(len(context.reqPacket.rcpPayload))
context.reqHeader.rpcSize = rpcSize
context.reqPacket.header, err = context.reqHeader.Pack()
if err != nil {
return err
}
return err
err = ConnExec(ctx, conn, method, param, result, auth)
if err != nil {
return err
}
return err
}
func (context *Context) WriteRequest() error {
var err error
_, err = context.sockWriter.Write(context.reqPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
if err != nil {
return err
}
return err
func ExecTLS(ctx context.Context, tlsConfig *tls.Config, address, method string, param any, result any, auth *Auth) error {
var err error
conn, err := tls.Dial("tcp", address, tlsConfig)
if err != nil {
return err
}
defer conn.Close()
err = ConnExec(ctx, conn, method, param, result, auth)
if err != nil {
return err
}
return err
}
func (context *Context) UploadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return err
func ConnExec(ctx context.Context, conn net.Conn, method string, param any, result any, auth *Auth) error {
var err error
content := CreateContent(conn)
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if result != nil {
content.resBlock.Result = result
}
if auth != nil {
content.reqBlock.Auth = auth
}
err = content.createRequest()
if err != nil {
return err
}
err = content.writeRequest()
if err != nil {
return err
}
err = content.readResponse()
if err != nil {
return err
}
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func (context *Context) ReadResponse() error {
var err error
func (content *Content) createRequest() error {
var err error
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
return err
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil {
return err
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
return err
}
return err
content.reqPacket.rcpPayload, err = content.reqBlock.Pack()
if err != nil {
return err
}
rpcSize := int64(len(content.reqPacket.rcpPayload))
content.reqHeader.rpcSize = rpcSize
content.reqPacket.header, err = content.reqHeader.Pack()
if err != nil {
return err
}
return err
}
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) {
exitFunc := func() {
wg.Done()
}
defer exitFunc()
_, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return
func (content *Content) writeRequest() error {
var err error
_, err = content.sockWriter.Write(content.reqPacket.header)
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.reqPacket.rcpPayload)
if err != nil {
return err
}
return err
}
func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) {
var err error
exitFunc := func() {
errChan <- err
wg.Done()
}
defer exitFunc()
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
return
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil {
return
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
return
}
return
func (content *Content) uploadBin(ctx context.Context) error {
var err error
_, err = CopyBytes(ctx, content.binReader, content.binWriter, content.reqHeader.binSize)
return err
}
func (context *Context) DownloadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
return err
func (content *Content) readResponse() error {
var err error
content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
return err
}
content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil {
return err
}
rpcSize := content.resHeader.rpcSize
content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
return err
}
return err
}
func (context *Context) BindResponse() error {
var err error
err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
if err != nil {
return err
}
if len(context.resRPC.Error) > 0 {
return errors.New(context.resRPC.Error)
}
return err
func (content *Content) uploadBinAsync(ctx context.Context, wg *sync.WaitGroup) {
exitFunc := func() {
wg.Done()
}
defer exitFunc()
_, _ = CopyBytes(ctx, content.binReader, content.binWriter, content.reqHeader.binSize)
return
}
func (content *Content) readResponseAsync(wg *sync.WaitGroup, errChan chan error) {
var err error
exitFunc := func() {
errChan <- err
wg.Done()
}
defer exitFunc()
content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
err = err
return
}
content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil {
err = err
return
}
rpcSize := content.resHeader.rpcSize
content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
err = err
return
}
return
}
func (content *Content) downloadBin(ctx context.Context) error {
var err error
_, err = CopyBytes(ctx, content.binReader, content.binWriter, content.resHeader.binSize)
return err
}
func (content *Content) bindResponse() error {
var err error
err = encoder.Unmarshal(content.resPacket.rcpPayload, content.resBlock)
if err != nil {
return err
}
if len(content.resBlock.Error) > 0 {
err = errors.New(content.resBlock.Error)
return err
}
return err
}

View File

@@ -1,3 +0,0 @@
package dsrpc
type any = interface{}

145
content.go Normal file
View File

@@ -0,0 +1,145 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
import (
"io"
"net"
"time"
)
type Content struct {
start time.Time
remoteHost string
sockReader io.Reader
sockWriter io.Writer
reqPacket *Packet
reqHeader *Header
reqBlock *Request
resPacket *Packet
resHeader *Header
resBlock *Response
binReader io.Reader
binWriter io.Writer
}
func CreateContent(conn net.Conn) *Content {
context := &Content{
start: time.Now(),
sockReader: conn,
sockWriter: conn,
reqPacket: NewEmptyPacket(),
reqHeader: NewEmptyHeader(),
reqBlock: NewEmptyRequest(),
resPacket: NewEmptyPacket(),
resHeader: NewEmptyHeader(),
resBlock: NewEmptyResponse(),
}
return context
}
func (context *Content) Request() *Request {
return context.reqBlock
}
func (context *Content) RemoteHost() string {
return context.remoteHost
}
func (context *Content) Start() time.Time {
return context.start
}
func (context *Content) Method() string {
var method string
if context.reqBlock != nil {
method = context.reqBlock.Method
}
return method
}
func (context *Content) ReqRpcSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.rpcSize
}
return size
}
func (context *Content) ReqBinSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.binSize
}
return size
}
func (context *Content) ResBinSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.binSize
}
return size
}
func (context *Content) ResRpcSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.rpcSize
}
return size
}
func (context *Content) ReqSize() int64 {
var size int64
if context.reqHeader != nil {
size += context.reqHeader.binSize
size += context.reqHeader.rpcSize
}
return size
}
func (context *Content) ResSize() int64 {
var size int64
if context.resHeader != nil {
size += context.resHeader.binSize
size += context.resHeader.rpcSize
}
return size
}
func (context *Content) SetAuthIdent(ident []byte) {
context.reqBlock.Auth.Ident = ident
}
func (context *Content) SetAuthSalt(salt []byte) {
context.reqBlock.Auth.Salt = salt
}
func (context *Content) SetAuthHash(hash []byte) {
context.reqBlock.Auth.Hash = hash
}
func (context *Content) AuthIdent() []byte {
return context.reqBlock.Auth.Ident
}
func (context *Content) AuthSalt() []byte {
return context.reqBlock.Auth.Salt
}
func (context *Content) AuthHash() []byte {
return context.reqBlock.Auth.Hash
}
func (context *Content) Auth() *Auth {
return context.reqBlock.Auth
}

View File

@@ -1,90 +0,0 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"io"
"net"
"time"
)
type Context struct {
start time.Time
remoteHost string
sockReader io.Reader
sockWriter io.Writer
reqHeader *Header
reqRPC *Request
reqPacket *Packet
resPacket *Packet
resHeader *Header
resRPC *Response
binReader io.Reader
binWriter io.Writer
}
func NewContext() *Context {
context := &Context{}
context.start = time.Now()
return context
}
func CreateContext(conn net.Conn) *Context {
context := &Context{}
context.start = time.Now()
context.sockReader = conn
context.sockWriter = conn
context.reqPacket = NewPacket()
context.resPacket = NewPacket()
context.reqHeader = NewHeader()
context.reqRPC = NewRequest()
context.resHeader = NewHeader()
context.resRPC = NewResponse()
context.resRPC = NewResponse()
return context
}
func (context *Context) Request() *Request {
return context.reqRPC
}
func (context *Context) SetAuthIdent(ident []byte) {
context.reqRPC.Auth.Ident = ident
}
func (context *Context) SetAuthSalt(salt []byte) {
context.reqRPC.Auth.Salt = salt
}
func (context *Context) SetAuthHash(hash []byte) {
context.reqRPC.Auth.Hash = hash
}
func (context *Context) AuthIdent() []byte {
return context.reqRPC.Auth.Ident
}
func (context *Context) AuthSalt() []byte {
return context.reqRPC.Auth.Salt
}
func (context *Context) AuthHash() []byte {
return context.reqRPC.Auth.Hash
}
func (context *Context) Auth() *Auth {
return context.reqRPC.Auth
}

View File

@@ -1,13 +0,0 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
type Empty struct {}
func NewEmpty() *Empty {
return &Empty{}
}

View File

@@ -12,14 +12,7 @@ type HelloParams struct {
Message string `msgpack:"message" json:"message"`
}
func NewHelloParams() *HelloParams {
return &HelloParams{}
}
type HelloResult struct {
Message string `msgpack:"message" json:"message"`
}
func NewHelloResult() *HelloResult {
return &HelloResult{}
}

View File

@@ -7,32 +7,40 @@
package main
import (
"fmt"
"github.com/kindsoldier/dsrpc"
"netsrv/api"
"context"
"fmt"
"time"
"github.com/kindsoldier/dsrpc"
"netsrv/api"
)
func main() {
err := exec()
if err != nil {
fmt.Println("exec err:", err)
}
err := exec()
if err != nil {
fmt.Println("exec err:", err)
}
}
func exec() error {
var err error
var err error
params := api.NewHelloParams()
params.Message = "hello, server!"
params := api.HelloParams{
Message: "hello, server!",
}
result := api.NewHelloResult()
result := api.HelloResult{}
err = dsrpc.Exec("127.0.0.1:8081", api.HelloMethod, params, result, nil)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
fmt.Println("result:", result.Message)
return err
err = dsrpc.Exec(ctx, "127.0.0.1:8081", api.HelloMethod, &params, &result, nil)
if err != nil {
return err
}
fmt.Println("result:", result.Message)
return err
}

View File

@@ -1,7 +1,5 @@
module netsrv
go 1.17
go 1.19
require github.com/kindsoldier/dsrpc v0.0.1
replace github.com/kindsoldier/dsrpc => ../
require github.com/kindsoldier/dsrpc v1.2.1

View File

@@ -1,10 +1,6 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/kindsoldier/dsrpc v1.2.1 h1:sw1a3MAD83Do1Fu+Dh+AHArrwVMgZ/KTLUWWTkQ6vj8=
github.com/kindsoldier/dsrpc v1.2.1/go.mod h1:zYb5yYfE/18BYK+iCUNcpkZ4uArwUNNhwYkUK8xDHQk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -8,8 +8,9 @@ package main
import (
"log"
"github.com/kindsoldier/dsrpc"
"netsrv/api"
"github.com/kindsoldier/dsrpc"
)
func main() {
@@ -46,20 +47,19 @@ func NewController() *Controller {
return &Controller{}
}
func (cont *Controller) HelloHandler(context *dsrpc.Context) error {
func (cont *Controller) HelloHandler(content *dsrpc.Content) error {
var err error
params := api.NewHelloParams()
err = context.BindParams(params)
params := api.HelloParams{}
err = content.BindParams(&params)
if err != nil {
return err
}
log.Println("hello message:", params.Message)
result := api.NewHelloResult()
result.Message = "hello!"
err = context.SendResult(result, 0)
result := api.HelloResult{
Message: "hello, client!",
}
err = content.SendResult(&result, 0)
if err != nil {
return err
}

View File

@@ -1,376 +1,370 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"bytes"
"encoding/json"
"errors"
"io"
"math/rand"
"testing"
"time"
"bytes"
"context"
"encoding/json"
"errors"
"io"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
func TestLocalExec(t *testing.T) {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
err = LocalExec(HelloMethod, params, result, auth, helloHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
}
func TestLocalSave(t *testing.T) {
var err error
params := NewSaveParams()
params.Message = "save data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
err = LocalPut(SaveMethod, reader, binSize, params, result, auth, saveHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
}
func TestLocalLoad(t *testing.T) {
var err error
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = LocalGet(LoadMethod, writer, params, result, auth, loadHandler)
require.NoError(t, err)
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
logDebug("bin size:", len(writer.Bytes()))
}
func TestNetExec(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientHello()
require.NoError(t, err)
}
func TestNetSave(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientSave()
require.NoError(t, err)
}
func TestNetLoad(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientLoad()
require.NoError(t, err)
}
func BenchmarkNetPut(b *testing.B) {
go testServ(true)
time.Sleep(10 * time.Millisecond)
clientSave()
pBench := func(pb *testing.PB) {
for pb.Next() {
clientSave()
}
}
b.SetParallelism(10)
b.RunParallel(pBench)
}
func clientHello() error {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
err = Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
return err
}
func clientSave() error {
var err error
params := NewSaveParams()
params.Message = "save data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
return err
}
func clientLoad() error {
var err error
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJSON, _ := json.Marshal(result)
logDebug("method result:", string(resultJSON))
logDebug("bin size:", len(writer.Bytes()))
return err
}
var testServRun bool = false
func testServ(quiet bool) error {
var err error
if testServRun {
return err
}
testServRun = true
if quiet {
SetAccessWriter(io.Discard)
SetMessageWriter(io.Discard)
}
serv := NewService()
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.PreMiddleware(LogRequest)
serv.PreMiddleware(auth)
serv.PostMiddleware(LogResponse)
serv.PostMiddleware(LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
func auth(context *Context) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
ident := reqIdent
pass := []byte("12345")
auth := context.Auth()
logDebug("auth ", string(auth.JSON()))
ok := CheckHash(ident, pass, reqSalt, reqHash)
logDebug("auth ok:", ok)
if !ok {
err = errors.New("auth ident or pass missmatch")
context.SendError(err)
return err
}
return err
}
func helloHandler(context *Context) error {
var err error
params := NewHelloParams()
err = context.BindParams(params)
if err != nil {
return err
}
err = context.ReadBin(io.Discard)
if err != nil {
context.SendError(err)
return err
}
result := NewHelloResult()
result.Message = "hello, client!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func saveHandler(context *Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
if err != nil {
context.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func loadHandler(context *Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
err = context.ReadBin(io.Discard)
if err != nil {
context.SendError(err)
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}
const HelloMethod string = "hello"
type HelloParams struct {
Message string `json:"message" json:"message"`
}
func NewHelloParams() *HelloParams {
return &HelloParams{}
Message string `json:"message" msgpack:"message"`
}
type HelloResult struct {
Message string `json:"message" json:"message"`
Message string `json:"message" msgpack:"message"`
}
func NewHelloResult() *HelloResult {
return &HelloResult{}
}
const SaveMethod string = "save"
type SaveParams HelloParams
type SaveResult HelloResult
func NewSaveParams() *SaveParams {
return &SaveParams{}
}
func NewSaveResult() *SaveResult {
return &SaveResult{}
}
const LoadMethod string = "load"
type LoadParams HelloParams
type LoadResult HelloResult
func NewLoadParams() *LoadParams {
return &LoadParams{}
func TestLocalExec(t *testing.T) {
var err error
params := HelloParams{}
params.Message = "hello server!"
result := HelloResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
err = LocalExec(HelloMethod, &params, &result, auth, helloHandler)
require.NoError(t, err)
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
}
func NewLoadResult() *LoadResult {
return &LoadResult{}
func TestLocalSave(t *testing.T) {
var err error
params := SaveParams{}
params.Message = "save data!"
result := SaveResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = LocalPut(ctx, SaveMethod, reader, binSize, &params, &result, auth, saveHandler)
require.NoError(t, err)
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
}
func TestLocalLoad(t *testing.T) {
var err error
params := LoadParams{}
params.Message = "load data!"
result := LoadResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = LocalGet(ctx, LoadMethod, writer, &params, &result, auth, loadHandler)
require.NoError(t, err)
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
logDebug("bin size:", len(writer.Bytes()))
}
func TestNetExec(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientHello()
require.NoError(t, err)
}
func TestNetSave(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientSave()
require.NoError(t, err)
}
func TestNetLoad(t *testing.T) {
go testServ(false)
time.Sleep(10 * time.Millisecond)
err := clientLoad()
require.NoError(t, err)
}
func BenchmarkNetPut(b *testing.B) {
go testServ(true)
time.Sleep(10 * time.Millisecond)
clientSave()
pBench := func(pb *testing.PB) {
for pb.Next() {
clientSave()
}
}
b.SetParallelism(2000)
b.RunParallel(pBench)
}
func clientHello() error {
var err error
params := HelloParams{}
params.Message = "hello server!"
result := HelloResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = Exec(ctx, "127.0.0.1:8081", HelloMethod, &params, &result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
return err
}
func clientSave() error {
var err error
params := SaveParams{}
params.Message = "save data!"
result := SaveResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = Put(ctx, "127.0.0.1:8081", SaveMethod, reader, binSize, &params, &result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
return err
}
func clientLoad() error {
var err error
params := LoadParams{}
params.Message = "load data!"
result := LoadResult{}
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = Get(ctx, "127.0.0.1:8081", LoadMethod, writer, &params, &result, auth)
if err != nil {
logError("method err:", err)
return err
}
resultJson, _ := json.Marshal(result)
logDebug("method result:", string(resultJson))
logDebug("bin size:", len(writer.Bytes()))
return err
}
var testServRun bool = false
func testServ(quiet bool) error {
var err error
if testServRun {
return err
}
testServRun = true
if quiet {
SetAccessWriter(io.Discard)
SetMessageWriter(io.Discard)
}
serv := NewService()
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.PreMiddleware(LogRequest)
serv.PreMiddleware(auth)
serv.PostMiddleware(LogResponse)
serv.PostMiddleware(LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
func auth(content *Content) error {
var err error
reqIdent := content.AuthIdent()
reqSalt := content.AuthSalt()
reqHash := content.AuthHash()
ident := reqIdent
pass := []byte("12345")
auth := content.Auth()
logDebug("auth ", string(auth.Json()))
ok := CheckHash(ident, pass, reqSalt, reqHash)
logDebug("auth ok:", ok)
if !ok {
err = errors.New("auth ident or pass missmatch")
content.SendError(err)
return err
}
return err
}
func helloHandler(content *Content) error {
var err error
params := HelloParams{}
err = content.BindParams(&params)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = content.ReadBin(ctx, io.Discard)
if err != nil {
content.SendError(err)
return err
}
result := HelloResult{}
result.Message = "hello, client!"
err = content.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func saveHandler(content *Content) error {
var err error
params := SaveParams{}
err = content.BindParams(&params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = content.ReadBin(ctx, binWriter)
if err != nil {
content.SendError(err)
return err
}
result := SaveResult{}
result.Message = "saved successfully!"
err = content.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func loadHandler(content *Content) error {
var err error
params := SaveParams{}
err = content.BindParams(&params)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second))
defer cancel()
err = content.ReadBin(ctx, io.Discard)
if err != nil {
content.SendError(err)
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
result := SaveResult{}
result.Message = "load successfully!"
err = content.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := content.BinWriter()
_, err = CopyBytes(ctx, binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}

View File

@@ -1,27 +1,26 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
type FAddr struct {
network string
address string
network string
address string
}
func NewFAddr() *FAddr {
var addr FAddr
addr.network = "tcp"
addr.address = "127.0.0.1:5000"
return &addr
addr := FAddr{
network: "tcp",
address: "127.0.0.1:5000",
}
return &addr
}
func (addr *FAddr) Network() string {
return addr.network
return addr.network
}
func (addr *FAddr) String() string {
return addr.address
return addr.address
}

View File

@@ -7,51 +7,52 @@
package dsrpc
import (
"net"
"testing"
"github.com/stretchr/testify/require"
"net"
"testing"
"github.com/stretchr/testify/require"
)
func TestFConn0(t *testing.T) {
var cConn, sConn net.Conn
sConn, cConn = NewFConn()
var cConn, sConn net.Conn
sConn, cConn = NewFConn()
cData := []byte("qwerty")
count := 10
cData := []byte("qwerty")
count := 10
for i := 0; i < count; i++ {
wc, err := cConn.Write(cData)
if err != nil {
t.Error(err)
}
require.Equal(t, wc, len(cData))
for i := 0; i < count; i++ {
wc, err := cConn.Write(cData)
if err != nil {
t.Error(err)
}
require.Equal(t, wc, len(cData))
sData := make([]byte, len(cData))
rc, err := sConn.Read(sData)
require.NoError(t, err)
require.Equal(t, rc, len(cData))
require.Equal(t, cData, sData)
}
sData := make([]byte, len(cData))
rc, err := sConn.Read(sData)
require.NoError(t, err)
require.Equal(t, rc, len(cData))
require.Equal(t, cData, sData)
}
}
func TestFConn1(t *testing.T) {
var cConn, sConn net.Conn
cConn, sConn = NewFConn()
var cConn, sConn net.Conn
cConn, sConn = NewFConn()
cData := []byte("qwerty")
count := 10
cData := []byte("qwerty")
count := 10
for i := 0; i < count; i++ {
wc, err := cConn.Write(cData)
if err != nil {
t.Error(err)
}
require.Equal(t, wc, len(cData))
for i := 0; i < count; i++ {
wc, err := cConn.Write(cData)
if err != nil {
t.Error(err)
}
require.Equal(t, wc, len(cData))
sData := make([]byte, len(cData))
rc, err := sConn.Read(sData)
require.NoError(t, err)
require.Equal(t, rc, len(cData))
require.Equal(t, cData, sData)
}
sData := make([]byte, len(cData))
rc, err := sConn.Read(sData)
require.NoError(t, err)
require.Equal(t, rc, len(cData))
require.Equal(t, cData, sData)
}
}

View File

@@ -7,62 +7,62 @@
package dsrpc
import (
"bytes"
"io"
"net"
"time"
"bytes"
"io"
"net"
"time"
)
type FConn struct {
reader io.Reader
writer io.Writer
reader io.Reader
writer io.Writer
}
func NewFConn() (*FConn, *FConn){
c2sBuffer := bytes.NewBuffer(make([]byte, 0))
s2cBuffer := bytes.NewBuffer(make([]byte, 0))
func NewFConn() (*FConn, *FConn) {
c2sBuffer := bytes.NewBuffer(make([]byte, 0))
s2cBuffer := bytes.NewBuffer(make([]byte, 0))
var client FConn
client.writer = c2sBuffer
client.reader = s2cBuffer
var client FConn
client.writer = c2sBuffer
client.reader = s2cBuffer
var server FConn
server.writer = s2cBuffer
server.reader = c2sBuffer
var server FConn
server.writer = s2cBuffer
server.reader = c2sBuffer
return &client, &server
return &client, &server
}
func (conn FConn) SetDeadline(t time.Time) error {
var err error
return err
var err error
return err
}
func (conn FConn) SetReadDeadline(t time.Time) error {
var err error
return err
func (conn FConn) SetReadDeadline(t time.Time) error {
var err error
return err
}
func (conn FConn) SetWriteDeadline(t time.Time) error {
var err error
return err
var err error
return err
}
func (conn FConn) LocalAddr() net.Addr {
return NewFAddr()
return NewFAddr()
}
func (conn FConn) RemoteAddr() net.Addr {
return NewFAddr()
return NewFAddr()
}
func (conn FConn) Write(data []byte) (int, error) {
return conn.writer.Write(data)
return conn.writer.Write(data)
}
func (conn FConn) Read(data []byte) (int, error) {
return conn.reader.Read(data)
return conn.reader.Read(data)
}
func (conn FConn) Close() error {
var err error
return err
var err error
return err
}

8
go.mod
View File

@@ -1,11 +1,11 @@
module github.com/kindsoldier/dsrpc
go 1.17
go 1.19
require github.com/stretchr/testify v1.7.1
require github.com/stretchr/testify v1.8.2
require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

12
go.sum
View File

@@ -1,11 +1,17 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

123
header.go
View File

@@ -7,92 +7,95 @@
package dsrpc
import (
"errors"
"encoding/binary"
"encoding/json"
"bytes"
"bytes"
"encoding/binary"
"encoding/json"
"errors"
)
const headerSize int64 = 16 * 2
const sizeOfInt64 int = 8
const magicCodeA int64 = 0xEE00ABBA
const magicCodeB int64 = 0xEE44ABBA
const (
headerSize int64 = 16 * 2
sizeOfInt64 int = 8
magicCodeA int64 = 0xEE00ABBA
magicCodeB int64 = 0xEE44ABBA
)
type Header struct {
magicCodeA int64 `json:"magicCodeA"`
rpcSize int64 `json:"rpcSize"`
binSize int64 `json:"binSize"`
magicCodeB int64 `json:"magicCodeB"`
magicCodeA int64 `json:"magicCodeA"`
rpcSize int64 `json:"rpcSize"`
binSize int64 `json:"binSize"`
magicCodeB int64 `json:"magicCodeB"`
}
func NewHeader() *Header {
return &Header{
magicCodeA: magicCodeA,
magicCodeB: magicCodeB,
}
func NewEmptyHeader() *Header {
return &Header{
magicCodeA: magicCodeA,
magicCodeB: magicCodeB,
}
}
func (this *Header) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func (hdr *Header) ToJson() []byte {
jBytes, _ := json.Marshal(hdr)
return jBytes
}
func (hdr *Header) Pack() ([]byte, error) {
var err error
headerBytes := make([]byte, 0, headerSize)
headerBuffer := bytes.NewBuffer(headerBytes)
func (this *Header) Pack() ([]byte, error) {
var err error
headerBytes := make([]byte, 0, headerSize)
headerBuffer := bytes.NewBuffer(headerBytes)
magicCodeABytes := EncoderI64(hdr.magicCodeA)
headerBuffer.Write(magicCodeABytes)
magicCodeABytes := encoderI64(this.magicCodeA)
headerBuffer.Write(magicCodeABytes)
rpcSizeBytes := EncoderI64(hdr.rpcSize)
headerBuffer.Write(rpcSizeBytes)
rpcSizeBytes := encoderI64(this.rpcSize)
headerBuffer.Write(rpcSizeBytes)
binSizeBytes := EncoderI64(hdr.binSize)
headerBuffer.Write(binSizeBytes)
binSizeBytes := encoderI64(this.binSize)
headerBuffer.Write(binSizeBytes)
magicCodeBBytes := EncoderI64(hdr.magicCodeB)
headerBuffer.Write(magicCodeBBytes)
magicCodeBBytes := encoderI64(this.magicCodeB)
headerBuffer.Write(magicCodeBBytes)
return headerBuffer.Bytes(), err
return headerBuffer.Bytes(), err
}
func UnpackHeader(headerBytes []byte) (*Header, error) {
var err error
header := NewHeader()
headerReader := bytes.NewReader(headerBytes)
var err error
magicCodeABytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeABytes)
header.magicCodeA = decoderI64(magicCodeABytes)
headerReader := bytes.NewReader(headerBytes)
rpcSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(rpcSizeBytes)
header.rpcSize = decoderI64(rpcSizeBytes)
magicCodeABytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeABytes)
binSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(binSizeBytes)
header.binSize = decoderI64(binSizeBytes)
rpcSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(rpcSizeBytes)
magicCodeBBytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeBBytes)
header.magicCodeB = decoderI64(magicCodeBBytes)
binSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(binSizeBytes)
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
return header, errors.New("wrong protocol magic code")
}
magicCodeBBytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeBBytes)
return header, err
header := &Header{
magicCodeA: DecoderI64(magicCodeABytes),
rpcSize: DecoderI64(rpcSizeBytes),
binSize: DecoderI64(binSizeBytes),
magicCodeB: DecoderI64(magicCodeBBytes),
}
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
err = errors.New("Wrong protocol magic code")
return header, err
}
return header, err
}
func encoderI64(i int64) []byte {
buffer := make([]byte, sizeOfInt64)
binary.BigEndian.PutUint64(buffer, uint64(i))
return buffer
func EncoderI64(i int64) []byte {
buffer := make([]byte, sizeOfInt64)
binary.BigEndian.PutUint64(buffer, uint64(i))
return buffer
}
func decoderI64(b []byte) int64 {
return int64(binary.BigEndian.Uint64(b))
func DecoderI64(b []byte) int64 {
return int64(binary.BigEndian.Uint64(b))
}

View File

@@ -7,39 +7,39 @@
package dsrpc
import (
"fmt"
"io"
"os"
"time"
"fmt"
"io"
"os"
"time"
)
var messageWriter io.Writer = os.Stdout
var accessWriter io.Writer = os.Stdout
func logDebug(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "debug", messages)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "debug", messages)
}
func logInfo(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "info", messages)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "info", messages)
}
func logError(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(messageWriter, stamp, "error", messages)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "error", messages)
}
func logAccess(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano)
fmt.Fprintln(accessWriter, stamp, "access", messages)
stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(accessWriter, stamp, "access", messages)
}
func SetAccessWriter(writer io.Writer) {
accessWriter = writer
accessWriter = writer
}
func SetMessageWriter(writer io.Writer) {
messageWriter = writer
messageWriter = writer
}

View File

@@ -1,31 +1,29 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"time"
"time"
)
func LogRequest(context *Context) error {
var err error
logDebug("request:", string(context.reqRPC.JSON()))
return err
func LogRequest(content *Content) error {
var err error
logDebug("request:", string(content.reqBlock.ToJson()))
return err
}
func LogResponse(context *Context) error {
var err error
logDebug("response:", string(context.resRPC.JSON()))
return err
func LogResponse(content *Content) error {
var err error
logDebug("response:", string(content.resBlock.ToJson()))
return err
}
func LogAccess(context *Context) error {
var err error
execTime := time.Now().Sub(context.start)
logAccess(context.remoteHost, context.reqRPC.Method, execTime)
return err
func LogAccess(content *Content) error {
var err error
execTime := time.Now().Sub(content.start)
login := string(content.AuthIdent())
logAccess(content.remoteHost, login, content.reqBlock.Method, execTime)
return err
}

View File

@@ -1,25 +1,18 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
)
type Packet struct {
header []byte
rcpPayload []byte
header []byte
rcpPayload []byte
}
func NewPacket() *Packet {
return &Packet{}
}
func (this *Packet) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func NewEmptyPacket() *Packet {
packet := &Packet{
header: make([]byte, 0),
rcpPayload: make([]byte, 0),
}
return packet
}

View File

@@ -7,27 +7,36 @@
package dsrpc
import (
"encoding/json"
"encoding/json"
encoder "encoding/json"
)
type EmptyParams struct{}
func NewEmptyParams() *EmptyParams {
return &EmptyParams{}
}
type Request struct {
Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params,omitempty"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth,omitempty"`
Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
}
func NewRequest() *Request {
req := &Request{}
req.Auth = &Auth{}
return req
func NewEmptyRequest() *Request {
req := &Request{}
req.Auth = &Auth{}
req.Params = NewEmptyParams()
return req
}
func (this *Request) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this)
return rBytes, err
func (req *Request) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(req)
return rBytes, err
}
func (this *Request) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func (req *Request) ToJson() []byte {
jBytes, _ := json.Marshal(req)
return jBytes
}

View File

@@ -1,31 +1,38 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"encoding/json"
"encoding/json"
encoder "encoding/json"
)
type EmptyResult struct{}
func NewEmptyResult() *EmptyResult {
return &EmptyResult{}
}
type Response struct {
Error string `json:"error,omitempty" msgpack:"error,omitempty"`
Result any `json:"result,omitemty" msgpack:"result,omitemty"`
Error string `json:"error" msgpack:"error"`
Result any `json:"result" msgpack:"result"`
}
func NewResponse() *Response {
return &Response{}
func NewEmptyResponse() *Response {
return &Response{
Result: NewEmptyResult(),
}
}
func (this *Response) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func (resp *Response) ToJson() []byte {
jBytes, _ := json.Marshal(resp)
return jBytes
}
func (this *Response) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this)
return rBytes, err
func (resp *Response) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(resp)
return rBytes, err
}

479
server.go
View File

@@ -1,263 +1,348 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
import (
"context"
"encoding/json"
"errors"
"io"
"net"
"sync"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
encoder "encoding/json"
)
type HandlerFunc = func(*Context) error
type HandlerFunc = func(*Content) error
type Service struct {
handlers map[string]HandlerFunc
ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup
preMw []HandlerFunc
postMw []HandlerFunc
handlers map[string]HandlerFunc
ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup
preMw []HandlerFunc
postMw []HandlerFunc
keepalive bool
kaTime time.Duration
kaMtx sync.Mutex
listener net.Listener
tcpListener *net.TCPListener
}
func NewService() *Service {
rdrpc := &Service{}
rdrpc.handlers = make(map[string]HandlerFunc)
ctx, cancel := context.WithCancel(context.Background())
rdrpc.ctx = ctx
rdrpc.cancel = cancel
var wg sync.WaitGroup
rdrpc.wg = &wg
rdrpc.preMw = make([]HandlerFunc, 0)
rdrpc.postMw = make([]HandlerFunc, 0)
rdrpc := &Service{}
rdrpc.handlers = make(map[string]HandlerFunc)
ctx, cancel := context.WithCancel(context.Background())
rdrpc.ctx = ctx
rdrpc.cancel = cancel
var wg sync.WaitGroup
rdrpc.wg = &wg
rdrpc.preMw = make([]HandlerFunc, 0)
rdrpc.postMw = make([]HandlerFunc, 0)
return rdrpc
return rdrpc
}
func (this *Service) PreMiddleware(mw HandlerFunc) {
this.preMw = append(this.preMw, mw)
func (svc *Service) PreMiddleware(mw HandlerFunc) {
svc.preMw = append(svc.preMw, mw)
}
func (this *Service) PostMiddleware(mw HandlerFunc) {
this.postMw = append(this.postMw, mw)
func (svc *Service) PostMiddleware(mw HandlerFunc) {
svc.postMw = append(svc.postMw, mw)
}
func (this *Service) Handler(method string, handler HandlerFunc) {
this.handlers[method] = handler
func (svc *Service) Handle(method string, handler HandlerFunc) {
svc.handlers[method] = handler
}
func (this *Service) Listen(address string) error {
var err error
logInfo("server listen:", address)
listener, err := net.Listen("tcp", address)
if err != nil {
return err
}
this.wg.Add(1)
for {
select {
case <- this.ctx.Done():
this.wg.Done()
func (svc *Service) SetKeepAlive(flag bool) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.keepalive = true
}
func (svc *Service) SetKeepAlivePeriod(interval time.Duration) {
svc.kaMtx.Lock()
defer svc.kaMtx.Unlock()
svc.kaTime = interval
}
func (svc *Service) Listen(address string) error {
var err error
logInfo("server listen:", address)
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return err
default:
}
conn, err := listener.Accept()
svc.tcpListener, err = net.ListenTCP("tcp", addr)
if err != nil {
logError("conn accept err:", err)
err = fmt.Errorf("unable to start listener: %s", err)
return err
}
go this.handleConn(conn)
}
}
func notFound(context *Context) error {
execErr := errors.New("method not found")
err := context.SendError(execErr)
return err
}
func (this *Service) Stop() error {
var err error
this.cancel()
this.wg.Wait()
return err
}
func (this *Service) handleConn(conn net.Conn) {
var err error
context := CreateContext(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost
context.binReader = conn
context.binWriter = io.Discard
exitFunc := func() {
conn.Close()
if err != nil {
logError("conn handler err:", err)
}
}
defer exitFunc()
recovFunc := func () {
panicMsg := recover()
if panicMsg != nil {
logError("handler panic message:", panicMsg)
for {
conn, err := svc.tcpListener.AcceptTCP()
if err != nil {
logError("conn accept err:", err)
}
select {
case <-svc.ctx.Done():
return err
default:
}
svc.wg.Add(1)
go svc.handleTCPConn(conn, svc.wg)
}
}
defer recovFunc()
return err
}
err = context.ReadRequest()
if err != nil {
return
}
func (svc *Service) ListenTLS(address string, tlsConfig *tls.Config) error {
var err error
logInfo("server listen:", address)
err = context.BindMethod()
if err != nil {
return
}
for _, mw := range this.preMw {
err = mw(context)
svc.listener, err = tls.Listen("tcp", address, tlsConfig)
if err != nil {
return
err = fmt.Errorf("unable to start listener: %s", err)
return err
}
}
err = this.Route(context)
if err != nil {
return
}
for _, mw := range this.postMw {
err = mw(context)
for {
conn, err := svc.listener.Accept()
if err != nil {
logError("conn accept err:", err)
}
select {
case <-svc.ctx.Done():
logMessage("accept loop done")
return err
default:
}
svc.wg.Add(1)
go svc.handleConn(conn, svc.wg)
}
return err
}
func notFound(content *Content) error {
execErr := errors.New("method not found")
err := content.SendError(execErr)
return err
}
func (svc *Service) Stop() error {
var err error
// Disable new connection
logInfo("cancel rpc accept loop")
if svc.listener != nil {
svc.listener.Close()
}
if svc.tcpListener != nil {
svc.tcpListener.Close()
}
svc.cancel()
// Wait handlers
logInfo("wait rpc handlers")
svc.wg.Wait()
return err
}
func (svc *Service) handleTCPConn(conn *net.TCPConn, wg *sync.WaitGroup) {
var err error
if svc.keepalive {
err = conn.SetKeepAlive(true)
if err != nil {
err = fmt.Errorf("unable to set keepalive: %s", err)
return
}
if svc.kaTime > 0 {
err = conn.SetKeepAlivePeriod(svc.kaTime)
if err != nil {
err = fmt.Errorf("unable to set keepalive period: %s", err)
return
}
}
}
svc.handleConn(conn, wg)
}
func (svc *Service) handleConn(conn net.Conn, wg *sync.WaitGroup) {
var err error
content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
content.remoteHost = remoteHost
content.binReader = conn
content.binWriter = io.Discard
exitFunc := func() {
conn.Close()
wg.Done()
if err != nil {
logError("conn handler err:", err)
}
}
defer exitFunc()
recovFunc := func() {
panicMsg := recover()
if panicMsg != nil {
logError("handler panic message:", panicMsg)
}
}
defer recovFunc()
err = content.ReadRequest()
if err != nil {
return
err = err
return
}
}
return
err = content.BindMethod()
if err != nil {
err = err
return
}
for _, mw := range svc.preMw {
err = mw(content)
if err != nil {
err = err
return
}
}
err = svc.Route(content)
if err != nil {
err = err
return
}
for _, mw := range svc.postMw {
err = mw(content)
if err != nil {
err = err
return
}
}
return
}
func (this *Service) Route(context *Context) error {
handler, ok := this.handlers[context.reqRPC.Method]
if ok {
return handler(context)
}
return notFound(context)
func (svc *Service) Route(content *Content) error {
handler, ok := svc.handlers[content.reqBlock.Method]
if ok {
return handler(content)
}
return notFound(content)
}
func (context *Context) ReadRequest() error {
var err error
func (content *Content) ReadRequest() error {
var err error
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil {
return err
}
context.reqHeader, err = UnpackHeader(context.reqPacket.header)
if err != nil {
return err
}
content.reqPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
return err
}
content.reqHeader, err = UnpackHeader(content.reqPacket.header)
if err != nil {
return err
}
rpcSize := context.reqHeader.rpcSize
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil {
rpcSize := content.reqHeader.rpcSize
content.reqPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
return err
}
return err
}
return err
}
func (context *Context) BinWriter() io.Writer {
return context.sockWriter
func (content *Content) BinWriter() io.Writer {
return content.sockWriter
}
func (context *Context) BinReader() io.Reader {
return context.sockReader
func (content *Content) BinReader() io.Reader {
return content.sockReader
}
func (context *Context) BinSize() int64 {
return context.reqHeader.binSize
func (content *Content) BinSize() int64 {
return content.reqHeader.binSize
}
func (context *Context) ReadBin(writer io.Writer) error {
var err error
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
return err
func (content *Content) ReadBin(ctx context.Context, writer io.Writer) error {
var err error
_, err = CopyBytes(ctx, content.sockReader, writer, content.reqHeader.binSize)
return err
}
func (context *Context) BindMethod() error {
var err error
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
return err
func (content *Content) BindMethod() error {
var err error
err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
return err
}
func (context *Context) BindParams(params any) error {
var err error
context.reqRPC.Params = params
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
if err != nil {
func (content *Content) BindParams(params any) error {
var err error
content.reqBlock.Params = params
err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
if err != nil {
return err
}
return err
}
return err
}
func (context *Context) SendResult(result any, binSize int64) error {
var err error
context.resRPC.Result = result
func (content *Content) SendResult(result any, binSize int64) error {
var err error
content.resBlock.Result = result
context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil {
return err
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resHeader.binSize = binSize
content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil {
return err
}
content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
content.resHeader.binSize = binSize
context.resPacket.header, err = context.resHeader.Pack()
if err != nil {
content.resPacket.header, err = content.resHeader.Pack()
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.resPacket.header)
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil {
return err
}
return err
}
_, err = context.sockWriter.Write(context.resPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil {
return err
}
return err
}
func (content *Content) SendError(execErr error) error {
var err error
func (context *Context) SendError(execErr error) error {
var err error
content.resBlock.Error = execErr.Error()
content.resBlock.Result = NewEmptyResult()
context.resRPC.Error = execErr.Error()
context.resRPC.Result = NewEmpty()
context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil {
content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil {
return err
}
content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
content.resPacket.header, err = content.resHeader.Pack()
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.resPacket.header)
if err != nil {
return err
}
_, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil {
return err
}
return err
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resPacket.header, err = context.resHeader.Pack()
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.header)
if err != nil {
return err
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil {
return err
}
return err
}

View File

@@ -5,50 +5,60 @@
package dsrpc
import (
"errors"
"io"
"fmt"
"context"
"errors"
"fmt"
"io"
)
func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer)
return buffer[0:read], err
buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer)
return buffer[0:read], err
}
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
var err error
var bSize int64 = 1024 * 4
var total int64 = 0
var remains int64 = dataSize
buffer := make([]byte, bSize)
func CopyBytes(ctx context.Context, reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
var err error
var bSize int64 = 1024 * 16
var total int64 = 0
var remains int64 = dataSize
buffer := make([]byte, bSize)
for {
if reader == nil {
return total, errors.New("reader is nil")
}
if writer == nil {
return total, errors.New("writer is nil")
}
if remains == 0 {
return total, err
}
if remains < bSize {
bSize = remains
}
received, err := reader.Read(buffer[0:bSize])
if err != nil {
return total, fmt.Errorf("read error: %v", err)
}
recorded, err := writer.Write(buffer[0:received])
if err != nil {
return total, fmt.Errorf("write error: %v", err)
}
if recorded != received {
return total, errors.New("size mismatch")
}
total += int64(recorded)
remains -= int64(recorded)
}
return total, err
for {
select {
case <-ctx.Done():
return total, errors.New("break by context")
default:
}
if reader == nil {
return total, errors.New("reader is nil")
}
if writer == nil {
return total, errors.New("writer is nil")
}
if remains == 0 {
return total, err
}
if remains < bSize {
bSize = remains
}
received, err := reader.Read(buffer[0:bSize])
if err != nil {
err = fmt.Errorf("read error: %v", err)
return total, err
}
recorded, err := writer.Write(buffer[0:received])
if err != nil {
err = fmt.Errorf("write error: %v", err)
return total, err
}
if recorded != received {
err = errors.New("size mismatch")
return total, err
}
total += int64(recorded)
remains -= int64(recorded)
}
return total, err
}

View File

@@ -4,161 +4,172 @@
package dsrpc
import (
"io"
"net"
"context"
"io"
"net"
)
func LocalExec(method string, param any, result any, auth *Auth, handler HandlerFunc) error {
var err error
func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
cliConn, srvConn := NewFConn()
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
return err
err = content.createRequest()
if err != nil {
return err
}
err = content.writeRequest()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = content.readResponse()
if err != nil {
return err
}
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func LocalPut(method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error {
func LocalPut(ctx context.Context, method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
var err error
cliConn, srvConn := NewFConn()
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
context.binReader = reader
context.binWriter = cliConn
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
context.reqHeader.binSize = size
content.binReader = reader
content.binWriter = cliConn
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = context.UploadBin()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
content.reqHeader.binSize = size
err = content.createRequest()
if err != nil {
return err
}
err = content.writeRequest()
if err != nil {
return err
}
err = content.uploadBin(ctx)
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = content.readResponse()
if err != nil {
return err
}
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func LocalGet(ctx context.Context, method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
cliConn, srvConn := NewFConn()
cliConn, srvConn := NewFConn()
content := CreateContent(cliConn)
content.reqBlock.Method = method
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
context.binReader = cliConn
context.binWriter = writer
content.binReader = cliConn
content.binWriter = writer
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
if err != nil {
return err
}
err = content.createRequest()
if err != nil {
return err
}
err = content.writeRequest()
if err != nil {
return err
}
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = context.ReadResponse()
if err != nil {
return err
}
err = context.DownloadBin()
if err != nil {
return err
}
err = context.BindResponse()
if err != nil {
return err
}
return err
err = LocalService(srvConn, handler)
if err != nil {
return err
}
err = content.readResponse()
if err != nil {
return err
}
err = content.downloadBin(ctx)
if err != nil {
return err
}
err = content.bindResponse()
if err != nil {
return err
}
return err
}
func LocalService(conn net.Conn, handler HandlerFunc) error {
var err error
context := CreateContext(conn)
var err error
content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
content.remoteHost = remoteHost
context.binReader = conn
context.binWriter = io.Discard
content.binReader = conn
content.binWriter = io.Discard
err = context.ReadRequest()
if err != nil {
return err
}
err = context.BindMethod()
if err != nil {
return err
}
return handler(context)
err = content.ReadRequest()
if err != nil {
return err
}
err = content.BindMethod()
if err != nil {
return err
}
return handler(content)
}

View File

@@ -5,60 +5,60 @@
package dsrpc
import (
"encoding/json"
"bytes"
"math/rand"
"time"
"crypto/sha256"
"bytes"
"crypto/sha256"
"encoding/json"
"math/rand"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
rand.Seed(time.Now().UnixNano())
}
type Auth struct {
Ident []byte `json:"ident,omitempty"`
Salt []byte `json:"salt,omitempty"`
Hash []byte `json:"hash,omitempty"`
Ident []byte `msgpack:"ident" json:"ident"`
Salt []byte `msgpack:"salt" json:"salt"`
Hash []byte `msgpack:"hash" json:"hash"`
}
func NewAuth() *Auth {
return &Auth{}
return &Auth{}
}
func (this *Auth) JSON() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
func (this *Auth) Json() []byte {
jBytes, _ := json.Marshal(this)
return jBytes
}
func CreateAuth(ident, pass []byte) *Auth {
salt := CreateSalt()
hash := CreateHash(ident, pass, salt)
auth := &Auth{}
auth.Ident = ident
auth.Salt = salt
auth.Hash = hash
return auth
salt := CreateSalt()
hash := CreateHash(ident, pass, salt)
auth := &Auth{}
auth.Ident = ident
auth.Salt = salt
auth.Hash = hash
return auth
}
func CreateSalt() []byte {
const saltSize = 16
randBytes := make([]byte, saltSize)
rand.Read(randBytes)
return randBytes
const saltSize = 16
randBytes := make([]byte, saltSize)
rand.Read(randBytes)
return randBytes
}
func CreateHash(ident, pass, salt []byte) []byte {
vec := make([]byte, 0, len(ident) + len(salt) + len(pass))
vec = append(vec, ident...)
vec = append(vec, salt...)
vec = append(vec, pass...)
hasher := sha256.New()
hash := hasher.Sum(vec)
return hash
vec := make([]byte, 0, len(ident)+len(salt)+len(pass))
vec = append(vec, ident...)
vec = append(vec, salt...)
vec = append(vec, pass...)
hasher := sha256.New()
hash := hasher.Sum(vec)
return hash
}
func CheckHash(ident, pass, reqSalt, reqHash []byte) bool {
localHash := CreateHash(ident, pass, reqSalt)
return bytes.Equal(reqHash, localHash)
localHash := CreateHash(ident, pass, reqSalt)
return bytes.Equal(reqHash, localHash)
}