Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const graphqlMiddleware = graphqlHTTP({
});
```

You can provide a configuration object with custom costs for scalars and objects as `scalarCost` and `objectCost` respectively, and a custom cost factor for lists as `listFactor`.
You can provide a configuration object with custom global costs for scalars and objects as `scalarCost` and `objectCost` respectively, and a custom cost factor for lists as `listFactor`.

```js
const ComplexityLimitRule = createComplexityLimitRule(1000, {
Expand All @@ -34,6 +34,29 @@ const ComplexityLimitRule = createComplexityLimitRule(1000, {
});
```

You can also set custom costs and cost factors on fields definitions with `getCost` and `getCostFactor` callbacks.

```js
const expensiveField = {
type: ExpensiveItem,
getCost: () => 50,
};

const expensiveList = {
type: new GraphQLList(MyItem),
getCostFactor: () => 100,
};
```

You can also define these via field directives in the SDL.

```graphql
type CustomCostItem {
expensiveField: ExpensiveItem @cost(value: 50)
expensiveList: [MyItem] @costFactor(value: 100)
}
```

The configuration object also supports an `onCost` callback for logging query costs and a `formatErrorMessage` callback for customizing error messages. `onCost` will be called for every query with its cost. `formatErrorMessage` will be called with the cost whenever a query exceeds the complexity limit, and should return a string containing the error message.

```js
Expand Down
118 changes: 80 additions & 38 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ export class ComplexityVisitor {
this.introspectionListFactor = introspectionListFactor;

this.currentFragment = null;
this.listDepth = 0;
this.introspectionListDepth = 0;
this.costFactor = 1;

this.rootCalculator = new CostCalculator();
this.fragmentCalculators = Object.create(null);
Expand All @@ -80,25 +79,38 @@ export class ComplexityVisitor {
}

enterField() {
this.enterType(this.context.getType());
this.costFactor *= this.getFieldCostFactor();
this.getCalculator().addImmediate(this.costFactor * this.getFieldCost());
}

enterType(type) {
leaveField() {
this.costFactor /= this.getFieldCostFactor();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

every time, i forget this operator exists

}

getFieldCostFactor() {
const fieldDef = this.context.getFieldDef();
if (fieldDef.getCostFactor) {
return fieldDef.getCostFactor();
}

const directiveCostFactor = this.getDirectiveValue('costFactor');
if (directiveCostFactor != null) {
return directiveCostFactor;
}

return this.getTypeCostFactor(this.context.getType());
}

getTypeCostFactor(type) {
if (type instanceof GraphQLNonNull) {
this.enterType(type.ofType);
return this.getTypeCostFactor(type.ofType);
} else if (type instanceof GraphQLList) {
if (this.isIntrospectionList(type)) {
++this.introspectionListDepth;
} else {
++this.listDepth;
}

this.enterType(type.ofType);
} else {
const fieldCost = type instanceof GraphQLObjectType ?
this.objectCost : this.scalarCost;
this.getCalculator().addImmediate(this.getDepthFactor() * fieldCost);
const typeListFactor = this.isIntrospectionList(type) ?
this.introspectionListFactor : this.listFactor;
return typeListFactor * this.getTypeCostFactor(type.ofType);
}

return 1;
}

isIntrospectionList({ ofType }) {
Expand All @@ -110,38 +122,68 @@ export class ComplexityVisitor {
return IntrospectionTypes[type.name] === type;
}

getCalculator() {
return this.currentFragment === null ?
this.rootCalculator : this.fragmentCalculators[this.currentFragment];
}
getFieldCost() {
const fieldDef = this.context.getFieldDef();
if (fieldDef.getCost) {
return fieldDef.getCost();
}

const directiveCost = this.getDirectiveValue('cost');
if (directiveCost != null) {
return directiveCost;
}

getDepthFactor() {
return (
this.listFactor ** this.listDepth *
this.introspectionListFactor ** this.introspectionListDepth
);
return this.getTypeCost(this.context.getType());
}

leaveField() {
this.leaveType(this.context.getType());
getTypeCost(type) {
if (type instanceof GraphQLNonNull || type instanceof GraphQLList) {
return this.getTypeCost(type.ofType);
}

return type instanceof GraphQLObjectType ?
this.objectCost : this.scalarCost;
}

leaveType(type) {
if (type instanceof GraphQLNonNull) {
this.leaveType(type.ofType);
} else if (type instanceof GraphQLList) {
if (this.isIntrospectionList(type)) {
--this.introspectionListDepth;
} else {
--this.listDepth;
}
getDirectiveValue(directiveName) {
const fieldDef = this.context.getFieldDef();

const { astNode } = fieldDef;
if (!astNode || !astNode.directives) {
return null;
}

this.leaveType(type.ofType);
const directive = astNode.directives.find(({ name }) => (
name.value === directiveName
));
if (!directive) {
return null;
}

const valueArgument = directive.arguments.find(argument => (
argument.name.value === 'value'
));

if (!valueArgument) {
const fieldName = fieldDef.name;
const parentTypeName = this.context.getParentType().name;

throw new Error(
`No \`value\` argument defined in \`@${directiveName}\` directive ` +
`on \`${fieldName}\` field on \`${parentTypeName}\`.`,
);
}

return parseFloat(valueArgument.value.value);
}

getCalculator() {
return this.currentFragment === null ?
this.rootCalculator : this.fragmentCalculators[this.currentFragment];
}

enterFragmentSpread(node) {
this.getCalculator().addFragment(this.getDepthFactor(), node.name.value);
this.getCalculator().addFragment(this.costFactor, node.name.value);
}

enterFragmentDefinition(node) {
Expand Down
61 changes: 60 additions & 1 deletion test/ComplexityVisitor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import {
import { ComplexityVisitor } from '../src';

import schema from './fixtures/schema';
import sdlSchema from './fixtures/sdlSchema';

describe('ComplexityVisitor', () => {
const typeInfo = new TypeInfo(schema);
const sdlTypeInfo = new TypeInfo(sdlSchema);

describe('simple queries', () => {
it('should calculate the correct cost', () => {
Expand Down Expand Up @@ -109,7 +111,7 @@ describe('ComplexityVisitor', () => {
});
});

describe('custom costs', () => {
describe('custom visitor costs', () => {
it('should calculate the correct cost', () => {
const ast = parse(`
query {
Expand All @@ -135,6 +137,63 @@ describe('ComplexityVisitor', () => {
});
});

describe('custom field costs', () => {
it('should calculate the correct cost', () => {
const ast = parse(`
query {
expensiveItem {
name
}
expensiveList {
name
}
}
`);

const context = new ValidationContext(schema, ast, typeInfo);
const visitor = new ComplexityVisitor(context, {});

visit(ast, visitWithTypeInfo(typeInfo, visitor));
expect(visitor.getCost()).toBe(271);
});

it('should calculate the correct cost on an SDL schema', () => {
const ast = parse(`
query {
expensiveItem {
name
}
expensiveList {
name
}
}
`);

const context = new ValidationContext(sdlSchema, ast, sdlTypeInfo);
const visitor = new ComplexityVisitor(context, {});

visit(ast, visitWithTypeInfo(sdlTypeInfo, visitor));
expect(visitor.getCost()).toBe(271);
});

it('should error on missing value in cost directive', () => {
const ast = parse(`
query {
missingCostValue
}
`);

const context = new ValidationContext(sdlSchema, ast, sdlTypeInfo);
const visitor = new ComplexityVisitor(context, {});

expect(() => {
visit(ast, visitWithTypeInfo(sdlTypeInfo, visitor));
}).toThrow(
/`@cost` directive on `missingCostValue` field on `Query`/,
);
});
});

describe('introspection query', () => {
it('should calculate a reduced cost for the introspection query', () => {
const ast = parse(introspectionQuery);
Expand Down
20 changes: 20 additions & 0 deletions test/createComplexityLimitRule.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { GraphQLError, parse, validate } from 'graphql';
import { createComplexityLimitRule } from '../src';

import schema from './fixtures/schema';
import sdlSchema from './fixtures/sdlSchema';

describe('createComplexityLimitRule', () => {
it('should not report errors on a valid query', () => {
Expand Down Expand Up @@ -55,6 +56,25 @@ describe('createComplexityLimitRule', () => {
expect(onCostSpy).toHaveBeenCalledWith(1);
});

it('should call onCost with complexity score on an SDL schema', () => {
const ast = parse(`
query {
expensiveItem {
name
}
}
`);

const onCostSpy = jest.fn();

const errors = validate(sdlSchema, ast, [
createComplexityLimitRule(60, { onCost: onCostSpy }),
]);

expect(errors).toHaveLength(0);
expect(onCostSpy).toHaveBeenCalledWith(51);
});

it('should call onCost with cost when there are errors', () => {
const ast = parse(`
query {
Expand Down
15 changes: 14 additions & 1 deletion test/fixtures/schema.js
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
import {
GraphQLList, GraphQLObjectType, GraphQLNonNull, GraphQLSchema, GraphQLString,
GraphQLList,
GraphQLObjectType,
GraphQLNonNull,
GraphQLSchema,
GraphQLString,
} from 'graphql';

const Item = new GraphQLObjectType({
name: 'Item',
fields: () => ({
name: { type: GraphQLString },
item: { type: Item },
expensiveItem: {
type: Item,
getCost: () => 50,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this safe to extend like this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, it works, doesn’t it?

},
list: { type: new GraphQLList(Item) },
expensiveList: {
type: new GraphQLList(Item),
getCost: () => 10,
getCostFactor: () => 20,
},
nonNullItem: {
type: new GraphQLNonNull(Item),
resolve: () => ({}),
Expand Down
14 changes: 14 additions & 0 deletions test/fixtures/sdlSchema.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { buildSchema } from 'graphql';

export default buildSchema(`
type Query {
name: String

item: Query
expensiveItem: Query @cost(value: 50)
list: [Query]
expensiveList: [Query] @cost(value: 10) @costFactor(value: 20)

missingCostValue: String @cost
}
`);