diff --git a/src/sources/splunk_hec/mod.rs b/src/sources/splunk_hec/mod.rs index d6cb758063536..837389a81f945 100644 --- a/src/sources/splunk_hec/mod.rs +++ b/src/sources/splunk_hec/mod.rs @@ -67,6 +67,7 @@ use crate::{ }; mod acknowledgements; +use std::marker::PhantomData; // Event fields unique to splunk_hec source pub const CHANNEL: &str = "splunk_channel"; @@ -154,59 +155,12 @@ fn default_socket_address() -> SocketAddr { impl SourceConfig for SplunkConfig { async fn build(&self, cx: SourceContext) -> crate::Result { let tls = MaybeTlsSettings::from_config(self.tls.as_ref(), true)?; - let shutdown = cx.shutdown.clone(); - let out = cx.out.clone(); - let source = SplunkSource::new(self, tls.http_protocol_name(), cx); - - let event_service = source.event_service(out.clone()); - let raw_service = source.raw_service(out); - let health_service = source.health_service(); - let ack_service = source.ack_service(); - let options = SplunkSource::options(); - let services = path!("services" / "collector" / ..) - .and( - event_service - .or(raw_service) - .unify() - .or(health_service) - .unify() - .or(ack_service) - .unify() - .or(options) - .unify(), - ) - .or_else(finish_err); + let source: SplunkSource = SplunkSource::new(self, tls.http_protocol_name(), cx); let listener = tls.bind(&self.address).await?; - let keepalive_settings = self.keepalive.clone(); - Ok(Box::pin(async move { - let span = Span::current(); - let make_svc = make_service_fn(move |conn: &MaybeTlsIncomingStream| { - let svc = ServiceBuilder::new() - .layer(build_http_trace_layer(span.clone())) - .option_layer(keepalive_settings.max_connection_age_secs.map(|secs| { - MaxConnectionAgeLayer::new( - Duration::from_secs(secs), - keepalive_settings.max_connection_age_jitter_factor, - conn.peer_addr(), - ) - })) - .service(warp::service(services.clone())); - futures_util::future::ok::<_, Infallible>(svc) - }); - - Server::builder(hyper::server::accept::from_stream(listener.accept_stream())) - .serve(make_svc) - .with_graceful_shutdown(shutdown.map(|_| ())) - .await - .map_err(|err| { - error!("An error occurred: {:?}.", err); - })?; - - Ok(()) - })) + Ok(source.listen(listener, self.keepalive.clone())) } fn outputs(&self, global_log_namespace: LogNamespace) -> Vec { @@ -295,21 +249,24 @@ impl SourceConfig for SplunkConfig { } } -/// Shared data for responding to requests. -struct SplunkSource { +pub struct SplunkSource { valid_credentials: Vec, protocol: &'static str, idx_ack: Option>, store_hec_token: bool, log_namespace: LogNamespace, events_received: Registered, + shutdown: vector_lib::shutdown::ShutdownSignal, + out: SourceSender, + + _extractor: PhantomData, } -impl SplunkSource { +impl SplunkSource { fn new(config: &SplunkConfig, protocol: &'static str, cx: SourceContext) -> Self { let log_namespace = cx.log_namespace(config.log_namespace); let acknowledgements = cx.do_acknowledgements(config.acknowledgements.enabled.into()); - let shutdown = cx.shutdown; + let shutdown = cx.shutdown.clone(); let valid_tokens = config .valid_tokens .iter() @@ -323,7 +280,7 @@ impl SplunkSource { )) }); - SplunkSource { + SplunkSource:: { valid_credentials: valid_tokens .map(|token| format!("Splunk {}", token.inner())) .collect(), @@ -332,9 +289,66 @@ impl SplunkSource { store_hec_token: config.store_hec_token, log_namespace, events_received: register!(EventsReceived), + _extractor: PhantomData, + shutdown: cx.shutdown, + out: cx.out, } } + fn listen( + self, + listener: vector_lib::tls::MaybeTlsListener, + keepalive_settings: KeepaliveConfig, + ) -> super::Source { + let event_service = self.event_service(self.out.clone()); + let raw_service = self.raw_service(self.out.clone()); + let health_service = self.health_service(); + let ack_service = self.ack_service(); + + let options = options(); + + let services = path!("services" / "collector" / ..) + .and( + event_service + .or(raw_service) + .unify() + .or(health_service) + .unify() + .or(ack_service) + .unify() + .or(options) + .unify(), + ) + .or_else(finish_err); + + Box::pin(async move { + let span = Span::current(); + let make_svc = make_service_fn(move |conn: &MaybeTlsIncomingStream| { + let svc = ServiceBuilder::new() + .layer(build_http_trace_layer(span.clone())) + .option_layer(keepalive_settings.max_connection_age_secs.map(|secs| { + MaxConnectionAgeLayer::new( + Duration::from_secs(secs), + keepalive_settings.max_connection_age_jitter_factor, + conn.peer_addr(), + ) + })) + .service(warp::service(services.clone())); + futures_util::future::ok::<_, Infallible>(svc) + }); + + Server::builder(hyper::server::accept::from_stream(listener.accept_stream())) + .serve(make_svc) + .with_graceful_shutdown(self.shutdown.map(|_| ())) + .await + .map_err(|err| { + error!("An error occurred: {:?}.", err); + })?; + + Ok(()) + }) + } + fn event_service(&self, out: SourceSender) -> BoxedFilter<(Response,)> { let splunk_channel_query_param = warp::query::>() .map(|qs: HashMap| qs.get("channel").map(|v| v.to_owned())); @@ -349,7 +363,6 @@ impl SplunkSource { let store_hec_token = self.store_hec_token; let log_namespace = self.log_namespace; let events_received = self.events_received.clone(); - warp::post() .and( path!("event") @@ -411,15 +424,19 @@ impl SplunkSource { let mut error = None; let mut events = Vec::new(); - let iter: EventIterator<'_, StrRead<'_>> = EventIteratorGenerator { - deserializer: Deserializer::from_str(&body).into_iter::(), - channel, + let meta = RequestMeta { + token: token.clone(), remote, remote_addr, + }; + + let iter: EventIterator<'_, StrRead<'_>, E> = EventIteratorGenerator { + deserializer: Deserializer::from_str(&body).into_iter::(), + channel, batch, - token: token.filter(|_| store_hec_token).map(Into::into), log_namespace, events_received, + extractor: E::new(meta, log_namespace, store_hec_token), } .into(); @@ -461,7 +478,7 @@ impl SplunkSource { warp::post() .and(path!("raw" / "1.0").or(path!("raw"))) .and(self.authorization()) - .and(SplunkSource::required_channel()) + .and(required_channel()) .and(warp::addr::remote()) .and(warp::header::optional::("X-Forwarded-For")) .and(self.gzip()) @@ -538,40 +555,14 @@ impl SplunkSource { .boxed() } - fn lenient_json_content_type_check() -> impl Filter + Clone - where - T: Send + DeserializeOwned + 'static, - { - warp::header::optional::(CONTENT_TYPE.as_str()) - .and(warp::body::bytes()) - .and_then( - |ctype: Option, body: bytes::Bytes| async move { - let ok = ctype - .as_ref() - .and_then(|v| v.to_str().ok()) - .map(|h| h.to_ascii_lowercase().contains("application/json")) - .unwrap_or(true); - - if !ok { - return Err(warp::reject::custom(ApiError::UnsupportedContentType)); - } - - let value = serde_json::from_slice::(&body) - .map_err(|_| warp::reject::custom(ApiError::BadRequest))?; - - Ok(value) - }, - ) - } - fn ack_service(&self) -> BoxedFilter<(Response,)> { let idx_ack = self.idx_ack.clone(); warp::post() .and(warp::path!("ack")) .and(self.authorization()) - .and(SplunkSource::required_channel()) - .and(Self::lenient_json_content_type_check::()) + .and(required_channel()) + .and(lenient_json_content_type_check::()) .and_then(move |_, channel: String, req: HecAckStatusRequest| { let idx_ack = idx_ack.clone(); async move { @@ -588,23 +579,6 @@ impl SplunkSource { .boxed() } - fn options() -> BoxedFilter<(Response,)> { - let post = warp::options() - .and( - path!("event") - .or(path!("event" / "1.0")) - .or(path!("raw" / "1.0")) - .or(path!("raw")), - ) - .map(|_| warp::reply::with_header(warp::reply(), "Allow", "POST").into_response()); - - let get = warp::options() - .and(path!("health").or(path!("health" / "1.0"))) - .map(|_| warp::reply::with_header(warp::reply(), "Allow", "GET").into_response()); - - post.or(get).unify().boxed() - } - /// Authorize request fn authorization(&self) -> BoxedFilter<(Option,)> { let valid_credentials = self.valid_credentials.clone(); @@ -645,25 +619,69 @@ impl SplunkSource { }) .boxed() } +} - fn required_channel() -> BoxedFilter<(String,)> { - let splunk_channel_query_param = warp::query::>() - .map(|qs: HashMap| qs.get("channel").map(|v| v.to_owned())); - let splunk_channel_header = warp::header::optional::(X_SPLUNK_REQUEST_CHANNEL); +fn options() -> BoxedFilter<(Response,)> { + let post = warp::options() + .and( + path!("event") + .or(path!("event" / "1.0")) + .or(path!("raw" / "1.0")) + .or(path!("raw")), + ) + .map(|_| warp::reply::with_header(warp::reply(), "Allow", "POST").into_response()); - splunk_channel_header - .and(splunk_channel_query_param) - .and_then(|header: Option, query_param| async move { - header - .or(query_param) - .ok_or_else(|| Rejection::from(ApiError::MissingChannel)) - }) - .boxed() - } + let get = warp::options() + .and(path!("health").or(path!("health" / "1.0"))) + .map(|_| warp::reply::with_header(warp::reply(), "Allow", "GET").into_response()); + + post.or(get).unify().boxed() } + +fn required_channel() -> BoxedFilter<(String,)> { + let splunk_channel_query_param = warp::query::>() + .map(|qs: HashMap| qs.get("channel").map(|v| v.to_owned())); + let splunk_channel_header = warp::header::optional::(X_SPLUNK_REQUEST_CHANNEL); + + splunk_channel_header + .and(splunk_channel_query_param) + .and_then(|header: Option, query_param| async move { + header + .or(query_param) + .ok_or_else(|| Rejection::from(ApiError::MissingChannel)) + }) + .boxed() +} + +fn lenient_json_content_type_check() -> impl Filter + Clone +where + T: Send + DeserializeOwned + 'static, +{ + warp::header::optional::(CONTENT_TYPE.as_str()) + .and(warp::body::bytes()) + .and_then( + |ctype: Option, body: bytes::Bytes| async move { + let ok = ctype + .as_ref() + .and_then(|v| v.to_str().ok()) + .map(|h| h.to_ascii_lowercase().contains("application/json")) + .unwrap_or(true); + + if !ok { + return Err(warp::reject::custom(ApiError::UnsupportedContentType)); + } + + let value = serde_json::from_slice::(&body) + .map_err(|_| warp::reject::custom(ApiError::BadRequest))?; + + Ok(value) + }, + ) +} + /// Constructs one or more events from json-s coming from reader. /// If errors, it's done with input. -struct EventIterator<'de, R: JsonRead<'de>> { +struct EventIterator<'de, R: JsonRead<'de>, E: Extractor> { /// Remaining request with JSON events deserializer: serde_json::StreamDeserializer<'de, R, JsonValue>, /// Count of sent events @@ -673,66 +691,50 @@ struct EventIterator<'de, R: JsonRead<'de>> { /// Default time time: Time, /// Remaining extracted default values - extractors: [DefaultExtractor; 4], + extractor: E, /// Event finalization batch: Option, - /// Splunk HEC Token for passthrough - token: Option>, /// Lognamespace to put the events in log_namespace: LogNamespace, /// handle to EventsReceived registry events_received: Registered, } +#[derive(Debug, Clone)] +pub struct RequestMeta { + pub token: Option, + pub remote: Option, + pub remote_addr: Option, +} + /// Intermediate struct to generate an `EventIterator` -struct EventIteratorGenerator<'de, R: JsonRead<'de>> { +struct EventIteratorGenerator<'de, R: JsonRead<'de>, E: Extractor> { deserializer: serde_json::StreamDeserializer<'de, R, JsonValue>, channel: Option, batch: Option, - token: Option>, log_namespace: LogNamespace, events_received: Registered, - remote: Option, - remote_addr: Option, + extractor: E, } -impl<'de, R: JsonRead<'de>> From> for EventIterator<'de, R> { - fn from(f: EventIteratorGenerator<'de, R>) -> Self { +impl<'de, R: JsonRead<'de>, E: Extractor> From> + for EventIterator<'de, R, E> +{ + fn from(f: EventIteratorGenerator<'de, R, E>) -> Self { Self { deserializer: f.deserializer, events: 0, channel: f.channel.map(Value::from), time: Time::Now(Utc::now()), - extractors: [ - // Extract the host field with the given priority: - // 1. The host field is present in the event payload - // 2. The x-forwarded-for header is present in the incoming request - // 3. Use the `remote`: SocketAddr value provided by warp - DefaultExtractor::new_with( - "host", - log_schema().host_key().cloned().into(), - f.remote_addr - .or_else(|| f.remote.map(|addr| addr.to_string())) - .map(Value::from), - f.log_namespace, - ), - DefaultExtractor::new("index", OptionalValuePath::new(INDEX), f.log_namespace), - DefaultExtractor::new("source", OptionalValuePath::new(SOURCE), f.log_namespace), - DefaultExtractor::new( - "sourcetype", - OptionalValuePath::new(SOURCETYPE), - f.log_namespace, - ), - ], + extractor: f.extractor, batch: f.batch, - token: f.token, log_namespace: f.log_namespace, events_received: f.events_received, } } } -impl<'de, R: JsonRead<'de>> EventIterator<'de, R> { +impl<'de, R: JsonRead<'de>, E: Extractor> EventIterator<'de, R, E> { fn build_event(&mut self, mut json: JsonValue) -> Result { // Construct Event from parsed json event let mut log = match self.log_namespace { @@ -827,14 +829,7 @@ impl<'de, R: JsonRead<'de>> EventIterator<'de, R> { ); // Extract default extracted fields - for de in self.extractors.iter_mut() { - de.extract(&mut log, &mut json); - } - - // Add passthrough token if present - if let Some(token) = &self.token { - log.metadata_mut().set_splunk_hec_token(Arc::clone(token)); - } + self.extractor.extract(&mut log, &mut json); if let Some(batch) = self.batch.clone() { log = log.with_batch_notifier(&batch); @@ -921,7 +916,7 @@ impl<'de, R: JsonRead<'de>> EventIterator<'de, R> { } } -impl<'de, R: JsonRead<'de>> Iterator for EventIterator<'de, R> { +impl<'de, R: JsonRead<'de>, E: Extractor> Iterator for EventIterator<'de, R, E> { type Item = Result; fn next(&mut self) -> Option { @@ -980,21 +975,22 @@ fn parse_timestamp(t: i64) -> Option> { Some(ts) } -/// Maintains last known extracted value of field and uses it in the absence of field. -struct DefaultExtractor { +/// MetaExtractor is a helper struct that extracts a field from the request and adds it to the log event metadata. +/// It maintains last known extracted value of field and uses it in the absence of field. +struct FieldExtractor { field: &'static str, to_field: OptionalValuePath, value: Option, log_namespace: LogNamespace, } -impl DefaultExtractor { +impl FieldExtractor { const fn new( field: &'static str, to_field: OptionalValuePath, log_namespace: LogNamespace, ) -> Self { - DefaultExtractor { + Self { field, to_field, value: None, @@ -1008,7 +1004,7 @@ impl DefaultExtractor { value: impl Into>, log_namespace: LogNamespace, ) -> Self { - DefaultExtractor { + Self { field, to_field, value: value.into(), @@ -1037,6 +1033,71 @@ impl DefaultExtractor { } } +/// Extractor describes a hook that can be attached to the splunk source to extract custom properties +/// from the request and received values and add them to the log event. +/// +/// The default implementation of the trait is DefaultExtractor which extracts the host field, index, +/// source and sourcetype fields and stores them in the log metadata. +/// +/// This DefaultExtractor can be wrapped in a custom implementation to extract additional properties. +pub trait Extractor { + /// create a new instance of the extractor + fn new(meta: RequestMeta, log_namespace: LogNamespace, store_hec_token: bool) -> Self; + + /// extract will be called for each value in the request and the associated log. + fn extract(&mut self, log: &mut LogEvent, value: &mut JsonValue); +} + +pub struct DefaultExtractor { + extractors: [FieldExtractor; 4], + store_hec_token: bool, + token: Option, +} + +impl Extractor for DefaultExtractor { + fn new(meta: RequestMeta, log_namespace: LogNamespace, store_hec_token: bool) -> Self { + DefaultExtractor { + extractors: [ + // Extract the host field with the given priority: + // 1. The host field is present in the event payload + // 2. The x-forwarded-for header is present in the incoming request + // 3. Use the `remote`: SocketAddr value provided by warp + FieldExtractor::new_with( + "host", + log_schema().host_key().cloned().into(), + meta.remote_addr + .or_else(|| meta.remote.map(|addr| addr.to_string())) + .map(Value::from), + log_namespace, + ), + FieldExtractor::new("index", OptionalValuePath::new(INDEX), log_namespace), + FieldExtractor::new("source", OptionalValuePath::new(SOURCE), log_namespace), + FieldExtractor::new( + "sourcetype", + OptionalValuePath::new(SOURCETYPE), + log_namespace, + ), + ], + store_hec_token, + token: meta.token, + } + } + + fn extract(&mut self, log: &mut LogEvent, value: &mut JsonValue) { + for de in self.extractors.iter_mut() { + de.extract(log, value); + } + + // Add passthrough token if present + if let Some(token) = &self.token + && self.store_hec_token + { + log.metadata_mut() + .set_splunk_hec_token(token.clone().into()); + } + } +} + /// For tracking origin of the timestamp #[derive(Clone, Debug)] enum Time {