Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
f814eb9
Draft Context Free Grammar Constrained Decoding
Saibo-creator Nov 17, 2023
7940a54
add support for batch decoding
Saibo-creator Nov 17, 2023
3c97dfb
integrate GCD to `generate` function api
Saibo-creator Nov 17, 2023
f171b0c
fix: add stop condition that if stacks is empty, the eos_token should…
Saibo-creator Nov 17, 2023
5c1d519
remove trailing white space from json grammar
Saibo-creator Nov 18, 2023
ff9c245
downgrade "empty stack" from warning to debug.
Saibo-creator Nov 18, 2023
c8e62aa
Fix: Raise Error when the input tokens are not consistent with the gr…
Saibo-creator Nov 18, 2023
e5677ea
fix: fix parsing error with \\ by adding this case as escape char
Saibo-creator Nov 19, 2023
2c53fcf
use the same json grammar as in llama-cpp
Saibo-creator Nov 19, 2023
70e52ef
feat: add c grammar
Saibo-creator Nov 19, 2023
f9c7081
hotfix: rolling json.gbnf back to the previous version. Parsing is no…
Saibo-creator Nov 20, 2023
ccdacbd
fix newline error, now the same state as in commit 3e78f00 in llama-c…
Saibo-creator Nov 20, 2023
5daf160
add tests for grammar-constrained decoding
Saibo-creator Nov 23, 2023
48e0567
1.run code cleaning
Saibo-creator Nov 23, 2023
313708c
support top-k and top-p sampling
Saibo-creator Nov 23, 2023
904baba
Merge branch 'huggingface:main' into feature/cfgcd
Saibo-creator Nov 28, 2023
ecfd347
Merge branch 'huggingface:main' into feature/cfgcd
Saibo-creator Nov 30, 2023
8af9c88
introduce transformers-CFG README.md
Saibo-creator Dec 1, 2023
fb890bd
add example in README.md
Saibo-creator Dec 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
537 changes: 44 additions & 493 deletions README.md

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions examples/grammars/arithmetic.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Arithmetic equations: one or more lines of the form `expr = term`, each ended by a newline.
root ::= (expr "=" ws term "\n")+
# An expression is a sequence of terms joined by the binary operators - + * /.
expr ::= term ([-+*/] term)*
# A term is an identifier, a number, or a parenthesized expression.
term ::= ident | num | "(" ws expr ")" ws
# Identifiers start with a lowercase letter, followed by lowercase letters, digits or underscores.
ident ::= [a-z] [a-z0-9_]* ws
# Numbers are unsigned integers (one or more digits).
num ::= [0-9]+ ws
# Optional whitespace: any run of spaces, tabs and newlines, possibly empty.
ws ::= [ \t\n]*
42 changes: 42 additions & 0 deletions examples/grammars/c.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Grammar for a small C-like subset: a program is zero or more function declarations.
root ::= (declaration)*

# A function: return type, name, at most one parameter, and a brace-enclosed list of statements.
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"

# The three supported type keywords; each must be followed by whitespace (see `ws`).
dataType ::= "int" ws | "float" ws | "char" ws
# A letter or underscore followed by letters, digits or underscores.
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*

parameter ::= dataType identifier

# A statement is one of: typed variable definition, assignment, call, return,
# while loop, for loop, if with optional else, single-line comment, multi-line comment.
statement ::=
( dataType identifier ws "=" ws expression ";" ) |
( identifier ws "=" ws expression ";" ) |
( identifier ws "(" argList? ")" ";" ) |
( "return" ws expression ";" ) |
( "while" "(" condition ")" "{" statement* "}" ) |
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
( singleLineComment ) |
( multiLineComment )

# Loop header parts: the initializer may declare a new variable or assign an existing one;
# the update is always a plain assignment.
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
forUpdate ::= identifier ws "=" ws expression

# A condition is exactly one comparison between two expressions.
condition ::= expression relationOperator expression
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")

# Two-level expression grammar: `term` ("*" and "/") is nested inside `expression` ("+" and "-").
expression ::= term (("+" | "-") term)*
term ::= factor(("*" | "/") factor)*

factor ::= identifier | number | unaryTerm | funcCall | parenExpression
unaryTerm ::= "-" factor
funcCall ::= identifier "(" argList? ")"
parenExpression ::= "(" ws expression ws ")"

# Comma-separated call arguments; whitespace is required after each comma (see `ws`).
argList ::= expression ("," ws expression)*

# Unsigned integer literal.
number ::= [0-9]+

