go fmt
This commit is contained in:
508
client.go
508
client.go
@@ -5,332 +5,328 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
encoder "github.com/vmihailenco/msgpack/v5"
|
||||
encoder "github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
|
||||
func Put(address string, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to resolve adddress: %s", err)
|
||||
return Err(err)
|
||||
}
|
||||
conn, err := net.DialTCP("tcp", nil, addr)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to resolve adddress: %s", err)
|
||||
return Err(err)
|
||||
}
|
||||
conn, err := net.DialTCP("tcp", nil, addr)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
//err = conn.SetKeepAlive(true)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
//err = conn.SetKeepAlive(true)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
|
||||
//err = conn.SetKeepAlivePeriod(10 * time.Second)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive period: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
//err = conn.SetKeepAlivePeriod(10 * time.Second)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive period: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
|
||||
return ConnPut(conn, method, reader, size, param, result, auth)
|
||||
return ConnPut(conn, method, reader, size, param, result, auth)
|
||||
}
|
||||
|
||||
|
||||
func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param, result any, auth *Auth) error {
|
||||
var err error
|
||||
context := CreateContext(conn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
var err error
|
||||
context := CreateContext(conn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
|
||||
context.binReader = reader
|
||||
context.binWriter = conn
|
||||
context.binReader = reader
|
||||
context.binWriter = conn
|
||||
|
||||
context.reqHeader.binSize = size
|
||||
context.reqHeader.binSize = size
|
||||
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 1)
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
wg.Add(1)
|
||||
go context.ReadResponseAsync(&wg, errChan)
|
||||
wg.Add(1)
|
||||
go context.ReadResponseAsync(&wg, errChan)
|
||||
|
||||
wg.Add(1)
|
||||
go context.UploadBinAsync(&wg)
|
||||
wg.Add(1)
|
||||
go context.UploadBinAsync(&wg)
|
||||
|
||||
wg.Wait()
|
||||
err = <- errChan
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
wg.Wait()
|
||||
err = <-errChan
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to resolve adddress: %s", err)
|
||||
return Err(err)
|
||||
}
|
||||
conn, err := net.DialTCP("tcp", nil, addr)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to resolve adddress: %s", err)
|
||||
return Err(err)
|
||||
}
|
||||
conn, err := net.DialTCP("tcp", nil, addr)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
//err = conn.SetKeepAlive(true)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
//err = conn.SetKeepAlive(true)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
|
||||
//err = conn.SetKeepAlivePeriod(10 * time.Second)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive period: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
//err = conn.SetKeepAlivePeriod(10 * time.Second)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive period: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
|
||||
return ConnGet(conn, method, writer, param, result, auth)
|
||||
return ConnGet(conn, method, writer, param, result, auth)
|
||||
}
|
||||
|
||||
func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
context := CreateContext(conn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
context := CreateContext(conn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
|
||||
context.binReader = conn
|
||||
context.binWriter = writer
|
||||
context.binReader = conn
|
||||
context.binWriter = writer
|
||||
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.DownloadBin()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.DownloadBin()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func Exec(address, method string, param any, result any, auth *Auth) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to resolve adddress: %s", err)
|
||||
return Err(err)
|
||||
}
|
||||
conn, err := net.DialTCP("tcp", nil, addr)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to resolve adddress: %s", err)
|
||||
return Err(err)
|
||||
}
|
||||
conn, err := net.DialTCP("tcp", nil, addr)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
//err = conn.SetKeepAlive(true)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
//err = conn.SetKeepAlivePeriod(10 * time.Second)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive period: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
//err = conn.SetKeepAlive(true)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
//err = conn.SetKeepAlivePeriod(10 * time.Second)
|
||||
//if err != nil {
|
||||
// err = fmt.Errorf("unable to set keepalive period: %s", err)
|
||||
// return Err(err)
|
||||
//}
|
||||
|
||||
err = ConnExec(conn, method, param, result, auth)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
err = ConnExec(conn, method, param, result, auth)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
|
||||
func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
context := CreateContext(conn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
context := CreateContext(conn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
|
||||
func (context *Context) CreateRequest() error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
rpcSize := int64(len(context.reqPacket.rcpPayload))
|
||||
context.reqHeader.rpcSize = rpcSize
|
||||
context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
rpcSize := int64(len(context.reqPacket.rcpPayload))
|
||||
context.reqHeader.rpcSize = rpcSize
|
||||
|
||||
context.reqPacket.header, err = context.reqHeader.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
context.reqPacket.header, err = context.reqHeader.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func (context *Context) WriteRequest() error {
|
||||
var err error
|
||||
_, err = context.sockWriter.Write(context.reqPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
var err error
|
||||
_, err = context.sockWriter.Write(context.reqPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func (context *Context) UploadBin() error {
|
||||
var err error
|
||||
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
|
||||
return Err(err)
|
||||
var err error
|
||||
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func (context *Context) ReadResponse() error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.resHeader, err = UnpackHeader(context.resPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
rpcSize := context.resHeader.rpcSize
|
||||
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.resHeader, err = UnpackHeader(context.resPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
rpcSize := context.resHeader.rpcSize
|
||||
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) {
|
||||
exitFunc := func() {
|
||||
wg.Done()
|
||||
}
|
||||
defer exitFunc()
|
||||
_, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
|
||||
return
|
||||
exitFunc := func() {
|
||||
wg.Done()
|
||||
}
|
||||
defer exitFunc()
|
||||
_, _ = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
|
||||
return
|
||||
}
|
||||
|
||||
func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) {
|
||||
var err error
|
||||
exitFunc := func() {
|
||||
errChan <- err
|
||||
wg.Done()
|
||||
}
|
||||
defer exitFunc()
|
||||
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
context.resHeader, err = UnpackHeader(context.resPacket.header)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
rpcSize := context.resHeader.rpcSize
|
||||
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
return
|
||||
var err error
|
||||
exitFunc := func() {
|
||||
errChan <- err
|
||||
wg.Done()
|
||||
}
|
||||
defer exitFunc()
|
||||
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
context.resHeader, err = UnpackHeader(context.resPacket.header)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
rpcSize := context.resHeader.rpcSize
|
||||
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (context *Context) DownloadBin() error {
|
||||
var err error
|
||||
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
|
||||
return Err(err)
|
||||
var err error
|
||||
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func (context *Context) BindResponse() error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
if len(context.resRPC.Error) > 0 {
|
||||
err = errors.New(context.resRPC.Error)
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
err = encoder.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
if len(context.resRPC.Error) > 0 {
|
||||
err = errors.New(context.resRPC.Error)
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
166
context.go
166
context.go
@@ -5,152 +5,148 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
start time.Time
|
||||
remoteHost string
|
||||
sockReader io.Reader
|
||||
sockWriter io.Writer
|
||||
start time.Time
|
||||
remoteHost string
|
||||
sockReader io.Reader
|
||||
sockWriter io.Writer
|
||||
|
||||
reqHeader *Header
|
||||
reqRPC *Request
|
||||
reqPacket *Packet
|
||||
reqHeader *Header
|
||||
reqRPC *Request
|
||||
reqPacket *Packet
|
||||
|
||||
resPacket *Packet
|
||||
resHeader *Header
|
||||
resRPC *Response
|
||||
resPacket *Packet
|
||||
resHeader *Header
|
||||
resRPC *Response
|
||||
|
||||
binReader io.Reader
|
||||
binWriter io.Writer
|
||||
binReader io.Reader
|
||||
binWriter io.Writer
|
||||
}
|
||||
|
||||
|
||||
func NewContext() *Context {
|
||||
context := &Context{}
|
||||
context.start = time.Now()
|
||||
return context
|
||||
context := &Context{}
|
||||
context.start = time.Now()
|
||||
return context
|
||||
}
|
||||
|
||||
func CreateContext(conn net.Conn) *Context {
|
||||
context := &Context{}
|
||||
context.start = time.Now()
|
||||
context.sockReader = conn
|
||||
context.sockWriter = conn
|
||||
context := &Context{}
|
||||
context.start = time.Now()
|
||||
context.sockReader = conn
|
||||
context.sockWriter = conn
|
||||
|
||||
context.reqPacket = NewPacket()
|
||||
context.resPacket = NewPacket()
|
||||
context.reqPacket = NewPacket()
|
||||
context.resPacket = NewPacket()
|
||||
|
||||
context.reqHeader = NewHeader()
|
||||
context.reqRPC = NewRequest()
|
||||
context.reqHeader = NewHeader()
|
||||
context.reqRPC = NewRequest()
|
||||
|
||||
context.resHeader = NewHeader()
|
||||
context.resRPC = NewResponse()
|
||||
context.resRPC = NewResponse()
|
||||
context.resHeader = NewHeader()
|
||||
context.resRPC = NewResponse()
|
||||
context.resRPC = NewResponse()
|
||||
|
||||
return context
|
||||
return context
|
||||
}
|
||||
|
||||
func (context *Context) Request() *Request {
|
||||
return context.reqRPC
|
||||
func (context *Context) Request() *Request {
|
||||
return context.reqRPC
|
||||
}
|
||||
|
||||
func (context *Context) RemoteHost() string {
|
||||
return context.remoteHost
|
||||
return context.remoteHost
|
||||
}
|
||||
|
||||
func (context *Context) Start() time.Time {
|
||||
return context.start
|
||||
return context.start
|
||||
}
|
||||
|
||||
func (context *Context) Method() string {
|
||||
var method string
|
||||
if context.reqRPC != nil {
|
||||
method = context.reqRPC.Method
|
||||
}
|
||||
return method
|
||||
var method string
|
||||
if context.reqRPC != nil {
|
||||
method = context.reqRPC.Method
|
||||
}
|
||||
return method
|
||||
}
|
||||
|
||||
func (context *Context) ReqRpcSize() int64 {
|
||||
var size int64
|
||||
if context.reqHeader != nil {
|
||||
size = context.reqHeader.rpcSize
|
||||
}
|
||||
return size
|
||||
var size int64
|
||||
if context.reqHeader != nil {
|
||||
size = context.reqHeader.rpcSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
|
||||
func (context *Context) ReqBinSize() int64 {
|
||||
var size int64
|
||||
if context.reqHeader != nil {
|
||||
size = context.reqHeader.binSize
|
||||
}
|
||||
return size
|
||||
var size int64
|
||||
if context.reqHeader != nil {
|
||||
size = context.reqHeader.binSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (context *Context) ResBinSize() int64 {
|
||||
var size int64
|
||||
if context.resHeader != nil {
|
||||
size = context.resHeader.binSize
|
||||
}
|
||||
return size
|
||||
var size int64
|
||||
if context.resHeader != nil {
|
||||
size = context.resHeader.binSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (context *Context) ResRpcSize() int64 {
|
||||
var size int64
|
||||
if context.resHeader != nil {
|
||||
size = context.resHeader.rpcSize
|
||||
}
|
||||
return size
|
||||
var size int64
|
||||
if context.resHeader != nil {
|
||||
size = context.resHeader.rpcSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (context *Context) ReqSize() int64 {
|
||||
var size int64
|
||||
if context.reqHeader != nil {
|
||||
size += context.reqHeader.binSize
|
||||
size += context.reqHeader.rpcSize
|
||||
}
|
||||
return size
|
||||
var size int64
|
||||
if context.reqHeader != nil {
|
||||
size += context.reqHeader.binSize
|
||||
size += context.reqHeader.rpcSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (context *Context) ResSize() int64 {
|
||||
var size int64
|
||||
if context.resHeader != nil {
|
||||
size += context.resHeader.binSize
|
||||
size += context.resHeader.rpcSize
|
||||
}
|
||||
return size
|
||||
var size int64
|
||||
if context.resHeader != nil {
|
||||
size += context.resHeader.binSize
|
||||
size += context.resHeader.rpcSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
|
||||
|
||||
func (context *Context) SetAuthIdent(ident []byte) {
|
||||
context.reqRPC.Auth.Ident = ident
|
||||
func (context *Context) SetAuthIdent(ident []byte) {
|
||||
context.reqRPC.Auth.Ident = ident
|
||||
}
|
||||
|
||||
func (context *Context) SetAuthSalt(salt []byte) {
|
||||
context.reqRPC.Auth.Salt = salt
|
||||
func (context *Context) SetAuthSalt(salt []byte) {
|
||||
context.reqRPC.Auth.Salt = salt
|
||||
}
|
||||
|
||||
func (context *Context) SetAuthHash(hash []byte) {
|
||||
context.reqRPC.Auth.Hash = hash
|
||||
func (context *Context) SetAuthHash(hash []byte) {
|
||||
context.reqRPC.Auth.Hash = hash
|
||||
}
|
||||
|
||||
func (context *Context) AuthIdent() []byte {
|
||||
return context.reqRPC.Auth.Ident
|
||||
return context.reqRPC.Auth.Ident
|
||||
}
|
||||
|
||||
func (context *Context) AuthSalt() []byte {
|
||||
return context.reqRPC.Auth.Salt
|
||||
return context.reqRPC.Auth.Salt
|
||||
}
|
||||
|
||||
func (context *Context) AuthHash() []byte {
|
||||
return context.reqRPC.Auth.Hash
|
||||
return context.reqRPC.Auth.Hash
|
||||
}
|
||||
|
||||
func (context *Context) Auth() *Auth {
|
||||
return context.reqRPC.Auth
|
||||
return context.reqRPC.Auth
|
||||
}
|
||||
|
||||
4
empty.go
4
empty.go
@@ -6,8 +6,8 @@
|
||||
|
||||
package dsrpc
|
||||
|
||||
type Empty struct {}
|
||||
type Empty struct{}
|
||||
|
||||
func NewEmpty() *Empty {
|
||||
return &Empty{}
|
||||
return &Empty{}
|
||||
}
|
||||
|
||||
47
error.go
47
error.go
@@ -5,39 +5,38 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"io"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var develMode bool = false
|
||||
var debugMode bool = false
|
||||
|
||||
|
||||
func SetDevelMode(mode bool) {
|
||||
develMode = mode
|
||||
develMode = mode
|
||||
}
|
||||
func SetDebugMode(mode bool) {
|
||||
debugMode = mode
|
||||
debugMode = mode
|
||||
}
|
||||
|
||||
func Err(err error) error {
|
||||
switch err {
|
||||
case io.EOF:
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
switch {
|
||||
case develMode == true:
|
||||
pc, filename, line, _ := runtime.Caller(1)
|
||||
funcName := runtime.FuncForPC(pc).Name()
|
||||
err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error())
|
||||
case debugMode == true:
|
||||
pc, _, line, _ := runtime.Caller(1)
|
||||
funcName := runtime.FuncForPC(pc).Name()
|
||||
err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error())
|
||||
default:
|
||||
}
|
||||
}
|
||||
return err
|
||||
switch err {
|
||||
case io.EOF:
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
switch {
|
||||
case develMode == true:
|
||||
pc, filename, line, _ := runtime.Caller(1)
|
||||
funcName := runtime.FuncForPC(pc).Name()
|
||||
err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error())
|
||||
case debugMode == true:
|
||||
pc, _, line, _ := runtime.Caller(1)
|
||||
funcName := runtime.FuncForPC(pc).Name()
|
||||
err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error())
|
||||
default:
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
469
exec_test.go
469
exec_test.go
@@ -5,370 +5,361 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLocalExec(t *testing.T) {
|
||||
var err error
|
||||
params := NewHelloParams()
|
||||
params.Message = "hello server!"
|
||||
result := NewHelloResult()
|
||||
var err error
|
||||
params := NewHelloParams()
|
||||
params.Message = "hello server!"
|
||||
result := NewHelloResult()
|
||||
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
|
||||
err = LocalExec(HelloMethod, params, result, auth, helloHandler)
|
||||
require.NoError(t, err)
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
err = LocalExec(HelloMethod, params, result, auth, helloHandler)
|
||||
require.NoError(t, err)
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
}
|
||||
|
||||
|
||||
func TestLocalSave(t *testing.T) {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
params := NewSaveParams()
|
||||
params.Message = "save data!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
params := NewSaveParams()
|
||||
params.Message = "save data!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
|
||||
var binSize int64 = 16
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
binBytes := make([]byte, binSize)
|
||||
rand.Read(binBytes)
|
||||
var binSize int64 = 16
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
binBytes := make([]byte, binSize)
|
||||
rand.Read(binBytes)
|
||||
|
||||
reader := bytes.NewReader(binBytes)
|
||||
reader := bytes.NewReader(binBytes)
|
||||
|
||||
err = LocalPut(SaveMethod, reader, binSize, params, result, auth, saveHandler)
|
||||
require.NoError(t, err)
|
||||
err = LocalPut(SaveMethod, reader, binSize, params, result, auth, saveHandler)
|
||||
require.NoError(t, err)
|
||||
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
}
|
||||
|
||||
|
||||
func TestLocalLoad(t *testing.T) {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
params := NewLoadParams()
|
||||
params.Message = "load data!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
params := NewLoadParams()
|
||||
params.Message = "load data!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
|
||||
binBytes := make([]byte, 0)
|
||||
writer := bytes.NewBuffer(binBytes)
|
||||
binBytes := make([]byte, 0)
|
||||
writer := bytes.NewBuffer(binBytes)
|
||||
|
||||
err = LocalGet(LoadMethod, writer, params, result, auth, loadHandler)
|
||||
require.NoError(t, err)
|
||||
err = LocalGet(LoadMethod, writer, params, result, auth, loadHandler)
|
||||
require.NoError(t, err)
|
||||
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
logDebug("bin size:", len(writer.Bytes()))
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
logDebug("bin size:", len(writer.Bytes()))
|
||||
}
|
||||
|
||||
|
||||
func TestNetExec(t *testing.T) {
|
||||
go testServ(false)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := clientHello()
|
||||
go testServ(false)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := clientHello()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNetSave(t *testing.T) {
|
||||
go testServ(false)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := clientSave()
|
||||
require.NoError(t, err)
|
||||
go testServ(false)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := clientSave()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNetLoad(t *testing.T) {
|
||||
go testServ(false)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := clientLoad()
|
||||
require.NoError(t, err)
|
||||
go testServ(false)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err := clientLoad()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func BenchmarkNetPut(b *testing.B) {
|
||||
go testServ(true)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
clientSave()
|
||||
go testServ(true)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
clientSave()
|
||||
|
||||
pBench := func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
clientSave()
|
||||
}
|
||||
}
|
||||
b.SetParallelism(10)
|
||||
b.RunParallel(pBench)
|
||||
pBench := func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
clientSave()
|
||||
}
|
||||
}
|
||||
b.SetParallelism(10)
|
||||
b.RunParallel(pBench)
|
||||
}
|
||||
|
||||
func clientHello() error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
params := NewHelloParams()
|
||||
params.Message = "hello server!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
params := NewHelloParams()
|
||||
params.Message = "hello server!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
|
||||
var binSize int64 = 16
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
binBytes := make([]byte, binSize)
|
||||
rand.Read(binBytes)
|
||||
var binSize int64 = 16
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
binBytes := make([]byte, binSize)
|
||||
rand.Read(binBytes)
|
||||
|
||||
err = Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
|
||||
if err != nil {
|
||||
logError("method err:", err)
|
||||
return err
|
||||
}
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
return err
|
||||
err = Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
|
||||
if err != nil {
|
||||
logError("method err:", err)
|
||||
return err
|
||||
}
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
func clientSave() error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
params := NewSaveParams()
|
||||
params.Message = "save data!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
params := NewSaveParams()
|
||||
params.Message = "save data!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
|
||||
var binSize int64 = 16
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
binBytes := make([]byte, binSize)
|
||||
rand.Read(binBytes)
|
||||
var binSize int64 = 16
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
binBytes := make([]byte, binSize)
|
||||
rand.Read(binBytes)
|
||||
|
||||
reader := bytes.NewReader(binBytes)
|
||||
reader := bytes.NewReader(binBytes)
|
||||
|
||||
err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
|
||||
if err != nil {
|
||||
logError("method err:", err)
|
||||
return err
|
||||
}
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
return err
|
||||
err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
|
||||
if err != nil {
|
||||
logError("method err:", err)
|
||||
return err
|
||||
}
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
func clientLoad() error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
params := NewLoadParams()
|
||||
params.Message = "load data!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
params := NewLoadParams()
|
||||
params.Message = "load data!"
|
||||
result := NewHelloResult()
|
||||
auth := CreateAuth([]byte("qwert"), []byte("12345"))
|
||||
|
||||
binBytes := make([]byte, 0)
|
||||
writer := bytes.NewBuffer(binBytes)
|
||||
|
||||
binBytes := make([]byte, 0)
|
||||
writer := bytes.NewBuffer(binBytes)
|
||||
|
||||
err = Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
|
||||
if err != nil {
|
||||
logError("method err:", err)
|
||||
return err
|
||||
}
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
logDebug("bin size:", len(writer.Bytes()))
|
||||
return err
|
||||
err = Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
|
||||
if err != nil {
|
||||
logError("method err:", err)
|
||||
return err
|
||||
}
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
logDebug("method result:", string(resultJSON))
|
||||
logDebug("bin size:", len(writer.Bytes()))
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
var testServRun bool = false
|
||||
|
||||
func testServ(quiet bool) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
if testServRun {
|
||||
return err
|
||||
}
|
||||
testServRun = true
|
||||
if testServRun {
|
||||
return err
|
||||
}
|
||||
testServRun = true
|
||||
|
||||
if quiet {
|
||||
SetAccessWriter(io.Discard)
|
||||
SetMessageWriter(io.Discard)
|
||||
}
|
||||
serv := NewService()
|
||||
serv.Handler(HelloMethod, helloHandler)
|
||||
serv.Handler(SaveMethod, saveHandler)
|
||||
serv.Handler(LoadMethod, loadHandler)
|
||||
if quiet {
|
||||
SetAccessWriter(io.Discard)
|
||||
SetMessageWriter(io.Discard)
|
||||
}
|
||||
serv := NewService()
|
||||
serv.Handler(HelloMethod, helloHandler)
|
||||
serv.Handler(SaveMethod, saveHandler)
|
||||
serv.Handler(LoadMethod, loadHandler)
|
||||
|
||||
serv.PreMiddleware(LogRequest)
|
||||
serv.PreMiddleware(auth)
|
||||
serv.PreMiddleware(LogRequest)
|
||||
serv.PreMiddleware(auth)
|
||||
|
||||
serv.PostMiddleware(LogResponse)
|
||||
serv.PostMiddleware(LogAccess)
|
||||
serv.PostMiddleware(LogResponse)
|
||||
serv.PostMiddleware(LogAccess)
|
||||
|
||||
err = serv.Listen(":8081")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
err = serv.Listen(":8081")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func auth(context *Context) error {
|
||||
var err error
|
||||
reqIdent := context.AuthIdent()
|
||||
reqSalt := context.AuthSalt()
|
||||
reqHash := context.AuthHash()
|
||||
var err error
|
||||
reqIdent := context.AuthIdent()
|
||||
reqSalt := context.AuthSalt()
|
||||
reqHash := context.AuthHash()
|
||||
|
||||
ident := reqIdent
|
||||
pass := []byte("12345")
|
||||
ident := reqIdent
|
||||
pass := []byte("12345")
|
||||
|
||||
auth := context.Auth()
|
||||
logDebug("auth ", string(auth.JSON()))
|
||||
auth := context.Auth()
|
||||
logDebug("auth ", string(auth.JSON()))
|
||||
|
||||
ok := CheckHash(ident, pass, reqSalt, reqHash)
|
||||
logDebug("auth ok:", ok)
|
||||
if !ok {
|
||||
err = errors.New("auth ident or pass missmatch")
|
||||
context.SendError(err)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
ok := CheckHash(ident, pass, reqSalt, reqHash)
|
||||
logDebug("auth ok:", ok)
|
||||
if !ok {
|
||||
err = errors.New("auth ident or pass missmatch")
|
||||
context.SendError(err)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func helloHandler(context *Context) error {
|
||||
var err error
|
||||
params := NewHelloParams()
|
||||
var err error
|
||||
params := NewHelloParams()
|
||||
|
||||
err = context.BindParams(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindParams(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = context.ReadBin(io.Discard)
|
||||
if err != nil {
|
||||
context.SendError(err)
|
||||
return err
|
||||
}
|
||||
err = context.ReadBin(io.Discard)
|
||||
if err != nil {
|
||||
context.SendError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
result := NewHelloResult()
|
||||
result.Message = "hello, client!"
|
||||
result := NewHelloResult()
|
||||
result.Message = "hello, client!"
|
||||
|
||||
err = context.SendResult(result, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
err = context.SendResult(result, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func saveHandler(context *Context) error {
|
||||
var err error
|
||||
params := NewSaveParams()
|
||||
var err error
|
||||
params := NewSaveParams()
|
||||
|
||||
err = context.BindParams(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindParams(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bufferBytes := make([]byte, 0, 1024)
|
||||
binWriter := bytes.NewBuffer(bufferBytes)
|
||||
bufferBytes := make([]byte, 0, 1024)
|
||||
binWriter := bytes.NewBuffer(bufferBytes)
|
||||
|
||||
err = context.ReadBin(binWriter)
|
||||
if err != nil {
|
||||
context.SendError(err)
|
||||
return err
|
||||
}
|
||||
err = context.ReadBin(binWriter)
|
||||
if err != nil {
|
||||
context.SendError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
result := NewSaveResult()
|
||||
result.Message = "saved successfully!"
|
||||
result := NewSaveResult()
|
||||
result.Message = "saved successfully!"
|
||||
|
||||
err = context.SendResult(result, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
err = context.SendResult(result, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func loadHandler(context *Context) error {
|
||||
var err error
|
||||
params := NewSaveParams()
|
||||
var err error
|
||||
params := NewSaveParams()
|
||||
|
||||
err = context.BindParams(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindParams(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = context.ReadBin(io.Discard)
|
||||
if err != nil {
|
||||
context.SendError(err)
|
||||
return err
|
||||
}
|
||||
err = context.ReadBin(io.Discard)
|
||||
if err != nil {
|
||||
context.SendError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
var binSize int64 = 1024
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
binBytes := make([]byte, binSize)
|
||||
rand.Read(binBytes)
|
||||
var binSize int64 = 1024
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
binBytes := make([]byte, binSize)
|
||||
rand.Read(binBytes)
|
||||
|
||||
binReader := bytes.NewReader(binBytes)
|
||||
binReader := bytes.NewReader(binBytes)
|
||||
|
||||
result := NewSaveResult()
|
||||
result.Message = "load successfully!"
|
||||
result := NewSaveResult()
|
||||
result.Message = "load successfully!"
|
||||
|
||||
err = context.SendResult(result, binSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
binWriter := context.BinWriter()
|
||||
_, err = CopyBytes(binReader, binWriter, binSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.SendResult(result, binSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
binWriter := context.BinWriter()
|
||||
_, err = CopyBytes(binReader, binWriter, binSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
const HelloMethod string = "hello"
|
||||
|
||||
type HelloParams struct {
|
||||
Message string `json:"message" msgpack:"message"`
|
||||
Message string `json:"message" msgpack:"message"`
|
||||
}
|
||||
|
||||
func NewHelloParams() *HelloParams {
|
||||
return &HelloParams{}
|
||||
return &HelloParams{}
|
||||
}
|
||||
|
||||
type HelloResult struct {
|
||||
Message string `json:"message" msgpack:"message"`
|
||||
Message string `json:"message" msgpack:"message"`
|
||||
}
|
||||
|
||||
func NewHelloResult() *HelloResult {
|
||||
return &HelloResult{}
|
||||
return &HelloResult{}
|
||||
}
|
||||
|
||||
|
||||
const SaveMethod string = "save"
|
||||
|
||||
type SaveParams HelloParams
|
||||
type SaveResult HelloResult
|
||||
|
||||
func NewSaveParams() *SaveParams {
|
||||
return &SaveParams{}
|
||||
return &SaveParams{}
|
||||
}
|
||||
func NewSaveResult() *SaveResult {
|
||||
return &SaveResult{}
|
||||
return &SaveResult{}
|
||||
}
|
||||
|
||||
|
||||
|
||||
const LoadMethod string = "load"
|
||||
|
||||
type LoadParams HelloParams
|
||||
type LoadResult HelloResult
|
||||
|
||||
func NewLoadParams() *LoadParams {
|
||||
return &LoadParams{}
|
||||
return &LoadParams{}
|
||||
}
|
||||
func NewLoadResult() *LoadResult {
|
||||
return &LoadResult{}
|
||||
return &LoadResult{}
|
||||
}
|
||||
|
||||
16
faddr.go
16
faddr.go
@@ -5,21 +5,21 @@
|
||||
package dsrpc
|
||||
|
||||
type FAddr struct {
|
||||
network string
|
||||
address string
|
||||
network string
|
||||
address string
|
||||
}
|
||||
|
||||
func NewFAddr() *FAddr {
|
||||
var addr FAddr
|
||||
addr.network = "tcp"
|
||||
addr.address = "127.0.0.1:5000"
|
||||
return &addr
|
||||
var addr FAddr
|
||||
addr.network = "tcp"
|
||||
addr.address = "127.0.0.1:5000"
|
||||
return &addr
|
||||
}
|
||||
|
||||
func (addr *FAddr) Network() string {
|
||||
return addr.network
|
||||
return addr.network
|
||||
}
|
||||
|
||||
func (addr *FAddr) String() string {
|
||||
return addr.address
|
||||
return addr.address
|
||||
}
|
||||
|
||||
@@ -7,51 +7,51 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFConn0(t *testing.T) {
|
||||
var cConn, sConn net.Conn
|
||||
sConn, cConn = NewFConn()
|
||||
var cConn, sConn net.Conn
|
||||
sConn, cConn = NewFConn()
|
||||
|
||||
cData := []byte("qwerty")
|
||||
count := 10
|
||||
cData := []byte("qwerty")
|
||||
count := 10
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
wc, err := cConn.Write(cData)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.Equal(t, wc, len(cData))
|
||||
for i := 0; i < count; i++ {
|
||||
wc, err := cConn.Write(cData)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.Equal(t, wc, len(cData))
|
||||
|
||||
sData := make([]byte, len(cData))
|
||||
rc, err := sConn.Read(sData)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rc, len(cData))
|
||||
require.Equal(t, cData, sData)
|
||||
}
|
||||
sData := make([]byte, len(cData))
|
||||
rc, err := sConn.Read(sData)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rc, len(cData))
|
||||
require.Equal(t, cData, sData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFConn1(t *testing.T) {
|
||||
var cConn, sConn net.Conn
|
||||
cConn, sConn = NewFConn()
|
||||
var cConn, sConn net.Conn
|
||||
cConn, sConn = NewFConn()
|
||||
|
||||
cData := []byte("qwerty")
|
||||
count := 10
|
||||
cData := []byte("qwerty")
|
||||
count := 10
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
wc, err := cConn.Write(cData)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.Equal(t, wc, len(cData))
|
||||
for i := 0; i < count; i++ {
|
||||
wc, err := cConn.Write(cData)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.Equal(t, wc, len(cData))
|
||||
|
||||
sData := make([]byte, len(cData))
|
||||
rc, err := sConn.Read(sData)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rc, len(cData))
|
||||
require.Equal(t, cData, sData)
|
||||
}
|
||||
sData := make([]byte, len(cData))
|
||||
rc, err := sConn.Read(sData)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rc, len(cData))
|
||||
require.Equal(t, cData, sData)
|
||||
}
|
||||
}
|
||||
|
||||
58
fconn.go
58
fconn.go
@@ -7,62 +7,62 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FConn struct {
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func NewFConn() (*FConn, *FConn){
|
||||
c2sBuffer := bytes.NewBuffer(make([]byte, 0))
|
||||
s2cBuffer := bytes.NewBuffer(make([]byte, 0))
|
||||
func NewFConn() (*FConn, *FConn) {
|
||||
c2sBuffer := bytes.NewBuffer(make([]byte, 0))
|
||||
s2cBuffer := bytes.NewBuffer(make([]byte, 0))
|
||||
|
||||
var client FConn
|
||||
client.writer = c2sBuffer
|
||||
client.reader = s2cBuffer
|
||||
var client FConn
|
||||
client.writer = c2sBuffer
|
||||
client.reader = s2cBuffer
|
||||
|
||||
var server FConn
|
||||
server.writer = s2cBuffer
|
||||
server.reader = c2sBuffer
|
||||
var server FConn
|
||||
server.writer = s2cBuffer
|
||||
server.reader = c2sBuffer
|
||||
|
||||
return &client, &server
|
||||
return &client, &server
|
||||
}
|
||||
|
||||
func (conn FConn) SetDeadline(t time.Time) error {
|
||||
var err error
|
||||
return err
|
||||
var err error
|
||||
return err
|
||||
}
|
||||
func (conn FConn) SetReadDeadline(t time.Time) error {
|
||||
var err error
|
||||
return err
|
||||
func (conn FConn) SetReadDeadline(t time.Time) error {
|
||||
var err error
|
||||
return err
|
||||
}
|
||||
func (conn FConn) SetWriteDeadline(t time.Time) error {
|
||||
var err error
|
||||
return err
|
||||
var err error
|
||||
return err
|
||||
}
|
||||
|
||||
func (conn FConn) LocalAddr() net.Addr {
|
||||
return NewFAddr()
|
||||
return NewFAddr()
|
||||
}
|
||||
|
||||
func (conn FConn) RemoteAddr() net.Addr {
|
||||
return NewFAddr()
|
||||
return NewFAddr()
|
||||
}
|
||||
|
||||
func (conn FConn) Write(data []byte) (int, error) {
|
||||
return conn.writer.Write(data)
|
||||
return conn.writer.Write(data)
|
||||
}
|
||||
|
||||
func (conn FConn) Read(data []byte) (int, error) {
|
||||
return conn.reader.Read(data)
|
||||
return conn.reader.Read(data)
|
||||
}
|
||||
|
||||
func (conn FConn) Close() error {
|
||||
var err error
|
||||
return err
|
||||
var err error
|
||||
return err
|
||||
}
|
||||
|
||||
110
header.go
110
header.go
@@ -7,92 +7,90 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"bytes"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const headerSize int64 = 16 * 2
|
||||
const sizeOfInt64 int = 8
|
||||
const magicCodeA int64 = 0xEE00ABBA
|
||||
const magicCodeB int64 = 0xEE44ABBA
|
||||
const headerSize int64 = 16 * 2
|
||||
const sizeOfInt64 int = 8
|
||||
const magicCodeA int64 = 0xEE00ABBA
|
||||
const magicCodeB int64 = 0xEE44ABBA
|
||||
|
||||
type Header struct {
|
||||
magicCodeA int64 `json:"magicCodeA"`
|
||||
rpcSize int64 `json:"rpcSize"`
|
||||
binSize int64 `json:"binSize"`
|
||||
magicCodeB int64 `json:"magicCodeB"`
|
||||
magicCodeA int64 `json:"magicCodeA"`
|
||||
rpcSize int64 `json:"rpcSize"`
|
||||
binSize int64 `json:"binSize"`
|
||||
magicCodeB int64 `json:"magicCodeB"`
|
||||
}
|
||||
|
||||
|
||||
func NewHeader() *Header {
|
||||
return &Header{
|
||||
magicCodeA: magicCodeA,
|
||||
magicCodeB: magicCodeB,
|
||||
}
|
||||
return &Header{
|
||||
magicCodeA: magicCodeA,
|
||||
magicCodeB: magicCodeB,
|
||||
}
|
||||
}
|
||||
|
||||
func (hdr *Header) JSON() []byte {
|
||||
jBytes, _ := json.Marshal(hdr)
|
||||
return jBytes
|
||||
jBytes, _ := json.Marshal(hdr)
|
||||
return jBytes
|
||||
}
|
||||
|
||||
|
||||
func (hdr *Header) Pack() ([]byte, error) {
|
||||
var err error
|
||||
headerBytes := make([]byte, 0, headerSize)
|
||||
headerBuffer := bytes.NewBuffer(headerBytes)
|
||||
var err error
|
||||
headerBytes := make([]byte, 0, headerSize)
|
||||
headerBuffer := bytes.NewBuffer(headerBytes)
|
||||
|
||||
magicCodeABytes := encoderI64(hdr.magicCodeA)
|
||||
headerBuffer.Write(magicCodeABytes)
|
||||
magicCodeABytes := encoderI64(hdr.magicCodeA)
|
||||
headerBuffer.Write(magicCodeABytes)
|
||||
|
||||
rpcSizeBytes := encoderI64(hdr.rpcSize)
|
||||
headerBuffer.Write(rpcSizeBytes)
|
||||
rpcSizeBytes := encoderI64(hdr.rpcSize)
|
||||
headerBuffer.Write(rpcSizeBytes)
|
||||
|
||||
binSizeBytes := encoderI64(hdr.binSize)
|
||||
headerBuffer.Write(binSizeBytes)
|
||||
binSizeBytes := encoderI64(hdr.binSize)
|
||||
headerBuffer.Write(binSizeBytes)
|
||||
|
||||
magicCodeBBytes := encoderI64(hdr.magicCodeB)
|
||||
headerBuffer.Write(magicCodeBBytes)
|
||||
magicCodeBBytes := encoderI64(hdr.magicCodeB)
|
||||
headerBuffer.Write(magicCodeBBytes)
|
||||
|
||||
return headerBuffer.Bytes(), Err(err)
|
||||
return headerBuffer.Bytes(), Err(err)
|
||||
}
|
||||
|
||||
func UnpackHeader(headerBytes []byte) (*Header, error) {
|
||||
var err error
|
||||
header := NewHeader()
|
||||
headerReader := bytes.NewReader(headerBytes)
|
||||
var err error
|
||||
header := NewHeader()
|
||||
headerReader := bytes.NewReader(headerBytes)
|
||||
|
||||
magicCodeABytes := make([]byte, sizeOfInt64)
|
||||
headerReader.Read(magicCodeABytes)
|
||||
header.magicCodeA = decoderI64(magicCodeABytes)
|
||||
magicCodeABytes := make([]byte, sizeOfInt64)
|
||||
headerReader.Read(magicCodeABytes)
|
||||
header.magicCodeA = decoderI64(magicCodeABytes)
|
||||
|
||||
rpcSizeBytes := make([]byte, sizeOfInt64)
|
||||
headerReader.Read(rpcSizeBytes)
|
||||
header.rpcSize = decoderI64(rpcSizeBytes)
|
||||
rpcSizeBytes := make([]byte, sizeOfInt64)
|
||||
headerReader.Read(rpcSizeBytes)
|
||||
header.rpcSize = decoderI64(rpcSizeBytes)
|
||||
|
||||
binSizeBytes := make([]byte, sizeOfInt64)
|
||||
headerReader.Read(binSizeBytes)
|
||||
header.binSize = decoderI64(binSizeBytes)
|
||||
binSizeBytes := make([]byte, sizeOfInt64)
|
||||
headerReader.Read(binSizeBytes)
|
||||
header.binSize = decoderI64(binSizeBytes)
|
||||
|
||||
magicCodeBBytes := make([]byte, sizeOfInt64)
|
||||
headerReader.Read(magicCodeBBytes)
|
||||
header.magicCodeB = decoderI64(magicCodeBBytes)
|
||||
magicCodeBBytes := make([]byte, sizeOfInt64)
|
||||
headerReader.Read(magicCodeBBytes)
|
||||
header.magicCodeB = decoderI64(magicCodeBBytes)
|
||||
|
||||
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
|
||||
err = errors.New("wrong protocol magic code")
|
||||
return header, Err(err)
|
||||
}
|
||||
return header, Err(err)
|
||||
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
|
||||
err = errors.New("wrong protocol magic code")
|
||||
return header, Err(err)
|
||||
}
|
||||
return header, Err(err)
|
||||
}
|
||||
|
||||
func encoderI64(i int64) []byte {
|
||||
buffer := make([]byte, sizeOfInt64)
|
||||
binary.BigEndian.PutUint64(buffer, uint64(i))
|
||||
return buffer
|
||||
buffer := make([]byte, sizeOfInt64)
|
||||
binary.BigEndian.PutUint64(buffer, uint64(i))
|
||||
return buffer
|
||||
}
|
||||
|
||||
func decoderI64(b []byte) int64 {
|
||||
return int64(binary.BigEndian.Uint64(b))
|
||||
return int64(binary.BigEndian.Uint64(b))
|
||||
}
|
||||
|
||||
28
logger.go
28
logger.go
@@ -7,39 +7,39 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
var messageWriter io.Writer = os.Stdout
|
||||
var accessWriter io.Writer = os.Stdout
|
||||
|
||||
func logDebug(messages ...any) {
|
||||
stamp := time.Now().Format(time.RFC3339)
|
||||
fmt.Fprintln(messageWriter, stamp, "debug", messages)
|
||||
stamp := time.Now().Format(time.RFC3339)
|
||||
fmt.Fprintln(messageWriter, stamp, "debug", messages)
|
||||
}
|
||||
|
||||
func logInfo(messages ...any) {
|
||||
stamp := time.Now().Format(time.RFC3339)
|
||||
fmt.Fprintln(messageWriter, stamp, "info", messages)
|
||||
stamp := time.Now().Format(time.RFC3339)
|
||||
fmt.Fprintln(messageWriter, stamp, "info", messages)
|
||||
}
|
||||
|
||||
func logError(messages ...any) {
|
||||
stamp := time.Now().Format(time.RFC3339)
|
||||
fmt.Fprintln(messageWriter, stamp, "error", messages)
|
||||
stamp := time.Now().Format(time.RFC3339)
|
||||
fmt.Fprintln(messageWriter, stamp, "error", messages)
|
||||
}
|
||||
|
||||
func logAccess(messages ...any) {
|
||||
stamp := time.Now().Format(time.RFC3339)
|
||||
fmt.Fprintln(accessWriter, stamp, "access", messages)
|
||||
stamp := time.Now().Format(time.RFC3339)
|
||||
fmt.Fprintln(accessWriter, stamp, "access", messages)
|
||||
}
|
||||
|
||||
func SetAccessWriter(writer io.Writer) {
|
||||
accessWriter = writer
|
||||
accessWriter = writer
|
||||
}
|
||||
|
||||
func SetMessageWriter(writer io.Writer) {
|
||||
messageWriter = writer
|
||||
messageWriter = writer
|
||||
}
|
||||
|
||||
24
midware.go
24
midware.go
@@ -5,25 +5,25 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"time"
|
||||
"time"
|
||||
)
|
||||
|
||||
func LogRequest(context *Context) error {
|
||||
var err error
|
||||
logDebug("request:", string(context.reqRPC.JSON()))
|
||||
return Err(err)
|
||||
var err error
|
||||
logDebug("request:", string(context.reqRPC.JSON()))
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func LogResponse(context *Context) error {
|
||||
var err error
|
||||
logDebug("response:", string(context.resRPC.JSON()))
|
||||
return Err(err)
|
||||
var err error
|
||||
logDebug("response:", string(context.resRPC.JSON()))
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func LogAccess(context *Context) error {
|
||||
var err error
|
||||
execTime := time.Now().Sub(context.start)
|
||||
login := string(context.AuthIdent())
|
||||
logAccess(context.remoteHost, login, context.reqRPC.Method, execTime)
|
||||
return Err(err)
|
||||
var err error
|
||||
execTime := time.Now().Sub(context.start)
|
||||
login := string(context.AuthIdent())
|
||||
logAccess(context.remoteHost, login, context.reqRPC.Method, execTime)
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
12
packet.go
12
packet.go
@@ -5,19 +5,19 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Packet struct {
|
||||
header []byte
|
||||
rcpPayload []byte
|
||||
header []byte
|
||||
rcpPayload []byte
|
||||
}
|
||||
|
||||
func NewPacket() *Packet {
|
||||
return &Packet{}
|
||||
return &Packet{}
|
||||
}
|
||||
|
||||
func (pkt *Packet) JSON() []byte {
|
||||
jBytes, _ := json.Marshal(pkt)
|
||||
return jBytes
|
||||
jBytes, _ := json.Marshal(pkt)
|
||||
return jBytes
|
||||
}
|
||||
|
||||
24
request.go
24
request.go
@@ -7,28 +7,28 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
encoder "github.com/vmihailenco/msgpack/v5"
|
||||
"encoding/json"
|
||||
encoder "github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Method string `json:"method" msgpack:"method"`
|
||||
Params any `json:"params,omitempty" msgpack:"params"`
|
||||
Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
|
||||
Method string `json:"method" msgpack:"method"`
|
||||
Params any `json:"params,omitempty" msgpack:"params"`
|
||||
Auth *Auth `json:"auth,omitempty" msgpack:"auth"`
|
||||
}
|
||||
|
||||
func NewRequest() *Request {
|
||||
req := &Request{}
|
||||
req.Auth = &Auth{}
|
||||
return req
|
||||
req := &Request{}
|
||||
req.Auth = &Auth{}
|
||||
return req
|
||||
}
|
||||
|
||||
func (req *Request) Pack() ([]byte, error) {
|
||||
rBytes, err := encoder.Marshal(req)
|
||||
return rBytes, Err(err)
|
||||
rBytes, err := encoder.Marshal(req)
|
||||
return rBytes, Err(err)
|
||||
}
|
||||
|
||||
func (req *Request) JSON() []byte {
|
||||
jBytes, _ := json.Marshal(req)
|
||||
return jBytes
|
||||
jBytes, _ := json.Marshal(req)
|
||||
return jBytes
|
||||
}
|
||||
|
||||
19
response.go
19
response.go
@@ -5,26 +5,25 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
encoder "github.com/vmihailenco/msgpack/v5"
|
||||
"encoding/json"
|
||||
encoder "github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
|
||||
type Response struct {
|
||||
Error string `json:"error" msgpack:"error"`
|
||||
Result any `json:"result" msgpack:"result"`
|
||||
Error string `json:"error" msgpack:"error"`
|
||||
Result any `json:"result" msgpack:"result"`
|
||||
}
|
||||
|
||||
func NewResponse() *Response {
|
||||
return &Response{}
|
||||
return &Response{}
|
||||
}
|
||||
|
||||
func (resp *Response) JSON() []byte {
|
||||
jBytes, _ := json.Marshal(resp)
|
||||
return jBytes
|
||||
jBytes, _ := json.Marshal(resp)
|
||||
return jBytes
|
||||
}
|
||||
|
||||
func (resp *Response) Pack() ([]byte, error) {
|
||||
rBytes, err := encoder.Marshal(resp)
|
||||
return rBytes, Err(err)
|
||||
rBytes, err := encoder.Marshal(resp)
|
||||
return rBytes, Err(err)
|
||||
}
|
||||
|
||||
438
server.go
438
server.go
@@ -5,305 +5,303 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
encoder "github.com/vmihailenco/msgpack/v5"
|
||||
encoder "github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
type HandlerFunc = func(*Context) error
|
||||
type HandlerFunc = func(*Context) error
|
||||
|
||||
type Service struct {
|
||||
handlers map[string]HandlerFunc
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg *sync.WaitGroup
|
||||
preMw []HandlerFunc
|
||||
postMw []HandlerFunc
|
||||
keepalive bool
|
||||
kaTime time.Duration
|
||||
kaMtx sync.Mutex
|
||||
handlers map[string]HandlerFunc
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg *sync.WaitGroup
|
||||
preMw []HandlerFunc
|
||||
postMw []HandlerFunc
|
||||
keepalive bool
|
||||
kaTime time.Duration
|
||||
kaMtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewService() *Service {
|
||||
rdrpc := &Service{}
|
||||
rdrpc.handlers = make(map[string]HandlerFunc)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
rdrpc.ctx = ctx
|
||||
rdrpc.cancel = cancel
|
||||
var wg sync.WaitGroup
|
||||
rdrpc.wg = &wg
|
||||
rdrpc.preMw = make([]HandlerFunc, 0)
|
||||
rdrpc.postMw = make([]HandlerFunc, 0)
|
||||
rdrpc := &Service{}
|
||||
rdrpc.handlers = make(map[string]HandlerFunc)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
rdrpc.ctx = ctx
|
||||
rdrpc.cancel = cancel
|
||||
var wg sync.WaitGroup
|
||||
rdrpc.wg = &wg
|
||||
rdrpc.preMw = make([]HandlerFunc, 0)
|
||||
rdrpc.postMw = make([]HandlerFunc, 0)
|
||||
|
||||
return rdrpc
|
||||
return rdrpc
|
||||
}
|
||||
|
||||
func (svc *Service) PreMiddleware(mw HandlerFunc) {
|
||||
svc.preMw = append(svc.preMw, mw)
|
||||
svc.preMw = append(svc.preMw, mw)
|
||||
}
|
||||
|
||||
func (svc *Service) PostMiddleware(mw HandlerFunc) {
|
||||
svc.postMw = append(svc.postMw, mw)
|
||||
svc.postMw = append(svc.postMw, mw)
|
||||
}
|
||||
|
||||
func (svc *Service) Handler(method string, handler HandlerFunc) {
|
||||
svc.handlers[method] = handler
|
||||
svc.handlers[method] = handler
|
||||
}
|
||||
|
||||
func (svc *Service) SetKeepAlive(flag bool) {
|
||||
svc.kaMtx.Lock()
|
||||
defer svc.kaMtx.Unlock()
|
||||
svc.keepalive = true
|
||||
svc.kaMtx.Lock()
|
||||
defer svc.kaMtx.Unlock()
|
||||
svc.keepalive = true
|
||||
}
|
||||
|
||||
func (svc *Service) SetKeepAlivePeriod(interval time.Duration) {
|
||||
svc.kaMtx.Lock()
|
||||
defer svc.kaMtx.Unlock()
|
||||
svc.kaTime = interval
|
||||
svc.kaMtx.Lock()
|
||||
defer svc.kaMtx.Unlock()
|
||||
svc.kaTime = interval
|
||||
}
|
||||
|
||||
func (svc *Service) Listen(address string) error {
|
||||
var err error
|
||||
logInfo("server listen:", address)
|
||||
var err error
|
||||
logInfo("server listen:", address)
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to resolve adddress: %s", err)
|
||||
return err
|
||||
}
|
||||
listener, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to start listener: %s", err)
|
||||
return err
|
||||
}
|
||||
addr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to resolve adddress: %s", err)
|
||||
return err
|
||||
}
|
||||
listener, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to start listener: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := listener.AcceptTCP()
|
||||
if err != nil {
|
||||
logError("conn accept err:", err)
|
||||
}
|
||||
select {
|
||||
case <-svc.ctx.Done():
|
||||
return err
|
||||
default:
|
||||
}
|
||||
svc.wg.Add(1)
|
||||
go svc.handleConn(conn, svc.wg)
|
||||
}
|
||||
return err
|
||||
for {
|
||||
conn, err := listener.AcceptTCP()
|
||||
if err != nil {
|
||||
logError("conn accept err:", err)
|
||||
}
|
||||
select {
|
||||
case <-svc.ctx.Done():
|
||||
return err
|
||||
default:
|
||||
}
|
||||
svc.wg.Add(1)
|
||||
go svc.handleConn(conn, svc.wg)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func notFound(context *Context) error {
|
||||
execErr := errors.New("method not found")
|
||||
err := context.SendError(execErr)
|
||||
return err
|
||||
execErr := errors.New("method not found")
|
||||
err := context.SendError(execErr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (svc *Service) Stop() error {
|
||||
var err error
|
||||
// Disable new connection
|
||||
logInfo("cancel rpc accept loop")
|
||||
svc.cancel()
|
||||
// Wait handlers
|
||||
logInfo("wait rpc handlers")
|
||||
svc.wg.Wait()
|
||||
return err
|
||||
var err error
|
||||
// Disable new connection
|
||||
logInfo("cancel rpc accept loop")
|
||||
svc.cancel()
|
||||
// Wait handlers
|
||||
logInfo("wait rpc handlers")
|
||||
svc.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (svc *Service) handleConn(conn *net.TCPConn, wg *sync.WaitGroup) {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
if svc.keepalive {
|
||||
err = conn.SetKeepAlive(true)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to set keepalive: %s", err)
|
||||
return
|
||||
}
|
||||
if svc.kaTime > 0 {
|
||||
err = conn.SetKeepAlivePeriod(svc.kaTime)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to set keepalive period: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
context := CreateContext(conn)
|
||||
if svc.keepalive {
|
||||
err = conn.SetKeepAlive(true)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to set keepalive: %s", err)
|
||||
return
|
||||
}
|
||||
if svc.kaTime > 0 {
|
||||
err = conn.SetKeepAlivePeriod(svc.kaTime)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to set keepalive period: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
context := CreateContext(conn)
|
||||
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
|
||||
context.remoteHost = remoteHost
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
|
||||
context.remoteHost = remoteHost
|
||||
|
||||
context.binReader = conn
|
||||
context.binWriter = io.Discard
|
||||
context.binReader = conn
|
||||
context.binWriter = io.Discard
|
||||
|
||||
exitFunc := func() {
|
||||
conn.Close()
|
||||
wg.Done()
|
||||
if err != nil {
|
||||
logError("conn handler err:", err)
|
||||
}
|
||||
}
|
||||
defer exitFunc()
|
||||
exitFunc := func() {
|
||||
conn.Close()
|
||||
wg.Done()
|
||||
if err != nil {
|
||||
logError("conn handler err:", err)
|
||||
}
|
||||
}
|
||||
defer exitFunc()
|
||||
|
||||
recovFunc := func () {
|
||||
panicMsg := recover()
|
||||
if panicMsg != nil {
|
||||
logError("handler panic message:", panicMsg)
|
||||
}
|
||||
}
|
||||
defer recovFunc()
|
||||
recovFunc := func() {
|
||||
panicMsg := recover()
|
||||
if panicMsg != nil {
|
||||
logError("handler panic message:", panicMsg)
|
||||
}
|
||||
}
|
||||
defer recovFunc()
|
||||
|
||||
err = context.ReadRequest()
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
err = context.ReadRequest()
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
|
||||
err = context.BindMethod()
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
for _, mw := range svc.preMw {
|
||||
err = mw(context)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
err = svc.Route(context)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
for _, mw := range svc.postMw {
|
||||
err = mw(context)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
err = context.BindMethod()
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
for _, mw := range svc.preMw {
|
||||
err = mw(context)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
err = svc.Route(context)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
for _, mw := range svc.postMw {
|
||||
err = mw(context)
|
||||
if err != nil {
|
||||
err = Err(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (svc *Service) Route(context *Context) error {
|
||||
handler, ok := svc.handlers[context.reqRPC.Method]
|
||||
if ok {
|
||||
return Err(handler(context))
|
||||
}
|
||||
return Err(notFound(context))
|
||||
handler, ok := svc.handlers[context.reqRPC.Method]
|
||||
if ok {
|
||||
return Err(handler(context))
|
||||
}
|
||||
return Err(notFound(context))
|
||||
}
|
||||
|
||||
func (context *Context) ReadRequest() error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.reqHeader, err = UnpackHeader(context.reqPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.reqHeader, err = UnpackHeader(context.reqPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
rpcSize := context.reqHeader.rpcSize
|
||||
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
rpcSize := context.reqHeader.rpcSize
|
||||
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func (context *Context) BinWriter() io.Writer {
|
||||
return context.sockWriter
|
||||
return context.sockWriter
|
||||
}
|
||||
|
||||
func (context *Context) BinReader() io.Reader {
|
||||
return context.sockReader
|
||||
return context.sockReader
|
||||
}
|
||||
|
||||
func (context *Context) BinSize() int64 {
|
||||
return context.reqHeader.binSize
|
||||
return context.reqHeader.binSize
|
||||
}
|
||||
|
||||
func (context *Context) ReadBin(writer io.Writer) error {
|
||||
var err error
|
||||
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
|
||||
return Err(err)
|
||||
var err error
|
||||
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
|
||||
func (context *Context) BindMethod() error {
|
||||
var err error
|
||||
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
|
||||
return Err(err)
|
||||
var err error
|
||||
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func (context *Context) BindParams(params any) error {
|
||||
var err error
|
||||
context.reqRPC.Params = params
|
||||
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
var err error
|
||||
context.reqRPC.Params = params
|
||||
err = encoder.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
func (context *Context) SendResult(result any, binSize int64) error {
|
||||
var err error
|
||||
context.resRPC.Result = result
|
||||
var err error
|
||||
context.resRPC.Result = result
|
||||
|
||||
context.resPacket.rcpPayload, err = context.resRPC.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
|
||||
context.resHeader.binSize = binSize
|
||||
context.resPacket.rcpPayload, err = context.resRPC.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
|
||||
context.resHeader.binSize = binSize
|
||||
|
||||
context.resPacket.header, err = context.resHeader.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.resPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
context.resPacket.header, err = context.resHeader.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.resPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
|
||||
func (context *Context) SendError(execErr error) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
context.resRPC.Error = execErr.Error()
|
||||
context.resRPC.Result = NewEmpty()
|
||||
context.resRPC.Error = execErr.Error()
|
||||
context.resRPC.Result = NewEmpty()
|
||||
|
||||
context.resPacket.rcpPayload, err = context.resRPC.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
|
||||
context.resPacket.header, err = context.resHeader.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.resPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
context.resPacket.rcpPayload, err = context.resRPC.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
|
||||
context.resPacket.header, err = context.resHeader.Pack()
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.resPacket.header)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
_, err = context.sockWriter.Write(context.resPacket.rcpPayload)
|
||||
if err != nil {
|
||||
return Err(err)
|
||||
}
|
||||
return Err(err)
|
||||
}
|
||||
|
||||
84
tools.go
84
tools.go
@@ -5,53 +5,53 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"fmt"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
|
||||
buffer := make([]byte, size)
|
||||
read, err := io.ReadFull(reader, buffer)
|
||||
return buffer[0:read], Err(err)
|
||||
buffer := make([]byte, size)
|
||||
read, err := io.ReadFull(reader, buffer)
|
||||
return buffer[0:read], Err(err)
|
||||
}
|
||||
|
||||
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
|
||||
var err error
|
||||
var bSize int64 = 1024 * 16
|
||||
var total int64 = 0
|
||||
var remains int64 = dataSize
|
||||
buffer := make([]byte, bSize)
|
||||
var err error
|
||||
var bSize int64 = 1024 * 16
|
||||
var total int64 = 0
|
||||
var remains int64 = dataSize
|
||||
buffer := make([]byte, bSize)
|
||||
|
||||
for {
|
||||
if reader == nil {
|
||||
return total, errors.New("reader is nil")
|
||||
}
|
||||
if writer == nil {
|
||||
return total, errors.New("writer is nil")
|
||||
}
|
||||
if remains == 0 {
|
||||
return total, err
|
||||
}
|
||||
if remains < bSize {
|
||||
bSize = remains
|
||||
}
|
||||
received, err := reader.Read(buffer[0:bSize])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("read error: %v", err)
|
||||
return total, Err(err)
|
||||
}
|
||||
recorded, err := writer.Write(buffer[0:received])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("write error: %v", err)
|
||||
return total, Err(err)
|
||||
}
|
||||
if recorded != received {
|
||||
err = errors.New("size mismatch")
|
||||
return total, Err(err)
|
||||
}
|
||||
total += int64(recorded)
|
||||
remains -= int64(recorded)
|
||||
}
|
||||
return total, Err(err)
|
||||
for {
|
||||
if reader == nil {
|
||||
return total, errors.New("reader is nil")
|
||||
}
|
||||
if writer == nil {
|
||||
return total, errors.New("writer is nil")
|
||||
}
|
||||
if remains == 0 {
|
||||
return total, err
|
||||
}
|
||||
if remains < bSize {
|
||||
bSize = remains
|
||||
}
|
||||
received, err := reader.Read(buffer[0:bSize])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("read error: %v", err)
|
||||
return total, Err(err)
|
||||
}
|
||||
recorded, err := writer.Write(buffer[0:received])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("write error: %v", err)
|
||||
return total, Err(err)
|
||||
}
|
||||
if recorded != received {
|
||||
err = errors.New("size mismatch")
|
||||
return total, Err(err)
|
||||
}
|
||||
total += int64(recorded)
|
||||
remains -= int64(recorded)
|
||||
}
|
||||
return total, Err(err)
|
||||
}
|
||||
|
||||
250
validate.go
250
validate.go
@@ -4,161 +4,159 @@
|
||||
|
||||
package dsrpc
|
||||
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
func LocalExec(method string, param any, result any, auth *Auth, handler HandlerFunc) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
cliConn, srvConn := NewFConn()
|
||||
cliConn, srvConn := NewFConn()
|
||||
|
||||
context := CreateContext(cliConn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
context := CreateContext(cliConn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = LocalService(srvConn, handler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = LocalService(srvConn, handler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
return err
|
||||
}
|
||||
|
||||
func LocalPut(method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error {
|
||||
|
||||
var err error
|
||||
var err error
|
||||
|
||||
cliConn, srvConn := NewFConn()
|
||||
cliConn, srvConn := NewFConn()
|
||||
|
||||
context := CreateContext(cliConn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
context := CreateContext(cliConn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
|
||||
context.binReader = reader
|
||||
context.binWriter = cliConn
|
||||
context.binReader = reader
|
||||
context.binWriter = cliConn
|
||||
|
||||
context.reqHeader.binSize = size
|
||||
context.reqHeader.binSize = size
|
||||
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.UploadBin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = LocalService(srvConn, handler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.UploadBin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = LocalService(srvConn, handler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error {
|
||||
var err error
|
||||
var err error
|
||||
|
||||
cliConn, srvConn := NewFConn()
|
||||
cliConn, srvConn := NewFConn()
|
||||
|
||||
context := CreateContext(cliConn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
context := CreateContext(cliConn)
|
||||
context.reqRPC.Method = method
|
||||
context.reqRPC.Params = param
|
||||
context.reqRPC.Auth = auth
|
||||
context.resRPC.Result = result
|
||||
|
||||
context.binReader = cliConn
|
||||
context.binWriter = writer
|
||||
context.binReader = cliConn
|
||||
context.binWriter = writer
|
||||
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if context.reqRPC.Params == nil {
|
||||
context.reqRPC.Params = NewEmpty()
|
||||
}
|
||||
err = context.CreateRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.WriteRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = LocalService(srvConn, handler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.DownloadBin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
err = LocalService(srvConn, handler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.ReadResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.DownloadBin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func LocalService(conn net.Conn, handler HandlerFunc) error {
|
||||
var err error
|
||||
context := CreateContext(conn)
|
||||
var err error
|
||||
context := CreateContext(conn)
|
||||
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
|
||||
context.remoteHost = remoteHost
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
remoteHost, _, _ := net.SplitHostPort(remoteAddr)
|
||||
context.remoteHost = remoteHost
|
||||
|
||||
context.binReader = conn
|
||||
context.binWriter = io.Discard
|
||||
context.binReader = conn
|
||||
context.binWriter = io.Discard
|
||||
|
||||
err = context.ReadRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindMethod()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return handler(context)
|
||||
err = context.ReadRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = context.BindMethod()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return handler(context)
|
||||
}
|
||||
|
||||
64
xauth.go
64
xauth.go
@@ -5,60 +5,60 @@
|
||||
package dsrpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"bytes"
|
||||
"math/rand"
|
||||
"time"
|
||||
"crypto/sha256"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
Ident []byte `msgpack:"ident" json:"ident"`
|
||||
Salt []byte `msgpack:"salt" json:"salt"`
|
||||
Hash []byte `msgpack:"hash" json:"hash"`
|
||||
Ident []byte `msgpack:"ident" json:"ident"`
|
||||
Salt []byte `msgpack:"salt" json:"salt"`
|
||||
Hash []byte `msgpack:"hash" json:"hash"`
|
||||
}
|
||||
|
||||
func NewAuth() *Auth {
|
||||
return &Auth{}
|
||||
return &Auth{}
|
||||
}
|
||||
|
||||
func (this *Auth) JSON() []byte {
|
||||
jBytes, _ := json.Marshal(this)
|
||||
return jBytes
|
||||
jBytes, _ := json.Marshal(this)
|
||||
return jBytes
|
||||
}
|
||||
|
||||
func CreateAuth(ident, pass []byte) *Auth {
|
||||
salt := CreateSalt()
|
||||
hash := CreateHash(ident, pass, salt)
|
||||
auth := &Auth{}
|
||||
auth.Ident = ident
|
||||
auth.Salt = salt
|
||||
auth.Hash = hash
|
||||
return auth
|
||||
salt := CreateSalt()
|
||||
hash := CreateHash(ident, pass, salt)
|
||||
auth := &Auth{}
|
||||
auth.Ident = ident
|
||||
auth.Salt = salt
|
||||
auth.Hash = hash
|
||||
return auth
|
||||
}
|
||||
|
||||
func CreateSalt() []byte {
|
||||
const saltSize = 16
|
||||
randBytes := make([]byte, saltSize)
|
||||
rand.Read(randBytes)
|
||||
return randBytes
|
||||
const saltSize = 16
|
||||
randBytes := make([]byte, saltSize)
|
||||
rand.Read(randBytes)
|
||||
return randBytes
|
||||
}
|
||||
|
||||
func CreateHash(ident, pass, salt []byte) []byte {
|
||||
vec := make([]byte, 0, len(ident) + len(salt) + len(pass))
|
||||
vec = append(vec, ident...)
|
||||
vec = append(vec, salt...)
|
||||
vec = append(vec, pass...)
|
||||
hasher := sha256.New()
|
||||
hash := hasher.Sum(vec)
|
||||
return hash
|
||||
vec := make([]byte, 0, len(ident)+len(salt)+len(pass))
|
||||
vec = append(vec, ident...)
|
||||
vec = append(vec, salt...)
|
||||
vec = append(vec, pass...)
|
||||
hasher := sha256.New()
|
||||
hash := hasher.Sum(vec)
|
||||
return hash
|
||||
}
|
||||
|
||||
func CheckHash(ident, pass, reqSalt, reqHash []byte) bool {
|
||||
localHash := CreateHash(ident, pass, reqSalt)
|
||||
return bytes.Equal(reqHash, localHash)
|
||||
localHash := CreateHash(ident, pass, reqSalt)
|
||||
return bytes.Equal(reqHash, localHash)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user