diff --git a/dd-java-agent/instrumentation/spring-webmvc-3.1/src/main/java/datadog/trace/instrumentation/springweb/HttpMessageConverterInstrumentation.java b/dd-java-agent/instrumentation/spring-webmvc-3.1/src/main/java/datadog/trace/instrumentation/springweb/HttpMessageConverterInstrumentation.java index e2684d5473e..59bbb726476 100644 --- a/dd-java-agent/instrumentation/spring-webmvc-3.1/src/main/java/datadog/trace/instrumentation/springweb/HttpMessageConverterInstrumentation.java +++ b/dd-java-agent/instrumentation/spring-webmvc-3.1/src/main/java/datadog/trace/instrumentation/springweb/HttpMessageConverterInstrumentation.java @@ -71,6 +71,27 @@ public void methodAdvice(MethodTransformer transformer) { .and(takesArgument(1, Class.class)) .and(takesArgument(2, named("org.springframework.http.HttpInputMessage"))), HttpMessageConverterInstrumentation.class.getName() + "$HttpMessageConverterReadAdvice"); + + transformer.applyAdvice( + isMethod() + .and(isPublic()) + .and(named("write")) + .and(takesArguments(3)) + .and(takesArgument(0, Object.class)) + .and(takesArgument(1, named("org.springframework.http.MediaType"))) + .and(takesArgument(2, named("org.springframework.http.HttpOutputMessage"))), + HttpMessageConverterInstrumentation.class.getName() + "$HttpMessageConverterWriteAdvice"); + + transformer.applyAdvice( + isMethod() + .and(isPublic()) + .and(named("write")) + .and(takesArguments(4)) + .and(takesArgument(0, Object.class)) + .and(takesArgument(1, Type.class)) + .and(takesArgument(2, named("org.springframework.http.MediaType"))) + .and(takesArgument(3, named("org.springframework.http.HttpOutputMessage"))), + HttpMessageConverterInstrumentation.class.getName() + "$HttpMessageConverterWriteAdvice"); } @RequiresRequestContext(RequestContextSlot.APPSEC) @@ -106,4 +127,37 @@ public static void after( } } } + + @RequiresRequestContext(RequestContextSlot.APPSEC) + public static class HttpMessageConverterWriteAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void before( + @Advice.Argument(0) final Object obj, @ActiveRequestContext RequestContext reqCtx) { + if (obj == null) { + return; + } + + CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC); + BiFunction> callback = + cbp.getCallback(EVENTS.responseBody()); + if (callback == null) { + return; + } + + Flow flow = callback.apply(reqCtx, obj); + Flow.Action action = flow.getAction(); + if (action instanceof Flow.Action.RequestBlockingAction) { + Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) action; + BlockResponseFunction brf = reqCtx.getBlockResponseFunction(); + if (brf != null) { + brf.tryCommitBlockingResponse( + reqCtx.getTraceSegment(), + rba.getStatusCode(), + rba.getBlockingContentType(), + rba.getExtraHeaders()); + } + throw new BlockingException("Blocked response (for HttpMessageConverter/write)"); + } + } + } } diff --git a/dd-java-agent/instrumentation/spring-webmvc-3.1/src/test/groovy/test/boot/SpringBootBasedTest.groovy b/dd-java-agent/instrumentation/spring-webmvc-3.1/src/test/groovy/test/boot/SpringBootBasedTest.groovy index a80247071ee..aa2dde1850e 100644 --- a/dd-java-agent/instrumentation/spring-webmvc-3.1/src/test/groovy/test/boot/SpringBootBasedTest.groovy +++ b/dd-java-agent/instrumentation/spring-webmvc-3.1/src/test/groovy/test/boot/SpringBootBasedTest.groovy @@ -77,6 +77,11 @@ class SpringBootBasedTest extends HttpServerTest return "boot-context" } + @Override + boolean testResponseBodyJson() { + return true + } + @Override String expectedServiceName() { servletContext @@ -163,8 +168,7 @@ class SpringBootBasedTest extends HttpServerTest @Override Map expectedExtraServerTags(ServerEndpoint endpoint) { - ["servlet.path": endpoint.path, "servlet.context": "/$servletContext"] + - extraServerTags + ["servlet.path": endpoint.path, "servlet.context": "/$servletContext"] + extraServerTags } @Override diff --git a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java index 0338d5afe97..791bd30c29c 100644 --- a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java +++ b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/WebController.java @@ -9,6 +9,7 @@ import java.nio.file.Paths; import java.sql.Connection; import java.sql.DriverManager; +import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; import org.apache.commons.httpclient.HttpClient; @@ -239,6 +240,14 @@ public ResponseEntity exceedResponseHeaders() { return new ResponseEntity<>("Custom headers added", headers, HttpStatus.OK); } + @PostMapping("/api_security/response") + public ResponseEntity> apiSecurityResponse( + @RequestBody Map body) { + // This endpoint is used to test API security response handling + // It simply returns the body received in the request + return ResponseEntity.ok(body); + } + private void withProcess(final Operation op) { Process process = null; try { diff --git a/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AppSecHttpMessageConverterSmokeTest.groovy b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AppSecHttpMessageConverterSmokeTest.groovy new file mode 100644 index 00000000000..3e20741c8af --- /dev/null +++ b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AppSecHttpMessageConverterSmokeTest.groovy @@ -0,0 +1,70 @@ +package datadog.smoketest.appsec + +import groovy.json.JsonOutput +import groovy.json.JsonSlurper +import okhttp3.MediaType +import okhttp3.Request +import okhttp3.RequestBody + +import java.util.zip.GZIPInputStream + +class AppSecHttpMessageConverterSmokeTest extends AbstractAppSecServerSmokeTest { + + @Override + def logLevel() { + 'DEBUG' + } + + @Override + ProcessBuilder createProcessBuilder() { + String springBootShadowJar = System.getProperty("datadog.smoketest.appsec.springboot.shadowJar.path") + + List command = new ArrayList<>() + command.add(javaPath()) + command.addAll(defaultJavaProperties) + command.addAll(defaultAppSecProperties) + command.addAll((String[]) [ + "-Ddd.writer.type=MultiWriter:TraceStructureWriter:${output.getAbsolutePath()},DDAgentWriter", + "-jar", + springBootShadowJar, + "--server.port=${httpPort}" + ]) + ProcessBuilder processBuilder = new ProcessBuilder(command) + processBuilder.directory(new File(buildDirectory)) + } + + @Override + File createTemporaryFile() { + return new File("${buildDirectory}/tmp/trace-structure-http-converter.out") + } + + void 'test response schema extraction'() { + given: + def url = "http://localhost:${httpPort}/api_security/response" + def body = [ + "main" : [["key": "id001", "value": 1345.67], ["value": 1567.89, "key": "id002"]], + "nullable": null, + ] + def request = new Request.Builder() + .url(url) + .post(RequestBody.create(MediaType.get('application/json'), JsonOutput.toJson(body))) + .build() + + when: + final response = client.newCall(request).execute() + waitForTraceCount(1) + + then: + response.code() == 200 + def span = rootSpans.first() + span.meta.containsKey('_dd.appsec.s.res.headers') + span.meta.containsKey('_dd.appsec.s.res.body') + final schema = new JsonSlurper().parse(unzip(span.meta.get('_dd.appsec.s.res.body'))) + assert schema == [["main": [[[["key": [8], "value": [16]]]], ["len": 2]], "nullable": [1]]] + } + + private static byte[] unzip(final String text) { + final inflaterStream = new GZIPInputStream(new ByteArrayInputStream(text.decodeBase64())) + return inflaterStream.getBytes() + } +}