Little refactoring

This commit is contained in:
2023-04-01 00:18:44 +02:00
parent 8dc753cd95
commit 8b3e722ea5
16 changed files with 354 additions and 562 deletions

276
client.go
View File

@@ -14,80 +14,71 @@ import (
encoder "github.com/vmihailenco/msgpack/v5" encoder "github.com/vmihailenco/msgpack/v5"
) )
func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error { func Put(address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error var err error
addr, err := net.ResolveTCPAddr("tcp", address) addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err) err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err) return err
} }
conn, err := net.DialTCP("tcp", nil, addr) conn, err := net.DialTCP("tcp", nil, addr)
if err != nil { if err != nil {
return Err(err) return err
} }
defer conn.Close() defer conn.Close()
//err = conn.SetKeepAlive(true) return ConnPut(conn, method, reader, binSize, param, result, auth)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
return ConnPut(conn, method, reader, size, param, result, auth)
} }
func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, result any, auth *Auth) error { func ConnPut(conn net.Conn, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error var err error
context := CreateContext(conn) content := CreateContent(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
context.binReader = reader content.reqBlock.Method = method
context.binWriter = conn if param != nil {
content.reqBlock.Params = param
context.reqHeader.binSize = size }
if auth != nil {
if context.reqRPC.Params == nil { content.reqBlock.Auth = auth
context.reqRPC.Params = NewEmpty() }
if result != nil {
content.resBlock.Result = result
} }
err = context.CreateRequest() content.binReader = reader
content.binWriter = conn
content.reqHeader.binSize = binSize
err = content.CreateRequest()
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.WriteRequest() err = content.WriteRequest()
if err != nil { if err != nil {
return Err(err) return err
} }
var wg sync.WaitGroup var wg sync.WaitGroup
errChan := make(chan error, 1) errChan := make(chan error, 1)
wg.Add(1) wg.Add(1)
go context.ReadResponseAsync(&wg, errChan) go content.ReadResponseAsync(&wg, errChan)
wg.Add(1) wg.Add(1)
go context.UploadBinAsync(&wg) go content.UploadBinAsync(&wg)
wg.Wait() wg.Wait()
err = <-errChan err = <-errChan
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.BindResponse() err = content.BindResponse()
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error { func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
@@ -96,65 +87,56 @@ func Get(address string, method string, writer io.Writer, param, result any, aut
addr, err := net.ResolveTCPAddr("tcp", address) addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err) err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err) return err
} }
conn, err := net.DialTCP("tcp", nil, addr) conn, err := net.DialTCP("tcp", nil, addr)
if err != nil { if err != nil {
return Err(err) return err
} }
defer conn.Close() defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
return ConnGet(conn, method, writer, param, result, auth) return ConnGet(conn, method, writer, param, result, auth)
} }
func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error { func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error var err error
context := CreateContext(conn) content := CreateContent(conn)
context.reqRPC.Method = method content.reqBlock.Method = method
context.reqRPC.Params = param if param != nil {
context.reqRPC.Auth = auth content.reqBlock.Params = param
context.resRPC.Result = result }
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
context.binReader = conn content.binReader = conn
context.binWriter = writer content.binWriter = writer
if context.reqRPC.Params == nil { err = content.CreateRequest()
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.WriteRequest() err = content.WriteRequest()
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.ReadResponse() err = content.ReadResponse()
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.DownloadBin() err = content.DownloadBin()
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.BindResponse() err = content.BindResponse()
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func Exec(address, method string, param any, result any, auth *Auth) error { func Exec(address, method string, param any, result any, auth *Auth) error {
@@ -163,170 +145,162 @@ func Exec(address, method string, param any, result any, auth *Auth) error {
addr, err := net.ResolveTCPAddr("tcp", address) addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil { if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err) err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err) return err
} }
conn, err := net.DialTCP("tcp", nil, addr) conn, err := net.DialTCP("tcp", nil, addr)
if err != nil { if err != nil {
return Err(err) return err
} }
defer conn.Close() defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
err = ConnExec(conn, method, param, result, auth) err = ConnExec(conn, method, param, result, auth)
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error { func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error {
var err error var err error
context := CreateContext(conn) content := CreateContent(conn)
context.reqRPC.Method = method content.reqBlock.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
if context.reqRPC.Params == nil { if param != nil {
context.reqRPC.Params = NewEmpty() content.reqBlock.Params = param
}
if result != nil {
content.resBlock.Result = result
}
if auth != nil {
content.reqBlock.Auth = auth
} }
err = context.CreateRequest() err = content.CreateRequest()
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.WriteRequest() err = content.WriteRequest()
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.ReadResponse() err = content.ReadResponse()
if err != nil { if err != nil {
return Err(err) return err
} }
err = context.BindResponse() err = content.BindResponse()
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func (context *Context) CreateRequest() error { func (content *Content) CreateRequest() error {
var err error var err error
context.reqPacket.rcpPayload, err = context.reqRPC.Pack() content.reqPacket.rcpPayload, err = content.reqBlock.Pack()
if err != nil { if err != nil {
return Err(err) return err
} }
rpcSize := int64(len(context.reqPacket.rcpPayload)) rpcSize := int64(len(content.reqPacket.rcpPayload))
context.reqHeader.rpcSize = rpcSize content.reqHeader.rpcSize = rpcSize
context.reqPacket.header, err = context.reqHeader.Pack() content.reqPacket.header, err = content.reqHeader.Pack()
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func (context *Context) WriteRequest() error { func (content *Content) WriteRequest() error {
var err error var err error
_, err = context.sockWriter.Write(context.reqPacket.header) _, err = content.sockWriter.Write(content.reqPacket.header)
if err != nil { if err != nil {
return Err(err) return err
} }
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload) _, err = content.sockWriter.Write(content.reqPacket.rcpPayload)
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func (context *Context) UploadBin() error { func (content *Content) UploadBin() error {
var err error var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) _, err = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize)
return Err(err) return err
} }
func (context *Context) ReadResponse() error { func (content *Content) ReadResponse() error {
var err error var err error
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil { if err != nil {
return Err(err) return err
} }
context.resHeader, err = UnpackHeader(context.resPacket.header) content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil { if err != nil {
return Err(err) return err
} }
rpcSize := context.resHeader.rpcSize rpcSize := content.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) { func (content *Content) UploadBinAsync(wg *sync.WaitGroup) {
exitFunc := func() { exitFunc := func() {
wg.Done() wg.Done()
} }
defer exitFunc() defer exitFunc()
_, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) _, _ = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize)
return return
} }
func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) { func (content *Content) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) {
var err error var err error
exitFunc := func() { exitFunc := func() {
errChan <- err errChan <- err
wg.Done() wg.Done()
} }
defer exitFunc() defer exitFunc()
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil { if err != nil {
err = Err(err) err = err
return return
} }
context.resHeader, err = UnpackHeader(context.resPacket.header) content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil { if err != nil {
err = Err(err) err = err
return return
} }
rpcSize := context.resHeader.rpcSize rpcSize := content.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil { if err != nil {
err = Err(err) err = err
return return
} }
return return
} }
func (context *Context) DownloadBin() error { func (content *Content) DownloadBin() error {
var err error var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize) _, err = CopyBytes(content.binReader, content.binWriter, content.resHeader.binSize)
return Err(err) return err
} }
func (context *Context) BindResponse() error { func (content *Content) BindResponse() error {
var err error var err error
err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC) err = encoder.Unmarshal(content.resPacket.rcpPayload, content.resBlock)
if err != nil { if err != nil {
return Err(err) return err
} }
if len(context.resRPC.Error) > 0 { if len(content.resBlock.Error) > 0 {
err = errors.New(context.resRPC.Error) err = errors.New(content.resBlock.Error)
return Err(err) return err
} }
return Err(err) return err
} }

