Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

import static graphql.Directives.DeprecatedDirective;
import static graphql.com.google.common.base.Predicates.not;
import static graphql.introspection.Introspection.DirectiveLocation.ARGUMENT_DEFINITION;
import static graphql.introspection.Introspection.DirectiveLocation.ENUM_VALUE;
import static graphql.introspection.Introspection.DirectiveLocation.FIELD_DEFINITION;
import static graphql.introspection.Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION;
import static graphql.schema.visibility.DefaultGraphqlFieldVisibility.DEFAULT_FIELD_VISIBILITY;
import static graphql.util.EscapeUtil.escapeJsonString;
import static java.util.Optional.ofNullable;
Expand All @@ -15,42 +11,8 @@
import graphql.Assert;
import graphql.PublicApi;
import graphql.execution.ValuesResolver;
import graphql.language.AstPrinter;
import graphql.language.Description;
import graphql.language.Document;
import graphql.language.EnumTypeDefinition;
import graphql.language.EnumValueDefinition;
import graphql.language.FieldDefinition;
import graphql.language.InputObjectTypeDefinition;
import graphql.language.InputValueDefinition;
import graphql.language.InterfaceTypeDefinition;
import graphql.language.ObjectTypeDefinition;
import graphql.language.ScalarTypeDefinition;
import graphql.language.TypeDefinition;
import graphql.language.UnionTypeDefinition;
import graphql.schema.DefaultGraphqlTypeComparatorRegistry;
import graphql.schema.GraphQLArgument;
import graphql.schema.GraphQLDirective;
import graphql.schema.GraphQLEnumType;
import graphql.schema.GraphQLEnumValueDefinition;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLInputObjectField;
import graphql.schema.GraphQLInputObjectType;
import graphql.schema.GraphQLInputType;
import graphql.schema.GraphQLInterfaceType;
import graphql.schema.GraphQLNamedOutputType;
import graphql.schema.GraphQLNamedType;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLOutputType;
import graphql.schema.GraphQLScalarType;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLType;
import graphql.schema.GraphQLTypeUtil;
import graphql.schema.GraphQLUnionType;
import graphql.schema.GraphqlTypeComparatorEnvironment;
import graphql.schema.GraphqlTypeComparatorRegistry;
import graphql.schema.InputValueWithState;
import graphql.language.*;
import graphql.schema.*;
import graphql.schema.idl.DirectiveInfo;
import graphql.schema.idl.ScalarInfo;
import graphql.schema.idl.SchemaParser;
Expand All @@ -76,11 +38,8 @@ public class FederationSdlPrinter {
// we use this so that we get the simple "@deprecated" as text and not a full exploded
// text with arguments (but only when we auto add this)
//
private static final GraphQLDirective DeprecatedDirective4Printing =
GraphQLDirective.newDirective()
.name("deprecated")
.validLocations(FIELD_DEFINITION, ENUM_VALUE, ARGUMENT_DEFINITION, INPUT_FIELD_DEFINITION)
.build();
private static final GraphQLAppliedDirective DeprecatedDirective4Printing =
GraphQLAppliedDirective.newDirective().name("deprecated").build();

/**
* This predicate excludes all directives which are specified bt the GraphQL Specification.
Expand All @@ -104,7 +63,7 @@ public static class Options {

private final boolean descriptionsAsHashComments;

private final Predicate<GraphQLDirective> includeDirective;
private final Predicate<GraphQLAppliedDirective> includeDirective;

private final Predicate<GraphQLDirective> includeDirectiveDefinition;

Expand All @@ -121,7 +80,7 @@ private Options(
boolean includeDirectiveDefinitions,
boolean useAstDefinitions,
boolean descriptionsAsHashComments,
Predicate<GraphQLDirective> includeDirective,
Predicate<GraphQLAppliedDirective> includeDirective,
Predicate<GraphQLDirective> includeDirectiveDefinition,
Predicate<GraphQLNamedType> includeTypeDefinition,
Predicate<GraphQLSchemaElement> includeSchemaElement,
Expand Down Expand Up @@ -155,7 +114,7 @@ public boolean isIncludeDirectiveDefinitions() {
return includeDirectiveDefinitions;
}

public Predicate<GraphQLDirective> getIncludeDirective() {
public Predicate<GraphQLAppliedDirective> getIncludeDirective() {
return includeDirective;
}

Expand Down Expand Up @@ -321,7 +280,7 @@ public Options includeDirectives(boolean flag) {
* @param includeDirective the predicate to decide of a directive is printed
* @return new instance of options
*/
public Options includeDirectives(Predicate<GraphQLDirective> includeDirective) {
public Options includeDirectives(Predicate<GraphQLAppliedDirective> includeDirective) {
return new Options(
this.includeIntrospectionTypes,
this.includeScalars,
Expand Down Expand Up @@ -580,7 +539,8 @@ private TypePrinter<GraphQLScalarType> scalarPrinter() {
printComments(out, type, "");
out.format(
"scalar %s%s\n\n",
type.getName(), directivesString(GraphQLScalarType.class, type.getDirectives()));
type.getName(),
directivesString(GraphQLScalarType.class, type.getAppliedDirectives()));
}
}
};
Expand All @@ -606,14 +566,15 @@ private TypePrinter<GraphQLEnumType> enumPrinter() {
printComments(out, type, "");
out.format(
"enum %s%s",
type.getName(), directivesString(GraphQLEnumType.class, type.getDirectives()));
type.getName(), directivesString(GraphQLEnumType.class, type.getAppliedDirectives()));
List<GraphQLEnumValueDefinition> values =
type.getValues().stream().sorted(comparator).collect(toList());
if (values.size() > 0) {
out.format(" {\n");
for (GraphQLEnumValueDefinition enumValueDefinition : values) {
printComments(out, enumValueDefinition, " ");
List<GraphQLDirective> enumValueDirectives = enumValueDefinition.getDirectives();
List<GraphQLAppliedDirective> enumValueDirectives =
enumValueDefinition.getAppliedDirectives();
if (enumValueDefinition.isDeprecated()) {
enumValueDirectives = addDeprecatedDirectiveIfNeeded(enumValueDirectives);
}
Expand Down Expand Up @@ -644,7 +605,7 @@ private void printFieldDefinitions(
.forEach(
fd -> {
printComments(out, fd, " ");
List<GraphQLDirective> fieldDirectives = fd.getDirectives();
List<GraphQLAppliedDirective> fieldDirectives = fd.getAppliedDirectives();
if (fd.isDeprecated()) {
fieldDirectives = addDeprecatedDirectiveIfNeeded(fieldDirectives);
}
Expand Down Expand Up @@ -672,7 +633,8 @@ private TypePrinter<GraphQLInterfaceType> interfacePrinter() {
if (type.getInterfaces().isEmpty()) {
out.format(
"interface %s%s",
type.getName(), directivesString(GraphQLInterfaceType.class, type.getDirectives()));
type.getName(),
directivesString(GraphQLInterfaceType.class, type.getAppliedDirectives()));
} else {

GraphqlTypeComparatorEnvironment environment =
Expand All @@ -691,7 +653,7 @@ private TypePrinter<GraphQLInterfaceType> interfacePrinter() {
"interface %s implements %s%s",
type.getName(),
interfaceNames.collect(joining(" & ")),
directivesString(GraphQLInterfaceType.class, type.getDirectives()));
directivesString(GraphQLInterfaceType.class, type.getAppliedDirectives()));
}

GraphqlTypeComparatorEnvironment environment =
Expand Down Expand Up @@ -728,7 +690,7 @@ private TypePrinter<GraphQLUnionType> unionPrinter() {
printComments(out, type, "");
out.format(
"union %s%s = ",
type.getName(), directivesString(GraphQLUnionType.class, type.getDirectives()));
type.getName(), directivesString(GraphQLUnionType.class, type.getAppliedDirectives()));
List<GraphQLNamedOutputType> types =
type.getTypes().stream().sorted(comparator).collect(toList());
for (int i = 0; i < types.size(); i++) {
Expand All @@ -755,7 +717,8 @@ private TypePrinter<GraphQLObjectType> objectPrinter() {
if (type.getInterfaces().isEmpty()) {
out.format(
"type %s%s",
type.getName(), directivesString(GraphQLObjectType.class, type.getDirectives()));
type.getName(),
directivesString(GraphQLObjectType.class, type.getAppliedDirectives()));
} else {

GraphqlTypeComparatorEnvironment environment =
Expand All @@ -774,7 +737,7 @@ private TypePrinter<GraphQLObjectType> objectPrinter() {
"type %s implements %s%s",
type.getName(),
interfaceNames.collect(joining(" & ")),
directivesString(GraphQLObjectType.class, type.getDirectives()));
directivesString(GraphQLObjectType.class, type.getAppliedDirectives()));
}

GraphqlTypeComparatorEnvironment environment =
Expand Down Expand Up @@ -810,7 +773,8 @@ private TypePrinter<GraphQLInputObjectType> inputObjectPrinter() {

out.format(
"input %s%s",
type.getName(), directivesString(GraphQLInputObjectType.class, type.getDirectives()));
type.getName(),
directivesString(GraphQLInputObjectType.class, type.getAppliedDirectives()));
List<GraphQLInputObjectField> inputObjectFields = visibility.getFieldDefinitions(type);
if (inputObjectFields.size() > 0) {
out.format(" {\n");
Expand All @@ -826,7 +790,8 @@ private TypePrinter<GraphQLInputObjectType> inputObjectPrinter() {
String astValue = printAst(defaultValue, fd.getType());
out.format(" = %s", astValue);
}
out.format(directivesString(GraphQLInputObjectField.class, fd.getDirectives()));
out.format(
directivesString(GraphQLInputObjectField.class, fd.getAppliedDirectives()));
out.format("\n");
});
out.format("}");
Expand Down Expand Up @@ -871,7 +836,7 @@ private static String printAst(InputValueWithState value, GraphQLInputType type)

private TypePrinter<GraphQLSchema> schemaPrinter() {
return (out, schema, visibility) -> {
List<GraphQLDirective> schemaDirectives = schema.getSchemaDirectives();
List<GraphQLAppliedDirective> schemaDirectives = schema.getSchemaAppliedDirectives();
GraphQLObjectType queryType = schema.getQueryType();
GraphQLObjectType mutationType = schema.getMutationType();
GraphQLObjectType subscriptionType = schema.getSubscriptionType();
Expand Down Expand Up @@ -917,7 +882,6 @@ private TypePrinter<GraphQLSchema> schemaPrinter() {

private List<GraphQLDirective> getSchemaDirectives(GraphQLSchema schema) {
return schema.getDirectives().stream()
.filter(options.getIncludeDirective())
.filter(options.getIncludeSchemaElement())
.filter(options.getIncludeDirectiveDefinition())
.collect(toList());
Expand Down Expand Up @@ -972,7 +936,7 @@ String argsString(Class<? extends GraphQLSchemaElement> parent, List<GraphQLArgu
sb.append(printAst(defaultValue, argument.getType()));
}

argument.getDirectives().stream()
argument.getAppliedDirectives().stream()
.filter(options.getIncludeSchemaElement())
.map(this::directiveString)
.filter(it -> !it.isEmpty())
Expand All @@ -990,7 +954,7 @@ String argsString(Class<? extends GraphQLSchemaElement> parent, List<GraphQLArgu
}

String directivesString(
Class<? extends GraphQLSchemaElement> parent, List<GraphQLDirective> directives) {
Class<? extends GraphQLSchemaElement> parent, List<GraphQLAppliedDirective> directives) {
directives =
directives.stream()
// @deprecated is special - we always print it if something is deprecated
Expand Down Expand Up @@ -1019,7 +983,7 @@ String directivesString(

directives = directives.stream().sorted(comparator).collect(toList());
for (int i = 0; i < directives.size(); i++) {
GraphQLDirective directive = directives.get(i);
GraphQLAppliedDirective directive = directives.get(i);
sb.append(directiveString(directive));
if (i < directives.size() - 1) {
sb.append(" ");
Expand All @@ -1028,7 +992,7 @@ String directivesString(
return sb.toString();
}

private String directiveString(GraphQLDirective directive) {
private String directiveString(GraphQLAppliedDirective directive) {
if (!options.getIncludeSchemaElement().test(directive)) {
return "";
}
Expand All @@ -1050,21 +1014,19 @@ private String directiveString(GraphQLDirective directive) {
Comparator<? super GraphQLSchemaElement> comparator =
options.comparatorRegistry.getComparator(environment);

List<GraphQLArgument> args = directive.getArguments();
List<GraphQLAppliedDirectiveArgument> args = directive.getArguments();
args =
args.stream()
.filter(arg -> arg.getArgumentValue().isSet() || arg.getArgumentDefaultValue().isSet())
.filter(arg -> arg.getArgumentValue().isSet())
.sorted(comparator)
.collect(toList());
if (!args.isEmpty()) {
sb.append("(");
for (int i = 0; i < args.size(); i++) {
GraphQLArgument arg = args.get(i);
GraphQLAppliedDirectiveArgument arg = args.get(i);
String argValue = null;
if (arg.hasSetValue()) {
argValue = printAst(arg.getArgumentValue(), arg.getType());
} else if (arg.hasSetDefaultValue()) {
argValue = printAst(arg.getArgumentDefaultValue(), arg.getType());
}
if (!isNullOrEmpty(argValue)) {
sb.append(arg.getName());
Expand All @@ -1080,15 +1042,16 @@ private String directiveString(GraphQLDirective directive) {
return sb.toString();
}

private boolean isDeprecatedDirective(GraphQLDirective directive) {
private boolean isDeprecatedDirective(GraphQLAppliedDirective directive) {
return directive.getName().equals(DeprecatedDirective.getName());
}

private boolean hasDeprecatedDirective(List<GraphQLDirective> directives) {
private boolean hasDeprecatedDirective(List<GraphQLAppliedDirective> directives) {
return directives.stream().filter(this::isDeprecatedDirective).count() == 1;
}

private List<GraphQLDirective> addDeprecatedDirectiveIfNeeded(List<GraphQLDirective> directives) {
private List<GraphQLAppliedDirective> addDeprecatedDirectiveIfNeeded(
List<GraphQLAppliedDirective> directives) {
if (!hasDeprecatedDirective(directives)) {
directives = new ArrayList<>(directives);
directives.add(DeprecatedDirective4Printing);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@
import graphql.ExecutionResult;
import graphql.Scalars;
import graphql.com.google.common.collect.ImmutableMap;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLNamedType;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLScalarType;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLUnionType;
import graphql.language.StringValue;
import graphql.schema.*;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.SchemaGenerator;
import graphql.schema.idl.SchemaParser;
Expand Down Expand Up @@ -51,6 +47,10 @@ class FederationTest {
TestUtils.readResource("schemas/printerFilterExpected.graphql");
private final String fed2SDL = TestUtils.readResource("schemas/fed2.graphql");
private final String fed2FederatedSDL = TestUtils.readResource("schemas/fed2Federated.graphql");
private final String directiveFederatedSDL =
TestUtils.readResource("schemas/directiveFederated.graphql");
private final String directiveFederatedServiceSDL =
TestUtils.readResource("schemas/directiveFederatedService.graphql");
private final String fed2ServiceSDL = TestUtils.readResource("schemas/fed2Service.graphql");
private final String unionsSDL = TestUtils.readResource("schemas/unions.graphql");
private final String unionsFederatedSDL =
Expand Down Expand Up @@ -237,6 +237,51 @@ void testFed2() {
SchemaUtils.assertSDL(federatedSchema, fed2FederatedSDL, fed2ServiceSDL);
}

@Test
void testFederationDirectives() {
final GraphQLSchema federatedSchema =
Federation.transform(
GraphQLSchema.newSchema()
.query(
GraphQLObjectType.newObject()
.name("Query")
.field(
GraphQLFieldDefinition.newFieldDefinition()
.name("dummy")
.type(Scalars.GraphQLString)
.build())
.build())
.additionalType(
GraphQLObjectType.newObject()
.name("Test")
.withAppliedDirective(
GraphQLAppliedDirective.newDirective()
.name("key")
.argument(
GraphQLAppliedDirectiveArgument.newArgument()
.name("fields")
.type(_FieldSet.type)
.valueLiteral(
StringValue.newStringValue("dummy").build())
.build())
.build())
.field(
GraphQLFieldDefinition.newFieldDefinition()
.name("dummy")
.type(Scalars.GraphQLString)
.build())
.build())
.build(),
true)
.resolveEntityType(env -> env.getSchema().getObjectType("Point"))
.fetchEntities(
entityFetcher ->
ImmutableMap.builder().put("id", "1000").put("x", 0).put("y", 0).build())
.build();

SchemaUtils.assertSDL(federatedSchema, directiveFederatedSDL, directiveFederatedServiceSDL);
}

@Test
void testUnions() {
final RuntimeWiring runtimeWiring = RuntimeWiring.newRuntimeWiring().build();
Expand Down
Loading