Skip to content
This repository was archived by the owner on Dec 19, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package graphql.kickstart.autoconfigure.web.servlet;

import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor
class DefaultWsCsrfToken implements WsCsrfToken {

private final String token;
private final String parameterName;

@Override
public String getToken() {
return token;
}

@Override
public String getParameterName() {
return parameterName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,12 @@ class GraphQLSubscriptionWebsocketProperties {

private String path = "/subscriptions";
private List<String> allowedOrigins = emptyList();
private CsrfProperties csrf = new CsrfProperties();

@Data
static
class CsrfProperties {

private boolean enabled = false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Conditional;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
import org.springframework.web.socket.server.standard.ServerEndpointRegistration;
Expand Down Expand Up @@ -62,11 +63,7 @@ public GraphQLWebsocketServlet graphQLWebsocketServlet(
}
keepAliveListener().ifPresent(listeners::add);
return new GraphQLWebsocketServlet(
graphQLInvoker,
invocationInputFactory,
graphQLObjectMapper,
listeners,
websocketProperties.getAllowedOrigins());
graphQLInvoker, invocationInputFactory, graphQLObjectMapper, listeners);
}

private Optional<SubscriptionConnectionListener> keepAliveListener() {
Expand All @@ -78,10 +75,28 @@ private Optional<SubscriptionConnectionListener> keepAliveListener() {
return Optional.empty();
}

@Bean
public WsCsrfFilter wsCsrfFilter(
@Autowired(required = false) WsCsrfTokenRepository csrfTokenRepository) {
return new WsCsrfFilter(websocketProperties.getCsrf(), csrfTokenRepository);
}

@Bean
@ConditionalOnMissingBean
@ConditionalOnClass(HttpSessionCsrfTokenRepository.class)
public WsCsrfTokenRepository wsCsrfTokenRepository() {
return new WsSessionCsrfTokenRepository();
}

@Bean
@ConditionalOnClass(ServerContainer.class)
public ServerEndpointRegistration serverEndpointRegistration(GraphQLWebsocketServlet servlet) {
return new GraphQLWsServerEndpointRegistration(websocketProperties.getPath(), servlet);
public ServerEndpointRegistration serverEndpointRegistration(
GraphQLWebsocketServlet servlet, WsCsrfFilter csrfFilter) {
return new GraphQLWsServerEndpointRegistration(
websocketProperties.getPath(),
servlet,
csrfFilter,
websocketProperties.getAllowedOrigins());
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,63 @@
package graphql.kickstart.autoconfigure.web.servlet;

import graphql.kickstart.servlet.GraphQLWebsocketServlet;
import java.util.ArrayList;
import java.util.List;
import jakarta.websocket.HandshakeResponse;
import jakarta.websocket.server.HandshakeRequest;
import jakarta.websocket.server.ServerEndpointConfig;
import org.springframework.context.Lifecycle;
import org.springframework.web.socket.server.standard.ServerEndpointRegistration;

/** @author Andrew Potter */
/**
* @author Andrew Potter
*/
public class GraphQLWsServerEndpointRegistration extends ServerEndpointRegistration
implements Lifecycle {

private static final String ALL = "*";
private final GraphQLWebsocketServlet servlet;
private final WsCsrfFilter csrfFilter;
private final List<String> allowedOrigins;

public GraphQLWsServerEndpointRegistration(String path, GraphQLWebsocketServlet servlet) {
public GraphQLWsServerEndpointRegistration(
String path,
GraphQLWebsocketServlet servlet,
WsCsrfFilter csrfFilter,
List<String> allowedOrigins) {
super(path, servlet);
this.servlet = servlet;
if (allowedOrigins == null || allowedOrigins.isEmpty()) {
this.allowedOrigins = List.of(ALL);
} else {
this.allowedOrigins = new ArrayList<>(allowedOrigins);
}
this.csrfFilter = csrfFilter;
}

@Override
public boolean checkOrigin(String originHeaderValue) {
return servlet.checkOrigin(originHeaderValue);
if (originHeaderValue == null || originHeaderValue.isBlank()) {
return allowedOrigins.contains(ALL);
}
if (allowedOrigins.contains(ALL)) {
return true;
}
String originToCheck = trimTrailingSlash(originHeaderValue);
return allowedOrigins.stream()
.map(this::trimTrailingSlash)
.anyMatch(originToCheck::equalsIgnoreCase);
}

private String trimTrailingSlash(String origin) {
return (origin.endsWith("/") ? origin.substring(0, origin.length() - 1) : origin);
}

@Override
public void modifyHandshake(
ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
super.modifyHandshake(sec, request, response);
csrfFilter.doFilter(request);
servlet.modifyHandshake(sec, request, response);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package graphql.kickstart.autoconfigure.web.servlet;

import static org.springframework.util.CollectionUtils.firstElement;

import graphql.kickstart.autoconfigure.web.servlet.GraphQLSubscriptionWebsocketProperties.CsrfProperties;
import jakarta.websocket.server.HandshakeRequest;
import java.util.Objects;
import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor
class WsCsrfFilter {

private final CsrfProperties csrfProperties;
private final WsCsrfTokenRepository tokenRepository;

void doFilter(HandshakeRequest request) {
if (csrfProperties.isEnabled() && tokenRepository != null) {
WsCsrfToken csrfToken = tokenRepository.loadToken(request);
boolean missingToken = csrfToken == null;
if (missingToken) {
csrfToken = tokenRepository.generateToken(request);
tokenRepository.saveToken(csrfToken, request);
}

String actualToken =
firstElement(request.getParameterMap().get(csrfToken.getParameterName()));
if (!Objects.equals(csrfToken.getToken(), actualToken)) {
throw new IllegalStateException(
"Invalid CSRF Token '"
+ actualToken
+ "' was found on the request parameter '"
+ csrfToken.getParameterName()
+ "'.");
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package graphql.kickstart.autoconfigure.web.servlet;

import java.io.Serializable;

public interface WsCsrfToken extends Serializable {

String getToken();

String getParameterName();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package graphql.kickstart.autoconfigure.web.servlet;

import jakarta.websocket.server.HandshakeRequest;

public interface WsCsrfTokenRepository {

WsCsrfToken loadToken(HandshakeRequest request);

WsCsrfToken generateToken(HandshakeRequest request);

void saveToken(WsCsrfToken csrfToken, HandshakeRequest request);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package graphql.kickstart.autoconfigure.web.servlet;

import jakarta.servlet.http.HttpSession;
import jakarta.websocket.server.HandshakeRequest;
import java.util.UUID;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;

class WsSessionCsrfTokenRepository implements WsCsrfTokenRepository {

private static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf";

private static final String DEFAULT_CSRF_TOKEN_ATTR_NAME =
HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN");

private String sessionAttributeName = DEFAULT_CSRF_TOKEN_ATTR_NAME;

@Override
public void saveToken(WsCsrfToken token, HandshakeRequest request) {
HttpSession session = (HttpSession) request.getHttpSession();
if (session != null) {
if (token == null) {
session.removeAttribute(this.sessionAttributeName);
} else {
session.setAttribute(this.sessionAttributeName, token);
}
}
}

@Override
public WsCsrfToken loadToken(HandshakeRequest request) {
HttpSession session = (HttpSession) request.getHttpSession();
if (session == null) {
return null;
}
return (WsCsrfToken) session.getAttribute(this.sessionAttributeName);
}

@Override
public WsCsrfToken generateToken(HandshakeRequest request) {
return new DefaultWsCsrfToken(UUID.randomUUID().toString(), DEFAULT_CSRF_PARAMETER_NAME);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package graphql.kickstart.autoconfigure.web.servlet;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

import graphql.kickstart.servlet.GraphQLWebsocketServlet;
import java.util.List;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
class GraphQLWsServerEndpointRegistrationTest {

private static final String PATH = "/subscriptions";

@Mock private GraphQLWebsocketServlet servlet;
@Mock private WsCsrfFilter csrfFilter;

@ParameterizedTest
@CsvSource(
value = {"https://trusted.com", "NULL", "' '"},
nullValues = {"NULL"})
void givenDefaultAllowedOrigins_whenCheckOrigin_thenReturnTrue(String origin) {
var registration = createRegistration();
var allowed = registration.checkOrigin("null".equals(origin) ? null : origin);
assertThat(allowed).isTrue();
}

private GraphQLWsServerEndpointRegistration createRegistration(String... allowedOrigins) {
return new GraphQLWsServerEndpointRegistration(
PATH, servlet, csrfFilter, List.of(allowedOrigins));
}

@ParameterizedTest(name = "{index} => allowedOrigin=''{0}'', originToCheck=''{1}''")
@CsvSource(
delimiterString = "|",
textBlock =
"""
* | https://trusted.com
https://trusted.com | https://trusted.com
https://trusted.com/ | https://trusted.com
https://trusted.com/ | https://trusted.com/
https://trusted.com | https://trusted.com/
""")
void givenAllowedOrigins_whenCheckOrigin_thenReturnTrue(
String allowedOrigin, String originToCheck) {
var registration = createRegistration(allowedOrigin);
var allowed = registration.checkOrigin(originToCheck);
assertThat(allowed).isTrue();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package graphql.kickstart.autoconfigure.web.servlet;

import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import graphql.kickstart.autoconfigure.web.servlet.GraphQLSubscriptionWebsocketProperties.CsrfProperties;
import jakarta.websocket.server.HandshakeRequest;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
class WsCsrfFilterTest {

private CsrfProperties csrfProperties = new CsrfProperties();
@Mock private WsCsrfTokenRepository tokenRepository;
@Mock private HandshakeRequest handshakeRequest;

@Test
void givenCsrfDisabled_whenDoFilter_thenDoesNotLoadToken() {
csrfProperties.setEnabled(false);
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
filter.doFilter(handshakeRequest);

verify(tokenRepository, never()).loadToken(any());
}

@Test
void givenCsrfEnabledAndRepositoryNull_whenDoFilter_thenDoesNotGetTokenFromRequest() {
csrfProperties.setEnabled(true);
WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, null);
filter.doFilter(handshakeRequest);

verify(handshakeRequest, never()).getParameterMap();
}

@Test
void givenNoTokenInSession_whenDoFilter_thenGenerateAndSaveToken() {
csrfProperties.setEnabled(true);
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(null);
WsCsrfToken csrfToken = mock(WsCsrfToken.class);
when(tokenRepository.generateToken(handshakeRequest)).thenReturn(csrfToken);

WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
filter.doFilter(handshakeRequest);

verify(tokenRepository).saveToken(csrfToken, handshakeRequest);
}

@Test
void givenDifferentActualToken_whenDoFilter_thenThrowsException() {
csrfProperties.setEnabled(true);
WsCsrfToken csrfToken = new DefaultWsCsrfToken("some-token", "_csrf");
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(csrfToken);
when(handshakeRequest.getParameterMap())
.thenReturn(Map.of("_csrf", List.of("different-token")));

WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
assertThatThrownBy(() -> filter.doFilter(handshakeRequest))
.isInstanceOf(IllegalStateException.class)
.hasMessage(
"Invalid CSRF Token 'different-token' was found on the request parameter '_csrf'.");
}

@Test
void givenSameToken_whenDoFilter_thenDoesNotThrow() {
csrfProperties.setEnabled(true);
WsCsrfToken csrfToken = new DefaultWsCsrfToken("some-token", "_csrf");
when(tokenRepository.loadToken(handshakeRequest)).thenReturn(csrfToken);
when(handshakeRequest.getParameterMap())
.thenReturn(Map.of("_csrf", List.of("some-token")));

WsCsrfFilter filter = new WsCsrfFilter(csrfProperties, tokenRepository);
assertDoesNotThrow(() -> filter.doFilter(handshakeRequest));

verify(tokenRepository).loadToken(handshakeRequest);
}
}
Loading