View File

@@ -1,5 +0,0 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc

View File

@@ -1,152 +0,0 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
import (
"io"
"net"
"time"
)
type Context struct {
start time.Time
remoteHost string
sockReader io.Reader
sockWriter io.Writer
reqHeader *Header
reqRPC *Request
reqPacket *Packet
resPacket *Packet
resHeader *Header
resRPC *Response
binReader io.Reader
binWriter io.Writer
}
func NewContext() *Context {
context := &Context{}
context.start = time.Now()
return context
}
func CreateContext(conn net.Conn) *Context {
context := &Context{}
context.start = time.Now()
context.sockReader = conn
context.sockWriter = conn
context.reqPacket = NewPacket()
context.resPacket = NewPacket()
context.reqHeader = NewHeader()
context.reqRPC = NewRequest()
context.resHeader = NewHeader()
context.resRPC = NewResponse()
context.resRPC = NewResponse()
return context
}
func (context *Context) Request() *Request {
return context.reqRPC
}
func (context *Context) RemoteHost() string {
return context.remoteHost
}
func (context *Context) Start() time.Time {
return context.start
}
func (context *Context) Method() string {
var method string
if context.reqRPC != nil {
method = context.reqRPC.Method
}
return method
}
func (context *Context) ReqRpcSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.rpcSize
}
return size
}
func (context *Context) ReqBinSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.binSize
}
return size
}
func (context *Context) ResBinSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.binSize
}
return size
}
func (context *Context) ResRpcSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.rpcSize
}
return size
}
func (context *Context) ReqSize() int64 {
var size int64
if context.reqHeader != nil {
size += context.reqHeader.binSize
size += context.reqHeader.rpcSize
}
return size
}
func (context *Context) ResSize() int64 {
var size int64
if context.resHeader != nil {
size += context.resHeader.binSize
size += context.resHeader.rpcSize
}
return size
}
func (context *Context) SetAuthIdent(ident []byte) {
context.reqRPC.Auth.Ident = ident
}
func (context *Context) SetAuthSalt(salt []byte) {
context.reqRPC.Auth.Salt = salt
}
func (context *Context) SetAuthHash(hash []byte) {
context.reqRPC.Auth.Hash = hash
}
func (context *Context) AuthIdent() []byte {
return context.reqRPC.Auth.Ident
}
func (context *Context) AuthSalt() []byte {
return context.reqRPC.Auth.Salt
}
func (context *Context) AuthHash() []byte {
return context.reqRPC.Auth.Hash
}
func (context *Context) Auth() *Auth {
return context.reqRPC.Auth
}

