Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: feat(generic): add binarypb generics #1484

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .licenserc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ header:
- pkg/generic/httppb_test/idl/echo.pb.go
- pkg/utils/json.go
- pkg/protocol/bthrift/test/kitex_gen/**
- pkg/generic/binarypb_test/kitex_gen/**

comment: on-failure
182 changes: 182 additions & 0 deletions pkg/generic/binarypb_codec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package generic

import (
"context"
"encoding/binary"
"fmt"

"github.com/bytedance/gopkg/lang/dirtmake"

"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/serviceinfo"
)

var _ remote.PayloadCodec = &binaryProtobufCodec{}

type binaryPbReqType = []byte

type binaryProtobufCodec struct {
protobufCodec remote.PayloadCodec
}

func (c *binaryProtobufCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
data := msg.Data()
if data == nil {
return perrors.NewProtocolErrorWithMsg("invalid marshal data in rawProtobufBinaryCodec: nil")
}
if msg.MessageType() == remote.Exception {
if err := c.protobufCodec.Marshal(ctx, msg, out); err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("rawProtobufBinaryCodec Marshal exception failed, err: %s", err.Error()))
}
return nil
}
var transBuff []byte
var ok bool
if msg.RPCRole() == remote.Server {
gResult := data.(*Result)
transBinary := gResult.Success
if transBuff, ok = transBinary.(binaryPbReqType); !ok {
return perrors.NewProtocolErrorWithMsg("invalid marshal result in rawProtobufBinaryCodec: must be []byte")
}
} else {
gArg := data.(*Args)
transBinary := gArg.Request
if transBuff, ok = transBinary.(binaryPbReqType); !ok {
return perrors.NewProtocolErrorWithMsg("invalid marshal request in rawProtobufBinaryCodec: must be []byte")
}
if err := PbSetSeqID(msg.RPCInfo().Invocation().SeqID(), transBuff); err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("rawProtobufBinaryCodec set seqID failed, err: %s", err.Error()))
}
}
out.WriteBinary(transBuff)
return nil
}

func (c *binaryProtobufCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
magicAndMsgType, err := codec.PeekUint32(in)
if err != nil {
return err
}
msgType := magicAndMsgType & codec.FrontMask
if msgType == uint32(remote.Exception) {
return c.protobufCodec.Unmarshal(ctx, msg, in)
}
payloadLen := msg.PayloadLen()
transBuff := dirtmake.Bytes(payloadLen, payloadLen)
_, err = in.ReadBinary(transBuff)
if err != nil {
return err
}
if err := pbReadBinaryMethod(transBuff, msg); err != nil {
return err
}

if err = codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil {
return err
}
data := msg.Data()
if msg.RPCRole() == remote.Server {
gArg := data.(*Args)
gArg.Method = msg.RPCInfo().Invocation().MethodName()
gArg.Request = transBuff
} else {
gResult := data.(*Result)
gResult.Success = transBuff
}
return nil
}

func (c *binaryProtobufCodec) Name() string {
return "RawProtobufBinary"
}

// SetSeqID is used to reset seqID for kitex protobufs payload.
// For client side, you don't need this function, Kitex will gen seqID and set it into transport protocol to ignore
// inconsistent seqID between protobuf payload and transport protocol, reset the seqID to that generated by kitex for
// client side by default.
// But for server side(binary generic server), you need to return the same seqID with upstream, it is suggested to keep
// the upstream seqID(use PbGetSeqID) then use PbSetSeqID to reset the seqID of transBuff.
func PbSetSeqID(seqID int32, transBuff []byte) error {
seqID4Bytes, err := pbGetSeqID4Bytes(transBuff)
if err != nil {
return err
}
binary.BigEndian.PutUint32(seqID4Bytes, uint32(seqID))
return nil
}

// GetSeqID from protobuf buffered binary.
func PbGetSeqID(transBuff []byte) (int32, error) {
seqID4Bytes, err := pbGetSeqID4Bytes(transBuff)
if err != nil {
return 0, err
}
seqID := binary.BigEndian.Uint32(seqID4Bytes)
return int32(seqID), nil
}

// seqID has 4 bytes
func pbGetSeqID4Bytes(transBuff []byte) ([]byte, error) {
idx := 4
ret, e := codec.Bytes2Uint32(transBuff[:idx])
if e != nil {
return nil, e
}
first4Bytes := int32(ret)
if first4Bytes > 0 {
return nil, perrors.NewProtocolErrorWithMsg("missing version in Protobuf Message")
}
version := int64(first4Bytes) & codec.MagicMask
if version != codec.ProtobufV1Magic {
return nil, perrors.NewProtocolErrorWithType(perrors.BadVersion, "bad version in Protobuf Message")
}
idx += 4
ret, e = codec.Bytes2Uint32(transBuff[4:idx])
if e != nil {
return nil, e
}
methodNameLen := int32(ret)
if methodNameLen < 0 {
return nil, perrors.InvalidDataLength
}
idx += int(methodNameLen)
if len(transBuff) < idx+4 {
return nil, perrors.NewProtocolErrorWithMsg("invalid trans buffer")
}
return transBuff[idx : idx+4], nil
}

func pbReadBinaryMethod(buff []byte, msg remote.Message) error {
bufLen := len(buff)
if bufLen < codec.Size32*2 {
return perrors.NewProtocolErrorWithMsg(
fmt.Sprintf("invalid trans buffer in binaryProtobufCodec Unmarshal, size=%d less than 8 bytes", bufLen))
}
methodLen := int(binary.BigEndian.Uint32(buff[4:8]))
if bufLen < codec.Size32*2+methodLen || methodLen <= 0 {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("method len[%d] invalid in binaryProtobufCodec Unmarshal", methodLen))
}
method := string(buff[8:(8 + methodLen)])
if err := codec.SetOrCheckMethodName(method, msg); err != nil {
return perrors.NewProtocolError(err)
}
return nil
}
Loading