Skip to content

Commit

Permalink
Allow MLLP adapter to receive multiple messages on the same connection
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652946889
  • Loading branch information
Cloud Healthcare Team authored and nikklassen committed Jul 19, 2024
1 parent 8b49249 commit 903a912
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 54 deletions.
3 changes: 2 additions & 1 deletion cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,12 @@ steps:
waitFor: ['Create pubsub topic/subscription', 'Create cluster']

- id: 'Get endpoint'
name: 'gcr.io/cloud-builders/kubectl'
name: 'gcr.io/google.com/cloudsdktool/cloud-sdk'
entrypoint: 'bash'
args:
- '-c'
- |
gcloud container clusters get-credentials $(cat _cluster-name) --location us-central1-b --project $PROJECT_ID
# grab the public IP of the load balancer
get_ip() {
kubectl get service mllp-adapter-presubmit-service -o=jsonpath='{.status.loadBalancer.ingress[0].ip}'
Expand Down
3 changes: 3 additions & 0 deletions mllp_adapter/mllp/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ go_library(
name = "go_default_library",
srcs = ["mllp.go"],
importpath = "github.com/GoogleCloudPlatform/mllp/mllp_adapter/mllp",
deps = [
"@com_github_golang_glog//:go_default_library",
],
)

go_test(
Expand Down
80 changes: 39 additions & 41 deletions mllp_adapter/mllp/mllp.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"bufio"
"fmt"
"io"

log "github.com/golang/glog"
)

const (
Expand All @@ -29,66 +31,62 @@ const (
cr = '\x0d'
)

func encapsulate(in []byte) []byte {
out := make([]byte, len(in)+3)
out[0] = startBlock
for i, b := range in {
out[i+1] = b
}
out[len(out)-2] = endBlock
out[len(out)-1] = cr
return out
}

// WriteMsg wraps an HL7 message in the start block, end block, and carriage return bytes
// required for MLLP transmission and then writes the wrapped message to writer.
func WriteMsg(writer io.Writer, msg []byte) error {
if _, err := writer.Write(encapsulate(msg)); err != nil {
if _, err := writer.Write([]byte{startBlock}); err != nil {
return fmt.Errorf("writing message: %v", err)
}
if _, err := writer.Write(msg); err != nil {
return fmt.Errorf("writing message: %v", err)
}
if _, err := writer.Write([]byte{endBlock, cr}); err != nil {
return fmt.Errorf("writing message: %v", err)
}
return nil
}

func checkByte(msg []byte, pos int, expected byte) error {
if msg[pos] != expected {
return fmt.Errorf("invalid message %v, expected %v at position %v but got %v",
msg, expected, pos, msg[pos])
}
return nil
// MessageReader consumes MLLP messages from a stream.
type MessageReader struct {
r *bufio.Reader
}

func decapsulate(msg []byte) ([]byte, error) {
if len(msg) < 3 {
return nil, fmt.Errorf("short message, length %v", len(msg))
}
if err := checkByte(msg, 0, startBlock); err != nil {
return nil, err
}
if err := checkByte(msg, len(msg)-2, endBlock); err != nil {
// NewMessageReader to unwrap MLLP messages the provided stream.
func NewMessageReader(r io.Reader) *MessageReader {
return &MessageReader{r: bufio.NewReader(r)}
}

// Next message in the reader. Unwraps the inner message by removing the start
// block, end block, and carriage return bytes.
func (mr *MessageReader) Next() ([]byte, error) {
data, err := mr.r.ReadBytes(startBlock)
if err != nil {
return nil, err
}
if err := checkByte(msg, len(msg)-1, cr); err != nil {
return nil, err
if len(data) > 1 {
log.Infof("dropped %d bytes before start of message", len(data)-1)
}
return msg[1 : len(msg)-2], nil
}

// ReadMsg reads a message from reader and removes the start block, end block, and carriage return bytes.
func ReadMsg(reader io.Reader) ([]byte, error) {
r := bufio.NewReader(reader)
// Read everything up to the endBlock byte.
rawMsg, err := r.ReadBytes(endBlock)
rawMsg, err := mr.r.ReadBytes(endBlock)
if err != nil {
return nil, err
}
// Read one more byte for the carriage return.
lastByte, err := r.ReadByte()
lastByte, err := mr.r.ReadByte()
if err != nil {
return nil, fmt.Errorf("reading last byte: %v", err)
return nil, err
}
msg, err := decapsulate(append(rawMsg, lastByte))
if err != nil {
return nil, fmt.Errorf("decapsulating message: %v", err)
if lastByte != cr {
if err := mr.r.UnreadByte(); err != nil {
return nil, err
}
return nil, fmt.Errorf("message ends with %c, want %c", lastByte, cr)
}
return msg, nil
return rawMsg[:len(rawMsg)-1], nil
}

// ReadMsg from reader and removes the start block, end block, and carriage return bytes.
// The reader must return a single message, any trailing bytes may be consumed.
func ReadMsg(r io.Reader) ([]byte, error) {
return NewMessageReader(r).Next()
}
30 changes: 28 additions & 2 deletions mllp_adapter/mllp/mllp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestOK(t *testing.T) {
t.Errorf("Writing message %v: expected %v but got %v", tc.raw, tc.mllp, enc.Bytes())
}

dec, err := ReadMsg(bytes.NewBuffer(tc.mllp))
dec, err := ReadMsg(bytes.NewReader(tc.mllp))
if err != nil {
t.Errorf("Unexpected error reading message %v: %v", tc.mllp, err)
}
Expand All @@ -63,6 +63,32 @@ func TestOK(t *testing.T) {
}
}

func TestRead_MultipleMessages(t *testing.T) {
msg1 := []byte("msg1")
msg2 := []byte("msg2")
data := bytes.Join([][]byte{
[]byte{startBlock}, msg1, []byte{endBlock, cr},
[]byte{startBlock}, msg2, []byte{endBlock, cr},
}, nil)
reader := NewMessageReader(bytes.NewReader(data))

innerMsg, err := reader.Next()
if err != nil {
t.Errorf("Unexpected error reading first message in %s: %v", data, err)
}
if !bytes.Equal(innerMsg, msg1) {
t.Errorf("Reading first message: got %s, want %s", innerMsg, msg1)
}

innerMsg, err = reader.Next()
if err != nil {
t.Errorf("Unexpected error reading second message in %s: %v", data, err)
}
if !bytes.Equal(innerMsg, msg2) {
t.Errorf("Reading second message: got %s, want %s", innerMsg, msg2)
}
}

func TestError(t *testing.T) {
testCases := []struct {
name string
Expand All @@ -77,7 +103,7 @@ func TestError(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if _, err := ReadMsg(bytes.NewBuffer(tc.msg)); err == nil {
if _, err := ReadMsg(bytes.NewReader(tc.msg)); err == nil {
t.Errorf("Expected error for message %v", tc.msg)
}
})
Expand Down
3 changes: 2 additions & 1 deletion mllp_adapter/mllpreceiver/mllpreceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,9 @@ func (m *MLLPReceiver) handleConnection(conn *net.TCPConn) {
}
}()

reader := mllp.NewMessageReader(conn)
for {
msg, err := mllp.ReadMsg(conn)
msg, err := reader.Next()
if err != nil {
if err != io.EOF {
log.Errorf("MLLP Receiver: failed to read message: %v", err)
Expand Down
38 changes: 29 additions & 9 deletions mllp_adapter/mllpreceiver/mllpreceiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package mllpreceiver

import (
"bytes"
"net"
"reflect"
"strconv"
Expand Down Expand Up @@ -97,6 +98,14 @@ func TestValidMessages(t *testing.T) {
},
[][]byte{cannedMsg, cannedMsg, cannedMsg},
},
testCase{
"3 encapsulated messages, single connection",
[]connection{{
bytes.Join([][]byte{wrappedMsg, wrappedMsg, wrappedMsg}, nil),
[][]byte{cannedAck, cannedAck, cannedAck},
}},
[][]byte{cannedMsg, cannedMsg, cannedMsg},
},
testCase{
"2 encapsulated messages, 1 unencapsulated message (ignored and not sent), sent over separate connections",
[]connection{
Expand All @@ -116,14 +125,14 @@ func TestValidMessages(t *testing.T) {
[][]byte{cannedMsg, cannedMsg},
},
testCase{
"encapsulated message, unencapsulated message, encapsulated message, sent over a single connection, unencapsulated message and everything after is ignored",
"encapsulated message, unencapsulated message, encapsulated message, sent over a single connection",
[]connection{
connection{
append(append(wrappedMsg, cannedMsg...), wrappedMsg...),
bytes.Join([][]byte{wrappedMsg, cannedMsg, wrappedMsg}, nil),
[][]byte{cannedAck},
},
},
[][]byte{cannedMsg},
[][]byte{cannedMsg, cannedMsg},
},
testCase{
"garbage (ignored)",
Expand All @@ -142,11 +151,17 @@ func TestValidMessages(t *testing.T) {
s, r := setUp(t)
for _, c := range tc.connections {
conn := dial(t, r.port)
conn.Write(c.input)
for _, expectedAck := range c.expectedAcks {
ack := receiveAck(t, conn)
if !reflect.DeepEqual(ack, expectedAck) {
t.Errorf("Expected ack %v but got %v", expectedAck, ack)
if _, err := conn.Write(c.input); err != nil {
t.Fatalf("Failed to write message: %v", err)
}
reader := mllp.NewMessageReader(conn)
for _, wantResp := range c.expectedAcks {
gotResp, err := reader.Next()
if err != nil {
t.Fatalf("Next() failed: %v", err)
}
if !bytes.Equal(gotResp, wantResp) {
t.Errorf("Next() got message %v, want %v", gotResp, wantResp)
}
}
conn.Close()
Expand Down Expand Up @@ -186,7 +201,12 @@ func Test3SimultanousConnections(t *testing.T) {
func TestMessageStats(t *testing.T) {
s, r := setUp(t)
c := dial(t, r.port)
mllp.WriteMsg(c, cannedMsg)
if err := mllp.WriteMsg(c, cannedMsg); err != nil {
t.Errorf("Failed to write message: %v", err)
}
if _, err := mllp.ReadMsg(c); err != nil {
t.Fatalf("Failed to read ack: %v", err)
}
if err := c.Close(); err != nil {
t.Fatalf("Failure closing connection: %v", err)
}
Expand Down
2 changes: 2 additions & 0 deletions mllp_adapter/mllpsender/mllpsender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func TestSendError(t *testing.T) {
listener, sender, metrics := setUp()
go func() {
conn := accept(t, listener)
mllp.ReadMsg(conn)
conn.Close()
}()
if _, err := sender.Send(cannedMsg); err == nil {
Expand All @@ -94,6 +95,7 @@ func TestRecoverAfterSendError(t *testing.T) {
listener, sender, metrics := setUp()
go func() {
conn := accept(t, listener)
mllp.ReadMsg(conn)
conn.Close()
conn = accept(t, listener)
mllp.ReadMsg(conn)
Expand Down
1 change: 1 addition & 0 deletions shared/testingutil/testingutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

// CheckMetrics checks whether metrics match expected.
func CheckMetrics(t *testing.T, metrics *FakeMonitoringClient, expected map[string]int64) {
t.Helper()
for m, v := range expected {
if metrics.CounterValue(m) != v {
t.Errorf("Metric %v expected %v, got %v", m, v, metrics.CounterValue(m))
Expand Down

0 comments on commit 903a912

Please sign in to comment.