View File

@@ -1,13 +0,0 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
type Empty struct{}
func NewEmpty() *Empty {
return &Empty{}
}

View File

@@ -1,42 +0,0 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
import (
"fmt"
"io"
"runtime"
)
var develMode bool = false
var debugMode bool = false
func SetDevelMode(mode bool) {
develMode = mode
}
func SetDebugMode(mode bool) {
debugMode = mode
}
func Err(err error) error {
switch err {
case io.EOF:
return err
}
if err != nil {
switch {
case develMode == true:
pc, filename, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error())
case debugMode == true:
pc, _, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error())
default:
}
}
return err
}

View File

@@ -209,58 +209,58 @@ func testServ(quiet bool) error {
return err return err
} }
func auth(context *Context) error { func auth(content *Content) error {
var err error var err error
reqIdent := context.AuthIdent() reqIdent := content.AuthIdent()
reqSalt := context.AuthSalt() reqSalt := content.AuthSalt()
reqHash := context.AuthHash() reqHash := content.AuthHash()
ident := reqIdent ident := reqIdent
pass := []byte("12345") pass := []byte("12345")
auth := context.Auth() auth := content.Auth()
logDebug("auth ", string(auth.JSON())) logDebug("auth ", string(auth.JSON()))
ok := CheckHash(ident, pass, reqSalt, reqHash) ok := CheckHash(ident, pass, reqSalt, reqHash)
logDebug("auth ok:", ok) logDebug("auth ok:", ok)
if !ok { if !ok {
err = errors.New("auth ident or pass missmatch") err = errors.New("auth ident or pass missmatch")
context.SendError(err) content.SendError(err)
return err return err
} }
return err return err
} }
func helloHandler(context *Context) error { func helloHandler(content *Content) error {
var err error var err error
params := NewHelloParams() params := NewHelloParams()
err = context.BindParams(params) err = content.BindParams(params)
if err != nil { if err != nil {
return err return err
} }
err = context.ReadBin(io.Discard) err = content.ReadBin(io.Discard)
if err != nil { if err != nil {
context.SendError(err) content.SendError(err)
return err return err
} }
result := NewHelloResult() result := NewHelloResult()
result.Message = "hello, client!" result.Message = "hello, client!"
err = context.SendResult(result, 0) err = content.SendResult(result, 0)
if err != nil { if err != nil {
return err return err
} }
return err return err
} }
func saveHandler(context *Context) error { func saveHandler(content *Content) error {
var err error var err error
params := NewSaveParams() params := NewSaveParams()
err = context.BindParams(params) err = content.BindParams(params)
if err != nil { if err != nil {
return err return err
} }
@@ -268,34 +268,34 @@ func saveHandler(context *Context) error {
bufferBytes := make([]byte, 0, 1024) bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes) binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter) err = content.ReadBin(binWriter)
if err != nil { if err != nil {
context.SendError(err) content.SendError(err)
return err return err
} }
result := NewSaveResult() result := NewSaveResult()
result.Message = "saved successfully!" result.Message = "saved successfully!"
err = context.SendResult(result, 0) err = content.SendResult(result, 0)
if err != nil { if err != nil {
return err return err
} }
return err return err
} }
func loadHandler(context *Context) error { func loadHandler(content *Content) error {
var err error var err error
params := NewSaveParams() params := NewSaveParams()
err = context.BindParams(params) err = content.BindParams(params)
if err != nil { if err != nil {
return err return err
} }
err = context.ReadBin(io.Discard) err = content.ReadBin(io.Discard)
if err != nil { if err != nil {
context.SendError(err) content.SendError(err)
return err return err
} }
@@ -309,11 +309,11 @@ func loadHandler(context *Context) error {
result := NewSaveResult() result := NewSaveResult()
result.Message = "load successfully!" result.Message = "load successfully!"
err = context.SendResult(result, binSize) err = content.SendResult(result, binSize)
if err != nil { if err != nil {
return err return err
} }
binWriter := context.BinWriter() binWriter := content.BinWriter()
_, err = CopyBytes(binReader, binWriter, binSize) _, err = CopyBytes(binReader, binWriter, binSize)
if err != nil { if err != nil {
return err return err

View File

@@ -10,9 +10,10 @@ type FAddr struct {
} }
func NewFAddr() *FAddr { func NewFAddr() *FAddr {
var addr FAddr addr := FAddr{
addr.network = "tcp" network: "tcp",
addr.address = "127.0.0.1:5000" address: "127.0.0.1:5000",
}
return &addr return &addr
} }

View File

@@ -7,9 +7,10 @@
package dsrpc package dsrpc
import ( import (
"github.com/stretchr/testify/require"
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestFConn0(t *testing.T) { func TestFConn0(t *testing.T) {

View File

@@ -13,10 +13,12 @@ import (
"errors" "errors"
) )
const headerSize int64 = 16 * 2 const (
const sizeOfInt64 int = 8 headerSize int64 = 16 * 2
const magicCodeA int64 = 0xEE00ABBA sizeOfInt64 int = 8
const magicCodeB int64 = 0xEE44ABBA magicCodeA int64 = 0xEE00ABBA
magicCodeB int64 = 0xEE44ABBA
)
type Header struct { type Header struct {
magicCodeA int64 `json:"magicCodeA"` magicCodeA int64 `json:"magicCodeA"`
@@ -25,14 +27,14 @@ type Header struct {
magicCodeB int64 `json:"magicCodeB"` magicCodeB int64 `json:"magicCodeB"`
} }
func NewHeader() *Header { func NewEmptyHeader() *Header {
return &Header{ return &Header{
magicCodeA: magicCodeA, magicCodeA: magicCodeA,
magicCodeB: magicCodeB, magicCodeB: magicCodeB,
} }
} }
func (hdr *Header) JSON() []byte { func (hdr *Header) ToJson() []byte {
jBytes, _ := json.Marshal(hdr) jBytes, _ := json.Marshal(hdr)
return jBytes return jBytes
} }
@@ -54,35 +56,38 @@ func (hdr *Header) Pack() ([]byte, error) {
magicCodeBBytes := encoderI64(hdr.magicCodeB) magicCodeBBytes := encoderI64(hdr.magicCodeB)
headerBuffer.Write(magicCodeBBytes) headerBuffer.Write(magicCodeBBytes)
return headerBuffer.Bytes(), Err(err) return headerBuffer.Bytes(), err
} }
func UnpackHeader(headerBytes []byte) (*Header, error) { func UnpackHeader(headerBytes []byte) (*Header, error) {
var err error var err error
header := NewHeader()
headerReader := bytes.NewReader(headerBytes) headerReader := bytes.NewReader(headerBytes)
magicCodeABytes := make([]byte, sizeOfInt64) magicCodeABytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeABytes) headerReader.Read(magicCodeABytes)
header.magicCodeA = decoderI64(magicCodeABytes)
rpcSizeBytes := make([]byte, sizeOfInt64) rpcSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(rpcSizeBytes) headerReader.Read(rpcSizeBytes)
header.rpcSize = decoderI64(rpcSizeBytes)
binSizeBytes := make([]byte, sizeOfInt64) binSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(binSizeBytes) headerReader.Read(binSizeBytes)
header.binSize = decoderI64(binSizeBytes)
magicCodeBBytes := make([]byte, sizeOfInt64) magicCodeBBytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeBBytes) headerReader.Read(magicCodeBBytes)
header.magicCodeB = decoderI64(magicCodeBBytes)
header := &Header{
magicCodeA: decoderI64(magicCodeABytes),
rpcSize: decoderI64(rpcSizeBytes),
binSize: decoderI64(binSizeBytes),
magicCodeB: decoderI64(magicCodeBBytes),
}
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB { if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
err = errors.New("wrong protocol magic code") err = errors.New("Wrong protocol magic code")
return header, Err(err) return header, err
} }
return header, Err(err) return header, err
} }
func encoderI64(i int64) []byte { func encoderI64(i int64) []byte {

View File

@@ -8,22 +8,22 @@ import (
"time" "time"
) )
func LogRequest(context *Context) error { func LogRequest(content *Content) error {
var err error var err error
logDebug("request:", string(context.reqRPC.JSON())) logDebug("request:", string(content.reqBlock.ToJson()))
return Err(err) return err
} }
func LogResponse(context *Context) error { func LogResponse(content *Content) error {
var err error var err error
logDebug("response:", string(context.resRPC.JSON())) logDebug("response:", string(content.resBlock.ToJson()))
return Err(err) return err
} }
func LogAccess(context *Context) error { func LogAccess(content *Content) error {
var err error var err error
execTime := time.Now().Sub(context.start) execTime := time.Now().Sub(content.start)
login := string(context.AuthIdent()) login := string(content.AuthIdent())
logAccess(context.remoteHost, login, context.reqRPC.Method, execTime) logAccess(content.remoteHost, login, content.reqBlock.Method, execTime)
return Err(err) return err
} }

View File

@@ -4,20 +4,15 @@
package dsrpc package dsrpc
import (
"encoding/json"
)
type Packet struct { type Packet struct {
header []byte header []byte
rcpPayload []byte rcpPayload []byte
} }
func NewPacket() *Packet { func NewEmptyPacket() *Packet {
return &Packet{} packet := &Packet{
} header: make([]byte, 0),
rcpPayload: make([]byte, 0),
func (pkt *Packet) JSON() []byte { }
jBytes, _ := json.Marshal(pkt) return packet
return jBytes
} }

View File

@@ -11,24 +11,31 @@ import (
encoder "github.com/vmihailenco/msgpack/v5" encoder "github.com/vmihailenco/msgpack/v5"
) )
type EmptyParams struct{}
func NewEmptyParams() *EmptyParams {
return &EmptyParams{}
}
type Request struct { type Request struct {
Method string `json:"method" msgpack:"method"` Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params"` Params any `json:"params,omitempty" msgpack:"params"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth"` Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
} }
func NewRequest() *Request { func NewEmptyRequest() *Request {
req := &Request{} req := &Request{}
req.Auth = &Auth{} req.Auth = &Auth{}
req.Params = NewEmptyParams()
return req return req
} }
func (req *Request) Pack() ([]byte, error) { func (req *Request) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(req) rBytes, err := encoder.Marshal(req)
return rBytes, Err(err) return rBytes, err
} }
func (req *Request) JSON() []byte { func (req *Request) ToJson() []byte {
jBytes, _ := json.Marshal(req) jBytes, _ := json.Marshal(req)
return jBytes return jBytes
} }

View File

@@ -6,24 +6,33 @@ package dsrpc
import ( import (
"encoding/json" "encoding/json"
encoder "github.com/vmihailenco/msgpack/v5" encoder "github.com/vmihailenco/msgpack/v5"
) )
type EmptyResult struct{}
func NewEmptyResult() *EmptyResult {
return &EmptyResult{}
}
type Response struct { type Response struct {
Error string `json:"error" msgpack:"error"` Error string `json:"error" msgpack:"error"`
Result any `json:"result" msgpack:"result"` Result any `json:"result" msgpack:"result"`
} }
func NewResponse() *Response { func NewEmptyResponse() *Response {
return &Response{} return &Response{
Result: NewEmptyResult(),
}
} }
func (resp *Response) JSON() []byte { func (resp *Response) ToJson() []byte {
jBytes, _ := json.Marshal(resp) jBytes, _ := json.Marshal(resp)
return jBytes return jBytes
} }
func (resp *Response) Pack() ([]byte, error) { func (resp *Response) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(resp) rBytes, err := encoder.Marshal(resp)
return rBytes, Err(err) return rBytes, err
} }

146
server.go
View File

@@ -16,7 +16,7 @@ import (
encoder "github.com/vmihailenco/msgpack/v5" encoder "github.com/vmihailenco/msgpack/v5"
) )
type HandlerFunc = func(*Context) error type HandlerFunc = func(*Content) error
type Service struct { type Service struct {
handlers map[string]HandlerFunc handlers map[string]HandlerFunc
@@ -99,9 +99,9 @@ func (svc *Service) Listen(address string) error {
return err return err
} }
func notFound(context *Context) error { func notFound(content *Content) error {
execErr := errors.New("method not found") execErr := errors.New("method not found")
err := context.SendError(execErr) err := content.SendError(execErr)
return err return err
} }
@@ -133,14 +133,14 @@ func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) {
} }
} }
} }
context := CreateContext(conn) content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr) remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost content.remoteHost = remoteHost
context.binReader = conn content.binReader = conn
context.binWriter = io.Discard content.binWriter = io.Discard
exitFunc := func() { exitFunc := func() {
conn.Close() conn.Close()
@@ -159,149 +159,149 @@ func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) {
} }
defer recovFunc() defer recovFunc()
err = context.ReadRequest() err = content.ReadRequest()
if err != nil { if err != nil {
err = Err(err) err = err
return return
} }
err = context.BindMethod() err = content.BindMethod()
if err != nil { if err != nil {
err = Err(err) err = err
return return
} }
for _, mw := range svc.preMw { for _, mw := range svc.preMw {
err = mw(context) err = mw(content)
if err != nil { if err != nil {
err = Err(err) err = err
return return
} }
} }
err = svc.Route(context) err = svc.Route(content)
if err != nil { if err != nil {
err = Err(err) err = err
return return
} }
for _, mw := range svc.postMw { for _, mw := range svc.postMw {
err = mw(context) err = mw(content)
if err != nil { if err != nil {
err = Err(err) err = err
return return
} }
} }
return return
} }
func (svc *Service) Route(context *Context) error { func (svc *Service) Route(content *Content) error {
handler, ok := svc.handlers[context.reqRPC.Method] handler, ok := svc.handlers[content.reqBlock.Method]
if ok { if ok {
return Err(handler(context)) return handler(content)
} }
return Err(notFound(context)) return notFound(content)
} }
func (context *Context) ReadRequest() error { func (content *Content) ReadRequest() error {
var err error var err error
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize) content.reqPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil { if err != nil {
return Err(err) return err
} }
context.reqHeader, err = UnpackHeader(context.reqPacket.header) content.reqHeader, err = UnpackHeader(content.reqPacket.header)
if err != nil { if err != nil {
return Err(err) return err
} }
rpcSize := context.reqHeader.rpcSize rpcSize := content.reqHeader.rpcSize
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) content.reqPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func (context *Context) BinWriter() io.Writer { func (content *Content) BinWriter() io.Writer {
return context.sockWriter return content.sockWriter
} }
func (context *Context) BinReader() io.Reader { func (content *Content) BinReader() io.Reader {
return context.sockReader return content.sockReader
} }
func (context *Context) BinSize() int64 { func (content *Content) BinSize() int64 {
return context.reqHeader.binSize return content.reqHeader.binSize
} }
func (context *Context) ReadBin(writer io.Writer) error { func (content *Content) ReadBin(writer io.Writer) error {
var err error var err error
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize) _, err = CopyBytes(content.sockReader, writer, content.reqHeader.binSize)
return Err(err) return err
} }
func (context *Context) BindMethod() error { func (content *Content) BindMethod() error {
var err error var err error
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
return Err(err) return err
} }
func (context *Context) BindParams(params any) error { func (content *Content) BindParams(params any) error {
var err error var err error
context.reqRPC.Params = params content.reqBlock.Params = params
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func (context *Context) SendResult(result any, binSize int64) error { func (content *Content) SendResult(result any, binSize int64) error {
var err error var err error
context.resRPC.Result = result content.resBlock.Result = result
context.resPacket.rcpPayload, err = context.resRPC.Pack() content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil { if err != nil {
return Err(err) return err
} }
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
context.resHeader.binSize = binSize content.resHeader.binSize = binSize
context.resPacket.header, err = context.resHeader.Pack() content.resPacket.header, err = content.resHeader.Pack()
if err != nil { if err != nil {
return Err(err) return err
} }
_, err = context.sockWriter.Write(context.resPacket.header) _, err = content.sockWriter.Write(content.resPacket.header)
if err != nil { if err != nil {
return Err(err) return err
} }
_, err = context.sockWriter.Write(context.resPacket.rcpPayload) _, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }
func (context *Context) SendError(execErr error) error { func (content *Content) SendError(execErr error) error {
var err error var err error
context.resRPC.Error = execErr.Error() content.resBlock.Error = execErr.Error()
context.resRPC.Result = NewEmpty() content.resBlock.Result = NewEmptyResult()
context.resPacket.rcpPayload, err = context.resRPC.Pack() content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil { if err != nil {
return Err(err) return err
} }
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
context.resPacket.header, err = context.resHeader.Pack() content.resPacket.header, err = content.resHeader.Pack()
if err != nil { if err != nil {
return Err(err) return err
} }
_, err = context.sockWriter.Write(context.resPacket.header) _, err = content.sockWriter.Write(content.resPacket.header)
if err != nil { if err != nil {
return Err(err) return err
} }
_, err = context.sockWriter.Write(context.resPacket.rcpPayload) _, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil { if err != nil {
return Err(err) return err
} }
return Err(err) return err
} }

View File

@@ -13,7 +13,7 @@ import (
func ReadBytes(reader io.Reader, size int64) ([]byte, error) { func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
buffer := make([]byte, size) buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer) read, err := io.ReadFull(reader, buffer)
return buffer[0:read], Err(err) return buffer[0:read], err
} }
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
@@ -39,19 +39,19 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error
received, err := reader.Read(buffer[0:bSize]) received, err := reader.Read(buffer[0:bSize])
if err != nil { if err != nil {
err = fmt.Errorf("read error: %v", err) err = fmt.Errorf("read error: %v", err)
return total, Err(err) return total, err
} }
recorded, err := writer.Write(buffer[0:received]) recorded, err := writer.Write(buffer[0:received])
if err != nil { if err != nil {
err = fmt.Errorf("write error: %v", err) err = fmt.Errorf("write error: %v", err)
return total, Err(err) return total, err
} }
if recorded != received { if recorded != received {
err = errors.New("size mismatch") err = errors.New("size mismatch")
return total, Err(err) return total, err
} }
total += int64(recorded) total += int64(recorded)
remains -= int64(recorded) remains -= int64(recorded)
} }
return total, Err(err) return total, err
} }

View File

@@ -9,25 +9,29 @@ import (
"net" "net"
) )
func LocalExec(method string, param any, result any, auth *Auth, handler HandlerFunc) error { func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc) error {
var err error var err error
cliConn, srvConn := NewFConn() cliConn, srvConn := NewFConn()
context := CreateContext(cliConn) content := CreateContent(cliConn)
context.reqRPC.Method = method content.reqBlock.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
if context.reqRPC.Params == nil { if param != nil {
context.reqRPC.Params = NewEmpty() content.reqBlock.Params = param
} }
err = context.CreateRequest() if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
err = content.CreateRequest()
if err != nil { if err != nil {
return err return err
} }
err = context.WriteRequest() err = content.WriteRequest()
if err != nil { if err != nil {
return err return err
} }
@@ -35,11 +39,11 @@ func LocalExec(method string, param any, result any, auth *Auth, handler Handler
if err != nil { if err != nil {
return err return err
} }
err = context.ReadResponse() err = content.ReadResponse()
if err != nil { if err != nil {
return err return err
} }
err = context.BindResponse() err = content.BindResponse()
if err != nil { if err != nil {
return err return err
} }
@@ -53,29 +57,33 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au
cliConn, srvConn := NewFConn() cliConn, srvConn := NewFConn()
context := CreateContext(cliConn) content := CreateContent(cliConn)
context.reqRPC.Method = method content.reqBlock.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
context.binReader = reader if param != nil {
context.binWriter = cliConn content.reqBlock.Params = param
context.reqHeader.binSize = size
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
} }
err = context.CreateRequest() if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
content.binReader = reader
content.binWriter = cliConn
content.reqHeader.binSize = size
err = content.CreateRequest()
if err != nil { if err != nil {
return err return err
} }
err = context.WriteRequest() err = content.WriteRequest()
if err != nil { if err != nil {
return err return err
} }
err = context.UploadBin() err = content.UploadBin()
if err != nil { if err != nil {
return err return err
} }
@@ -83,11 +91,11 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au
if err != nil { if err != nil {
return err return err
} }
err = context.ReadResponse() err = content.ReadResponse()
if err != nil { if err != nil {
return err return err
} }
err = context.BindResponse() err = content.BindResponse()
if err != nil { if err != nil {
return err return err
} }
@@ -99,23 +107,27 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha
cliConn, srvConn := NewFConn() cliConn, srvConn := NewFConn()
context := CreateContext(cliConn) content := CreateContent(cliConn)
context.reqRPC.Method = method content.reqBlock.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
context.binReader = cliConn if param != nil {
context.binWriter = writer content.reqBlock.Params = param
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
} }
err = context.CreateRequest() if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
content.binReader = cliConn
content.binWriter = writer
err = content.CreateRequest()
if err != nil { if err != nil {
return err return err
} }
err = context.WriteRequest() err = content.WriteRequest()
if err != nil { if err != nil {
return err return err
} }
@@ -124,15 +136,15 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha
if err != nil { if err != nil {
return err return err
} }
err = context.ReadResponse() err = content.ReadResponse()
if err != nil { if err != nil {
return err return err
} }
err = context.DownloadBin() err = content.DownloadBin()
if err != nil { if err != nil {
return err return err
} }
err = context.BindResponse() err = content.BindResponse()
if err != nil { if err != nil {
return err return err
} }
@@ -141,22 +153,22 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha
func LocalService(conn net.Conn, handler HandlerFunc) error { func LocalService(conn net.Conn, handler HandlerFunc) error {
var err error var err error
context := CreateContext(conn) content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr) remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost content.remoteHost = remoteHost
context.binReader = conn content.binReader = conn
context.binWriter = io.Discard content.binWriter = io.Discard
err = context.ReadRequest() err = content.ReadRequest()
if err != nil { if err != nil {
return err return err
} }
err = context.BindMethod() err = content.BindMethod()
if err != nil { if err != nil {
return err return err
} }
return handler(context) return handler(content)
} }