Skip to content

Commit

Permalink
Update test TestConnectionAttributes
Browse files Browse the repository at this point in the history
  • Loading branch information
oblitorum committed Oct 31, 2023
1 parent c422dd1 commit 4ec6710
Showing 1 changed file with 40 additions and 24 deletions.
64 changes: 40 additions & 24 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"os"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -3377,11 +3378,31 @@ func TestConnectionAttributes(t *testing.T) {
t.Skipf("MySQL server not running on %s", netAddr)
}

attr1 := "attr1"
value1 := "value1"
attr2 := "foo"
value2 := "boo"
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)
defaultAttrs := []string{
connAttrClientName,
connAttrOS,
connAttrPlatform,
connAttrPid,
connAttrServerHost,
}
host, _, _ := net.SplitHostPort(addr)
defaultAttrValues := []string{
connAttrClientNameValue,
connAttrOSValue,
connAttrPlatformValue,
strconv.Itoa(os.Getpid()),
host,
}

customAttrs := []string{"attr1", "attr2"}
customAttrValues := []string{"foo", "bar"}

customAttrStrs := make([]string, len(customAttrs))
for i := range customAttrs {
customAttrStrs[i] = fmt.Sprintf("%s:%s", customAttrs[i], customAttrValues[i])
}

dsn += fmt.Sprintf("&connectionAttributes=%s", strings.Join(customAttrStrs, ","))

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
Expand All @@ -3394,27 +3415,22 @@ func TestConnectionAttributes(t *testing.T) {

dbt := &DBTest{t, db}

var attrValue string
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
rows := dbt.mustQuery(queryString, connAttrClientName)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != connAttrClientNameValue {
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
}
} else {
dbt.Errorf("no data")
queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()"
rows := dbt.mustQuery(queryString)
defer rows.Close()

rowsMap := make(map[string]string)
for rows.Next() {
var attrName, attrValue string
rows.Scan(&attrName, &attrValue)
rowsMap[attrName] = attrValue
}
rows.Close()

rows = dbt.mustQuery(queryString, attr2)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value2 {
dbt.Errorf("expected %q, got %q", value2, attrValue)
connAttrs := append(append([]string{}, defaultAttrs...), customAttrs...)
expectedAttrValues := append(append([]string{}, defaultAttrValues...), customAttrValues...)
for i := range connAttrs {
if gotValue := rowsMap[connAttrs[i]]; gotValue != expectedAttrValues[i] {
dbt.Errorf("expected %s, got %s", expectedAttrValues[i], gotValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
}

0 comments on commit 4ec6710

Please sign in to comment.