Little refactoring

This commit is contained in:
2023-04-01 00:18:44 +02:00
parent 8dc753cd95
commit 8b3e722ea5
16 changed files with 354 additions and 562 deletions

276
client.go
View File

@@ -14,80 +14,71 @@ import (
encoder "github.com/vmihailenco/msgpack/v5"
)
func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
func Put(address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
return err
}
defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
return ConnPut(conn, method, reader, size, param, result, auth)
return ConnPut(conn, method, reader, binSize, param, result, auth)
}
func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
func ConnPut(conn net.Conn, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(conn)
context.binReader = reader
context.binWriter = conn
context.reqHeader.binSize = size
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
err = context.CreateRequest()
content.binReader = reader
content.binWriter = conn
content.reqHeader.binSize = binSize
err = content.CreateRequest()
if err != nil {
return Err(err)
return err
}
err = context.WriteRequest()
err = content.WriteRequest()
if err != nil {
return Err(err)
return err
}
var wg sync.WaitGroup
errChan := make(chan error, 1)
wg.Add(1)
go context.ReadResponseAsync(&wg, errChan)
go content.ReadResponseAsync(&wg, errChan)
wg.Add(1)
go context.UploadBinAsync(&wg)
go content.UploadBinAsync(&wg)
wg.Wait()
err = <-errChan
if err != nil {
return Err(err)
return err
}
err = context.BindResponse()
err = content.BindResponse()
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
@@ -96,65 +87,56 @@ func Get(address string, method string, writer io.Writer, param, result any, aut
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
return err
}
defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
return ConnGet(conn, method, writer, param, result, auth)
}
func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(conn)
content.reqBlock.Method = method
if param != nil {
content.reqBlock.Params = param
}
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
context.binReader = conn
context.binWriter = writer
content.binReader = conn
content.binWriter = writer
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
}
err = context.CreateRequest()
err = content.CreateRequest()
if err != nil {
return Err(err)
return err
}
err = context.WriteRequest()
err = content.WriteRequest()
if err != nil {
return Err(err)
return err
}
err = context.ReadResponse()
err = content.ReadResponse()
if err != nil {
return Err(err)
return err
}
err = context.DownloadBin()
err = content.DownloadBin()
if err != nil {
return Err(err)
return err
}
err = context.BindResponse()
err = content.BindResponse()
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func Exec(address, method string, param any, result any, auth *Auth) error {
@@ -163,170 +145,162 @@ func Exec(address, method string, param any, result any, auth *Auth) error {
addr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
err = fmt.Errorf("unable to resolve adddress: %s", err)
return Err(err)
return err
}
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return Err(err)
return err
}
defer conn.Close()
//err = conn.SetKeepAlive(true)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive: %s", err)
// return Err(err)
//}
//err = conn.SetKeepAlivePeriod(10 * time.Second)
//if err != nil {
// err = fmt.Errorf("unable to set keepalive period: %s", err)
// return Err(err)
//}
err = ConnExec(conn, method, param, result, auth)
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error {
var err error
context := CreateContext(conn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(conn)
content.reqBlock.Method = method
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
if param != nil {
content.reqBlock.Params = param
}
if result != nil {
content.resBlock.Result = result
}
if auth != nil {
content.reqBlock.Auth = auth
}
err = context.CreateRequest()
err = content.CreateRequest()
if err != nil {
return Err(err)
return err
}
err = context.WriteRequest()
err = content.WriteRequest()
if err != nil {
return Err(err)
return err
}
err = context.ReadResponse()
err = content.ReadResponse()
if err != nil {
return Err(err)
return err
}
err = context.BindResponse()
err = content.BindResponse()
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func (context *Context) CreateRequest() error {
func (content *Content) CreateRequest() error {
var err error
context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
content.reqPacket.rcpPayload, err = content.reqBlock.Pack()
if err != nil {
return Err(err)
return err
}
rpcSize := int64(len(context.reqPacket.rcpPayload))
context.reqHeader.rpcSize = rpcSize
rpcSize := int64(len(content.reqPacket.rcpPayload))
content.reqHeader.rpcSize = rpcSize
context.reqPacket.header, err = context.reqHeader.Pack()
content.reqPacket.header, err = content.reqHeader.Pack()
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func (context *Context) WriteRequest() error {
func (content *Content) WriteRequest() error {
var err error
_, err = context.sockWriter.Write(context.reqPacket.header)
_, err = content.sockWriter.Write(content.reqPacket.header)
if err != nil {
return Err(err)
return err
}
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
_, err = content.sockWriter.Write(content.reqPacket.rcpPayload)
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func (context *Context) UploadBin() error {
func (content *Content) UploadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return Err(err)
_, err = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize)
return err
}
func (context *Context) ReadResponse() error {
func (content *Content) ReadResponse() error {
var err error
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
return Err(err)
return err
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil {
return Err(err)
return err
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
rpcSize := content.resHeader.rpcSize
content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) {
func (content *Content) UploadBinAsync(wg *sync.WaitGroup) {
exitFunc := func() {
wg.Done()
}
defer exitFunc()
_, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
_, _ = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize)
return
}
func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) {
func (content *Content) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) {
var err error
exitFunc := func() {
errChan <- err
wg.Done()
}
defer exitFunc()
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
content.resPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
err = Err(err)
err = err
return
}
context.resHeader, err = UnpackHeader(context.resPacket.header)
content.resHeader, err = UnpackHeader(content.resPacket.header)
if err != nil {
err = Err(err)
err = err
return
}
rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
rpcSize := content.resHeader.rpcSize
content.resPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
err = Err(err)
err = err
return
}
return
}
func (context *Context) DownloadBin() error {
func (content *Content) DownloadBin() error {
var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
return Err(err)
_, err = CopyBytes(content.binReader, content.binWriter, content.resHeader.binSize)
return err
}
func (context *Context) BindResponse() error {
func (content *Content) BindResponse() error {
var err error
err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
err = encoder.Unmarshal(content.resPacket.rcpPayload, content.resBlock)
if err != nil {
return Err(err)
return err
}
if len(context.resRPC.Error) > 0 {
err = errors.New(context.resRPC.Error)
return Err(err)
if len(content.resBlock.Error) > 0 {
err = errors.New(content.resBlock.Error)
return err
}
return Err(err)
return err
}

View File

@@ -1,5 +0,0 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc

View File

@@ -1,152 +0,0 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
import (
"io"
"net"
"time"
)
type Context struct {
start time.Time
remoteHost string
sockReader io.Reader
sockWriter io.Writer
reqHeader *Header
reqRPC *Request
reqPacket *Packet
resPacket *Packet
resHeader *Header
resRPC *Response
binReader io.Reader
binWriter io.Writer
}
func NewContext() *Context {
context := &Context{}
context.start = time.Now()
return context
}
func CreateContext(conn net.Conn) *Context {
context := &Context{}
context.start = time.Now()
context.sockReader = conn
context.sockWriter = conn
context.reqPacket = NewPacket()
context.resPacket = NewPacket()
context.reqHeader = NewHeader()
context.reqRPC = NewRequest()
context.resHeader = NewHeader()
context.resRPC = NewResponse()
context.resRPC = NewResponse()
return context
}
func (context *Context) Request() *Request {
return context.reqRPC
}
func (context *Context) RemoteHost() string {
return context.remoteHost
}
func (context *Context) Start() time.Time {
return context.start
}
func (context *Context) Method() string {
var method string
if context.reqRPC != nil {
method = context.reqRPC.Method
}
return method
}
func (context *Context) ReqRpcSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.rpcSize
}
return size
}
func (context *Context) ReqBinSize() int64 {
var size int64
if context.reqHeader != nil {
size = context.reqHeader.binSize
}
return size
}
func (context *Context) ResBinSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.binSize
}
return size
}
func (context *Context) ResRpcSize() int64 {
var size int64
if context.resHeader != nil {
size = context.resHeader.rpcSize
}
return size
}
func (context *Context) ReqSize() int64 {
var size int64
if context.reqHeader != nil {
size += context.reqHeader.binSize
size += context.reqHeader.rpcSize
}
return size
}
func (context *Context) ResSize() int64 {
var size int64
if context.resHeader != nil {
size += context.resHeader.binSize
size += context.resHeader.rpcSize
}
return size
}
func (context *Context) SetAuthIdent(ident []byte) {
context.reqRPC.Auth.Ident = ident
}
func (context *Context) SetAuthSalt(salt []byte) {
context.reqRPC.Auth.Salt = salt
}
func (context *Context) SetAuthHash(hash []byte) {
context.reqRPC.Auth.Hash = hash
}
func (context *Context) AuthIdent() []byte {
return context.reqRPC.Auth.Ident
}
func (context *Context) AuthSalt() []byte {
return context.reqRPC.Auth.Salt
}
func (context *Context) AuthHash() []byte {
return context.reqRPC.Auth.Hash
}
func (context *Context) Auth() *Auth {
return context.reqRPC.Auth
}

View File

@@ -1,13 +0,0 @@
/*
*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*
*/
package dsrpc
type Empty struct{}
func NewEmpty() *Empty {
return &Empty{}
}

View File

@@ -1,42 +0,0 @@
/*
* Copyright 2022 Oleg Borodin <borodin@unix7.org>
*/
package dsrpc
import (
"fmt"
"io"
"runtime"
)
var develMode bool = false
var debugMode bool = false
func SetDevelMode(mode bool) {
develMode = mode
}
func SetDebugMode(mode bool) {
debugMode = mode
}
func Err(err error) error {
switch err {
case io.EOF:
return err
}
if err != nil {
switch {
case develMode == true:
pc, filename, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error())
case debugMode == true:
pc, _, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error())
default:
}
}
return err
}

View File

@@ -209,58 +209,58 @@ func testServ(quiet bool) error {
return err
}
func auth(context *Context) error {
func auth(content *Content) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
reqIdent := content.AuthIdent()
reqSalt := content.AuthSalt()
reqHash := content.AuthHash()
ident := reqIdent
pass := []byte("12345")
auth := context.Auth()
auth := content.Auth()
logDebug("auth ", string(auth.JSON()))
ok := CheckHash(ident, pass, reqSalt, reqHash)
logDebug("auth ok:", ok)
if !ok {
err = errors.New("auth ident or pass missmatch")
context.SendError(err)
content.SendError(err)
return err
}
return err
}
func helloHandler(context *Context) error {
func helloHandler(content *Content) error {
var err error
params := NewHelloParams()
err = context.BindParams(params)
err = content.BindParams(params)
if err != nil {
return err
}
err = context.ReadBin(io.Discard)
err = content.ReadBin(io.Discard)
if err != nil {
context.SendError(err)
content.SendError(err)
return err
}
result := NewHelloResult()
result.Message = "hello, client!"
err = context.SendResult(result, 0)
err = content.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func saveHandler(context *Context) error {
func saveHandler(content *Content) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
err = content.BindParams(params)
if err != nil {
return err
}
@@ -268,34 +268,34 @@ func saveHandler(context *Context) error {
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
err = content.ReadBin(binWriter)
if err != nil {
context.SendError(err)
content.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
err = content.SendResult(result, 0)
if err != nil {
return err
}
return err
}
func loadHandler(context *Context) error {
func loadHandler(content *Content) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
err = content.BindParams(params)
if err != nil {
return err
}
err = context.ReadBin(io.Discard)
err = content.ReadBin(io.Discard)
if err != nil {
context.SendError(err)
content.SendError(err)
return err
}
@@ -309,11 +309,11 @@ func loadHandler(context *Context) error {
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
err = content.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
binWriter := content.BinWriter()
_, err = CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err

View File

@@ -10,9 +10,10 @@ type FAddr struct {
}
func NewFAddr() *FAddr {
var addr FAddr
addr.network = "tcp"
addr.address = "127.0.0.1:5000"
addr := FAddr{
network: "tcp",
address: "127.0.0.1:5000",
}
return &addr
}

View File

@@ -7,9 +7,10 @@
package dsrpc
import (
"github.com/stretchr/testify/require"
"net"
"testing"
"github.com/stretchr/testify/require"
)
func TestFConn0(t *testing.T) {

View File

@@ -13,10 +13,12 @@ import (
"errors"
)
const headerSize int64 = 16 * 2
const sizeOfInt64 int = 8
const magicCodeA int64 = 0xEE00ABBA
const magicCodeB int64 = 0xEE44ABBA
const (
headerSize int64 = 16 * 2
sizeOfInt64 int = 8
magicCodeA int64 = 0xEE00ABBA
magicCodeB int64 = 0xEE44ABBA
)
type Header struct {
magicCodeA int64 `json:"magicCodeA"`
@@ -25,14 +27,14 @@ type Header struct {
magicCodeB int64 `json:"magicCodeB"`
}
func NewHeader() *Header {
func NewEmptyHeader() *Header {
return &Header{
magicCodeA: magicCodeA,
magicCodeB: magicCodeB,
}
}
func (hdr *Header) JSON() []byte {
func (hdr *Header) ToJson() []byte {
jBytes, _ := json.Marshal(hdr)
return jBytes
}
@@ -54,35 +56,38 @@ func (hdr *Header) Pack() ([]byte, error) {
magicCodeBBytes := encoderI64(hdr.magicCodeB)
headerBuffer.Write(magicCodeBBytes)
return headerBuffer.Bytes(), Err(err)
return headerBuffer.Bytes(), err
}
func UnpackHeader(headerBytes []byte) (*Header, error) {
var err error
header := NewHeader()
headerReader := bytes.NewReader(headerBytes)
magicCodeABytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeABytes)
header.magicCodeA = decoderI64(magicCodeABytes)
rpcSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(rpcSizeBytes)
header.rpcSize = decoderI64(rpcSizeBytes)
binSizeBytes := make([]byte, sizeOfInt64)
headerReader.Read(binSizeBytes)
header.binSize = decoderI64(binSizeBytes)
magicCodeBBytes := make([]byte, sizeOfInt64)
headerReader.Read(magicCodeBBytes)
header.magicCodeB = decoderI64(magicCodeBBytes)
header := &Header{
magicCodeA: decoderI64(magicCodeABytes),
rpcSize: decoderI64(rpcSizeBytes),
binSize: decoderI64(binSizeBytes),
magicCodeB: decoderI64(magicCodeBBytes),
}
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
err = errors.New("wrong protocol magic code")
return header, Err(err)
err = errors.New("Wrong protocol magic code")
return header, err
}
return header, Err(err)
return header, err
}
func encoderI64(i int64) []byte {

View File

@@ -8,22 +8,22 @@ import (
"time"
)
func LogRequest(context *Context) error {
func LogRequest(content *Content) error {
var err error
logDebug("request:", string(context.reqRPC.JSON()))
return Err(err)
logDebug("request:", string(content.reqBlock.ToJson()))
return err
}
func LogResponse(context *Context) error {
func LogResponse(content *Content) error {
var err error
logDebug("response:", string(context.resRPC.JSON()))
return Err(err)
logDebug("response:", string(content.resBlock.ToJson()))
return err
}
func LogAccess(context *Context) error {
func LogAccess(content *Content) error {
var err error
execTime := time.Now().Sub(context.start)
login := string(context.AuthIdent())
logAccess(context.remoteHost, login, context.reqRPC.Method, execTime)
return Err(err)
execTime := time.Now().Sub(content.start)
login := string(content.AuthIdent())
logAccess(content.remoteHost, login, content.reqBlock.Method, execTime)
return err
}

View File

@@ -4,20 +4,15 @@
package dsrpc
import (
"encoding/json"
)
type Packet struct {
header []byte
rcpPayload []byte
}
func NewPacket() *Packet {
return &Packet{}
}
func (pkt *Packet) JSON() []byte {
jBytes, _ := json.Marshal(pkt)
return jBytes
func NewEmptyPacket() *Packet {
packet := &Packet{
header: make([]byte, 0),
rcpPayload: make([]byte, 0),
}
return packet
}

View File

@@ -11,24 +11,31 @@ import (
encoder "github.com/vmihailenco/msgpack/v5"
)
type EmptyParams struct{}
func NewEmptyParams() *EmptyParams {
return &EmptyParams{}
}
type Request struct {
Method string `json:"method" msgpack:"method"`
Params any `json:"params,omitempty" msgpack:"params"`
Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
}
func NewRequest() *Request {
func NewEmptyRequest() *Request {
req := &Request{}
req.Auth = &Auth{}
req.Params = NewEmptyParams()
return req
}
func (req *Request) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(req)
return rBytes, Err(err)
return rBytes, err
}
func (req *Request) JSON() []byte {
func (req *Request) ToJson() []byte {
jBytes, _ := json.Marshal(req)
return jBytes
}

View File

@@ -6,24 +6,33 @@ package dsrpc
import (
"encoding/json"
encoder "github.com/vmihailenco/msgpack/v5"
)
type EmptyResult struct{}
func NewEmptyResult() *EmptyResult {
return &EmptyResult{}
}
type Response struct {
Error string `json:"error" msgpack:"error"`
Result any `json:"result" msgpack:"result"`
}
func NewResponse() *Response {
return &Response{}
func NewEmptyResponse() *Response {
return &Response{
Result: NewEmptyResult(),
}
}
func (resp *Response) JSON() []byte {
func (resp *Response) ToJson() []byte {
jBytes, _ := json.Marshal(resp)
return jBytes
}
func (resp *Response) Pack() ([]byte, error) {
rBytes, err := encoder.Marshal(resp)
return rBytes, Err(err)
return rBytes, err
}

146
server.go
View File

@@ -16,7 +16,7 @@ import (
encoder "github.com/vmihailenco/msgpack/v5"
)
type HandlerFunc = func(*Context) error
type HandlerFunc = func(*Content) error
type Service struct {
handlers map[string]HandlerFunc
@@ -99,9 +99,9 @@ func (svc *Service) Listen(address string) error {
return err
}
func notFound(context *Context) error {
func notFound(content *Content) error {
execErr := errors.New("method not found")
err := context.SendError(execErr)
err := content.SendError(execErr)
return err
}
@@ -133,14 +133,14 @@ func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) {
}
}
}
context := CreateContext(conn)
content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost
content.remoteHost = remoteHost
context.binReader = conn
context.binWriter = io.Discard
content.binReader = conn
content.binWriter = io.Discard
exitFunc := func() {
conn.Close()
@@ -159,149 +159,149 @@ func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) {
}
defer recovFunc()
err = context.ReadRequest()
err = content.ReadRequest()
if err != nil {
err = Err(err)
err = err
return
}
err = context.BindMethod()
err = content.BindMethod()
if err != nil {
err = Err(err)
err = err
return
}
for _, mw := range svc.preMw {
err = mw(context)
err = mw(content)
if err != nil {
err = Err(err)
err = err
return
}
}
err = svc.Route(context)
err = svc.Route(content)
if err != nil {
err = Err(err)
err = err
return
}
for _, mw := range svc.postMw {
err = mw(context)
err = mw(content)
if err != nil {
err = Err(err)
err = err
return
}
}
return
}
func (svc *Service) Route(context *Context) error {
handler, ok := svc.handlers[context.reqRPC.Method]
func (svc *Service) Route(content *Content) error {
handler, ok := svc.handlers[content.reqBlock.Method]
if ok {
return Err(handler(context))
return handler(content)
}
return Err(notFound(context))
return notFound(content)
}
func (context *Context) ReadRequest() error {
func (content *Content) ReadRequest() error {
var err error
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
content.reqPacket.header, err = ReadBytes(content.sockReader, headerSize)
if err != nil {
return Err(err)
return err
}
context.reqHeader, err = UnpackHeader(context.reqPacket.header)
content.reqHeader, err = UnpackHeader(content.reqPacket.header)
if err != nil {
return Err(err)
return err
}
rpcSize := context.reqHeader.rpcSize
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
rpcSize := content.reqHeader.rpcSize
content.reqPacket.rcpPayload, err = ReadBytes(content.sockReader, rpcSize)
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func (context *Context) BinWriter() io.Writer {
return context.sockWriter
func (content *Content) BinWriter() io.Writer {
return content.sockWriter
}
func (context *Context) BinReader() io.Reader {
return context.sockReader
func (content *Content) BinReader() io.Reader {
return content.sockReader
}
func (context *Context) BinSize() int64 {
return context.reqHeader.binSize
func (content *Content) BinSize() int64 {
return content.reqHeader.binSize
}
func (context *Context) ReadBin(writer io.Writer) error {
func (content *Content) ReadBin(writer io.Writer) error {
var err error
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
return Err(err)
_, err = CopyBytes(content.sockReader, writer, content.reqHeader.binSize)
return err
}
func (context *Context) BindMethod() error {
func (content *Content) BindMethod() error {
var err error
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
return Err(err)
err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
return err
}
func (context *Context) BindParams(params any) error {
func (content *Content) BindParams(params any) error {
var err error
context.reqRPC.Params = params
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
content.reqBlock.Params = params
err = encoder.Unmarshal(content.reqPacket.rcpPayload, content.reqBlock)
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func (context *Context) SendResult(result any, binSize int64) error {
func (content *Content) SendResult(result any, binSize int64) error {
var err error
context.resRPC.Result = result
content.resBlock.Result = result
context.resPacket.rcpPayload, err = context.resRPC.Pack()
content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil {
return Err(err)
return err
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resHeader.binSize = binSize
content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
content.resHeader.binSize = binSize
context.resPacket.header, err = context.resHeader.Pack()
content.resPacket.header, err = content.resHeader.Pack()
if err != nil {
return Err(err)
return err
}
_, err = context.sockWriter.Write(context.resPacket.header)
_, err = content.sockWriter.Write(content.resPacket.header)
if err != nil {
return Err(err)
return err
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
_, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}
func (context *Context) SendError(execErr error) error {
func (content *Content) SendError(execErr error) error {
var err error
context.resRPC.Error = execErr.Error()
context.resRPC.Result = NewEmpty()
content.resBlock.Error = execErr.Error()
content.resBlock.Result = NewEmptyResult()
context.resPacket.rcpPayload, err = context.resRPC.Pack()
content.resPacket.rcpPayload, err = content.resBlock.Pack()
if err != nil {
return Err(err)
return err
}
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resPacket.header, err = context.resHeader.Pack()
content.resHeader.rpcSize = int64(len(content.resPacket.rcpPayload))
content.resPacket.header, err = content.resHeader.Pack()
if err != nil {
return Err(err)
return err
}
_, err = context.sockWriter.Write(context.resPacket.header)
_, err = content.sockWriter.Write(content.resPacket.header)
if err != nil {
return Err(err)
return err
}
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
_, err = content.sockWriter.Write(content.resPacket.rcpPayload)
if err != nil {
return Err(err)
return err
}
return Err(err)
return err
}

View File

@@ -13,7 +13,7 @@ import (
func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer)
return buffer[0:read], Err(err)
return buffer[0:read], err
}
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
@@ -39,19 +39,19 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error
received, err := reader.Read(buffer[0:bSize])
if err != nil {
err = fmt.Errorf("read error: %v", err)
return total, Err(err)
return total, err
}
recorded, err := writer.Write(buffer[0:received])
if err != nil {
err = fmt.Errorf("write error: %v", err)
return total, Err(err)
return total, err
}
if recorded != received {
err = errors.New("size mismatch")
return total, Err(err)
return total, err
}
total += int64(recorded)
remains -= int64(recorded)
}
return total, Err(err)
return total, err
}

View File

@@ -9,25 +9,29 @@ import (
"net"
)
func LocalExec(method string, param any, result any, auth *Auth, handler HandlerFunc) error {
func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc) error {
var err error
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
if param != nil {
content.reqBlock.Params = param
}
err = context.CreateRequest()
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
err = content.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
err = content.WriteRequest()
if err != nil {
return err
}
@@ -35,11 +39,11 @@ func LocalExec(method string, param any, result any, auth *Auth, handler Handler
if err != nil {
return err
}
err = context.ReadResponse()
err = content.ReadResponse()
if err != nil {
return err
}
err = context.BindResponse()
err = content.BindResponse()
if err != nil {
return err
}
@@ -53,29 +57,33 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
context.binReader = reader
context.binWriter = cliConn
context.reqHeader.binSize = size
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
if param != nil {
content.reqBlock.Params = param
}
err = context.CreateRequest()
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
content.binReader = reader
content.binWriter = cliConn
content.reqHeader.binSize = size
err = content.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
err = content.WriteRequest()
if err != nil {
return err
}
err = context.UploadBin()
err = content.UploadBin()
if err != nil {
return err
}
@@ -83,11 +91,11 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au
if err != nil {
return err
}
err = context.ReadResponse()
err = content.ReadResponse()
if err != nil {
return err
}
err = context.BindResponse()
err = content.BindResponse()
if err != nil {
return err
}
@@ -99,23 +107,27 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha
cliConn, srvConn := NewFConn()
context := CreateContext(cliConn)
context.reqRPC.Method = method
context.reqRPC.Params = param
context.reqRPC.Auth = auth
context.resRPC.Result = result
content := CreateContent(cliConn)
content.reqBlock.Method = method
context.binReader = cliConn
context.binWriter = writer
if context.reqRPC.Params == nil {
context.reqRPC.Params = NewEmpty()
if param != nil {
content.reqBlock.Params = param
}
err = context.CreateRequest()
if auth != nil {
content.reqBlock.Auth = auth
}
if result != nil {
content.resBlock.Result = result
}
content.binReader = cliConn
content.binWriter = writer
err = content.CreateRequest()
if err != nil {
return err
}
err = context.WriteRequest()
err = content.WriteRequest()
if err != nil {
return err
}
@@ -124,15 +136,15 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha
if err != nil {
return err
}
err = context.ReadResponse()
err = content.ReadResponse()
if err != nil {
return err
}
err = context.DownloadBin()
err = content.DownloadBin()
if err != nil {
return err
}
err = context.BindResponse()
err = content.BindResponse()
if err != nil {
return err
}
@@ -141,22 +153,22 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha
func LocalService(conn net.Conn, handler HandlerFunc) error {
var err error
context := CreateContext(conn)
content := CreateContent(conn)
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
context.remoteHost = remoteHost
content.remoteHost = remoteHost
context.binReader = conn
context.binWriter = io.Discard
content.binReader = conn
content.binWriter = io.Discard
err = context.ReadRequest()
err = content.ReadRequest()
if err != nil {
return err
}
err = context.BindMethod()
err = content.BindMethod()
if err != nil {
return err
}
return handler(context)
return handler(content)
}