Skip to content

Commit 5fadd4a

Browse files
committed
fix(contrib/database/sql): detect lib/pq and disable DBM-APM link on COPY
1 parent 66aa2dd commit 5fadd4a

File tree

3 files changed

+133
-35
lines changed

3 files changed

+133
-35
lines changed

contrib/database/sql/conn.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"database/sql/driver"
1111
"math"
12+
"strings"
1213
"time"
1314

1415
"github.com/DataDog/dd-trace-go/v2/appsec/events"
@@ -285,6 +286,11 @@ func (tc *TracedConn) providedPeerService(ctx context.Context) string {
285286
// with a span ID injected into SQL comments. The returned span ID should be used when the SQL span is created
286287
// following the traced database call.
287288
func (tc *TracedConn) injectComments(ctx context.Context, query string, mode tracer.DBMPropagationMode) (cquery string, spanID uint64) {
289+
if tc.cfg.copyNotSupported && strings.EqualFold(query[:4], "COPY") {
290+
// COPY is not supported for lib/pq, so we need to disable the comment injection
291+
mode = tracer.DBMPropagationModeDisabled
292+
}
293+
288294
// The sql span only gets created after the call to the database because we need to be able to skip spans
289295
// when a driver returns driver.ErrSkip. In order to work with those constraints, a new span id is generated and
290296
// used during SQL comment injection and returned for the sql span to be used later when/if the span

contrib/database/sql/option.go

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type config struct {
2929
dbmPropagationMode tracer.DBMPropagationMode
3030
dbStats bool
3131
statsdClient instrumentation.StatsdClient
32+
copyNotSupported bool
3233
}
3334

3435
// checkStatsdRequired adds a statsdclient onto the config if dbstats is enabled
@@ -48,49 +49,57 @@ func (c *config) checkStatsdRequired() {
4849
}
4950

5051
func (c *config) checkDBMPropagation(driverName string, driver driver.Driver, dsn string) {
51-
if c.dbmPropagationMode == tracer.DBMPropagationModeFull {
52-
if dsn == "" {
53-
dsn = c.dsn
54-
}
55-
if dbSystem, ok := dbmFullModeUnsupported(driverName, driver, dsn); ok {
56-
instr.Logger().Warn("Using DBM_PROPAGATION_MODE in 'full' mode is not supported for %s, downgrading to 'service' mode. "+
57-
"See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.",
58-
dbSystem,
59-
)
60-
c.dbmPropagationMode = tracer.DBMPropagationModeService
61-
}
52+
if c.dbmPropagationMode == tracer.DBMPropagationModeDisabled {
53+
return
54+
}
55+
if c.dbmPropagationMode == tracer.DBMPropagationModeUndefined {
56+
return
57+
}
58+
if dsn == "" {
59+
dsn = c.dsn
60+
}
61+
// this case applies to full and service modes
62+
if dbSystem, reason, ok := dbmPartiallySupported(driver, c); ok {
63+
instr.Logger().Warn("Using DBM_PROPAGATION_MODE in '%s' mode is partially supported for %s: %s. "+
64+
"See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.",
65+
c.dbmPropagationMode,
66+
dbSystem,
67+
reason,
68+
)
69+
}
70+
if c.dbmPropagationMode != tracer.DBMPropagationModeFull {
71+
return
72+
}
73+
// full mode is not supported for some drivers, so we need to check for that
74+
if dbSystem, ok := dbmFullModeUnsupported(driverName, driver, dsn); ok {
75+
instr.Logger().Warn("Using DBM_PROPAGATION_MODE in 'full' mode is not supported for %s, downgrading to 'service' mode. "+
76+
"See https://docs.datadoghq.com/database_monitoring/connect_dbm_and_apm/ for more info.",
77+
dbSystem,
78+
)
79+
c.dbmPropagationMode = tracer.DBMPropagationModeService
6280
}
6381
}
6482

83+
type unsupportedDriverModule struct {
84+
prefix string
85+
pkgName string
86+
dbSystem string
87+
reason string
88+
updateConfig func(*config)
89+
}
90+
6591
func dbmFullModeUnsupported(driverName string, driver driver.Driver, dsn string) (string, bool) {
6692
const (
6793
sqlServer = "SQL Server"
6894
oracle = "Oracle"
6995
)
70-
// check if the driver package path is one of the unsupported ones.
71-
if tp := reflect.TypeOf(driver); tp != nil && (tp.Kind() == reflect.Pointer || tp.Kind() == reflect.Struct) {
72-
pkgPath := ""
73-
switch tp.Kind() {
74-
case reflect.Pointer:
75-
pkgPath = tp.Elem().PkgPath()
76-
case reflect.Struct:
77-
pkgPath = tp.PkgPath()
78-
}
79-
driverPkgs := [][3]string{
80-
{"github.com", "denisenkom/go-mssqldb", sqlServer},
81-
{"github.com", "microsoft/go-mssqldb", sqlServer},
82-
{"github.com", "sijms/go-ora", oracle},
83-
}
84-
for _, dp := range driverPkgs {
85-
prefix, pkgName, dbSystem := dp[0], dp[1], dp[2]
86-
87-
// compare without the prefix to make it work for vendoring.
88-
// also, compare only the prefix to make the comparison work when using major versions
89-
// of the libraries or subpackages.
90-
if strings.HasPrefix(strings.TrimPrefix(pkgPath, prefix+"/"), pkgName) {
91-
return dbSystem, true
92-
}
93-
}
96+
driverPkgs := []unsupportedDriverModule{
97+
{"github.com", "denisenkom/go-mssqldb", sqlServer, "", nil},
98+
{"github.com", "microsoft/go-mssqldb", sqlServer, "", nil},
99+
{"github.com", "sijms/go-ora", oracle, "", nil},
100+
}
101+
if ix := unsupportedDriver(driver, driverPkgs); ix != -1 {
102+
return driverPkgs[ix].dbSystem, true
94103
}
95104

96105
// check the DSN if provided.
@@ -123,6 +132,45 @@ func dbmFullModeUnsupported(driverName string, driver driver.Driver, dsn string)
123132
return "", false
124133
}
125134

135+
func dbmPartiallySupported(driver driver.Driver, c *config) (string, string, bool) {
136+
driverPkgs := []unsupportedDriverModule{
137+
{"github.com", "lib/pq", "PostgreSQL", "COPY doesn't support comments", func(cfg *config) {
138+
cfg.copyNotSupported = true
139+
}},
140+
}
141+
if ix := unsupportedDriver(driver, driverPkgs); ix != -1 {
142+
if driverPkgs[ix].updateConfig != nil {
143+
driverPkgs[ix].updateConfig(c)
144+
}
145+
return driverPkgs[ix].dbSystem, driverPkgs[ix].reason, true
146+
}
147+
return "", "", false
148+
}
149+
150+
func unsupportedDriver(driver driver.Driver, driverPkgs []unsupportedDriverModule) int {
151+
// check if the driver package path is one of the unsupported ones.
152+
if tp := reflect.TypeOf(driver); tp != nil && (tp.Kind() == reflect.Pointer || tp.Kind() == reflect.Struct) {
153+
pkgPath := ""
154+
switch tp.Kind() {
155+
case reflect.Pointer:
156+
pkgPath = tp.Elem().PkgPath()
157+
case reflect.Struct:
158+
pkgPath = tp.PkgPath()
159+
}
160+
for ix, dp := range driverPkgs {
161+
prefix, pkgName := dp.prefix, dp.pkgName
162+
163+
// compare without the prefix to make it work for vendoring.
164+
// also, compare only the prefix to make the comparison work when using major versions
165+
// of the libraries or subpackages.
166+
if strings.HasPrefix(strings.TrimPrefix(pkgPath, prefix+"/"), pkgName) {
167+
return ix
168+
}
169+
}
170+
}
171+
return -1
172+
}
173+
126174
// Option describes options for the database/sql integration.
127175
type Option interface {
128176
apply(*config)

contrib/database/sql/propagation_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@ import (
1212
"database/sql/driver"
1313
"io"
1414
"net/http"
15+
"os"
1516
"regexp"
1617
"testing"
18+
"time"
1719

1820
mssql "github.com/denisenkom/go-mssqldb"
1921
"github.com/go-sql-driver/mysql"
22+
"github.com/lib/pq"
2023
"github.com/stretchr/testify/assert"
2124
"github.com/stretchr/testify/require"
2225

@@ -256,6 +259,47 @@ func TestDBMPropagation(t *testing.T) {
256259
}
257260
}
258261

262+
func TestDBMPropagationFullOnPqCopy(t *testing.T) {
263+
if _, ok := os.LookupEnv("INTEGRATION"); !ok {
264+
t.Skip("skipping integration test")
265+
}
266+
tr := mocktracer.Start()
267+
defer tr.Stop()
268+
269+
Register("postgres", &pq.Driver{}, WithDBMPropagation(tracer.DBMPropagationModeFull))
270+
db, err := Open("postgres", "postgres://postgres:[email protected]:5432/postgres?sslmode=disable")
271+
require.NoError(t, err)
272+
273+
t.Cleanup(func() {
274+
// Using a new 10s-timeout context, as we may be running cleanup after the original context expired.
275+
_, cancel := context.WithTimeout(context.Background(), 10*time.Second)
276+
defer cancel()
277+
assert.NoError(t, db.Close())
278+
})
279+
280+
db.Exec("DROP TABLE IF EXISTS testsql")
281+
db.Exec("CREATE TABLE testsql (dn text, name text, sam_account_name text, mail text, primary_group_id text)")
282+
t.Cleanup(func() {
283+
db.Exec("DROP TABLE IF EXISTS testsql")
284+
})
285+
286+
tx, err := db.Begin()
287+
require.NoError(t, err)
288+
defer tx.Rollback()
289+
290+
s := pq.CopyInSchema("public", "testsql", "dn", "name", "sam_account_name", "mail", "primary_group_id")
291+
stmt, err := tx.Prepare(s)
292+
require.NoError(t, err)
293+
defer stmt.Close()
294+
295+
_, err = stmt.Exec("dn", "name0", "sam", nil, nil)
296+
require.NoError(t, err)
297+
298+
spans := tr.FinishedSpans()
299+
require.Len(t, spans, 6)
300+
assert.Equal(t, `COPY "public"."testsql" ("dn", "name", "sam_account_name", "mail", "primary_group_id") FROM STDIN`, spans[5].Tags()[ext.ResourceName])
301+
}
302+
259303
func TestDBMTraceContextTagging(t *testing.T) {
260304
testCases := []struct {
261305
name string

0 commit comments

Comments
 (0)