diff --git a/README.md b/README.md index e254b6b0..2fb1c6ff 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,9 @@ Other supported formats are listed below. * `multisubnetfailover` * `true` (Default) Client attempt to connect to all IPs simultaneously. * `false` Client attempts to connect to IPs in serial. +* `sendStringParametersAsUnicode` + * `true` (Default) Go default string types sent as `nvarchar`. + * `false` Go default string types sent as `varchar`. ### Connection parameters for namedpipe package * `pipe` - If set, no Browser query is made and named pipe used will be `\\\pipe\` @@ -371,7 +374,7 @@ To pass specific types to the query parameters, say `varchar` or `date` types, you must convert the types to the type before passing in. The following types are supported: -* string -> nvarchar +* string -> nvarchar(by default, will be varchar if `sendStringParametersAsUnicode` is set to true) * mssql.VarChar -> varchar * time.Time -> datetimeoffset or datetime (TDS version dependent) * mssql.DateTime1 -> datetime diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 460dd4e3..b3bf2d17 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -53,32 +53,33 @@ const ( ) const ( - Database = "database" - Encrypt = "encrypt" - Password = "password" - ChangePassword = "change password" - UserID = "user id" - Port = "port" - TrustServerCertificate = "trustservercertificate" - Certificate = "certificate" - TLSMin = "tlsmin" - PacketSize = "packet size" - LogParam = "log" - ConnectionTimeout = "connection timeout" - HostNameInCertificate = "hostnameincertificate" - KeepAlive = "keepalive" - ServerSpn = "serverspn" - WorkstationID = "workstation id" - AppName = "app name" - ApplicationIntent = "applicationintent" - FailoverPartner = "failoverpartner" - FailOverPort = "failoverport" - DisableRetry = "disableretry" - Server = "server" - Protocol = "protocol" - DialTimeout = "dial timeout" - Pipe = "pipe" - MultiSubnetFailover = "multisubnetfailover" + Database = "database" + Encrypt = "encrypt" + Password = "password" + ChangePassword = "change password" + UserID = "user id" + Port = "port" + TrustServerCertificate = "trustservercertificate" + Certificate = "certificate" + TLSMin = "tlsmin" + PacketSize = "packet size" + LogParam = "log" + ConnectionTimeout = "connection timeout" + HostNameInCertificate = "hostnameincertificate" + KeepAlive = "keepalive" + ServerSpn = "serverspn" + WorkstationID = "workstation id" + AppName = "app name" + ApplicationIntent = "applicationintent" + FailoverPartner = "failoverpartner" + FailOverPort = "failoverport" + DisableRetry = "disableretry" + Server = "server" + Protocol = "protocol" + DialTimeout = "dial timeout" + Pipe = "pipe" + MultiSubnetFailover = "multisubnetfailover" + SendStringParametersAsUnicode = "sendstringparametersasunicode" ) type Config struct { @@ -131,6 +132,9 @@ type Config struct { ColumnEncryption bool // Attempt to connect to all IPs in parallel when MultiSubnetFailover is true MultiSubnetFailover bool + + // Sets a boolean value that indicates if sending string parameters to the server in UNICODE format is enabled. + SendStringParametersAsUnicode bool } func readDERFile(filename string) ([]byte, error) { @@ -504,6 +508,19 @@ func Parse(dsn string) (Config, error) { // Defaulting to true to prevent breaking change although other client libraries default to false p.MultiSubnetFailover = true } + + sendStringParametersAsUnicode, ok := params[SendStringParametersAsUnicode] + if ok { + p.SendStringParametersAsUnicode, err = strconv.ParseBool(sendStringParametersAsUnicode) + if err != nil { + return p, fmt.Errorf("invalid %s '%s': %s", SendStringParametersAsUnicode, + sendStringParametersAsUnicode, err.Error()) + } + } else { + // defaulting to true for backward compatibility + p.SendStringParametersAsUnicode = true + } + return p, nil } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index f1bf03eb..32aee78a 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -193,6 +193,15 @@ func TestValidConnectionString(t *testing.T) { {"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool { return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption }}, + {"sqlserver://somehost", func(p Config) bool { + return p.Host == "somehost" && p.SendStringParametersAsUnicode + }}, + {"sqlserver://somehost?sendStringParametersAsUnicode=true", func(p Config) bool { + return p.Host == "somehost" && p.SendStringParametersAsUnicode + }}, + {"sqlserver://somehost?sendStringParametersAsUnicode=false", func(p Config) bool { + return p.Host == "somehost" && !p.SendStringParametersAsUnicode + }}, } for _, ts := range connStrings { p, err := Parse(ts.connStr) diff --git a/mssql.go b/mssql.go index 8870410d..0b544501 100644 --- a/mssql.go +++ b/mssql.go @@ -563,8 +563,8 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) { if err != nil { return } - params[0] = makeStrParam(s.query) - params[1] = makeStrParam(strings.Join(decls, ",")) + params[0] = makeStrParam(s.query, true) + params[1] = makeStrParam(strings.Join(decls, ","), true) } if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil { if conn.sess.logFlags&logErrors != 0 { @@ -968,9 +968,19 @@ func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) { return } -func makeStrParam(val string) (res param) { - res.ti.TypeId = typeNVarChar - res.buffer = str2ucs2(val) +func getSendStringParametersAsUnicode(s *Stmt) bool { + return s == nil || s.c == nil || s.c.connector == nil || s.c.connector.params.SendStringParametersAsUnicode +} + +func makeStrParam(val string, sendStringParametersAsUnicode bool) (res param) { + if sendStringParametersAsUnicode { + res.ti.TypeId = typeNVarChar + res.buffer = str2ucs2(val) + res.ti.Size = len(res.buffer) + return + } + res.ti.TypeId = typeBigVarChar + res.buffer = []byte(val) res.ti.Size = len(res.buffer) return } @@ -1046,7 +1056,7 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { res.ti.Size = len(val) res.buffer = val case string: - res = makeStrParam(val) + res = makeStrParam(val, getSendStringParametersAsUnicode(s)) case sql.NullString: // only null values should be getting here res.ti.TypeId = typeNVarChar diff --git a/queries_test.go b/queries_test.go index 3d150d21..efd980b4 100644 --- a/queries_test.go +++ b/queries_test.go @@ -321,6 +321,61 @@ func TestSelectNewTypes(t *testing.T) { } } +func TestSelectWithVarchar(t *testing.T) { + conn, logger := openWithVarcharDSN(t) + defer conn.Close() + defer logger.StopLogging() + + t.Run("scan into string", func(t *testing.T) { + type testStruct struct { + sql string + args []interface{} + val string + } + + longstr := strings.Repeat("x", 10000) + + values := []testStruct{ + {"'abc'", []interface{}{}, "abc"}, + {"N'abc'", []interface{}{}, "abc"}, + {"cast(N'abc' as nvarchar(max))", []interface{}{}, "abc"}, + {"cast('abc' as text)", []interface{}{}, "abc"}, + {"cast(N'abc' as ntext)", []interface{}{}, "abc"}, + {"cast('abc' as char(3))", []interface{}{}, "abc"}, + {"cast('abc' as varchar(3))", []interface{}{}, "abc"}, + {fmt.Sprintf("cast(N'%s' as nvarchar(max))", longstr), []interface{}{}, longstr}, + {"cast(cast('abc' as varchar(3)) as sql_variant)", []interface{}{}, "abc"}, + {"cast(cast('abc' as char(3)) as sql_variant)", []interface{}{}, "abc"}, + {"cast(N'abc' as sql_variant)", []interface{}{}, "abc"}, + {"@p1", []interface{}{"abc"}, "abc"}, + {"@p1", []interface{}{longstr}, longstr}, + } + + for _, test := range values { + t.Run(test.sql, func(t *testing.T) { + stmt, err := conn.Prepare("select " + test.sql) + if err != nil { + t.Error("Prepare failed:", test.sql, err.Error()) + return + } + defer stmt.Close() + + row := stmt.QueryRow(test.args...) + var retval string + err = row.Scan(&retval) + if err != nil { + t.Error("Scan failed:", test.sql, err.Error()) + return + } + if retval != test.val { + t.Errorf("Values don't match '%s' '%s' for test: %s", retval, test.val, test.sql) + return + } + }) + } + }) +} + func TestTrans(t *testing.T) { conn, logger := open(t) defer conn.Close() diff --git a/tds_go110_test.go b/tds_go110_test.go index 76ecfc66..1c07d295 100644 --- a/tds_go110_test.go +++ b/tds_go110_test.go @@ -24,3 +24,22 @@ func getTestConnector(t testing.TB) (*Connector, *testLogger) { } return connector, &tl } + +func openWithVarcharDSN(t testing.TB) (*sql.DB, *testLogger) { + connector, logger := getTestConnectorWithVarcharDSN(t) + conn := sql.OpenDB(connector) + return conn, logger +} + +func getTestConnectorWithVarcharDSN(t testing.TB) (*Connector, *testLogger) { + tl := testLogger{t: t} + SetLogger(&tl) + s := testConnParams(t) + s.SendStringParametersAsUnicode = true + connector, err := NewConnector(s.URL().String()) + if err != nil { + t.Error("Open connection failed:", err.Error()) + return nil, &tl + } + return connector, &tl +}