|
| 1 | +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +use dd_trace::Config; |
| 5 | +use dd_trace_sampling::DatadogSampler; |
| 6 | +use opentelemetry::trace::TraceContextExt; |
| 7 | +use opentelemetry_sdk::{trace::ShouldSample, Resource}; |
| 8 | +use std::{ |
| 9 | + collections::HashMap, |
| 10 | + sync::{Arc, RwLock}, |
| 11 | +}; |
| 12 | + |
| 13 | +use crate::{ |
| 14 | + span_processor::{RegisterTracePropagationResult, SamplingDecision}, |
| 15 | + TraceRegistry, |
| 16 | +}; |
| 17 | + |
| 18 | +#[derive(Debug, Clone)] |
| 19 | +pub struct Sampler { |
| 20 | + sampler: DatadogSampler, |
| 21 | + trace_registry: Arc<TraceRegistry>, |
| 22 | +} |
| 23 | + |
| 24 | +impl Sampler { |
| 25 | + pub fn new( |
| 26 | + cfg: &Config, |
| 27 | + resource: Arc<RwLock<Resource>>, |
| 28 | + trace_registry: Arc<TraceRegistry>, |
| 29 | + ) -> Self { |
| 30 | + let rules = cfg |
| 31 | + .trace_sampling_rules() |
| 32 | + .iter() |
| 33 | + .map(|r| { |
| 34 | + dd_trace_sampling::SamplingRule::new( |
| 35 | + r.sample_rate, |
| 36 | + r.service.clone(), |
| 37 | + r.name.clone(), |
| 38 | + r.resource.clone(), |
| 39 | + Some(r.tags.clone()), |
| 40 | + Some(r.provenance.clone()), |
| 41 | + ) |
| 42 | + }) |
| 43 | + .collect::<Vec<_>>(); |
| 44 | + let sampler = |
| 45 | + dd_trace_sampling::DatadogSampler::new(rules, cfg.trace_rate_limit(), resource); |
| 46 | + Self { |
| 47 | + sampler, |
| 48 | + trace_registry, |
| 49 | + } |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +impl ShouldSample for Sampler { |
| 54 | + fn should_sample( |
| 55 | + &self, |
| 56 | + parent_context: Option<&opentelemetry::Context>, |
| 57 | + trace_id: opentelemetry::trace::TraceId, |
| 58 | + name: &str, |
| 59 | + span_kind: &opentelemetry::trace::SpanKind, |
| 60 | + attributes: &[opentelemetry::KeyValue], |
| 61 | + _links: &[opentelemetry::trace::Link], |
| 62 | + ) -> opentelemetry::trace::SamplingResult { |
| 63 | + let result = self.sampler.sample( |
| 64 | + parent_context |
| 65 | + .filter(|c| c.has_active_span()) |
| 66 | + .map(|c| c.span().span_context().is_sampled()), |
| 67 | + trace_id, |
| 68 | + name, |
| 69 | + span_kind, |
| 70 | + attributes, |
| 71 | + ); |
| 72 | + if let Some(trace_root_info) = &result.trace_root_info { |
| 73 | + match self.trace_registry.register_trace_propagation_data( |
| 74 | + trace_id.to_bytes(), |
| 75 | + SamplingDecision { |
| 76 | + decision: trace_root_info.sampling_priority(result.is_sampled).value(), |
| 77 | + // TODO: unify these types with decision maker with the one in the span |
| 78 | + // processor |
| 79 | + decision_maker: trace_root_info.mechanism.value() as i8, |
| 80 | + }, |
| 81 | + None, |
| 82 | + // TODO(paullgdc): This is here so the injector adds the t.dm tag to |
| 83 | + // tracecontext. The injector should probably inject it from |
| 84 | + // the trace propagation data instead of tags. |
| 85 | + Some(HashMap::from_iter([( |
| 86 | + "_dd.p.dm".to_string(), |
| 87 | + format!("{}", -(trace_root_info.mechanism.value() as i32)), |
| 88 | + )])), |
| 89 | + ) { |
| 90 | + RegisterTracePropagationResult::Existing(sampling_decision) => { |
| 91 | + return opentelemetry::trace::SamplingResult { |
| 92 | + decision: if sampling_decision.decision > 0 { |
| 93 | + opentelemetry::trace::SamplingDecision::RecordAndSample |
| 94 | + } else { |
| 95 | + opentelemetry::trace::SamplingDecision::RecordOnly |
| 96 | + }, |
| 97 | + attributes: Vec::new(), |
| 98 | + trace_state: parent_context |
| 99 | + .map(|c| c.span().span_context().trace_state().clone()) |
| 100 | + .unwrap_or_default(), |
| 101 | + } |
| 102 | + } |
| 103 | + RegisterTracePropagationResult::New => {} |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + opentelemetry::trace::SamplingResult { |
| 108 | + decision: result.to_otel_decision(), |
| 109 | + attributes: result.to_dd_sampling_tags(), |
| 110 | + trace_state: parent_context |
| 111 | + .map(|c| c.span().span_context().trace_state().clone()) |
| 112 | + .unwrap_or_default(), |
| 113 | + } |
| 114 | + } |
| 115 | +} |
| 116 | + |
| 117 | +#[cfg(test)] |
| 118 | +mod tests { |
| 119 | + use super::*; |
| 120 | + use dd_trace::configuration::SamplingRuleConfig; |
| 121 | + use opentelemetry::{ |
| 122 | + trace::{SamplingDecision, SpanContext, SpanKind, TraceId, TraceState}, |
| 123 | + Context, SpanId, TraceFlags, |
| 124 | + }; |
| 125 | + use opentelemetry_sdk::trace::ShouldSample; |
| 126 | + use std::env; |
| 127 | + |
| 128 | + #[test] |
| 129 | + fn test_create_sampler_with_sampling_rules() { |
| 130 | + // Build a fresh config to pick up the env var |
| 131 | + let mut config = Config::builder(); |
| 132 | + config.set_trace_sampling_rules(vec![SamplingRuleConfig { |
| 133 | + sample_rate: 0.5, |
| 134 | + service: Some("test-service".to_string()), |
| 135 | + name: None, |
| 136 | + resource: None, |
| 137 | + tags: HashMap::new(), |
| 138 | + provenance: "customer".to_string(), |
| 139 | + }]); |
| 140 | + let config = config.build(); |
| 141 | + |
| 142 | + let test_resource = Arc::new(RwLock::new(Resource::builder().build())); |
| 143 | + let sampler = Sampler::new(&config, test_resource, Arc::new(TraceRegistry::new())); |
| 144 | + |
| 145 | + let trace_id_bytes = [1; 16]; |
| 146 | + let trace_id = TraceId::from_bytes(trace_id_bytes); |
| 147 | + |
| 148 | + // Basic assertion: Check if the attributes added by the sampler are not empty, |
| 149 | + // implying some sampling logic (like adding priority tags) ran. |
| 150 | + assert!( |
| 151 | + !sampler |
| 152 | + .should_sample(None, trace_id, "test", &SpanKind::Client, &[], &[]) |
| 153 | + .attributes |
| 154 | + .is_empty(), |
| 155 | + "Sampler should add attributes even if decision is complex" |
| 156 | + ); |
| 157 | + |
| 158 | + // Clean up environment |
| 159 | + env::remove_var("DD_TRACE_SAMPLING_RULES"); |
| 160 | + } |
| 161 | + |
| 162 | + #[test] |
| 163 | + fn test_create_default_sampler() { |
| 164 | + // Create a default config (no rules, no specific rate limit) |
| 165 | + let config = Config::builder().build(); |
| 166 | + |
| 167 | + let test_resource = Arc::new(RwLock::new(Resource::builder_empty().build())); |
| 168 | + let sampler = Sampler::new(&config, test_resource, Arc::new(TraceRegistry::new())); |
| 169 | + |
| 170 | + let trace_id_bytes = [2; 16]; |
| 171 | + let trace_id = TraceId::from_bytes(trace_id_bytes); |
| 172 | + |
| 173 | + // Verify the default sampler behavior |
| 174 | + let result = sampler.should_sample(None, trace_id, "test", &SpanKind::Client, &[], &[]); |
| 175 | + assert_eq!( |
| 176 | + result.decision, |
| 177 | + SamplingDecision::RecordAndSample, |
| 178 | + "Default sampler should record and sample by default" |
| 179 | + ); |
| 180 | + } |
| 181 | + |
| 182 | + #[test] |
| 183 | + fn test_trace_state_propagation() { |
| 184 | + let config = Config::builder().build(); |
| 185 | + |
| 186 | + let test_resource = Arc::new(RwLock::new(Resource::builder_empty().build())); |
| 187 | + let sampler = Sampler::new(&config, test_resource, Arc::new(TraceRegistry::new())); |
| 188 | + |
| 189 | + let trace_id = TraceId::from_bytes([2; 16]); |
| 190 | + let span_id = SpanId::from_bytes([3; 8]); |
| 191 | + |
| 192 | + for is_sampled in [true, false] { |
| 193 | + let trace_state = TraceState::from_key_value([("test_key", "test_value")]).unwrap(); |
| 194 | + let span_context = SpanContext::new( |
| 195 | + trace_id, |
| 196 | + span_id, |
| 197 | + is_sampled |
| 198 | + .then_some(TraceFlags::SAMPLED) |
| 199 | + .unwrap_or_default(), |
| 200 | + true, |
| 201 | + trace_state.clone(), |
| 202 | + ); |
| 203 | + |
| 204 | + // Verify the sampler with a parent context |
| 205 | + let result = sampler.should_sample( |
| 206 | + Some(&Context::new().with_remote_span_context(span_context)), |
| 207 | + trace_id, |
| 208 | + "test", |
| 209 | + &SpanKind::Client, |
| 210 | + &[], |
| 211 | + &[], |
| 212 | + ); |
| 213 | + assert_eq!( |
| 214 | + result.decision, |
| 215 | + if is_sampled { |
| 216 | + SamplingDecision::RecordAndSample |
| 217 | + } else { |
| 218 | + SamplingDecision::RecordOnly |
| 219 | + }, |
| 220 | + "Sampler should respect parent context sampling decision" |
| 221 | + ); |
| 222 | + assert_eq!( |
| 223 | + result.trace_state.header(), |
| 224 | + "test_key=test_value", |
| 225 | + "Sampler should propagate trace state from parent context" |
| 226 | + ); |
| 227 | + } |
| 228 | + } |
| 229 | +} |
0 commit comments