Skip to content

Commit

Permalink
Allow to change (or disable) the default driver name for registration (
Browse files Browse the repository at this point in the history
…#1499)

A link variable now allows to change or disable the name of the driver
that is automatically registered with database/sql:

Change driver name:
    go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom"

Disable driver registration (set driverName to empty string):
    go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName="

In the same way, a variable overridable at link time is also provided to
override the driver name used in the test suite. This allows to run our
test suite on another driver.

    go test "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom"

driverName is propagated to driverNameTest unless driverNameTest is
explicitely defined.
  • Loading branch information
dolmen authored Oct 24, 2023
1 parent 1e6b8d7 commit 62c29ce
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
8 changes: 4 additions & 4 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {

func initDB(b *testing.B, queries ...string) *sql.DB {
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
db := tb.checkDB(sql.Open(driverNameTest, dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
b.Fatalf("error on %q: %v", query, err)
Expand Down Expand Up @@ -105,7 +105,7 @@ func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := tb.checkDB(sql.Open("mysql", dsn))
db := tb.checkDB(sql.Open(driverNameTest, dsn))
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()

Expand Down Expand Up @@ -151,7 +151,7 @@ func BenchmarkRoundtripTxt(b *testing.B) {
sampleString := string(sample)
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
db := tb.checkDB(sql.Open(driverNameTest, dsn))
defer db.Close()
b.StartTimer()
var result string
Expand Down Expand Up @@ -184,7 +184,7 @@ func BenchmarkRoundtripBin(b *testing.B) {
sample, min, max := initRoundtripBenchmarks()
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
db := tb.checkDB(sql.Open(driverNameTest, dsn))
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
defer stmt.Close()
Expand Down
8 changes: 7 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,14 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
return c.Connect(context.Background())
}

// This variable can be replaced with -ldflags like below:
// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom"
var driverName = "mysql"

func init() {
sql.Register("mysql", &MySQLDriver{})
if driverName != "" {
sql.Register(driverName, &MySQLDriver{})
}
}

// NewConnector returns new driver.Connector.
Expand Down
28 changes: 19 additions & 9 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ import (
"time"
)

// This variable can be replaced with -ldflags like below:
// go test "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom"
var driverNameTest string

func init() {
if driverNameTest == "" {
driverNameTest = driverName
}
}

// Ensure that all the driver interfaces are implemented
var (
_ driver.Rows = &binaryRows{}
Expand Down Expand Up @@ -111,7 +121,7 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT
dsn += "&multiStatements=true"
var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
db, err = sql.Open(driverNameTest, dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand All @@ -130,7 +140,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
t.Skipf("MySQL server not running on %s", netAddr)
}

db, err := sql.Open("mysql", dsn)
db, err := sql.Open(driverNameTest, dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand All @@ -141,7 +151,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
dsn2 := dsn + "&interpolateParams=true"
var db2 *sql.DB
if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
db2, err = sql.Open("mysql", dsn2)
db2, err = sql.Open(driverNameTest, dsn2)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -1917,7 +1927,7 @@ func testDialError(t *testing.T, dialErr error, expectErr error) {
return nil, dialErr
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -1956,7 +1966,7 @@ func TestCustomDial(t *testing.T) {
return d.DialContext(ctx, prot, addr)
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -2054,7 +2064,7 @@ func TestUnixSocketAuthFail(t *testing.T) {
}
t.Logf("socket: %s", socket)
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname)
db, err := sql.Open("mysql", badDSN)
db, err := sql.Open(driverNameTest, badDSN)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -2243,7 +2253,7 @@ func TestEmptyPassword(t *testing.T) {
}

dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname)
db, err := sql.Open("mysql", dsn)
db, err := sql.Open(driverNameTest, dsn)
if err == nil {
defer db.Close()
err = db.Ping()
Expand Down Expand Up @@ -3210,7 +3220,7 @@ func TestConnectorObeysDialTimeouts(t *testing.T) {
return d.DialContext(ctx, prot, addr)
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -3375,7 +3385,7 @@ func TestConnectionAttributes(t *testing.T) {

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
db, err = sql.Open(driverNameTest, dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down

0 comments on commit 62c29ce

Please sign in to comment.