11/*
2- * Copyright 2002-2015 the original author or authors.
2+ * Copyright 2002-2018 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1616
1717package org .springframework .web .cors .reactive ;
1818
19+ import java .util .concurrent .atomic .AtomicReference ;
20+
1921import org .junit .Test ;
22+ import reactor .core .publisher .Mono ;
2023
2124import org .springframework .http .HttpHeaders ;
25+ import org .springframework .http .server .reactive .ServerHttpRequest ;
2226import org .springframework .mock .http .server .reactive .test .MockServerHttpRequest ;
27+ import org .springframework .mock .web .test .server .MockServerWebExchange ;
28+ import org .springframework .web .filter .reactive .ForwardedHeaderFilter ;
2329
24- import static org .junit .Assert .assertFalse ;
25- import static org .junit .Assert .assertTrue ;
26- import static org .springframework .mock .http .server .reactive .test .MockServerHttpRequest .get ;
27- import static org .springframework .mock .http .server .reactive .test .MockServerHttpRequest .options ;
30+ import static org .junit .Assert .*;
31+ import static org .springframework .mock .http .server .reactive .test .MockServerHttpRequest .*;
2832
2933/**
3034 * Test case for reactive {@link CorsUtils}.
@@ -35,19 +39,19 @@ public class CorsUtilsTests {
3539
3640 @ Test
3741 public void isCorsRequest () {
38- MockServerHttpRequest request = get ("/" ).header (HttpHeaders .ORIGIN , "http://domain.com" ).build ();
42+ ServerHttpRequest request = get ("/" ).header (HttpHeaders .ORIGIN , "http://domain.com" ).build ();
3943 assertTrue (CorsUtils .isCorsRequest (request ));
4044 }
4145
4246 @ Test
4347 public void isNotCorsRequest () {
44- MockServerHttpRequest request = get ("/" ).build ();
48+ ServerHttpRequest request = get ("/" ).build ();
4549 assertFalse (CorsUtils .isCorsRequest (request ));
4650 }
4751
4852 @ Test
4953 public void isPreFlightRequest () {
50- MockServerHttpRequest request = options ("/" )
54+ ServerHttpRequest request = options ("/" )
5155 .header (HttpHeaders .ORIGIN , "http://domain.com" )
5256 .header (HttpHeaders .ACCESS_CONTROL_REQUEST_METHOD , "GET" )
5357 .build ();
@@ -56,7 +60,7 @@ public void isPreFlightRequest() {
5660
5761 @ Test
5862 public void isNotPreFlightRequest () {
59- MockServerHttpRequest request = get ("/" ).build ();
63+ ServerHttpRequest request = get ("/" ).build ();
6064 assertFalse (CorsUtils .isPreFlightRequest (request ));
6165
6266 request = options ("/" ).header (HttpHeaders .ORIGIN , "http://domain.com" ).build ();
@@ -68,31 +72,35 @@ public void isNotPreFlightRequest() {
6872
6973 @ Test // SPR-16262
7074 public void isSameOriginWithXForwardedHeaders () {
71- assertTrue (checkSameOriginWithXForwardedHeaders ("mydomain1.com" , -1 , "https" , null , -1 , "https://mydomain1.com" ));
72- assertTrue (checkSameOriginWithXForwardedHeaders ("mydomain1.com" , 123 , "https" , null , -1 , "https://mydomain1.com" ));
73- assertTrue (checkSameOriginWithXForwardedHeaders ("mydomain1.com" , -1 , "https" , "mydomain2.com" , -1 , "https://mydomain2.com" ));
74- assertTrue (checkSameOriginWithXForwardedHeaders ("mydomain1.com" , 123 , "https" , "mydomain2.com" , -1 , "https://mydomain2.com" ));
75- assertTrue (checkSameOriginWithXForwardedHeaders ("mydomain1.com" , -1 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" ));
76- assertTrue (checkSameOriginWithXForwardedHeaders ("mydomain1.com" , 123 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" ));
75+ String server = "mydomain1.com" ;
76+ testWithXForwardedHeaders (server , -1 , "https" , null , -1 , "https://mydomain1.com" );
77+ testWithXForwardedHeaders (server , 123 , "https" , null , -1 , "https://mydomain1.com" );
78+ testWithXForwardedHeaders (server , -1 , "https" , "mydomain2.com" , -1 , "https://mydomain2.com" );
79+ testWithXForwardedHeaders (server , 123 , "https" , "mydomain2.com" , -1 , "https://mydomain2.com" );
80+ testWithXForwardedHeaders (server , -1 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" );
81+ testWithXForwardedHeaders (server , 123 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" );
7782 }
7883
7984 @ Test // SPR-16262
8085 public void isSameOriginWithForwardedHeader () {
81- assertTrue (checkSameOriginWithForwardedHeader ("mydomain1.com" , -1 , "proto=https" , "https://mydomain1.com" ));
82- assertTrue (checkSameOriginWithForwardedHeader ("mydomain1.com" , 123 , "proto=https" , "https://mydomain1.com" ));
83- assertTrue (checkSameOriginWithForwardedHeader ("mydomain1.com" , -1 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" ));
84- assertTrue (checkSameOriginWithForwardedHeader ("mydomain1.com" , 123 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" ));
85- assertTrue (checkSameOriginWithForwardedHeader ("mydomain1.com" , -1 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" ));
86- assertTrue (checkSameOriginWithForwardedHeader ("mydomain1.com" , 123 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" ));
86+ String server = "mydomain1.com" ;
87+ testWithForwardedHeader (server , -1 , "proto=https" , "https://mydomain1.com" );
88+ testWithForwardedHeader (server , 123 , "proto=https" , "https://mydomain1.com" );
89+ testWithForwardedHeader (server , -1 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" );
90+ testWithForwardedHeader (server , 123 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" );
91+ testWithForwardedHeader (server , -1 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" );
92+ testWithForwardedHeader (server , 123 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" );
8793 }
8894
89- private boolean checkSameOriginWithXForwardedHeaders (String serverName , int port , String forwardedProto , String forwardedHost , int forwardedPort , String originHeader ) {
95+ private void testWithXForwardedHeaders (String serverName , int port ,
96+ String forwardedProto , String forwardedHost , int forwardedPort , String originHeader ) {
97+
9098 String url = "http://" + serverName ;
9199 if (port != -1 ) {
92100 url = url + ":" + port ;
93101 }
94- MockServerHttpRequest . BaseBuilder <?> builder = get ( url )
95- .header (HttpHeaders .ORIGIN , originHeader );
102+
103+ MockServerHttpRequest . BaseBuilder <?> builder = get ( url ) .header (HttpHeaders .ORIGIN , originHeader );
96104 if (forwardedProto != null ) {
97105 builder .header ("X-Forwarded-Proto" , forwardedProto );
98106 }
@@ -102,18 +110,36 @@ private boolean checkSameOriginWithXForwardedHeaders(String serverName, int port
102110 if (forwardedPort != -1 ) {
103111 builder .header ("X-Forwarded-Port" , String .valueOf (forwardedPort ));
104112 }
105- return CorsUtils .isSameOrigin (builder .build ());
113+
114+ ServerHttpRequest request = adaptFromForwardedHeaders (builder );
115+ assertTrue (CorsUtils .isSameOrigin (request ));
106116 }
107117
108- private boolean checkSameOriginWithForwardedHeader (String serverName , int port , String forwardedHeader , String originHeader ) {
118+ private void testWithForwardedHeader (String serverName , int port ,
119+ String forwardedHeader , String originHeader ) {
120+
109121 String url = "http://" + serverName ;
110122 if (port != -1 ) {
111123 url = url + ":" + port ;
112124 }
125+
113126 MockServerHttpRequest .BaseBuilder <?> builder = get (url )
114127 .header ("Forwarded" , forwardedHeader )
115128 .header (HttpHeaders .ORIGIN , originHeader );
116- return CorsUtils .isSameOrigin (builder .build ());
129+
130+ ServerHttpRequest request = adaptFromForwardedHeaders (builder );
131+ assertTrue (CorsUtils .isSameOrigin (request ));
132+ }
133+
134+ // SPR-16668
135+ private ServerHttpRequest adaptFromForwardedHeaders (MockServerHttpRequest .BaseBuilder <?> builder ) {
136+ AtomicReference <ServerHttpRequest > requestRef = new AtomicReference <>();
137+ MockServerWebExchange exchange = MockServerWebExchange .from (builder );
138+ new ForwardedHeaderFilter ().filter (exchange , exchange2 -> {
139+ requestRef .set (exchange2 .getRequest ());
140+ return Mono .empty ();
141+ }).block ();
142+ return requestRef .get ();
117143 }
118144
119145}
0 commit comments