# `//` up to and including the terminating newline.
singleLineComment ::= "//" [^\n]* "\n"
# NOTE(review): a "*" directly before the closing "*/" (e.g. "/***/") does not appear to match this rule - confirm.
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"

# Mandatory whitespace: at least one space, tab or newline.
ws ::= ([ \t\n]+)
13 changes: 13 additions & 0 deletions examples/grammars/chess.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Specifies chess moves as a list in algebraic notation, using PGN conventions

# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern
root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+
# a move is a pawn move, a non-pawn piece move or castling, optionally suffixed with "+" or "#"
move ::= (pawn | nonpawn | castle) [+#]?

# piece type, optional file/rank, optional capture, dest file & rank
nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8]

# optional file & capture, dest file & rank, optional promotion
pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])?

# "O-O", or "O-O-O" when the optional "-O" suffix is present
castle ::= "O-O" "-O"?
14 changes: 14 additions & 0 deletions examples/grammars/json.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Simplified JSON grammar: the top-level value must be an object.
root ::= object

# Zero or more comma-separated `string: value` pairs; no whitespace is allowed after the closing brace.
object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}"

# Any value; the literals true/false/null may be followed by whitespace.
value ::= object | array | string | number | ("true" | "false" | "null") ws

# Zero or more comma-separated values.
array ::= "[" ws ( value ("," ws value)* )? "]" ws

# Strings are restricted to ASCII letters and digits: no spaces, punctuation or escape sequences.
string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws

# Optional minus sign, integer part, optional fraction, optional exponent.
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws


# Optional whitespace, defined recursively: empty, or one space/tab/newline followed by more whitespace.
ws ::= ([ \t\n] ws)?
14 changes: 14 additions & 0 deletions examples/grammars/json_w_trailing_space.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Simplified JSON grammar: the top-level value must be an object.
root ::= object

# Zero or more comma-separated `string: value` pairs; whitespace may follow the closing brace.
object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws

# Any value; the literals true/false/null may be followed by whitespace.
value ::= object | array | string | number | ("true" | "false" | "null") ws

# Zero or more comma-separated values.
array ::= "[" ws ( value ("," ws value)* )? "]" ws

# Strings are restricted to ASCII letters and digits: no spaces, punctuation or escape sequences.
string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws

# Optional minus sign, integer part, optional fraction, optional exponent.
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws


# Optional whitespace, defined recursively: empty, or one space/tab/newline followed by more whitespace.
ws ::= ([ \t\n] ws)?
7 changes: 7 additions & 0 deletions examples/grammars/simple_arithmetic.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Arithmetic equations over numbers only: one or more lines of the form `expr = term`, each ended by a newline.
root ::= (expr "=" ws term "\n")+
# An expression is a sequence of terms joined by the binary operators - + * /.
expr ::= term ([-+*/] term)*
# A term is a number or a parenthesized expression.
term ::= num | "(" ws expr ")" ws
# Numbers are unsigned integers (one or more digits).
num ::= [0-9]+ ws
# Optional whitespace: any run of spaces, tabs and newlines, possibly empty.
ws ::= [ \t\n]*
# this is a comment

Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.grammar_utils import IncrementalGrammarConstraint
from transformers.generation.logits_process import GrammarConstrainedLogitsProcessor


def main():
    """Run grammar-constrained beam search with GPT-2 on two prompts and print the results.

    The JSON grammar is read from ``examples/grammars/json.gbnf`` (path relative to the
    current working directory), wrapped in a ``GrammarConstrainedLogitsProcessor`` and
    passed to ``model.generate`` so that the generated continuation follows the grammar.
    """
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # `padding=True` below needs a pad token; reuse EOS for that purpose.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Load grammar (explicit encoding so the result does not depend on the platform default)
    with open("examples/grammars/json.gbnf", "r", encoding="utf-8") as grammar_file:
        grammar_str = grammar_file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"
    encoded = tokenizer([prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True)

    output = model.generate(
        encoded["input_ids"],
        # The pad token is the EOS token, so generate() cannot infer which positions are
        # padding from input_ids alone; pass the tokenizer's mask explicitly.
        attention_mask=encoded["attention_mask"],
        do_sample=False,
        max_length=50,
        num_beams=2,
        logits_processor=[grammar_processor],
        repetition_penalty=5.0,
        num_return_sequences=1,
    )
    # decode output
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)

    # Sample output recorded from an earlier run, before the attention mask was passed
    # (the exact text may differ):
    # 'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}
    # 'This is a valid json string for shopping cart:This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }


if __name__ == "__main__":
    main()
Loading