From 903a912c95ba1a33b62100b3c8a4486aa6b37b9c Mon Sep 17 00:00:00 2001 From: Cloud Healthcare Team Date: Tue, 16 Jul 2024 15:57:08 -0400 Subject: [PATCH] Allow MLLP adapter to receive multiple messages on the same connection PiperOrigin-RevId: 652946889 --- cloudbuild.yaml | 3 +- mllp_adapter/mllp/BUILD.bazel | 3 + mllp_adapter/mllp/mllp.go | 80 +++++++++---------- mllp_adapter/mllp/mllp_test.go | 30 ++++++- mllp_adapter/mllpreceiver/mllpreceiver.go | 3 +- .../mllpreceiver/mllpreceiver_test.go | 38 ++++++--- mllp_adapter/mllpsender/mllpsender_test.go | 2 + shared/testingutil/testingutil.go | 1 + 8 files changed, 106 insertions(+), 54 deletions(-) diff --git a/cloudbuild.yaml b/cloudbuild.yaml index f3c3376..2ad5f6d 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -85,11 +85,12 @@ steps: waitFor: ['Create pubsub topic/subscription', 'Create cluster'] - id: 'Get endpoint' - name: 'gcr.io/cloud-builders/kubectl' + name: 'gcr.io/google.com/cloudsdktool/cloud-sdk' entrypoint: 'bash' args: - '-c' - | + gcloud container clusters get-credentials $(cat _cluster-name) --location us-central1-b --project $PROJECT_ID # grab the public IP of the load balancer get_ip() { kubectl get service mllp-adapter-presubmit-service -o=jsonpath='{.status.loadBalancer.ingress[0].ip}' diff --git a/mllp_adapter/mllp/BUILD.bazel b/mllp_adapter/mllp/BUILD.bazel index ef2bb83..f67a5d5 100644 --- a/mllp_adapter/mllp/BUILD.bazel +++ b/mllp_adapter/mllp/BUILD.bazel @@ -20,6 +20,9 @@ go_library( name = "go_default_library", srcs = ["mllp.go"], importpath = "github.com/GoogleCloudPlatform/mllp/mllp_adapter/mllp", + deps = [ + "@com_github_golang_glog//:go_default_library", + ], ) go_test( diff --git a/mllp_adapter/mllp/mllp.go b/mllp_adapter/mllp/mllp.go index 47145e6..89ef95e 100644 --- a/mllp_adapter/mllp/mllp.go +++ b/mllp_adapter/mllp/mllp.go @@ -21,6 +21,8 @@ import ( "bufio" "fmt" "io" + + log "github.com/golang/glog" ) const ( @@ -29,66 +31,62 @@ const ( cr = '\x0d' ) -func encapsulate(in []byte) []byte { - out := make([]byte, len(in)+3) - out[0] = startBlock - for i, b := range in { - out[i+1] = b - } - out[len(out)-2] = endBlock - out[len(out)-1] = cr - return out -} - // WriteMsg wraps an HL7 message in the start block, end block, and carriage return bytes // required for MLLP transmission and then writes the wrapped message to writer. func WriteMsg(writer io.Writer, msg []byte) error { - if _, err := writer.Write(encapsulate(msg)); err != nil { + if _, err := writer.Write([]byte{startBlock}); err != nil { + return fmt.Errorf("writing message: %v", err) + } + if _, err := writer.Write(msg); err != nil { + return fmt.Errorf("writing message: %v", err) + } + if _, err := writer.Write([]byte{endBlock, cr}); err != nil { return fmt.Errorf("writing message: %v", err) } return nil } -func checkByte(msg []byte, pos int, expected byte) error { - if msg[pos] != expected { - return fmt.Errorf("invalid message %v, expected %v at position %v but got %v", - msg, expected, pos, msg[pos]) - } - return nil +// MessageReader consumes MLLP messages from a stream. +type MessageReader struct { + r *bufio.Reader } -func decapsulate(msg []byte) ([]byte, error) { - if len(msg) < 3 { - return nil, fmt.Errorf("short message, length %v", len(msg)) - } - if err := checkByte(msg, 0, startBlock); err != nil { - return nil, err - } - if err := checkByte(msg, len(msg)-2, endBlock); err != nil { +// NewMessageReader to unwrap MLLP messages the provided stream. +func NewMessageReader(r io.Reader) *MessageReader { + return &MessageReader{r: bufio.NewReader(r)} +} + +// Next message in the reader. Unwraps the inner message by removing the start +// block, end block, and carriage return bytes. +func (mr *MessageReader) Next() ([]byte, error) { + data, err := mr.r.ReadBytes(startBlock) + if err != nil { return nil, err } - if err := checkByte(msg, len(msg)-1, cr); err != nil { - return nil, err + if len(data) > 1 { + log.Infof("dropped %d bytes before start of message", len(data)-1) } - return msg[1 : len(msg)-2], nil -} - -// ReadMsg reads a message from reader and removes the start block, end block, and carriage return bytes. -func ReadMsg(reader io.Reader) ([]byte, error) { - r := bufio.NewReader(reader) // Read everything up to the endBlock byte. - rawMsg, err := r.ReadBytes(endBlock) + rawMsg, err := mr.r.ReadBytes(endBlock) if err != nil { return nil, err } // Read one more byte for the carriage return. - lastByte, err := r.ReadByte() + lastByte, err := mr.r.ReadByte() if err != nil { - return nil, fmt.Errorf("reading last byte: %v", err) + return nil, err } - msg, err := decapsulate(append(rawMsg, lastByte)) - if err != nil { - return nil, fmt.Errorf("decapsulating message: %v", err) + if lastByte != cr { + if err := mr.r.UnreadByte(); err != nil { + return nil, err + } + return nil, fmt.Errorf("message ends with %c, want %c", lastByte, cr) } - return msg, nil + return rawMsg[:len(rawMsg)-1], nil +} + +// ReadMsg from reader and removes the start block, end block, and carriage return bytes. +// The reader must return a single message, any trailing bytes may be consumed. +func ReadMsg(r io.Reader) ([]byte, error) { + return NewMessageReader(r).Next() } diff --git a/mllp_adapter/mllp/mllp_test.go b/mllp_adapter/mllp/mllp_test.go index 3153ff2..e72465e 100644 --- a/mllp_adapter/mllp/mllp_test.go +++ b/mllp_adapter/mllp/mllp_test.go @@ -52,7 +52,7 @@ func TestOK(t *testing.T) { t.Errorf("Writing message %v: expected %v but got %v", tc.raw, tc.mllp, enc.Bytes()) } - dec, err := ReadMsg(bytes.NewBuffer(tc.mllp)) + dec, err := ReadMsg(bytes.NewReader(tc.mllp)) if err != nil { t.Errorf("Unexpected error reading message %v: %v", tc.mllp, err) } @@ -63,6 +63,32 @@ func TestOK(t *testing.T) { } } +func TestRead_MultipleMessages(t *testing.T) { + msg1 := []byte("msg1") + msg2 := []byte("msg2") + data := bytes.Join([][]byte{ + []byte{startBlock}, msg1, []byte{endBlock, cr}, + []byte{startBlock}, msg2, []byte{endBlock, cr}, + }, nil) + reader := NewMessageReader(bytes.NewReader(data)) + + innerMsg, err := reader.Next() + if err != nil { + t.Errorf("Unexpected error reading first message in %s: %v", data, err) + } + if !bytes.Equal(innerMsg, msg1) { + t.Errorf("Reading first message: got %s, want %s", innerMsg, msg1) + } + + innerMsg, err = reader.Next() + if err != nil { + t.Errorf("Unexpected error reading second message in %s: %v", data, err) + } + if !bytes.Equal(innerMsg, msg2) { + t.Errorf("Reading second message: got %s, want %s", innerMsg, msg2) + } +} + func TestError(t *testing.T) { testCases := []struct { name string @@ -77,7 +103,7 @@ func TestError(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if _, err := ReadMsg(bytes.NewBuffer(tc.msg)); err == nil { + if _, err := ReadMsg(bytes.NewReader(tc.msg)); err == nil { t.Errorf("Expected error for message %v", tc.msg) } }) diff --git a/mllp_adapter/mllpreceiver/mllpreceiver.go b/mllp_adapter/mllpreceiver/mllpreceiver.go index 6a0b832..3d23b50 100644 --- a/mllp_adapter/mllpreceiver/mllpreceiver.go +++ b/mllp_adapter/mllpreceiver/mllpreceiver.go @@ -110,8 +110,9 @@ func (m *MLLPReceiver) handleConnection(conn *net.TCPConn) { } }() + reader := mllp.NewMessageReader(conn) for { - msg, err := mllp.ReadMsg(conn) + msg, err := reader.Next() if err != nil { if err != io.EOF { log.Errorf("MLLP Receiver: failed to read message: %v", err) diff --git a/mllp_adapter/mllpreceiver/mllpreceiver_test.go b/mllp_adapter/mllpreceiver/mllpreceiver_test.go index 8825ce4..58ba44d 100644 --- a/mllp_adapter/mllpreceiver/mllpreceiver_test.go +++ b/mllp_adapter/mllpreceiver/mllpreceiver_test.go @@ -15,6 +15,7 @@ package mllpreceiver import ( + "bytes" "net" "reflect" "strconv" @@ -97,6 +98,14 @@ func TestValidMessages(t *testing.T) { }, [][]byte{cannedMsg, cannedMsg, cannedMsg}, }, + testCase{ + "3 encapsulated messages, single connection", + []connection{{ + bytes.Join([][]byte{wrappedMsg, wrappedMsg, wrappedMsg}, nil), + [][]byte{cannedAck, cannedAck, cannedAck}, + }}, + [][]byte{cannedMsg, cannedMsg, cannedMsg}, + }, testCase{ "2 encapsulated messages, 1 unencapsulated message (ignored and not sent), sent over separate connections", []connection{ @@ -116,14 +125,14 @@ func TestValidMessages(t *testing.T) { [][]byte{cannedMsg, cannedMsg}, }, testCase{ - "encapsulated message, unencapsulated message, encapsulated message, sent over a single connection, unencapsulated message and everything after is ignored", + "encapsulated message, unencapsulated message, encapsulated message, sent over a single connection", []connection{ connection{ - append(append(wrappedMsg, cannedMsg...), wrappedMsg...), + bytes.Join([][]byte{wrappedMsg, cannedMsg, wrappedMsg}, nil), [][]byte{cannedAck}, }, }, - [][]byte{cannedMsg}, + [][]byte{cannedMsg, cannedMsg}, }, testCase{ "garbage (ignored)", @@ -142,11 +151,17 @@ func TestValidMessages(t *testing.T) { s, r := setUp(t) for _, c := range tc.connections { conn := dial(t, r.port) - conn.Write(c.input) - for _, expectedAck := range c.expectedAcks { - ack := receiveAck(t, conn) - if !reflect.DeepEqual(ack, expectedAck) { - t.Errorf("Expected ack %v but got %v", expectedAck, ack) + if _, err := conn.Write(c.input); err != nil { + t.Fatalf("Failed to write message: %v", err) + } + reader := mllp.NewMessageReader(conn) + for _, wantResp := range c.expectedAcks { + gotResp, err := reader.Next() + if err != nil { + t.Fatalf("Next() failed: %v", err) + } + if !bytes.Equal(gotResp, wantResp) { + t.Errorf("Next() got message %v, want %v", gotResp, wantResp) } } conn.Close() @@ -186,7 +201,12 @@ func Test3SimultanousConnections(t *testing.T) { func TestMessageStats(t *testing.T) { s, r := setUp(t) c := dial(t, r.port) - mllp.WriteMsg(c, cannedMsg) + if err := mllp.WriteMsg(c, cannedMsg); err != nil { + t.Errorf("Failed to write message: %v", err) + } + if _, err := mllp.ReadMsg(c); err != nil { + t.Fatalf("Failed to read ack: %v", err) + } if err := c.Close(); err != nil { t.Fatalf("Failure closing connection: %v", err) } diff --git a/mllp_adapter/mllpsender/mllpsender_test.go b/mllp_adapter/mllpsender/mllpsender_test.go index 31be845..8473013 100644 --- a/mllp_adapter/mllpsender/mllpsender_test.go +++ b/mllp_adapter/mllpsender/mllpsender_test.go @@ -82,6 +82,7 @@ func TestSendError(t *testing.T) { listener, sender, metrics := setUp() go func() { conn := accept(t, listener) + mllp.ReadMsg(conn) conn.Close() }() if _, err := sender.Send(cannedMsg); err == nil { @@ -94,6 +95,7 @@ func TestRecoverAfterSendError(t *testing.T) { listener, sender, metrics := setUp() go func() { conn := accept(t, listener) + mllp.ReadMsg(conn) conn.Close() conn = accept(t, listener) mllp.ReadMsg(conn) diff --git a/shared/testingutil/testingutil.go b/shared/testingutil/testingutil.go index bdf4eb0..bd9978d 100644 --- a/shared/testingutil/testingutil.go +++ b/shared/testingutil/testingutil.go @@ -22,6 +22,7 @@ import ( // CheckMetrics checks whether metrics match expected. func CheckMetrics(t *testing.T, metrics *FakeMonitoringClient, expected map[string]int64) { + t.Helper() for m, v := range expected { if metrics.CounterValue(m) != v { t.Errorf("Metric %v expected %v, got %v", m, v, metrics.CounterValue(m))