diff --git a/client.go b/client.go index b0906d6..05ffd06 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ package dsrpc import ( + "context" "errors" "fmt" "io" @@ -14,7 +15,7 @@ import ( encoder "github.com/vmihailenco/msgpack/v5" ) -func Put(address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error { +func Put(ctx context.Context, address string, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error { var err error addr, err := net.ResolveTCPAddr("tcp", address) @@ -28,10 +29,10 @@ func Put(address string, method string, reader io.Reader, binSize int64, param, } defer conn.Close() - return ConnPut(conn, method, reader, binSize, param, result, auth) + return ConnPut(ctx, conn, method, reader, binSize, param, result, auth) } -func ConnPut(conn net.Conn, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error { +func ConnPut(ctx context.Context, conn net.Conn, method string, reader io.Reader, binSize int64, param, result any, auth *Auth) error { var err error content := CreateContent(conn) @@ -51,11 +52,11 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, binSize int64, para content.reqHeader.binSize = binSize - err = content.CreateRequest() + err = content.createRequest() if err != nil { return err } - err = content.WriteRequest() + err = content.writeRequest() if err != nil { return err } @@ -64,24 +65,24 @@ func ConnPut(conn net.Conn, method string, reader io.Reader, binSize int64, para errChan := make(chan error, 1) wg.Add(1) - go content.ReadResponseAsync(&wg, errChan) + go content.readResponseAsync(&wg, errChan) wg.Add(1) - go content.UploadBinAsync(&wg) + go content.uploadBinAsync(ctx, &wg) wg.Wait() err = <-errChan if err != nil { return err } - err = content.BindResponse() + err = content.bindResponse() if err != nil { return err } return err } -func Get(address string, method string, writer io.Writer, param, result any, auth *Auth) error { +func Get(ctx context.Context, address string, method string, writer io.Writer, param, result any, auth *Auth) error { var err error addr, err := net.ResolveTCPAddr("tcp", address) @@ -95,10 +96,10 @@ func Get(address string, method string, writer io.Writer, param, result any, aut } defer conn.Close() - return ConnGet(conn, method, writer, param, result, auth) + return ConnGet(ctx, conn, method, writer, param, result, auth) } -func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error { +func ConnGet(ctx context.Context, conn net.Conn, method string, writer io.Writer, param, result any, auth *Auth) error { var err error content := CreateContent(conn) @@ -116,30 +117,30 @@ func ConnGet(conn net.Conn, method string, writer io.Writer, param, result any, content.binReader = conn content.binWriter = writer - err = content.CreateRequest() + err = content.createRequest() if err != nil { return err } - err = content.WriteRequest() + err = content.writeRequest() if err != nil { return err } - err = content.ReadResponse() + err = content.readResponse() if err != nil { return err } - err = content.DownloadBin() + err = content.downloadBin(ctx) if err != nil { return err } - err = content.BindResponse() + err = content.bindResponse() if err != nil { return err } return err } -func Exec(address, method string, param any, result any, auth *Auth) error { +func Exec(ctx context.Context, address, method string, param any, result any, auth *Auth) error { var err error addr, err := net.ResolveTCPAddr("tcp", address) @@ -153,14 +154,14 @@ func Exec(address, method string, param any, result any, auth *Auth) error { } defer conn.Close() - err = ConnExec(conn, method, param, result, auth) + err = ConnExec(ctx, conn, method, param, result, auth) if err != nil { return err } return err } -func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) error { +func ConnExec(ctx context.Context, conn net.Conn, method string, param any, result any, auth *Auth) error { var err error content := CreateContent(conn) @@ -176,26 +177,26 @@ func ConnExec(conn net.Conn, method string, param any, result any, auth *Auth) e content.reqBlock.Auth = auth } - err = content.CreateRequest() + err = content.createRequest() if err != nil { return err } - err = content.WriteRequest() + err = content.writeRequest() if err != nil { return err } - err = content.ReadResponse() + err = content.readResponse() if err != nil { return err } - err = content.BindResponse() + err = content.bindResponse() if err != nil { return err } return err } -func (content *Content) CreateRequest() error { +func (content *Content) createRequest() error { var err error content.reqPacket.rcpPayload, err = content.reqBlock.Pack() @@ -212,7 +213,7 @@ func (content *Content) CreateRequest() error { return err } -func (content *Content) WriteRequest() error { +func (content *Content) writeRequest() error { var err error _, err = content.sockWriter.Write(content.reqPacket.header) if err != nil { @@ -225,13 +226,13 @@ func (content *Content) WriteRequest() error { return err } -func (content *Content) UploadBin() error { +func (content *Content) uploadBin(ctx context.Context) error { var err error - _, err = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize) + _, err = CopyBytes(ctx, content.binReader, content.binWriter, content.reqHeader.binSize) return err } -func (content *Content) ReadResponse() error { +func (content *Content) readResponse() error { var err error content.resPacket.header, err = ReadBytes(content.sockReader, headerSize) @@ -250,16 +251,16 @@ func (content *Content) ReadResponse() error { return err } -func (content *Content) UploadBinAsync(wg *sync.WaitGroup) { +func (content *Content) uploadBinAsync(ctx context.Context, wg *sync.WaitGroup) { exitFunc := func() { wg.Done() } defer exitFunc() - _, _ = CopyBytes(content.binReader, content.binWriter, content.reqHeader.binSize) + _, _ = CopyBytes(ctx, content.binReader, content.binWriter, content.reqHeader.binSize) return } -func (content *Content) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error) { +func (content *Content) readResponseAsync(wg *sync.WaitGroup, errChan chan error) { var err error exitFunc := func() { errChan <- err @@ -285,13 +286,13 @@ func (content *Content) ReadResponseAsync(wg *sync.WaitGroup, errChan chan error return } -func (content *Content) DownloadBin() error { +func (content *Content) downloadBin(ctx context.Context) error { var err error - _, err = CopyBytes(content.binReader, content.binWriter, content.resHeader.binSize) + _, err = CopyBytes(ctx, content.binReader, content.binWriter, content.resHeader.binSize) return err } -func (content *Content) BindResponse() error { +func (content *Content) bindResponse() error { var err error err = encoder.Unmarshal(content.resPacket.rcpPayload, content.resBlock) diff --git a/example/go.mod b/example/go.mod index bb8242a..47e7cc6 100644 --- a/example/go.mod +++ b/example/go.mod @@ -2,7 +2,7 @@ module netsrv go 1.19 -require github.com/kindsoldier/dsrpc v1.1.2 +require github.com/kindsoldier/dsrpc v1.1.4 require ( github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect diff --git a/example/go.sum b/example/go.sum index 8b1ee29..0ddc5d4 100644 --- a/example/go.sum +++ b/example/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/kindsoldier/dsrpc v1.1.2 h1:bFTIGpRSMq5OK1a3dHQxLPMxf6R+Ik15slkqNN0QrNE= github.com/kindsoldier/dsrpc v1.1.2/go.mod h1:KG8x2ZPid/hPJdhkUHtt1mDulWPVhj9fh/1XL3Z2xT8= +github.com/kindsoldier/dsrpc v1.1.4 h1:F6e1K5C7C92jKGOrH4lF/XraLe5E2glsQTeVP9avYBE= +github.com/kindsoldier/dsrpc v1.1.4/go.mod h1:KG8x2ZPid/hPJdhkUHtt1mDulWPVhj9fh/1XL3Z2xT8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/exec_test.go b/exec_test.go index 25d5226..b9e43be 100644 --- a/exec_test.go +++ b/exec_test.go @@ -6,6 +6,7 @@ package dsrpc import ( "bytes" + "context" "encoding/json" "errors" "io" @@ -66,7 +67,10 @@ func TestLocalSave(t *testing.T) { reader := bytes.NewReader(binBytes) - err = LocalPut(SaveMethod, reader, binSize, ¶ms, &result, auth, saveHandler) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second)) + defer cancel() + + err = LocalPut(ctx, SaveMethod, reader, binSize, ¶ms, &result, auth, saveHandler) require.NoError(t, err) resultJson, _ := json.Marshal(result) @@ -84,7 +88,10 @@ func TestLocalLoad(t *testing.T) { binBytes := make([]byte, 0) writer := bytes.NewBuffer(binBytes) - err = LocalGet(LoadMethod, writer, ¶ms, &result, auth, loadHandler) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second)) + defer cancel() + + err = LocalGet(ctx, LoadMethod, writer, ¶ms, &result, auth, loadHandler) require.NoError(t, err) resultJson, _ := json.Marshal(result) @@ -141,7 +148,10 @@ func clientHello() error { binBytes := make([]byte, binSize) rand.Read(binBytes) - err = Exec("127.0.0.1:8081", HelloMethod, ¶ms, &result, auth) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second)) + defer cancel() + + err = Exec(ctx, "127.0.0.1:8081", HelloMethod, ¶ms, &result, auth) if err != nil { logError("method err:", err) return err @@ -166,7 +176,10 @@ func clientSave() error { reader := bytes.NewReader(binBytes) - err = Put("127.0.0.1:8081", SaveMethod, reader, binSize, ¶ms, &result, auth) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second)) + defer cancel() + + err = Put(ctx, "127.0.0.1:8081", SaveMethod, reader, binSize, ¶ms, &result, auth) if err != nil { logError("method err:", err) return err @@ -187,7 +200,10 @@ func clientLoad() error { binBytes := make([]byte, 0) writer := bytes.NewBuffer(binBytes) - err = Get("127.0.0.1:8081", LoadMethod, writer, ¶ms, &result, auth) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second)) + defer cancel() + + err = Get(ctx, "127.0.0.1:8081", LoadMethod, writer, ¶ms, &result, auth) if err != nil { logError("method err:", err) return err @@ -261,7 +277,10 @@ func helloHandler(content *Content) error { return err } - err = content.ReadBin(io.Discard) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second)) + defer cancel() + + err = content.ReadBin(ctx, io.Discard) if err != nil { content.SendError(err) return err @@ -289,7 +308,10 @@ func saveHandler(content *Content) error { bufferBytes := make([]byte, 0, 1024) binWriter := bytes.NewBuffer(bufferBytes) - err = content.ReadBin(binWriter) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second)) + defer cancel() + + err = content.ReadBin(ctx, binWriter) if err != nil { content.SendError(err) return err @@ -314,7 +336,10 @@ func loadHandler(content *Content) error { return err } - err = content.ReadBin(io.Discard) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5*time.Second)) + defer cancel() + + err = content.ReadBin(ctx, io.Discard) if err != nil { content.SendError(err) return err @@ -334,8 +359,9 @@ func loadHandler(content *Content) error { if err != nil { return err } + binWriter := content.BinWriter() - _, err = CopyBytes(binReader, binWriter, binSize) + _, err = CopyBytes(ctx, binReader, binWriter, binSize) if err != nil { return err } diff --git a/header.go b/header.go index 61493d9..d722390 100644 --- a/header.go +++ b/header.go @@ -44,16 +44,16 @@ func (hdr *Header) Pack() ([]byte, error) { headerBytes := make([]byte, 0, headerSize) headerBuffer := bytes.NewBuffer(headerBytes) - magicCodeABytes := encoderI64(hdr.magicCodeA) + magicCodeABytes := EncoderI64(hdr.magicCodeA) headerBuffer.Write(magicCodeABytes) - rpcSizeBytes := encoderI64(hdr.rpcSize) + rpcSizeBytes := EncoderI64(hdr.rpcSize) headerBuffer.Write(rpcSizeBytes) - binSizeBytes := encoderI64(hdr.binSize) + binSizeBytes := EncoderI64(hdr.binSize) headerBuffer.Write(binSizeBytes) - magicCodeBBytes := encoderI64(hdr.magicCodeB) + magicCodeBBytes := EncoderI64(hdr.magicCodeB) headerBuffer.Write(magicCodeBBytes) return headerBuffer.Bytes(), err @@ -77,10 +77,10 @@ func UnpackHeader(headerBytes []byte) (*Header, error) { headerReader.Read(magicCodeBBytes) header := &Header{ - magicCodeA: decoderI64(magicCodeABytes), - rpcSize: decoderI64(rpcSizeBytes), - binSize: decoderI64(binSizeBytes), - magicCodeB: decoderI64(magicCodeBBytes), + magicCodeA: DecoderI64(magicCodeABytes), + rpcSize: DecoderI64(rpcSizeBytes), + binSize: DecoderI64(binSizeBytes), + magicCodeB: DecoderI64(magicCodeBBytes), } if header.magicCodeA != magicCodeA || header.magicCodeB != magicCodeB { @@ -90,12 +90,12 @@ func UnpackHeader(headerBytes []byte) (*Header, error) { return header, err } -func encoderI64(i int64) []byte { +func EncoderI64(i int64) []byte { buffer := make([]byte, sizeOfInt64) binary.BigEndian.PutUint64(buffer, uint64(i)) return buffer } -func decoderI64(b []byte) int64 { +func DecoderI64(b []byte) int64 { return int64(binary.BigEndian.Uint64(b)) } diff --git a/server.go b/server.go index cdfb847..3350ed5 100644 --- a/server.go +++ b/server.go @@ -232,9 +232,9 @@ func (content *Content) BinSize() int64 { return content.reqHeader.binSize } -func (content *Content) ReadBin(writer io.Writer) error { +func (content *Content) ReadBin(ctx context.Context, writer io.Writer) error { var err error - _, err = CopyBytes(content.sockReader, writer, content.reqHeader.binSize) + _, err = CopyBytes(ctx, content.sockReader, writer, content.reqHeader.binSize) return err } diff --git a/tools.go b/tools.go index 10721c4..259b106 100644 --- a/tools.go +++ b/tools.go @@ -5,6 +5,7 @@ package dsrpc import ( + "context" "errors" "fmt" "io" @@ -16,7 +17,7 @@ func ReadBytes(reader io.Reader, size int64) ([]byte, error) { return buffer[0:read], err } -func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { +func CopyBytes(ctx context.Context, reader io.Reader, writer io.Writer, dataSize int64) (int64, error) { var err error var bSize int64 = 1024 * 16 var total int64 = 0 @@ -24,6 +25,12 @@ func CopyBytes(reader io.Reader, writer io.Writer, dataSize int64) (int64, error buffer := make([]byte, bSize) for { + select { + case <-ctx.Done(): + return total, errors.New("break by context") + default: + } + if reader == nil { return total, errors.New("reader is nil") } diff --git a/validate.go b/validate.go index 55bc762..9b60619 100644 --- a/validate.go +++ b/validate.go @@ -5,6 +5,7 @@ package dsrpc import ( + "context" "io" "net" ) @@ -27,11 +28,11 @@ func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc content.resBlock.Result = result } - err = content.CreateRequest() + err = content.createRequest() if err != nil { return err } - err = content.WriteRequest() + err = content.writeRequest() if err != nil { return err } @@ -39,11 +40,11 @@ func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc if err != nil { return err } - err = content.ReadResponse() + err = content.readResponse() if err != nil { return err } - err = content.BindResponse() + err = content.bindResponse() if err != nil { return err } @@ -51,7 +52,7 @@ func LocalExec(method string, param, result any, auth *Auth, handler HandlerFunc return err } -func LocalPut(method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error { +func LocalPut(ctx context.Context, method string, reader io.Reader, size int64, param, result any, auth *Auth, handler HandlerFunc) error { var err error @@ -75,15 +76,15 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au content.reqHeader.binSize = size - err = content.CreateRequest() + err = content.createRequest() if err != nil { return err } - err = content.WriteRequest() + err = content.writeRequest() if err != nil { return err } - err = content.UploadBin() + err = content.uploadBin(ctx) if err != nil { return err } @@ -91,18 +92,18 @@ func LocalPut(method string, reader io.Reader, size int64, param, result any, au if err != nil { return err } - err = content.ReadResponse() + err = content.readResponse() if err != nil { return err } - err = content.BindResponse() + err = content.bindResponse() if err != nil { return err } return err } -func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error { +func LocalGet(ctx context.Context, method string, writer io.Writer, param, result any, auth *Auth, handler HandlerFunc) error { var err error cliConn, srvConn := NewFConn() @@ -123,11 +124,11 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha content.binReader = cliConn content.binWriter = writer - err = content.CreateRequest() + err = content.createRequest() if err != nil { return err } - err = content.WriteRequest() + err = content.writeRequest() if err != nil { return err } @@ -136,15 +137,15 @@ func LocalGet(method string, writer io.Writer, param, result any, auth *Auth, ha if err != nil { return err } - err = content.ReadResponse() + err = content.readResponse() if err != nil { return err } - err = content.DownloadBin() + err = content.downloadBin(ctx) if err != nil { return err } - err = content.BindResponse() + err = content.bindResponse() if err != nil { return err }