diff --git a/handler.go b/handler.go index ba5ae32..ecb0d5e 100644 --- a/handler.go +++ b/handler.go @@ -191,7 +191,6 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp cb(w) } - var req request // We read the entire request upfront in a buffer to be able to tell if the // client sent more than maxRequestSize and report it back as an explicit error, // instead of just silently truncating it and reporting a more vague parsing @@ -205,11 +204,11 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp if err != nil { // ReadFrom will discard EOF so any error here is unexpected and should // be reported. - rpcError(wf, &req, rpcParseError, xerrors.Errorf("reading request: %w", err)) + rpcError(wf, nil, rpcParseError, xerrors.Errorf("reading request: %w", err)) return } if reqSize > s.maxRequestSize { - rpcError(wf, &req, rpcParseError, + rpcError(wf, nil, rpcParseError, // rpcParseError is the closest we have from the standard errors defined // in [jsonrpc spec](https://www.jsonrpc.org/specification#error_object) // to report the maximum limit. @@ -218,17 +217,56 @@ func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rp return } - if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil { - rpcError(wf, &req, rpcParseError, xerrors.Errorf("unmarshaling request: %w", err)) - return - } + // Trim spaces to avoid issues with batch request detection. + bufferedRequest = bytes.NewBuffer(bytes.TrimSpace(bufferedRequest.Bytes())) + reqSize = int64(bufferedRequest.Len()) - if req.ID, err = normalizeID(req.ID); err != nil { - rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err)) + if reqSize == 0 { + rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request")) return } - s.handle(ctx, req, wf, rpcError, func(bool) {}, nil) + if bufferedRequest.Bytes()[0] == '[' && bufferedRequest.Bytes()[reqSize-1] == ']' { + var reqs []request + + if err := json.NewDecoder(bufferedRequest).Decode(&reqs); err != nil { + rpcError(wf, nil, rpcParseError, xerrors.New("Parse error")) + return + } + + if len(reqs) == 0 { + rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request")) + return + } + + w.Write([]byte("[")) + for idx, req := range reqs { + if req.ID, err = normalizeID(req.ID); err != nil { + rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err)) + return + } + + s.handle(ctx, req, wf, rpcError, func(bool) {}, nil) + + if idx != len(reqs)-1 { + w.Write([]byte(",")) + } + } + w.Write([]byte("]")) + } else { + var req request + if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil { + rpcError(wf, &req, rpcParseError, xerrors.New("Parse error")) + return + } + + if req.ID, err = normalizeID(req.ID); err != nil { + rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err)) + return + } + + s.handle(ctx, req, wf, rpcError, func(bool) {}, nil) + } } func doCall(methodName string, f reflect.Value, params []reflect.Value) (out []reflect.Value, err error) { diff --git a/rpc_test.go b/rpc_test.go index 807aa19..cadeea3 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -90,7 +90,22 @@ func TestRawRequests(t *testing.T) { testServ := httptest.NewServer(rpcServer) defer testServ.Close() - tc := func(req, resp string, n int32) func(t *testing.T) { + removeSpaces := func(jsonStr string) (string, error) { + var jsonObj interface{} + err := json.Unmarshal([]byte(jsonStr), &jsonObj) + if err != nil { + return "", err + } + + compactJSONBytes, err := json.Marshal(jsonObj) + if err != nil { + return "", err + } + + return string(compactJSONBytes), nil + } + + tc := func(req, resp string, n int32, statusCode int) func(t *testing.T) { return func(t *testing.T) { rpcHandler.n = 0 @@ -100,16 +115,29 @@ func TestRawRequests(t *testing.T) { b, err := ioutil.ReadAll(res.Body) require.NoError(t, err) - assert.Equal(t, resp, strings.TrimSpace(string(b))) + expectedResp, err := removeSpaces(resp) + require.NoError(t, err) + + responseBody, err := removeSpaces(string(b)) + require.NoError(t, err) + + assert.Equal(t, expectedResp, responseBody) require.Equal(t, n, rpcHandler.n) + require.Equal(t, statusCode, res.StatusCode) } } - t.Run("inc", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": [], "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1)) - t.Run("inc-null", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": null, "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1)) - t.Run("inc-noparam", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "id": 2}`, `{"jsonrpc":"2.0","id":2}`, 1)) - t.Run("add", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [10], "id": 4}`, `{"jsonrpc":"2.0","id":4}`, 10)) - + t.Run("inc", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": [], "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1, 200)) + t.Run("inc-null", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": null, "id": 1}`, `{"jsonrpc":"2.0","id":1}`, 1, 200)) + t.Run("inc-noparam", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "id": 2}`, `{"jsonrpc":"2.0","id":2}`, 1, 200)) + t.Run("add", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [10], "id": 4}`, `{"jsonrpc":"2.0","id":4}`, 10, 200)) + // Batch requests + t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 5}`, `{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"Parse error"}}`, 0, 500)) + t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 6}]`, `[{"jsonrpc":"2.0","id":6}]`, 123, 200)) + t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 7},{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-122], "id": 8}]`, `[{"jsonrpc":"2.0","id":7},{"jsonrpc":"2.0","id":8}]`, 1, 200)) + t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 9},{"jsonrpc": "2.0", "params": [-122], "id": 10}]`, `[{"jsonrpc":"2.0","id":9},{"error":{"code":-32601,"message":"method '' not found"},"id":10,"jsonrpc":"2.0"}]`, 123, 200)) + t.Run("add", tc(` [{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-1], "id": 11}] `, `[{"jsonrpc":"2.0","id":11}]`, -1, 200)) + t.Run("add", tc(``, `{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid request"}}`, 0, 400)) } func TestReconnection(t *testing.T) { diff --git a/server.go b/server.go index fa13acd..4e4fa4f 100644 --- a/server.go +++ b/server.go @@ -15,6 +15,7 @@ import ( const ( rpcParseError = -32700 + rpcInvalidRequest = -32600 rpcMethodNotFound = -32601 rpcInvalidParams = -32602 ) @@ -107,13 +108,17 @@ func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error) log.Errorf("RPC Error: %s", err) wf(func(w io.Writer) { if hw, ok := w.(http.ResponseWriter); ok { - hw.WriteHeader(500) + if code == rpcInvalidRequest { + hw.WriteHeader(400) + } else { + hw.WriteHeader(500) + } } log.Warnf("rpc error: %s", err) - if req.ID == nil { // notification - return + if req == nil { + req = &request{} } resp := response{