1 Commits

10 changed files with 333 additions and 91 deletions

214
README.md
View File

@@ -4,21 +4,20 @@ DSRPC is easy and simple RPC framework over TCP socket.
### Purpose ### Purpose
A very easy and open RPC framework with data streaming. A very easy and open RPC framework with data streaming.
### You can
### You can - Use own post and pre-execution middleware
- Use post and pre-execution middleware
- Hash-based authentication in middleware - Hash-based authentication in middleware
- Test call remote function without service organization - Test remote function without network
Socket encryption is not used at this time since framefork Socket encryption is not used at this time since framefork
is oriented to transfer large amounts of data is oriented to transfer large amounts of data.
Style of the framework is similar to that of GIN framework. Style of the framework is similar of GIN framework.
## Example ## Exec method example
### Server ### Server
@@ -135,7 +134,7 @@ package api
const HelloMethod string = "hello" const HelloMethod string = "hello"
type HelloParams struct { type HelloParams struct {
Message string `msgpack:"message" json:"message"` Message string `json:"message"`
} }
func NewHelloParams() *HelloParams { func NewHelloParams() *HelloParams {
@@ -143,7 +142,7 @@ func NewHelloParams() *HelloParams {
} }
type HelloResult struct { type HelloResult struct {
Message string `msgpack:"message" json:"message"` Message string `json:"message"`
} }
func NewHelloResult() *HelloResult { func NewHelloResult() *HelloResult {
@@ -151,3 +150,196 @@ func NewHelloResult() *HelloResult {
} }
``` ```
### Authentication and authorization
#### Client side
```
func clientHello() error {
var err error
params := NewHelloParams()
params.Message = "hello server!"
result := NewHelloResult()
auth := dsrpc.CreateAuth([]byte("login"), []byte("password"))
err = dsrpc.Exec("127.0.0.1:8081", HelloMethod, params, result, auth)
if err != nil {
log.Println("method err:", err)
return err
}
//...
}
```
#### Server side
```
func authMiddleware(context *dsrpc.Context) error {
var err error
reqIdent := context.AuthIdent()
reqSalt := context.AuthSalt()
reqHash := context.AuthHash()
if reqIdent != "login" {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
ident := reqIdent
pass := []byte("password")
ok := dsrpc.CheckHash(ident, pass, reqSalt, reqHash)
log.Println("auth is ok:", ok)
if !ok {
err = errors.New("auth ident or pass mismatch")
context.SendError(err)
return err
}
return err
}
func sampleServ(quiet bool) error {
var err error
if quiet {
dsrpc.SetAccessWriter(io.Discard)
dsrpc.SetMessageWriter(io.Discard)
}
serv := NewService()
serv.PreMiddleware(authMiddleware)
serv.PreMiddleware(dsrpc.LogRequest)
serv.Handler(HelloMethod, helloHandler)
serv.Handler(SaveMethod, saveHandler)
serv.Handler(LoadMethod, loadHandler)
serv.PostMiddleware(dsrpc.LogResponse)
serv.PostMiddleware(dsrpc.LogAccess)
err = serv.Listen(":8081")
if err != nil {
return err
}
return err
}
```
### Put method
#### Client side sample
```
var binSize int64 = 16
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
reader := bytes.NewReader(binBytes)
err = dsrpc.Put("127.0.0.1:8081", SaveMethod, reader, binSize, params, result, auth)
```
#### Server side
```
func saveHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
bufferBytes := make([]byte, 0, 1024)
binWriter := bytes.NewBuffer(bufferBytes)
err = context.ReadBin(binWriter)
if err != nil {
context.SendError(err)
return err
}
result := NewSaveResult()
result.Message = "saved successfully!"
err = context.SendResult(result, 0)
if err != nil {
return err
}
return err
}
```
### Get method
#### Client side
```
params := NewLoadParams()
params.Message = "load data!"
result := NewHelloResult()
auth := CreateAuth([]byte("qwert"), []byte("12345"))
binBytes := make([]byte, 0)
writer := bytes.NewBuffer(binBytes)
err = dsrpc.Get("127.0.0.1:8081", LoadMethod, writer, params, result, auth)
if err != nil {
return err
}
//...
```
#### Server side
```
func getHandler(context *dsrpc.Context) error {
var err error
params := NewSaveParams()
err = context.BindParams(params)
if err != nil {
return err
}
var binSize int64 = 1024
rand.Seed(time.Now().UnixNano())
binBytes := make([]byte, binSize)
rand.Read(binBytes)
binReader := bytes.NewReader(binBytes)
result := NewSaveResult()
result.Message = "load successfully!"
err = context.SendResult(result, binSize)
if err != nil {
return err
}
binWriter := context.BinWriter()
_, err = dsrpc.CopyBytes(binReader, binWriter, binSize)
if err != nil {
return err
}
return err
}
```

View File

@@ -20,7 +20,7 @@ func Put(address string, method string, reader io.Reader, size int64, param, res
conn, err := net.Dial("tcp", address) conn, err := net.Dial("tcp", address)
if err != nil { if err != nil {
return err return Err(err)
} }
defer conn.Close() defer conn.Close()
@@ -47,11 +47,11 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param,
err = context.CreateRequest() err = context.CreateRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.WriteRequest() err = context.WriteRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -66,13 +66,13 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, size int64, param,
wg.Wait() wg.Wait()
err = <- errChan err = <- errChan
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.BindResponse() err = context.BindResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error { func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error {
@@ -80,7 +80,7 @@ func Get(address string, method string, writer io.Writer, param, result any, aut
conn, err := net.Dial("tcp", address) conn, err := net.Dial("tcp", address)
if err != nil { if err != nil {
return err return Err(err)
} }
defer conn.Close() defer conn.Close()
@@ -104,41 +104,40 @@ func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any,
} }
err = context.CreateRequest() err = context.CreateRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.WriteRequest() err = context.WriteRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.ReadResponse() err = context.ReadResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.DownloadBin() err = context.DownloadBin()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.BindResponse() err = context.BindResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func Exec(address, method string, param any, result any, auth *Auth) error { func Exec(address, method string, param any, result any, auth *Auth) error {
var err error var err error
conn, err := net.Dial("tcp", address) conn, err := net.Dial("tcp", address)
if err != nil { if err != nil {
return err return Err(err)
} }
defer conn.Close() defer conn.Close()
err = ConnExec(conn, method, param, result, auth) err = ConnExec(conn, method, param, result, auth)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
@@ -157,21 +156,21 @@ func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) e
err = context.CreateRequest() err = context.CreateRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.WriteRequest() err = context.WriteRequest()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.ReadResponse() err = context.ReadResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
err = context.BindResponse() err = context.BindResponse()
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
@@ -180,35 +179,35 @@ func (context *Context) CreateRequest() error {
context.reqPacket.rcpPayload, err = context.reqRPC.Pack() context.reqPacket.rcpPayload, err = context.reqRPC.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
rpcSize := int64(len(context.reqPacket.rcpPayload)) rpcSize := int64(len(context.reqPacket.rcpPayload))
context.reqHeader.rpcSize = rpcSize context.reqHeader.rpcSize = rpcSize
context.reqPacket.header, err = context.reqHeader.Pack() context.reqPacket.header, err = context.reqHeader.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) WriteRequest() error { func (context *Context) WriteRequest() error {
var err error var err error
_, err = context.sockWriter.Write(context.reqPacket.header) _, err = context.sockWriter.Write(context.reqPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.reqPacket.rcpPayload) _, err = context.sockWriter.Write(context.reqPacket.rcpPayload)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) UploadBin() error { func (context *Context) UploadBin() error {
var err error var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize) _, err = CopyBytes(context.binReader, context.binWriter, context.reqHeader.binSize)
return err return Err(err)
} }
func (context *Context) ReadResponse() error { func (context *Context) ReadResponse() error {
@@ -216,18 +215,18 @@ func (context *Context) ReadResponse() error {
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil { if err != nil {
return err return Err(err)
} }
context.resHeader, err = UnpackHeader(context.resPacket.header) context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
rpcSize := context.resHeader.rpcSize rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) UploadBinAsync(wg *sync.WaitGroup) { func (context *Context) UploadBinAsync(wg *sync.WaitGroup) {
@@ -248,15 +247,18 @@ func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error
defer exitFunc() defer exitFunc()
context.resPacket.header, err = ReadBytes(context.sockReader, headerSize) context.resPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil { if err != nil {
err = Err(err)
return return
} }
context.resHeader, err = UnpackHeader(context.resPacket.header) context.resHeader, err = UnpackHeader(context.resPacket.header)
if err != nil { if err != nil {
err = Err(err)
return return
} }
rpcSize := context.resHeader.rpcSize rpcSize := context.resHeader.rpcSize
context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) context.resPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil { if err != nil {
err = Err(err)
return return
} }
return return
@@ -265,7 +267,7 @@ func (context *Context) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error
func (context *Context) DownloadBin() error { func (context *Context) DownloadBin() error {
var err error var err error
_, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize) _, err = CopyBytes(context.binReader, context.binWriter, context.resHeader.binSize)
return err return Err(err)
} }
func (context *Context) BindResponse() error { func (context *Context) BindResponse() error {
@@ -273,10 +275,11 @@ func (context *Context) BindResponse() error {
err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC) err = json.Unmarshal(context.resPacket.rcpPayload, context.resRPC)
if err != nil { if err != nil {
return err return Err(err)
} }
if len(context.resRPC.Error) > 0 { if len(context.resRPC.Error) > 0 {
return errors.New(context.resRPC.Error) err = errors.New(context.resRPC.Error)
return Err(err)
} }
return err return Err(err)
} }

39
error.go Normal file
View File

@@ -0,0 +1,39 @@
package dsrpc
import (
"fmt"
"runtime"
"io"
)
var develMode bool = false
var debugMode bool = false
func SetDevelMode(mode bool) {
develMode = mode
}
func SetDebugMode(mode bool) {
debugMode = mode
}
func Err(err error) error {
switch err {
case io.EOF:
return err
}
if err != nil {
switch {
case develMode == true:
pc, filename, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s:%s", filename, line, funcName, err.Error())
case debugMode == true:
pc, _, line, _ := runtime.Caller(1)
funcName := runtime.FuncForPC(pc).Name()
err = fmt.Errorf(" %s:%d:%s ", funcName, line, err.Error())
default:
}
}
return err
}

View File

@@ -56,7 +56,7 @@ func (this *Header) Pack() ([]byte, error) {
magicCodeBBytes := encoderI64(this.magicCodeB) magicCodeBBytes := encoderI64(this.magicCodeB)
headerBuffer.Write(magicCodeBBytes) headerBuffer.Write(magicCodeBBytes)
return headerBuffer.Bytes(), err return headerBuffer.Bytes(), Err(err)
} }
func UnpackHeader(headerBytes []byte) (*Header, error) { func UnpackHeader(headerBytes []byte) (*Header, error) {
@@ -81,10 +81,10 @@ func UnpackHeader(headerBytes []byte) (*Header, error) {
header.magicCodeB = decoderI64(magicCodeBBytes) header.magicCodeB = decoderI64(magicCodeBBytes)
if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB { if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB {
return header, errors.New("wrong protocol magic code") err = errors.New("wrong protocol magic code")
return header, Err(err)
} }
return header, Err(err)
return header, err
} }
func encoderI64(i int64) []byte { func encoderI64(i int64) []byte {

View File

@@ -17,23 +17,23 @@ var messageWriter io.Writer = os.Stdout
var accessWriter io.Writer = os.Stdout var accessWriter io.Writer = os.Stdout
func logDebug(messages ...any) { func logDebug(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano) stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "debug", messages) fmt.Fprintln(messageWriter, stamp, "dsrpc debug", messages)
} }
func logInfo(messages ...any) { func logInfo(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano) stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "info", messages) fmt.Fprintln(messageWriter, stamp, "dsrpc info", messages)
} }
func logError(messages ...any) { func logError(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano) stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(messageWriter, stamp, "error", messages) fmt.Fprintln(messageWriter, stamp, "dsrpc error", messages)
} }
func logAccess(messages ...any) { func logAccess(messages ...any) {
stamp := time.Now().Format(time.RFC3339Nano) stamp := time.Now().Format(time.RFC3339)
fmt.Fprintln(accessWriter, stamp, "access", messages) fmt.Fprintln(accessWriter, stamp, "dsrpc access", messages)
} }
func SetAccessWriter(writer io.Writer) { func SetAccessWriter(writer io.Writer) {

View File

@@ -14,18 +14,18 @@ import (
func LogRequest(context *Context) error { func LogRequest(context *Context) error {
var err error var err error
logDebug("request:", string(context.reqRPC.JSON())) logDebug("request:", string(context.reqRPC.JSON()))
return err return Err(err)
} }
func LogResponse(context *Context) error { func LogResponse(context *Context) error {
var err error var err error
logDebug("response:", string(context.resRPC.JSON())) logDebug("response:", string(context.resRPC.JSON()))
return err return Err(err)
} }
func LogAccess(context *Context) error { func LogAccess(context *Context) error {
var err error var err error
execTime := time.Now().Sub(context.start) execTime := time.Now().Sub(context.start)
logAccess(context.remoteHost, context.reqRPC.Method, execTime) logAccess(context.remoteHost, context.reqRPC.Method, execTime)
return err return Err(err)
} }

View File

@@ -24,7 +24,7 @@ func NewRequest() *Request {
func (this *Request) Pack() ([]byte, error) { func (this *Request) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this) rBytes, err := json.Marshal(this)
return rBytes, err return rBytes, Err(err)
} }
func (this *Request) JSON() []byte { func (this *Request) JSON() []byte {

View File

@@ -27,5 +27,5 @@ func (this *Response) JSON() []byte {
func (this *Response) Pack() ([]byte, error) { func (this *Response) Pack() ([]byte, error) {
rBytes, err := json.Marshal(this) rBytes, err := json.Marshal(this)
return rBytes, err return rBytes, Err(err)
} }

View File

@@ -56,6 +56,7 @@ func (this *Service) Handler(method string, handler HandlerFunc) {
func (this *Service) Listen(address string) error { func (this *Service) Listen(address string) error {
var err error var err error
logInfo("server listen:", address) logInfo("server listen:", address)
listener, err := net.Listen("tcp", address) listener, err := net.Listen("tcp", address)
if err != nil { if err != nil {
return err return err
@@ -72,7 +73,6 @@ func (this *Service) Listen(address string) error {
if err != nil { if err != nil {
logError("conn accept err:", err) logError("conn accept err:", err)
} }
go this.handleConn(conn) go this.handleConn(conn)
} }
} }
@@ -120,26 +120,31 @@ func (this *Service) handleConn(conn net.Conn) {
err = context.ReadRequest() err = context.ReadRequest()
if err != nil { if err != nil {
err = Err(err)
return return
} }
err = context.BindMethod() err = context.BindMethod()
if err != nil { if err != nil {
err = Err(err)
return return
} }
for _, mw := range this.preMw { for _, mw := range this.preMw {
err = mw(context) err = mw(context)
if err != nil { if err != nil {
err = Err(err)
return return
} }
} }
err = this.Route(context) err = this.Route(context)
if err != nil { if err != nil {
err = Err(err)
return return
} }
for _, mw := range this.postMw { for _, mw := range this.postMw {
err = mw(context) err = mw(context)
if err != nil { if err != nil {
err = Err(err)
return return
} }
} }
@@ -149,9 +154,9 @@ func (this *Service) handleConn(conn net.Conn) {
func (this *Service) Route(context *Context) error { func (this *Service) Route(context *Context) error {
handler, ok := this.handlers[context.reqRPC.Method] handler, ok := this.handlers[context.reqRPC.Method]
if ok { if ok {
return handler(context) return Err(handler(context))
} }
return notFound(context) return Err(notFound(context))
} }
func (context *Context) ReadRequest() error { func (context *Context) ReadRequest() error {
@@ -159,19 +164,19 @@ func (context *Context) ReadRequest() error {
context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize) context.reqPacket.header, err = ReadBytes(context.sockReader, headerSize)
if err != nil { if err != nil {
return err return Err(err)
} }
context.reqHeader, err = UnpackHeader(context.reqPacket.header) context.reqHeader, err = UnpackHeader(context.reqPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
rpcSize := context.reqHeader.rpcSize rpcSize := context.reqHeader.rpcSize
context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize) context.reqPacket.rcpPayload, err = ReadBytes(context.sockReader, rpcSize)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) BinWriter() io.Writer { func (context *Context) BinWriter() io.Writer {
@@ -189,14 +194,14 @@ func (context *Context) BinSize() int64 {
func (context *Context) ReadBin(writer io.Writer) error { func (context *Context) ReadBin(writer io.Writer) error {
var err error var err error
_, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize) _, err = CopyBytes(context.sockReader, writer, context.reqHeader.binSize)
return err return Err(err)
} }
func (context *Context) BindMethod() error { func (context *Context) BindMethod() error {
var err error var err error
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
return err return Err(err)
} }
func (context *Context) BindParams(params any) error { func (context *Context) BindParams(params any) error {
@@ -204,9 +209,9 @@ func (context *Context) BindParams(params any) error {
context.reqRPC.Params = params context.reqRPC.Params = params
err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC) err = json.Unmarshal(context.reqPacket.rcpPayload, context.reqRPC)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
func (context *Context) SendResult(result any, binSize int64) error { func (context *Context) SendResult(result any, binSize int64) error {
@@ -215,24 +220,24 @@ func (context *Context) SendResult(result any, binSize int64) error {
context.resPacket.rcpPayload, err = context.resRPC.Pack() context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resHeader.binSize = binSize context.resHeader.binSize = binSize
context.resPacket.header, err = context.resHeader.Pack() context.resPacket.header, err = context.resHeader.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.resPacket.header) _, err = context.sockWriter.Write(context.resPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.resPacket.rcpPayload) _, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }
@@ -244,20 +249,20 @@ func (context *Context) SendError(execErr error) error {
context.resPacket.rcpPayload, err = context.resRPC.Pack() context.resPacket.rcpPayload, err = context.resRPC.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload)) context.resHeader.rpcSize = int64(len(context.resPacket.rcpPayload))
context.resPacket.header, err = context.resHeader.Pack() context.resPacket.header, err = context.resHeader.Pack()
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.resPacket.header) _, err = context.sockWriter.Write(context.resPacket.header)
if err != nil { if err != nil {
return err return Err(err)
} }
_, err = context.sockWriter.Write(context.resPacket.rcpPayload) _, err = context.sockWriter.Write(context.resPacket.rcpPayload)
if err != nil { if err != nil {
return err return Err(err)
} }
return err return Err(err)
} }

View File

@@ -13,7 +13,7 @@ import (
func ReadBytes(reader io.Reader, size int64) ([]byte, error) { func ReadBytes(reader io.Reader, size int64) ([]byte, error) {
buffer := make([]byte, size) buffer := make([]byte, size)
read, err := io.ReadFull(reader, buffer) read, err := io.ReadFull(reader, buffer)
return buffer[0:read], err return buffer[0:read], Err(err)
} }
func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) {
@@ -38,17 +38,20 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error
} }
received, err := reader.Read(buffer[0:bSize]) received, err := reader.Read(buffer[0:bSize])
if err != nil { if err != nil {
return total, fmt.Errorf("read error: %v", err) err = fmt.Errorf("read error: %v", err)
return total, Err(err)
} }
recorded, err := writer.Write(buffer[0:received]) recorded, err := writer.Write(buffer[0:received])
if err != nil { if err != nil {
return total, fmt.Errorf("write error: %v", err) err = fmt.Errorf("write error: %v", err)
return total, Err(err)
} }
if recorded != received { if recorded != received {
return total, errors.New("size mismatch") err = errors.New("size mismatch")
return total, Err(err)
} }
total += int64(recorded) total += int64(recorded)
remains -= int64(recorded) remains -= int64(recorded)
} }
return total, err return total, Err(err)
} }