diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 13:14:23 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 13:14:23 +0000 |
commit | 73df946d56c74384511a194dd01dbe099584fd1a (patch) | |
tree | fd0bcea490dd81327ddfbb31e215439672c9a068 /src/net/rpc | |
parent | Initial commit. (diff) | |
download | golang-1.16-73df946d56c74384511a194dd01dbe099584fd1a.tar.xz golang-1.16-73df946d56c74384511a194dd01dbe099584fd1a.zip |
Adding upstream version 1.16.10.upstream/1.16.10upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | src/net/rpc/client.go | 323 | ||||
-rw-r--r-- | src/net/rpc/client_test.go | 87 | ||||
-rw-r--r-- | src/net/rpc/debug.go | 90 | ||||
-rw-r--r-- | src/net/rpc/jsonrpc/all_test.go | 353 | ||||
-rw-r--r-- | src/net/rpc/jsonrpc/client.go | 124 | ||||
-rw-r--r-- | src/net/rpc/jsonrpc/server.go | 134 | ||||
-rw-r--r-- | src/net/rpc/server.go | 720 | ||||
-rw-r--r-- | src/net/rpc/server_test.go | 839 |
8 files changed, 2670 insertions, 0 deletions
diff --git a/src/net/rpc/client.go b/src/net/rpc/client.go new file mode 100644 index 0000000..60bb2cc --- /dev/null +++ b/src/net/rpc/client.go @@ -0,0 +1,323 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +import ( + "bufio" + "encoding/gob" + "errors" + "io" + "log" + "net" + "net/http" + "sync" +) + +// ServerError represents an error that has been returned from +// the remote side of the RPC connection. +type ServerError string + +func (e ServerError) Error() string { + return string(e) +} + +var ErrShutdown = errors.New("connection is shut down") + +// Call represents an active RPC. +type Call struct { + ServiceMethod string // The name of the service and method to call. + Args interface{} // The argument to the function (*struct). + Reply interface{} // The reply from the function (*struct). + Error error // After completion, the error status. + Done chan *Call // Receives *Call when Go is complete. +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client, and a Client may be used by +// multiple goroutines simultaneously. +type Client struct { + codec ClientCodec + + reqMutex sync.Mutex // protects following + request Request + + mutex sync.Mutex // protects following + seq uint64 + pending map[uint64]*Call + closing bool // user has called Close + shutdown bool // server has told us to stop +} + +// A ClientCodec implements writing of RPC requests and +// reading of RPC responses for the client side of an RPC session. +// The client calls WriteRequest to write a request to the connection +// and calls ReadResponseHeader and ReadResponseBody in pairs +// to read responses. The client calls Close when finished with the +// connection. ReadResponseBody may be called with a nil +// argument to force the body of the response to be read and then +// discarded. +// See NewClient's comment for information about concurrent access. +type ClientCodec interface { + WriteRequest(*Request, interface{}) error + ReadResponseHeader(*Response) error + ReadResponseBody(interface{}) error + + Close() error +} + +func (client *Client) send(call *Call) { + client.reqMutex.Lock() + defer client.reqMutex.Unlock() + + // Register this call. + client.mutex.Lock() + if client.shutdown || client.closing { + client.mutex.Unlock() + call.Error = ErrShutdown + call.done() + return + } + seq := client.seq + client.seq++ + client.pending[seq] = call + client.mutex.Unlock() + + // Encode and send the request. + client.request.Seq = seq + client.request.ServiceMethod = call.ServiceMethod + err := client.codec.WriteRequest(&client.request, call.Args) + if err != nil { + client.mutex.Lock() + call = client.pending[seq] + delete(client.pending, seq) + client.mutex.Unlock() + if call != nil { + call.Error = err + call.done() + } + } +} + +func (client *Client) input() { + var err error + var response Response + for err == nil { + response = Response{} + err = client.codec.ReadResponseHeader(&response) + if err != nil { + break + } + seq := response.Seq + client.mutex.Lock() + call := client.pending[seq] + delete(client.pending, seq) + client.mutex.Unlock() + + switch { + case call == nil: + // We've got no pending call. That usually means that + // WriteRequest partially failed, and call was already + // removed; response is a server telling us about an + // error reading request body. We should still attempt + // to read error body, but there's no one to give it to. + err = client.codec.ReadResponseBody(nil) + if err != nil { + err = errors.New("reading error body: " + err.Error()) + } + case response.Error != "": + // We've got an error response. Give this to the request; + // any subsequent requests will get the ReadResponseBody + // error if there is one. + call.Error = ServerError(response.Error) + err = client.codec.ReadResponseBody(nil) + if err != nil { + err = errors.New("reading error body: " + err.Error()) + } + call.done() + default: + err = client.codec.ReadResponseBody(call.Reply) + if err != nil { + call.Error = errors.New("reading body " + err.Error()) + } + call.done() + } + } + // Terminate pending calls. + client.reqMutex.Lock() + client.mutex.Lock() + client.shutdown = true + closing := client.closing + if err == io.EOF { + if closing { + err = ErrShutdown + } else { + err = io.ErrUnexpectedEOF + } + } + for _, call := range client.pending { + call.Error = err + call.done() + } + client.mutex.Unlock() + client.reqMutex.Unlock() + if debugLog && err != io.EOF && !closing { + log.Println("rpc: client protocol error:", err) + } +} + +func (call *Call) done() { + select { + case call.Done <- call: + // ok + default: + // We don't want to block here. It is the caller's responsibility to make + // sure the channel has enough buffer space. See comment in Go(). + if debugLog { + log.Println("rpc: discarding Call reply due to insufficient Done chan capacity") + } + } +} + +// NewClient returns a new Client to handle requests to the +// set of services at the other end of the connection. +// It adds a buffer to the write side of the connection so +// the header and payload are sent as a unit. +// +// The read and write halves of the connection are serialized independently, +// so no interlocking is required. However each half may be accessed +// concurrently so the implementation of conn should protect against +// concurrent reads or concurrent writes. +func NewClient(conn io.ReadWriteCloser) *Client { + encBuf := bufio.NewWriter(conn) + client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf} + return NewClientWithCodec(client) +} + +// NewClientWithCodec is like NewClient but uses the specified +// codec to encode requests and decode responses. +func NewClientWithCodec(codec ClientCodec) *Client { + client := &Client{ + codec: codec, + pending: make(map[uint64]*Call), + } + go client.input() + return client +} + +type gobClientCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder + encBuf *bufio.Writer +} + +func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) { + if err = c.enc.Encode(r); err != nil { + return + } + if err = c.enc.Encode(body); err != nil { + return + } + return c.encBuf.Flush() +} + +func (c *gobClientCodec) ReadResponseHeader(r *Response) error { + return c.dec.Decode(r) +} + +func (c *gobClientCodec) ReadResponseBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *gobClientCodec) Close() error { + return c.rwc.Close() +} + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string) (*Client, error) { + return DialHTTPPath(network, address, DefaultRPCPath) +} + +// DialHTTPPath connects to an HTTP RPC server +// at the specified network address and path. +func DialHTTPPath(network, address, path string) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n") + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err == nil && resp.Status == connected { + return NewClient(conn), nil + } + if err == nil { + err = errors.New("unexpected HTTP response: " + resp.Status) + } + conn.Close() + return nil, &net.OpError{ + Op: "dial-http", + Net: network + " " + address, + Addr: nil, + Err: err, + } +} + +// Dial connects to an RPC server at the specified network address. +func Dial(network, address string) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn), nil +} + +// Close calls the underlying codec's Close method. If the connection is already +// shutting down, ErrShutdown is returned. +func (client *Client) Close() error { + client.mutex.Lock() + if client.closing { + client.mutex.Unlock() + return ErrShutdown + } + client.closing = true + client.mutex.Unlock() + return client.codec.Close() +} + +// Go invokes the function asynchronously. It returns the Call structure representing +// the invocation. The done channel will signal when the call is complete by returning +// the same Call object. If done is nil, Go will allocate a new channel. +// If non-nil, done must be buffered or Go will deliberately crash. +func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call { + call := new(Call) + call.ServiceMethod = serviceMethod + call.Args = args + call.Reply = reply + if done == nil { + done = make(chan *Call, 10) // buffered. + } else { + // If caller passes done != nil, it must arrange that + // done has enough buffer for the number of simultaneous + // RPCs that will be using that channel. If the channel + // is totally unbuffered, it's best not to run at all. + if cap(done) == 0 { + log.Panic("rpc: done channel is unbuffered") + } + } + call.Done = done + client.send(call) + return call +} + +// Call invokes the named function, waits for it to complete, and returns its error status. +func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error { + call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done + return call.Error +} diff --git a/src/net/rpc/client_test.go b/src/net/rpc/client_test.go new file mode 100644 index 0000000..03225e3 --- /dev/null +++ b/src/net/rpc/client_test.go @@ -0,0 +1,87 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +import ( + "errors" + "fmt" + "net" + "strings" + "testing" +) + +type shutdownCodec struct { + responded chan int + closed bool +} + +func (c *shutdownCodec) WriteRequest(*Request, interface{}) error { return nil } +func (c *shutdownCodec) ReadResponseBody(interface{}) error { return nil } +func (c *shutdownCodec) ReadResponseHeader(*Response) error { + c.responded <- 1 + return errors.New("shutdownCodec ReadResponseHeader") +} +func (c *shutdownCodec) Close() error { + c.closed = true + return nil +} + +func TestCloseCodec(t *testing.T) { + codec := &shutdownCodec{responded: make(chan int)} + client := NewClientWithCodec(codec) + <-codec.responded + client.Close() + if !codec.closed { + t.Error("client.Close did not close codec") + } +} + +// Test that errors in gob shut down the connection. Issue 7689. + +type R struct { + msg []byte // Not exported, so R does not work with gob. +} + +type S struct{} + +func (s *S) Recv(nul *struct{}, reply *R) error { + *reply = R{[]byte("foo")} + return nil +} + +func TestGobError(t *testing.T) { + defer func() { + err := recover() + if err == nil { + t.Fatal("no error") + } + if !strings.Contains(err.(error).Error(), "reading body EOF") { + t.Fatal("expected `reading body EOF', got", err) + } + }() + Register(new(S)) + + listen, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + panic(err) + } + go Accept(listen) + + client, err := Dial("tcp", listen.Addr().String()) + if err != nil { + panic(err) + } + + var reply Reply + err = client.Call("S.Recv", &struct{}{}, &reply) + if err != nil { + panic(err) + } + + fmt.Printf("%#v\n", reply) + client.Close() + + listen.Close() +} diff --git a/src/net/rpc/debug.go b/src/net/rpc/debug.go new file mode 100644 index 0000000..a1d799f --- /dev/null +++ b/src/net/rpc/debug.go @@ -0,0 +1,90 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +/* + Some HTML presented at http://machine:port/debug/rpc + Lists services, their methods, and some statistics, still rudimentary. +*/ + +import ( + "fmt" + "html/template" + "net/http" + "sort" +) + +const debugText = `<html> + <body> + <title>Services</title> + {{range .}} + <hr> + Service {{.Name}} + <hr> + <table> + <th align=center>Method</th><th align=center>Calls</th> + {{range .Method}} + <tr> + <td align=left font=fixed>{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) error</td> + <td align=center>{{.Type.NumCalls}}</td> + </tr> + {{end}} + </table> + {{end}} + </body> + </html>` + +var debug = template.Must(template.New("RPC debug").Parse(debugText)) + +// If set, print log statements for internal and I/O errors. +var debugLog = false + +type debugMethod struct { + Type *methodType + Name string +} + +type methodArray []debugMethod + +type debugService struct { + Service *service + Name string + Method methodArray +} + +type serviceArray []debugService + +func (s serviceArray) Len() int { return len(s) } +func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name } +func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func (m methodArray) Len() int { return len(m) } +func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name } +func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] } + +type debugHTTP struct { + *Server +} + +// Runs at /debug/rpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services serviceArray + server.serviceMap.Range(func(snamei, svci interface{}) bool { + svc := svci.(*service) + ds := debugService{svc, snamei.(string), make(methodArray, 0, len(svc.method))} + for mname, method := range svc.method { + ds.Method = append(ds.Method, debugMethod{method, mname}) + } + sort.Sort(ds.Method) + services = append(services, ds) + return true + }) + sort.Sort(services) + err := debug.Execute(w, services) + if err != nil { + fmt.Fprintln(w, "rpc: error executing template:", err.Error()) + } +} diff --git a/src/net/rpc/jsonrpc/all_test.go b/src/net/rpc/jsonrpc/all_test.go new file mode 100644 index 0000000..667f839 --- /dev/null +++ b/src/net/rpc/jsonrpc/all_test.go @@ -0,0 +1,353 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jsonrpc + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/rpc" + "reflect" + "strings" + "testing" +) + +type Args struct { + A, B int +} + +type Reply struct { + C int +} + +type Arith int + +type ArithAddResp struct { + Id interface{} `json:"id"` + Result Reply `json:"result"` + Error interface{} `json:"error"` +} + +func (t *Arith) Add(args *Args, reply *Reply) error { + reply.C = args.A + args.B + return nil +} + +func (t *Arith) Mul(args *Args, reply *Reply) error { + reply.C = args.A * args.B + return nil +} + +func (t *Arith) Div(args *Args, reply *Reply) error { + if args.B == 0 { + return errors.New("divide by zero") + } + reply.C = args.A / args.B + return nil +} + +func (t *Arith) Error(args *Args, reply *Reply) error { + panic("ERROR") +} + +type BuiltinTypes struct{} + +func (BuiltinTypes) Map(i int, reply *map[int]int) error { + (*reply)[i] = i + return nil +} + +func (BuiltinTypes) Slice(i int, reply *[]int) error { + *reply = append(*reply, i) + return nil +} + +func (BuiltinTypes) Array(i int, reply *[1]int) error { + (*reply)[0] = i + return nil +} + +func init() { + rpc.Register(new(Arith)) + rpc.Register(BuiltinTypes{}) +} + +func TestServerNoParams(t *testing.T) { + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`) + var resp ArithAddResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after no params: %s", err) + } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestServerEmptyMessage(t *testing.T) { + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + fmt.Fprintf(cli, "{}") + var resp ArithAddResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after empty: %s", err) + } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestServer(t *testing.T) { + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + // Send hand-coded requests to server, parse responses. + for i := 0; i < 10; i++ { + fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1) + var resp ArithAddResp + err := dec.Decode(&resp) + if err != nil { + t.Fatalf("Decode: %s", err) + } + if resp.Error != nil { + t.Fatalf("resp.Error: %s", resp.Error) + } + if resp.Id.(string) != string(rune(i)) { + t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(rune(i))) + } + if resp.Result.C != 2*i+1 { + t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C) + } + } +} + +func TestClient(t *testing.T) { + // Assume server is okay (TestServer is above). + // Test client against server. + cli, srv := net.Pipe() + go ServeConn(srv) + + client := NewClient(cli) + defer client.Close() + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err := client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B) + } + + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Mul", args, reply) + if err != nil { + t.Errorf("Mul: expected no error but got string %q", err.Error()) + } + if reply.C != args.A*args.B { + t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B) + } + + // Out of order. + args = &Args{7, 8} + mulReply := new(Reply) + mulCall := client.Go("Arith.Mul", args, mulReply, nil) + addReply := new(Reply) + addCall := client.Go("Arith.Add", args, addReply, nil) + + addCall = <-addCall.Done + if addCall.Error != nil { + t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) + } + if addReply.C != args.A+args.B { + t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B) + } + + mulCall = <-mulCall.Done + if mulCall.Error != nil { + t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) + } + if mulReply.C != args.A*args.B { + t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B) + } + + // Error test + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Arith.Div", args, reply) + // expect an error: zero divide + if err == nil { + t.Error("Div: expected error") + } else if err.Error() != "divide by zero" { + t.Error("Div: expected divide by zero error; got", err) + } +} + +func TestBuiltinTypes(t *testing.T) { + cli, srv := net.Pipe() + go ServeConn(srv) + + client := NewClient(cli) + defer client.Close() + + // Map + arg := 7 + replyMap := map[int]int{} + err := client.Call("BuiltinTypes.Map", arg, &replyMap) + if err != nil { + t.Errorf("Map: expected no error but got string %q", err.Error()) + } + if replyMap[arg] != arg { + t.Errorf("Map: expected %d got %d", arg, replyMap[arg]) + } + + // Slice + replySlice := []int{} + err = client.Call("BuiltinTypes.Slice", arg, &replySlice) + if err != nil { + t.Errorf("Slice: expected no error but got string %q", err.Error()) + } + if e := []int{arg}; !reflect.DeepEqual(replySlice, e) { + t.Errorf("Slice: expected %v got %v", e, replySlice) + } + + // Array + replyArray := [1]int{} + err = client.Call("BuiltinTypes.Array", arg, &replyArray) + if err != nil { + t.Errorf("Array: expected no error but got string %q", err.Error()) + } + if e := [1]int{arg}; !reflect.DeepEqual(replyArray, e) { + t.Errorf("Array: expected %v got %v", e, replyArray) + } +} + +func TestMalformedInput(t *testing.T) { + cli, srv := net.Pipe() + go cli.Write([]byte(`{id:1}`)) // invalid json + ServeConn(srv) // must return, not loop +} + +func TestMalformedOutput(t *testing.T) { + cli, srv := net.Pipe() + go srv.Write([]byte(`{"id":0,"result":null,"error":null}`)) + go io.ReadAll(srv) + + client := NewClient(cli) + defer client.Close() + + args := &Args{7, 8} + reply := new(Reply) + err := client.Call("Arith.Add", args, reply) + if err == nil { + t.Error("expected error") + } +} + +func TestServerErrorHasNullResult(t *testing.T) { + var out bytes.Buffer + sc := NewServerCodec(struct { + io.Reader + io.Writer + io.Closer + }{ + Reader: strings.NewReader(`{"method": "Arith.Add", "id": "123", "params": []}`), + Writer: &out, + Closer: io.NopCloser(nil), + }) + r := new(rpc.Request) + if err := sc.ReadRequestHeader(r); err != nil { + t.Fatal(err) + } + const valueText = "the value we don't want to see" + const errorText = "some error" + err := sc.WriteResponse(&rpc.Response{ + ServiceMethod: "Method", + Seq: 1, + Error: errorText, + }, valueText) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(out.String(), errorText) { + t.Fatalf("Response didn't contain expected error %q: %s", errorText, &out) + } + if strings.Contains(out.String(), valueText) { + t.Errorf("Response contains both an error and value: %s", &out) + } +} + +func TestUnexpectedError(t *testing.T) { + cli, srv := myPipe() + go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error + ServeConn(srv) // must return, not loop +} + +// Copied from package net. +func myPipe() (*pipe, *pipe) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + return &pipe{r1, w2}, &pipe{r2, w1} +} + +type pipe struct { + *io.PipeReader + *io.PipeWriter +} + +type pipeAddr int + +func (pipeAddr) Network() string { + return "pipe" +} + +func (pipeAddr) String() string { + return "pipe" +} + +func (p *pipe) Close() error { + err := p.PipeReader.Close() + err1 := p.PipeWriter.Close() + if err == nil { + err = err1 + } + return err +} + +func (p *pipe) LocalAddr() net.Addr { + return pipeAddr(0) +} + +func (p *pipe) RemoteAddr() net.Addr { + return pipeAddr(0) +} + +func (p *pipe) SetTimeout(nsec int64) error { + return errors.New("net.Pipe does not support timeouts") +} + +func (p *pipe) SetReadTimeout(nsec int64) error { + return errors.New("net.Pipe does not support timeouts") +} + +func (p *pipe) SetWriteTimeout(nsec int64) error { + return errors.New("net.Pipe does not support timeouts") +} diff --git a/src/net/rpc/jsonrpc/client.go b/src/net/rpc/jsonrpc/client.go new file mode 100644 index 0000000..e6359be --- /dev/null +++ b/src/net/rpc/jsonrpc/client.go @@ -0,0 +1,124 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package jsonrpc implements a JSON-RPC 1.0 ClientCodec and ServerCodec +// for the rpc package. +// For JSON-RPC 2.0 support, see https://godoc.org/?q=json-rpc+2.0 +package jsonrpc + +import ( + "encoding/json" + "fmt" + "io" + "net" + "net/rpc" + "sync" +) + +type clientCodec struct { + dec *json.Decoder // for reading JSON values + enc *json.Encoder // for writing JSON values + c io.Closer + + // temporary work space + req clientRequest + resp clientResponse + + // JSON-RPC responses include the request id but not the request method. + // Package rpc expects both. + // We save the request method in pending when sending a request + // and then look it up by request ID when filling out the rpc Response. + mutex sync.Mutex // protects pending + pending map[uint64]string // map request id to method name +} + +// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn. +func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { + return &clientCodec{ + dec: json.NewDecoder(conn), + enc: json.NewEncoder(conn), + c: conn, + pending: make(map[uint64]string), + } +} + +type clientRequest struct { + Method string `json:"method"` + Params [1]interface{} `json:"params"` + Id uint64 `json:"id"` +} + +func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) error { + c.mutex.Lock() + c.pending[r.Seq] = r.ServiceMethod + c.mutex.Unlock() + c.req.Method = r.ServiceMethod + c.req.Params[0] = param + c.req.Id = r.Seq + return c.enc.Encode(&c.req) +} + +type clientResponse struct { + Id uint64 `json:"id"` + Result *json.RawMessage `json:"result"` + Error interface{} `json:"error"` +} + +func (r *clientResponse) reset() { + r.Id = 0 + r.Result = nil + r.Error = nil +} + +func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error { + c.resp.reset() + if err := c.dec.Decode(&c.resp); err != nil { + return err + } + + c.mutex.Lock() + r.ServiceMethod = c.pending[c.resp.Id] + delete(c.pending, c.resp.Id) + c.mutex.Unlock() + + r.Error = "" + r.Seq = c.resp.Id + if c.resp.Error != nil || c.resp.Result == nil { + x, ok := c.resp.Error.(string) + if !ok { + return fmt.Errorf("invalid error %v", c.resp.Error) + } + if x == "" { + x = "unspecified error" + } + r.Error = x + } + return nil +} + +func (c *clientCodec) ReadResponseBody(x interface{}) error { + if x == nil { + return nil + } + return json.Unmarshal(*c.resp.Result, x) +} + +func (c *clientCodec) Close() error { + return c.c.Close() +} + +// NewClient returns a new rpc.Client to handle requests to the +// set of services at the other end of the connection. +func NewClient(conn io.ReadWriteCloser) *rpc.Client { + return rpc.NewClientWithCodec(NewClientCodec(conn)) +} + +// Dial connects to a JSON-RPC server at the specified network address. +func Dial(network, address string) (*rpc.Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn), err +} diff --git a/src/net/rpc/jsonrpc/server.go b/src/net/rpc/jsonrpc/server.go new file mode 100644 index 0000000..40e4e6f --- /dev/null +++ b/src/net/rpc/jsonrpc/server.go @@ -0,0 +1,134 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jsonrpc + +import ( + "encoding/json" + "errors" + "io" + "net/rpc" + "sync" +) + +var errMissingParams = errors.New("jsonrpc: request body missing params") + +type serverCodec struct { + dec *json.Decoder // for reading JSON values + enc *json.Encoder // for writing JSON values + c io.Closer + + // temporary work space + req serverRequest + + // JSON-RPC clients can use arbitrary json values as request IDs. + // Package rpc expects uint64 request IDs. + // We assign uint64 sequence numbers to incoming requests + // but save the original request ID in the pending map. + // When rpc responds, we use the sequence number in + // the response to find the original request ID. + mutex sync.Mutex // protects seq, pending + seq uint64 + pending map[uint64]*json.RawMessage +} + +// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn. +func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { + return &serverCodec{ + dec: json.NewDecoder(conn), + enc: json.NewEncoder(conn), + c: conn, + pending: make(map[uint64]*json.RawMessage), + } +} + +type serverRequest struct { + Method string `json:"method"` + Params *json.RawMessage `json:"params"` + Id *json.RawMessage `json:"id"` +} + +func (r *serverRequest) reset() { + r.Method = "" + r.Params = nil + r.Id = nil +} + +type serverResponse struct { + Id *json.RawMessage `json:"id"` + Result interface{} `json:"result"` + Error interface{} `json:"error"` +} + +func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error { + c.req.reset() + if err := c.dec.Decode(&c.req); err != nil { + return err + } + r.ServiceMethod = c.req.Method + + // JSON request id can be any JSON value; + // RPC package expects uint64. Translate to + // internal uint64 and save JSON on the side. + c.mutex.Lock() + c.seq++ + c.pending[c.seq] = c.req.Id + c.req.Id = nil + r.Seq = c.seq + c.mutex.Unlock() + + return nil +} + +func (c *serverCodec) ReadRequestBody(x interface{}) error { + if x == nil { + return nil + } + if c.req.Params == nil { + return errMissingParams + } + // JSON params is array value. + // RPC params is struct. + // Unmarshal into array containing struct for now. + // Should think about making RPC more general. + var params [1]interface{} + params[0] = x + return json.Unmarshal(*c.req.Params, ¶ms) +} + +var null = json.RawMessage([]byte("null")) + +func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error { + c.mutex.Lock() + b, ok := c.pending[r.Seq] + if !ok { + c.mutex.Unlock() + return errors.New("invalid sequence number in response") + } + delete(c.pending, r.Seq) + c.mutex.Unlock() + + if b == nil { + // Invalid request so no id. Use JSON null. + b = &null + } + resp := serverResponse{Id: b} + if r.Error == "" { + resp.Result = x + } else { + resp.Error = r.Error + } + return c.enc.Encode(resp) +} + +func (c *serverCodec) Close() error { + return c.c.Close() +} + +// ServeConn runs the JSON-RPC server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +func ServeConn(conn io.ReadWriteCloser) { + rpc.ServeCodec(NewServerCodec(conn)) +} diff --git a/src/net/rpc/server.go b/src/net/rpc/server.go new file mode 100644 index 0000000..9cb9282 --- /dev/null +++ b/src/net/rpc/server.go @@ -0,0 +1,720 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + Package rpc provides access to the exported methods of an object across a + network or other I/O connection. A server registers an object, making it visible + as a service with the name of the type of the object. After registration, exported + methods of the object will be accessible remotely. A server may register multiple + objects (services) of different types but it is an error to register multiple + objects of the same type. + + Only methods that satisfy these criteria will be made available for remote access; + other methods will be ignored: + + - the method's type is exported. + - the method is exported. + - the method has two arguments, both exported (or builtin) types. + - the method's second argument is a pointer. + - the method has return type error. + + In effect, the method must look schematically like + + func (t *T) MethodName(argType T1, replyType *T2) error + + where T1 and T2 can be marshaled by encoding/gob. + These requirements apply even if a different codec is used. + (In the future, these requirements may soften for custom codecs.) + + The method's first argument represents the arguments provided by the caller; the + second argument represents the result parameters to be returned to the caller. + The method's return value, if non-nil, is passed back as a string that the client + sees as if created by errors.New. If an error is returned, the reply parameter + will not be sent back to the client. + + The server may handle requests on a single connection by calling ServeConn. More + typically it will create a network listener and call Accept or, for an HTTP + listener, HandleHTTP and http.Serve. + + A client wishing to use the service establishes a connection and then invokes + NewClient on the connection. The convenience function Dial (DialHTTP) performs + both steps for a raw network connection (an HTTP connection). The resulting + Client object has two methods, Call and Go, that specify the service and method to + call, a pointer containing the arguments, and a pointer to receive the result + parameters. + + The Call method waits for the remote call to complete while the Go method + launches the call asynchronously and signals completion using the Call + structure's Done channel. + + Unless an explicit codec is set up, package encoding/gob is used to + transport the data. + + Here is a simple example. A server wishes to export an object of type Arith: + + package server + + import "errors" + + type Args struct { + A, B int + } + + type Quotient struct { + Quo, Rem int + } + + type Arith int + + func (t *Arith) Multiply(args *Args, reply *int) error { + *reply = args.A * args.B + return nil + } + + func (t *Arith) Divide(args *Args, quo *Quotient) error { + if args.B == 0 { + return errors.New("divide by zero") + } + quo.Quo = args.A / args.B + quo.Rem = args.A % args.B + return nil + } + + The server calls (for HTTP service): + + arith := new(Arith) + rpc.Register(arith) + rpc.HandleHTTP() + l, e := net.Listen("tcp", ":1234") + if e != nil { + log.Fatal("listen error:", e) + } + go http.Serve(l, nil) + + At this point, clients can see a service "Arith" with methods "Arith.Multiply" and + "Arith.Divide". To invoke one, a client first dials the server: + + client, err := rpc.DialHTTP("tcp", serverAddress + ":1234") + if err != nil { + log.Fatal("dialing:", err) + } + + Then it can make a remote call: + + // Synchronous call + args := &server.Args{7,8} + var reply int + err = client.Call("Arith.Multiply", args, &reply) + if err != nil { + log.Fatal("arith error:", err) + } + fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply) + + or + + // Asynchronous call + quotient := new(Quotient) + divCall := client.Go("Arith.Divide", args, quotient, nil) + replyCall := <-divCall.Done // will be equal to divCall + // check errors, print, etc. + + A server implementation will often provide a simple, type-safe wrapper for the + client. + + The net/rpc package is frozen and is not accepting new features. +*/ +package rpc + +import ( + "bufio" + "encoding/gob" + "errors" + "go/token" + "io" + "log" + "net" + "net/http" + "reflect" + "strings" + "sync" +) + +const ( + // Defaults used by HandleHTTP + DefaultRPCPath = "/_goRPC_" + DefaultDebugPath = "/debug/rpc" +) + +// Precompute the reflect type for error. Can't use error directly +// because Typeof takes an empty interface value. This is annoying. +var typeOfError = reflect.TypeOf((*error)(nil)).Elem() + +type methodType struct { + sync.Mutex // protects counters + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + numCalls uint +} + +type service struct { + name string // name of service + rcvr reflect.Value // receiver of methods for the service + typ reflect.Type // type of the receiver + method map[string]*methodType // registered methods +} + +// Request is a header written before every RPC call. It is used internally +// but documented here as an aid to debugging, such as when analyzing +// network traffic. +type Request struct { + ServiceMethod string // format: "Service.Method" + Seq uint64 // sequence number chosen by client + next *Request // for free list in Server +} + +// Response is a header written before every RPC return. It is used internally +// but documented here as an aid to debugging, such as when analyzing +// network traffic. +type Response struct { + ServiceMethod string // echoes that of the Request + Seq uint64 // echoes that of the request + Error string // error, if any. + next *Response // for free list in Server +} + +// Server represents an RPC Server. +type Server struct { + serviceMap sync.Map // map[string]*service + reqLock sync.Mutex // protects freeReq + freeReq *Request + respLock sync.Mutex // protects freeResp + freeResp *Response +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// Is this type exported or a builtin? +func isExportedOrBuiltinType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + // PkgPath will be non-empty even for an exported type, + // so we need to check the type name as well. + return token.IsExported(t.Name()) || t.PkgPath() == "" +} + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +// It returns an error if the receiver is not an exported type or has +// no suitable methods. It also logs the error using package log. +// The client accesses each method using a string of the form "Type.Method", +// where Type is the receiver's concrete type. +func (server *Server) Register(rcvr interface{}) error { + return server.register(rcvr, "", false) +} + +// RegisterName is like Register but uses the provided name for the type +// instead of the receiver's concrete type. +func (server *Server) RegisterName(name string, rcvr interface{}) error { + return server.register(rcvr, name, true) +} + +func (server *Server) register(rcvr interface{}, name string, useName bool) error { + s := new(service) + s.typ = reflect.TypeOf(rcvr) + s.rcvr = reflect.ValueOf(rcvr) + sname := reflect.Indirect(s.rcvr).Type().Name() + if useName { + sname = name + } + if sname == "" { + s := "rpc.Register: no service name for type " + s.typ.String() + log.Print(s) + return errors.New(s) + } + if !token.IsExported(sname) && !useName { + s := "rpc.Register: type " + sname + " is not exported" + log.Print(s) + return errors.New(s) + } + s.name = sname + + // Install the methods + s.method = suitableMethods(s.typ, true) + + if len(s.method) == 0 { + str := "" + + // To help the user, see if a pointer receiver would work. + method := suitableMethods(reflect.PtrTo(s.typ), false) + if len(method) != 0 { + str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)" + } else { + str = "rpc.Register: type " + sname + " has no exported methods of suitable type" + } + log.Print(str) + return errors.New(str) + } + + if _, dup := server.serviceMap.LoadOrStore(sname, s); dup { + return errors.New("rpc: service already defined: " + sname) + } + return nil +} + +// suitableMethods returns suitable Rpc methods of typ, it will report +// error using log if reportErr is true. +func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { + methods := make(map[string]*methodType) + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) + mtype := method.Type + mname := method.Name + // Method must be exported. + if method.PkgPath != "" { + continue + } + // Method needs three ins: receiver, *args, *reply. + if mtype.NumIn() != 3 { + if reportErr { + log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn()) + } + continue + } + // First arg need not be a pointer. + argType := mtype.In(1) + if !isExportedOrBuiltinType(argType) { + if reportErr { + log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType) + } + continue + } + // Second arg must be a pointer. + replyType := mtype.In(2) + if replyType.Kind() != reflect.Ptr { + if reportErr { + log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType) + } + continue + } + // Reply type must be exported. + if !isExportedOrBuiltinType(replyType) { + if reportErr { + log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType) + } + continue + } + // Method needs one out. + if mtype.NumOut() != 1 { + if reportErr { + log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut()) + } + continue + } + // The return type of the method must be error. + if returnType := mtype.Out(0); returnType != typeOfError { + if reportErr { + log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType) + } + continue + } + methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} + } + return methods +} + +// A value sent as a placeholder for the server's response value when the server +// receives an invalid request. It is never decoded by the client since the Response +// contains an error when it is used. +var invalidRequest = struct{}{} + +func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { + resp := server.getResponse() + // Encode the response header + resp.ServiceMethod = req.ServiceMethod + if errmsg != "" { + resp.Error = errmsg + reply = invalidRequest + } + resp.Seq = req.Seq + sending.Lock() + err := codec.WriteResponse(resp, reply) + if debugLog && err != nil { + log.Println("rpc: writing response:", err) + } + sending.Unlock() + server.freeResponse(resp) +} + +func (m *methodType) NumCalls() (n uint) { + m.Lock() + n = m.numCalls + m.Unlock() + return n +} + +func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { + if wg != nil { + defer wg.Done() + } + mtype.Lock() + mtype.numCalls++ + mtype.Unlock() + function := mtype.method.Func + // Invoke the method, providing a new value for the reply. + returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv}) + // The return value for the method is an error. + errInter := returnValues[0].Interface() + errmsg := "" + if errInter != nil { + errmsg = errInter.(error).Error() + } + server.sendResponse(sending, req, replyv.Interface(), codec, errmsg) + server.freeRequest(req) +} + +type gobServerCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder + encBuf *bufio.Writer + closed bool +} + +func (c *gobServerCodec) ReadRequestHeader(r *Request) error { + return c.dec.Decode(r) +} + +func (c *gobServerCodec) ReadRequestBody(body interface{}) error { + return c.dec.Decode(body) +} + +func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) { + if err = c.enc.Encode(r); err != nil { + if c.encBuf.Flush() == nil { + // Gob couldn't encode the header. Should not happen, so if it does, + // shut down the connection to signal that the connection is broken. + log.Println("rpc: gob error encoding response:", err) + c.Close() + } + return + } + if err = c.enc.Encode(body); err != nil { + if c.encBuf.Flush() == nil { + // Was a gob problem encoding the body but the header has been written. + // Shut down the connection to signal that the connection is broken. + log.Println("rpc: gob error encoding body:", err) + c.Close() + } + return + } + return c.encBuf.Flush() +} + +func (c *gobServerCodec) Close() error { + if c.closed { + // Only call c.rwc.Close once; otherwise the semantics are undefined. + return nil + } + c.closed = true + return c.rwc.Close() +} + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +// See NewClient's comment for information about concurrent access. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + buf := bufio.NewWriter(conn) + srv := &gobServerCodec{ + rwc: conn, + dec: gob.NewDecoder(conn), + enc: gob.NewEncoder(buf), + encBuf: buf, + } + server.ServeCodec(srv) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func (server *Server) ServeCodec(codec ServerCodec) { + sending := new(sync.Mutex) + wg := new(sync.WaitGroup) + for { + service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) + if err != nil { + if debugLog && err != io.EOF { + log.Println("rpc:", err) + } + if !keepReading { + break + } + // send a response if we actually managed to read a header. + if req != nil { + server.sendResponse(sending, req, invalidRequest, codec, err.Error()) + server.freeRequest(req) + } + continue + } + wg.Add(1) + go service.call(server, sending, wg, mtype, req, argv, replyv, codec) + } + // We've seen that there are no more requests. + // Wait for responses to be sent before closing codec. + wg.Wait() + codec.Close() +} + +// ServeRequest is like ServeCodec but synchronously serves a single request. +// It does not close the codec upon completion. +func (server *Server) ServeRequest(codec ServerCodec) error { + sending := new(sync.Mutex) + service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) + if err != nil { + if !keepReading { + return err + } + // send a response if we actually managed to read a header. + if req != nil { + server.sendResponse(sending, req, invalidRequest, codec, err.Error()) + server.freeRequest(req) + } + return err + } + service.call(server, sending, nil, mtype, req, argv, replyv, codec) + return nil +} + +func (server *Server) getRequest() *Request { + server.reqLock.Lock() + req := server.freeReq + if req == nil { + req = new(Request) + } else { + server.freeReq = req.next + *req = Request{} + } + server.reqLock.Unlock() + return req +} + +func (server *Server) freeRequest(req *Request) { + server.reqLock.Lock() + req.next = server.freeReq + server.freeReq = req + server.reqLock.Unlock() +} + +func (server *Server) getResponse() *Response { + server.respLock.Lock() + resp := server.freeResp + if resp == nil { + resp = new(Response) + } else { + server.freeResp = resp.next + *resp = Response{} + } + server.respLock.Unlock() + return resp +} + +func (server *Server) freeResponse(resp *Response) { + server.respLock.Lock() + resp.next = server.freeResp + server.freeResp = resp + server.respLock.Unlock() +} + +func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) { + service, mtype, req, keepReading, err = server.readRequestHeader(codec) + if err != nil { + if !keepReading { + return + } + // discard body + codec.ReadRequestBody(nil) + return + } + + // Decode the argument value. + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + // argv guaranteed to be a pointer now. + if err = codec.ReadRequestBody(argv.Interface()); err != nil { + return + } + if argIsValue { + argv = argv.Elem() + } + + replyv = reflect.New(mtype.ReplyType.Elem()) + + switch mtype.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0)) + } + return +} + +func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) { + // Grab the request header. + req = server.getRequest() + err = codec.ReadRequestHeader(req) + if err != nil { + req = nil + if err == io.EOF || err == io.ErrUnexpectedEOF { + return + } + err = errors.New("rpc: server cannot decode request: " + err.Error()) + return + } + + // We read the header successfully. If we see an error now, + // we can still recover and move on to the next request. + keepReading = true + + dot := strings.LastIndex(req.ServiceMethod, ".") + if dot < 0 { + err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod) + return + } + serviceName := req.ServiceMethod[:dot] + methodName := req.ServiceMethod[dot+1:] + + // Look up the request. + svci, ok := server.serviceMap.Load(serviceName) + if !ok { + err = errors.New("rpc: can't find service " + req.ServiceMethod) + return + } + svc = svci.(*service) + mtype = svc.method[methodName] + if mtype == nil { + err = errors.New("rpc: can't find method " + req.ServiceMethod) + } + return +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. Accept blocks until the listener +// returns a non-nil error. The caller typically invokes Accept in a +// go statement. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Print("rpc.Serve: accept:", err.Error()) + return + } + go server.ServeConn(conn) + } +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } + +// RegisterName is like Register but uses the provided name for the type +// instead of the receiver's concrete type. +func RegisterName(name string, rcvr interface{}) error { + return DefaultServer.RegisterName(name, rcvr) +} + +// A ServerCodec implements reading of RPC requests and writing of +// RPC responses for the server side of an RPC session. +// The server calls ReadRequestHeader and ReadRequestBody in pairs +// to read requests from the connection, and it calls WriteResponse to +// write a response back. The server calls Close when finished with the +// connection. ReadRequestBody may be called with a nil +// argument to force the body of the request to be read and discarded. +// See NewClient's comment for information about concurrent access. +type ServerCodec interface { + ReadRequestHeader(*Request) error + ReadRequestBody(interface{}) error + WriteResponse(*Response, interface{}) error + + // Close can be called multiple times and must be idempotent. + Close() error +} + +// ServeConn runs the DefaultServer on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +// See NewClient's comment for information about concurrent access. +func ServeConn(conn io.ReadWriteCloser) { + DefaultServer.ServeConn(conn) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func ServeCodec(codec ServerCodec) { + DefaultServer.ServeCodec(codec) +} + +// ServeRequest is like ServeCodec but synchronously serves a single request. +// It does not close the codec upon completion. +func ServeRequest(codec ServerCodec) error { + return DefaultServer.ServeRequest(codec) +} + +// Accept accepts connections on the listener and serves requests +// to DefaultServer for each incoming connection. +// Accept blocks; the caller typically invokes it in a go statement. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Can connect to RPC service using HTTP CONNECT to rpcPath. +var connected = "200 Connected to Go RPC" + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) + return + } + io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP(rpcPath, debugPath string) { + http.Handle(rpcPath, server) + http.Handle(debugPath, debugHTTP{server}) +} + +// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer +// on DefaultRPCPath and a debugging handler on DefaultDebugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func HandleHTTP() { + DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath) +} diff --git a/src/net/rpc/server_test.go b/src/net/rpc/server_test.go new file mode 100644 index 0000000..e5d7fe0 --- /dev/null +++ b/src/net/rpc/server_test.go @@ -0,0 +1,839 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "net/http/httptest" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +var ( + newServer *Server + serverAddr, newServerAddr string + httpServerAddr string + once, newOnce, httpOnce sync.Once +) + +const ( + newHttpPath = "/foo" +) + +type Args struct { + A, B int +} + +type Reply struct { + C int +} + +type Arith int + +// Some of Arith's methods have value args, some have pointer args. That's deliberate. + +func (t *Arith) Add(args Args, reply *Reply) error { + reply.C = args.A + args.B + return nil +} + +func (t *Arith) Mul(args *Args, reply *Reply) error { + reply.C = args.A * args.B + return nil +} + +func (t *Arith) Div(args Args, reply *Reply) error { + if args.B == 0 { + return errors.New("divide by zero") + } + reply.C = args.A / args.B + return nil +} + +func (t *Arith) String(args *Args, reply *string) error { + *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) + return nil +} + +func (t *Arith) Scan(args string, reply *Reply) (err error) { + _, err = fmt.Sscan(args, &reply.C) + return +} + +func (t *Arith) Error(args *Args, reply *Reply) error { + panic("ERROR") +} + +func (t *Arith) SleepMilli(args *Args, reply *Reply) error { + time.Sleep(time.Duration(args.A) * time.Millisecond) + return nil +} + +type hidden int + +func (t *hidden) Exported(args Args, reply *Reply) error { + reply.C = args.A + args.B + return nil +} + +type Embed struct { + hidden +} + +type BuiltinTypes struct{} + +func (BuiltinTypes) Map(args *Args, reply *map[int]int) error { + (*reply)[args.A] = args.B + return nil +} + +func (BuiltinTypes) Slice(args *Args, reply *[]int) error { + *reply = append(*reply, args.A, args.B) + return nil +} + +func (BuiltinTypes) Array(args *Args, reply *[2]int) error { + (*reply)[0] = args.A + (*reply)[1] = args.B + return nil +} + +func listenTCP() (net.Listener, string) { + l, e := net.Listen("tcp", "127.0.0.1:0") // any available address + if e != nil { + log.Fatalf("net.Listen tcp :0: %v", e) + } + return l, l.Addr().String() +} + +func startServer() { + Register(new(Arith)) + Register(new(Embed)) + RegisterName("net.rpc.Arith", new(Arith)) + Register(BuiltinTypes{}) + + var l net.Listener + l, serverAddr = listenTCP() + log.Println("Test RPC server listening on", serverAddr) + go Accept(l) + + HandleHTTP() + httpOnce.Do(startHttpServer) +} + +func startNewServer() { + newServer = NewServer() + newServer.Register(new(Arith)) + newServer.Register(new(Embed)) + newServer.RegisterName("net.rpc.Arith", new(Arith)) + newServer.RegisterName("newServer.Arith", new(Arith)) + + var l net.Listener + l, newServerAddr = listenTCP() + log.Println("NewServer test RPC server listening on", newServerAddr) + go newServer.Accept(l) + + newServer.HandleHTTP(newHttpPath, "/bar") + httpOnce.Do(startHttpServer) +} + +func startHttpServer() { + server := httptest.NewServer(nil) + httpServerAddr = server.Listener.Addr().String() + log.Println("Test HTTP RPC server listening on", httpServerAddr) +} + +func TestRPC(t *testing.T) { + once.Do(startServer) + testRPC(t, serverAddr) + newOnce.Do(startNewServer) + testRPC(t, newServerAddr) + testNewServerRPC(t, newServerAddr) +} + +func testRPC(t *testing.T, addr string) { + client, err := Dial("tcp", addr) + if err != nil { + t.Fatal("dialing", err) + } + defer client.Close() + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + + // Methods exported from unexported embedded structs + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Embed.Exported", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + + // Nonexistent method + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Arith.BadOperation", args, reply) + // expect an error + if err == nil { + t.Error("BadOperation: expected error") + } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") { + t.Errorf("BadOperation: expected can't find method error; got %q", err) + } + + // Unknown service + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Unknown", args, reply) + if err == nil { + t.Error("expected error calling unknown service") + } else if !strings.Contains(err.Error(), "method") { + t.Error("expected error about method; got", err) + } + + // Out of order. + args = &Args{7, 8} + mulReply := new(Reply) + mulCall := client.Go("Arith.Mul", args, mulReply, nil) + addReply := new(Reply) + addCall := client.Go("Arith.Add", args, addReply, nil) + + addCall = <-addCall.Done + if addCall.Error != nil { + t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) + } + if addReply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) + } + + mulCall = <-mulCall.Done + if mulCall.Error != nil { + t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) + } + if mulReply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) + } + + // Error test + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Arith.Div", args, reply) + // expect an error: zero divide + if err == nil { + t.Error("Div: expected error") + } else if err.Error() != "divide by zero" { + t.Error("Div: expected divide by zero error; got", err) + } + + // Bad type. + reply = new(Reply) + err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use + if err == nil { + t.Error("expected error calling Arith.Add with wrong arg type") + } else if !strings.Contains(err.Error(), "type") { + t.Error("expected error about type; got", err) + } + + // Non-struct argument + const Val = 12345 + str := fmt.Sprint(Val) + reply = new(Reply) + err = client.Call("Arith.Scan", &str, reply) + if err != nil { + t.Errorf("Scan: expected no error but got string %q", err.Error()) + } else if reply.C != Val { + t.Errorf("Scan: expected %d got %d", Val, reply.C) + } + + // Non-struct reply + args = &Args{27, 35} + str = "" + err = client.Call("Arith.String", args, &str) + if err != nil { + t.Errorf("String: expected no error but got string %q", err.Error()) + } + expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) + if str != expect { + t.Errorf("String: expected %s got %s", expect, str) + } + + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Mul", args, reply) + if err != nil { + t.Errorf("Mul: expected no error but got string %q", err.Error()) + } + if reply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + } + + // ServiceName contain "." character + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("net.rpc.Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } +} + +func testNewServerRPC(t *testing.T, addr string) { + client, err := Dial("tcp", addr) + if err != nil { + t.Fatal("dialing", err) + } + defer client.Close() + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("newServer.Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } +} + +func TestHTTP(t *testing.T) { + once.Do(startServer) + testHTTPRPC(t, "") + newOnce.Do(startNewServer) + testHTTPRPC(t, newHttpPath) +} + +func testHTTPRPC(t *testing.T, path string) { + var client *Client + var err error + if path == "" { + client, err = DialHTTP("tcp", httpServerAddr) + } else { + client, err = DialHTTPPath("tcp", httpServerAddr, path) + } + if err != nil { + t.Fatal("dialing", err) + } + defer client.Close() + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } +} + +func TestBuiltinTypes(t *testing.T) { + once.Do(startServer) + + client, err := DialHTTP("tcp", httpServerAddr) + if err != nil { + t.Fatal("dialing", err) + } + defer client.Close() + + // Map + args := &Args{7, 8} + replyMap := map[int]int{} + err = client.Call("BuiltinTypes.Map", args, &replyMap) + if err != nil { + t.Errorf("Map: expected no error but got string %q", err.Error()) + } + if replyMap[args.A] != args.B { + t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A]) + } + + // Slice + args = &Args{7, 8} + replySlice := []int{} + err = client.Call("BuiltinTypes.Slice", args, &replySlice) + if err != nil { + t.Errorf("Slice: expected no error but got string %q", err.Error()) + } + if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) { + t.Errorf("Slice: expected %v got %v", e, replySlice) + } + + // Array + args = &Args{7, 8} + replyArray := [2]int{} + err = client.Call("BuiltinTypes.Array", args, &replyArray) + if err != nil { + t.Errorf("Array: expected no error but got string %q", err.Error()) + } + if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) { + t.Errorf("Array: expected %v got %v", e, replyArray) + } +} + +// CodecEmulator provides a client-like api and a ServerCodec interface. +// Can be used to test ServeRequest. +type CodecEmulator struct { + server *Server + serviceMethod string + args *Args + reply *Reply + err error +} + +func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error { + codec.serviceMethod = serviceMethod + codec.args = args + codec.reply = reply + codec.err = nil + var serverError error + if codec.server == nil { + serverError = ServeRequest(codec) + } else { + serverError = codec.server.ServeRequest(codec) + } + if codec.err == nil && serverError != nil { + codec.err = serverError + } + return codec.err +} + +func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { + req.ServiceMethod = codec.serviceMethod + req.Seq = 0 + return nil +} + +func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { + if codec.args == nil { + return io.ErrUnexpectedEOF + } + *(argv.(*Args)) = *codec.args + return nil +} + +func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error { + if resp.Error != "" { + codec.err = errors.New(resp.Error) + } else { + *codec.reply = *(reply.(*Reply)) + } + return nil +} + +func (codec *CodecEmulator) Close() error { + return nil +} + +func TestServeRequest(t *testing.T) { + once.Do(startServer) + testServeRequest(t, nil) + newOnce.Do(startNewServer) + testServeRequest(t, newServer) +} + +func testServeRequest(t *testing.T, server *Server) { + client := CodecEmulator{server: server} + defer client.Close() + + args := &Args{7, 8} + reply := new(Reply) + err := client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + + err = client.Call("Arith.Add", nil, reply) + if err == nil { + t.Errorf("expected error calling Arith.Add with nil arg") + } +} + +type ReplyNotPointer int +type ArgNotPublic int +type ReplyNotPublic int +type NeedsPtrType int +type local struct{} + +func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { + return nil +} + +func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error { + return nil +} + +func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { + return nil +} + +func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error { + return nil +} + +// Check that registration handles lots of bad methods and a type with no suitable methods. +func TestRegistrationError(t *testing.T) { + err := Register(new(ReplyNotPointer)) + if err == nil { + t.Error("expected error registering ReplyNotPointer") + } + err = Register(new(ArgNotPublic)) + if err == nil { + t.Error("expected error registering ArgNotPublic") + } + err = Register(new(ReplyNotPublic)) + if err == nil { + t.Error("expected error registering ReplyNotPublic") + } + err = Register(NeedsPtrType(0)) + if err == nil { + t.Error("expected error registering NeedsPtrType") + } else if !strings.Contains(err.Error(), "pointer") { + t.Error("expected hint when registering NeedsPtrType") + } +} + +type WriteFailCodec int + +func (WriteFailCodec) WriteRequest(*Request, interface{}) error { + // the panic caused by this error used to not unlock a lock. + return errors.New("fail") +} + +func (WriteFailCodec) ReadResponseHeader(*Response) error { + select {} +} + +func (WriteFailCodec) ReadResponseBody(interface{}) error { + select {} +} + +func (WriteFailCodec) Close() error { + return nil +} + +func TestSendDeadlock(t *testing.T) { + client := NewClientWithCodec(WriteFailCodec(0)) + defer client.Close() + + done := make(chan bool) + go func() { + testSendDeadlock(client) + testSendDeadlock(client) + done <- true + }() + select { + case <-done: + return + case <-time.After(5 * time.Second): + t.Fatal("deadlock") + } +} + +func testSendDeadlock(client *Client) { + defer func() { + recover() + }() + args := &Args{7, 8} + reply := new(Reply) + client.Call("Arith.Add", args, reply) +} + +func dialDirect() (*Client, error) { + return Dial("tcp", serverAddr) +} + +func dialHTTP() (*Client, error) { + return DialHTTP("tcp", httpServerAddr) +} + +func countMallocs(dial func() (*Client, error), t *testing.T) float64 { + once.Do(startServer) + client, err := dial() + if err != nil { + t.Fatal("error dialing", err) + } + defer client.Close() + + args := &Args{7, 8} + reply := new(Reply) + return testing.AllocsPerRun(100, func() { + err := client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + }) +} + +func TestCountMallocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) +} + +func TestCountMallocsOverHTTP(t *testing.T) { + if testing.Short() { + t.Skip("skipping malloc count in short mode") + } + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) +} + +type writeCrasher struct { + done chan bool +} + +func (writeCrasher) Close() error { + return nil +} + +func (w *writeCrasher) Read(p []byte) (int, error) { + <-w.done + return 0, io.EOF +} + +func (writeCrasher) Write(p []byte) (int, error) { + return 0, errors.New("fake write failure") +} + +func TestClientWriteError(t *testing.T) { + w := &writeCrasher{done: make(chan bool)} + c := NewClient(w) + defer c.Close() + + res := false + err := c.Call("foo", 1, &res) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "fake write failure" { + t.Error("unexpected value of error:", err) + } + w.done <- true +} + +func TestTCPClose(t *testing.T) { + once.Do(startServer) + + client, err := dialHTTP() + if err != nil { + t.Fatalf("dialing: %v", err) + } + defer client.Close() + + args := Args{17, 8} + var reply Reply + err = client.Call("Arith.Mul", args, &reply) + if err != nil { + t.Fatal("arith error:", err) + } + t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) + if reply.C != args.A*args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) + } +} + +func TestErrorAfterClientClose(t *testing.T) { + once.Do(startServer) + + client, err := dialHTTP() + if err != nil { + t.Fatalf("dialing: %v", err) + } + err = client.Close() + if err != nil { + t.Fatal("close error:", err) + } + err = client.Call("Arith.Add", &Args{7, 9}, new(Reply)) + if err != ErrShutdown { + t.Errorf("Forever: expected ErrShutdown got %v", err) + } +} + +// Tests the fix to issue 11221. Without the fix, this loops forever or crashes. +func TestAcceptExitAfterListenerClose(t *testing.T) { + newServer := NewServer() + newServer.Register(new(Arith)) + newServer.RegisterName("net.rpc.Arith", new(Arith)) + newServer.RegisterName("newServer.Arith", new(Arith)) + + var l net.Listener + l, _ = listenTCP() + l.Close() + newServer.Accept(l) +} + +func TestShutdown(t *testing.T) { + var l net.Listener + l, _ = listenTCP() + ch := make(chan net.Conn, 1) + go func() { + defer l.Close() + c, err := l.Accept() + if err != nil { + t.Error(err) + } + ch <- c + }() + c, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + c1 := <-ch + if c1 == nil { + t.Fatal(err) + } + + newServer := NewServer() + newServer.Register(new(Arith)) + go newServer.ServeConn(c1) + + args := &Args{7, 8} + reply := new(Reply) + client := NewClient(c) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Fatal(err) + } + + // On an unloaded system 10ms is usually enough to fail 100% of the time + // with a broken server. On a loaded system, a broken server might incorrectly + // be reported as passing, but we're OK with that kind of flakiness. + // If the code is correct, this test will never fail, regardless of timeout. + args.A = 10 // 10 ms + done := make(chan *Call, 1) + call := client.Go("Arith.SleepMilli", args, reply, done) + c.(*net.TCPConn).CloseWrite() + <-done + if call.Error != nil { + t.Fatal(err) + } +} + +func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { + once.Do(startServer) + client, err := dial() + if err != nil { + b.Fatal("error dialing:", err) + } + defer client.Close() + + // Synchronous calls + args := &Args{7, 8} + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + reply := new(Reply) + for pb.Next() { + err := client.Call("Arith.Add", args, reply) + if err != nil { + b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) + } + if reply.C != args.A+args.B { + b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) + } + } + }) +} + +func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { + if b.N == 0 { + return + } + const MaxConcurrentCalls = 100 + once.Do(startServer) + client, err := dial() + if err != nil { + b.Fatal("error dialing:", err) + } + defer client.Close() + + // Asynchronous calls + args := &Args{7, 8} + procs := 4 * runtime.GOMAXPROCS(-1) + send := int32(b.N) + recv := int32(b.N) + var wg sync.WaitGroup + wg.Add(procs) + gate := make(chan bool, MaxConcurrentCalls) + res := make(chan *Call, MaxConcurrentCalls) + b.ResetTimer() + + for p := 0; p < procs; p++ { + go func() { + for atomic.AddInt32(&send, -1) >= 0 { + gate <- true + reply := new(Reply) + client.Go("Arith.Add", args, reply, res) + } + }() + go func() { + for call := range res { + A := call.Args.(*Args).A + B := call.Args.(*Args).B + C := call.Reply.(*Reply).C + if A+B != C { + b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C) + return + } + <-gate + if atomic.AddInt32(&recv, -1) == 0 { + close(res) + } + } + wg.Done() + }() + } + wg.Wait() +} + +func BenchmarkEndToEnd(b *testing.B) { + benchmarkEndToEnd(dialDirect, b) +} + +func BenchmarkEndToEndHTTP(b *testing.B) { + benchmarkEndToEnd(dialHTTP, b) +} + +func BenchmarkEndToEndAsync(b *testing.B) { + benchmarkEndToEndAsync(dialDirect, b) +} + +func BenchmarkEndToEndAsyncHTTP(b *testing.B) { + benchmarkEndToEndAsync(dialHTTP, b) +} |