From 62c29ce0b1b8f84567de97ca0d32cebd53f05aa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Mengu=C3=A9?= Date: Tue, 24 Oct 2023 10:05:53 +0200 Subject: [PATCH] Allow to change (or disable) the default driver name for registration (#1499) A link variable now allows to change or disable the name of the driver that is automatically registered with database/sql: Change driver name: go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom" Disable driver registration (set driverName to empty string): go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=" In the same way, a variable overridable at link time is also provided to override the driver name used in the test suite. This allows to run our test suite on another driver. go test "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom" driverName is propagated to driverNameTest unless driverNameTest is explicitely defined. --- benchmark_test.go | 8 ++++---- driver.go | 8 +++++++- driver_test.go | 28 +++++++++++++++++++--------- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index fc70df60d..a4ecc0a63 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -48,7 +48,7 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { func initDB(b *testing.B, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -105,7 +105,7 @@ func BenchmarkExec(b *testing.B) { tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) db.SetMaxIdleConns(concurrencyLevel) defer db.Close() @@ -151,7 +151,7 @@ func BenchmarkRoundtripTxt(b *testing.B) { sampleString := string(sample) b.ReportAllocs() tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() b.StartTimer() var result string @@ -184,7 +184,7 @@ func BenchmarkRoundtripBin(b *testing.B) { sample, min, max := initRoundtripBenchmarks() b.ReportAllocs() tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() stmt := tb.checkStmt(db.Prepare("SELECT ?")) defer stmt.Close() diff --git a/driver.go b/driver.go index 0ed8fa1c5..45528b920 100644 --- a/driver.go +++ b/driver.go @@ -90,8 +90,14 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return c.Connect(context.Background()) } +// This variable can be replaced with -ldflags like below: +// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom" +var driverName = "mysql" + func init() { - sql.Register("mysql", &MySQLDriver{}) + if driverName != "" { + sql.Register(driverName, &MySQLDriver{}) + } } // NewConnector returns new driver.Connector. diff --git a/driver_test.go b/driver_test.go index f46d38df6..13e07e753 100644 --- a/driver_test.go +++ b/driver_test.go @@ -31,6 +31,16 @@ import ( "time" ) +// This variable can be replaced with -ldflags like below: +// go test "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom" +var driverNameTest string + +func init() { + if driverNameTest == "" { + driverNameTest = driverName + } +} + // Ensure that all the driver interfaces are implemented var ( _ driver.Rows = &binaryRows{} @@ -111,7 +121,7 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT dsn += "&multiStatements=true" var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { - db, err = sql.Open("mysql", dsn) + db, err = sql.Open(driverNameTest, dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -130,7 +140,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { t.Skipf("MySQL server not running on %s", netAddr) } - db, err := sql.Open("mysql", dsn) + db, err := sql.Open(driverNameTest, dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -141,7 +151,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { dsn2 := dsn + "&interpolateParams=true" var db2 *sql.DB if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { - db2, err = sql.Open("mysql", dsn2) + db2, err = sql.Open(driverNameTest, dsn2) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1917,7 +1927,7 @@ func testDialError(t *testing.T, dialErr error, expectErr error) { return nil, dialErr }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1956,7 +1966,7 @@ func TestCustomDial(t *testing.T) { return d.DialContext(ctx, prot, addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -2054,7 +2064,7 @@ func TestUnixSocketAuthFail(t *testing.T) { } t.Logf("socket: %s", socket) badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) - db, err := sql.Open("mysql", badDSN) + db, err := sql.Open(driverNameTest, badDSN) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -2243,7 +2253,7 @@ func TestEmptyPassword(t *testing.T) { } dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname) - db, err := sql.Open("mysql", dsn) + db, err := sql.Open(driverNameTest, dsn) if err == nil { defer db.Close() err = db.Ping() @@ -3210,7 +3220,7 @@ func TestConnectorObeysDialTimeouts(t *testing.T) { return d.DialContext(ctx, prot, addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -3375,7 +3385,7 @@ func TestConnectionAttributes(t *testing.T) { var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { - db, err = sql.Open("mysql", dsn) + db, err = sql.Open(driverNameTest, dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) }