Skip to content

Commit 9cf0dfa

Browse files
committed
Implement SMT match and lambda terms, reimplement SmtTerm#toString so as to output SMTLib S-exprs
1 parent 4fd1fcd commit 9cf0dfa

File tree

5 files changed

+135
-36
lines changed

5 files changed

+135
-36
lines changed

src/main/java/org/semgus/java/event/SmtSpecEvent.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import org.semgus.java.object.TypedVar;
66

77
import java.util.List;
8+
import java.util.stream.Collectors;
89

910
/**
1011
* A SemGuS parser event of the "smt" type.
1112
*/
1213
public sealed interface SmtSpecEvent extends SpecEvent {
1314

1415
/**
15-
* A "declare-function" event declaring the signature of an auxiliary function.
16+
* A "declare-function" event declaring the signature of a function.
1617
*
1718
* @param name The name of the function.
1819
* @param returnType The return type of the function.
@@ -27,7 +28,7 @@ record DeclareFunctionEvent(
2728
}
2829

2930
/**
30-
* A "define-function" event giving a definition for a previously-declared auxiliary function.
31+
* A "define-function" event giving a definition for a previously-declared function.
3132
*
3233
* @param name The name of the function.
3334
* @param returnType The return type of the function.
@@ -40,11 +41,20 @@ record DefineFunctionEvent(
4041
List<TypedVar> arguments,
4142
SmtTerm body
4243
) implements SmtSpecEvent {
43-
// NO-OP
44+
45+
/**
46+
* Constructs a {@link org.semgus.java.object.SmtTerm.Lambda} lambda abstraction from the function definition.
47+
*
48+
* @return The new lambda abstraction SMT term.
49+
*/
50+
public SmtTerm toLambda() {
51+
return new SmtTerm.Lambda(arguments.stream().map(TypedVar::name).collect(Collectors.toList()), body);
52+
}
53+
4454
}
4555

4656
/**
47-
* A "declare-datatype" event declaring the signature of an auxiliary datatype.
57+
* A "declare-datatype" event declaring the signature of a datatype.
4858
*
4959
* @param name The name of the datatype.
5060
*/
@@ -53,7 +63,7 @@ record DeclareDatatypeEvent(String name) implements SmtSpecEvent {
5363
}
5464

5565
/**
56-
* A "define-datatype" event giving a definition for a previously-declared auxiliary datatype.
66+
* A "define-datatype" event giving a definition for a previously-declared datatype.
5767
*
5868
* @param name The name of the datatype.
5969
* @param constructors The constructors of the datatype.

src/main/java/org/semgus/java/object/SmtContext.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import java.util.stream.Collectors;
66

77
/**
8-
* A context for SMT containing auxiliary datatype and function definitions.
8+
* A context for SMT containing datatype and function definitions.
99
*
10-
* @param datatypes The table of auxiliary datatype definitions.
11-
* @param functions The table of auxiliary function definitions.
10+
* @param datatypes The table of datatype definitions.
11+
* @param functions The table of function definitions.
1212
*/
1313
public record SmtContext(Map<String, Datatype> datatypes, Map<String, SmtContext.Function> functions) {
1414

@@ -29,7 +29,7 @@ public String toString() {
2929
}
3030

3131
/**
32-
* A definition of an auxiliary datatype.
32+
* A definition of an (inductive) datatype.
3333
*
3434
* @param name The name of the datatype.
3535
* @param constructors The set of constructors for the datatype.
@@ -47,7 +47,7 @@ public String toString() {
4747
}
4848

4949
/**
50-
* A constructor for an auxiliary datatype.
50+
* A constructor for a datatype.
5151
*
5252
* @param name The name of the constructor.
5353
* @param argumentTypes The types of the constructor's arguments.
@@ -72,7 +72,7 @@ public String toString() {
7272
}
7373

7474
/**
75-
* A definition of an auxiliary function.
75+
* A definition of a function.
7676
*
7777
* @param name The name of the function.
7878
* @param arguments The arguments to the function.
@@ -84,7 +84,7 @@ public record Function(String name, List<TypedVar> arguments, SmtTerm body) {
8484
public String toString() {
8585
return String.format(
8686
"(lambda (%s) %s)",
87-
arguments.stream().map(TypedVar::toStringSExpr).collect(Collectors.joining(" ")),
87+
arguments.stream().map(TypedVar::toString).collect(Collectors.joining(" ")),
8888
body);
8989
}
9090

src/main/java/org/semgus/java/object/SmtTerm.java

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ static SmtTerm deserialize(Object termDtoRaw) throws DeserializationException {
3333
case "application" -> deserializeApplication(termDto); // a function application
3434
case "exists" -> deserializeQuantifier(termDto, Quantifier.Type.EXISTS); // an existential quantifier
3535
case "forall" -> deserializeQuantifier(termDto, Quantifier.Type.FOR_ALL); // a universal quantifier
36+
case "lambda" -> deserializeLambda(termDto); // a lambda abstraction
37+
case "match" -> deserializeMatch(termDto); // a pattern matching expression
3638
case "variable" -> deserializeVariable(termDto); // a variable
3739
case "bitvector" -> deserializeBitVector(termDto); // a bit vector
3840
default -> throw new DeserializationException(
@@ -132,6 +134,44 @@ private static SmtTerm deserializeQuantifier(JSONObject termDto, Quantifier.Type
132134
return new Quantifier(qType, Arrays.asList(bindings), child);
133135
}
134136

137+
/**
138+
* Deserializes a lambda abstraction.
139+
*
140+
* @param termDto The JSON representation of the lambda abstraction.
141+
* @return The deserialized lambda abstraction.
142+
* @throws DeserializationException If {@code termDto} is not a valid representation of lambda abstraction.
143+
*/
144+
private static SmtTerm deserializeLambda(JSONObject termDto) throws DeserializationException {
145+
return new Lambda(JsonUtils.getStrings(termDto, "arguments"), deserializeAt(termDto, "body"));
146+
}
147+
148+
/**
149+
* Deserializes a pattern-matching expression.
150+
*
151+
* @param termDto The JSON representation of the pattern-matching expression.
152+
* @return The deserialized pattern-matching expression.
153+
* @throws DeserializationException If {@code termDto} is not a valid representation of a pattern match.
154+
*/
155+
private static SmtTerm deserializeMatch(JSONObject termDto) throws DeserializationException {
156+
SmtTerm matchTerm = deserializeAt(termDto, "term");
157+
List<JSONObject> casesDto = JsonUtils.getObjects(termDto, "binders");
158+
159+
Match.Case[] cases = new Match.Case[casesDto.size()];
160+
for (int i = 0; i < cases.length; i++) {
161+
JSONObject caseDto = casesDto.get(i);
162+
try {
163+
cases[i] = new Match.Case(
164+
JsonUtils.getString(caseDto, "operator"),
165+
JsonUtils.getStrings(caseDto, "arguments"),
166+
deserializeAt(caseDto, "child"));
167+
} catch (DeserializationException e) {
168+
throw e.prepend("binders." + i);
169+
}
170+
}
171+
172+
return new Match(matchTerm, Arrays.asList(cases));
173+
}
174+
135175
/**
136176
* Deserializes a variable.
137177
*
@@ -229,10 +269,7 @@ record Application(Identifier name, Identifier returnType, List<TypedTerm> argum
229269

230270
@Override
231271
public String toString() {
232-
if (arguments.size() == 0) {
233-
return "(" + name + ")";
234-
}
235-
return String.format("(%s %s)",
272+
return arguments.isEmpty() ? name.toString() : String.format("(%s %s)",
236273
name, arguments.stream().map(TypedTerm::toString).collect(Collectors.joining(" ")));
237274
}
238275

@@ -269,43 +306,94 @@ public enum Type {
269306
/**
270307
* The existential quantifier.
271308
*/
272-
EXISTS("∃"),
309+
EXISTS("∃", "exists"),
273310

274311
/**
275312
* The universal quantifier.
276313
*/
277-
FOR_ALL("∀");
314+
FOR_ALL("∀", "forall");
278315

279316
/**
280317
* The symbol representing the quantifier.
281318
*/
282-
public final String symbol;
319+
public final String symbol, name;
283320

284321
/**
285322
* Constructs a quantifier type.
286323
*
287324
* @param symbol The symbol representing the quantifier.
288325
*/
289-
Type(String symbol) {
326+
Type(String symbol, String name) {
290327
this.symbol = symbol;
328+
this.name = name;
291329
}
292330

293331
@Override
294332
public String toString() {
295-
return symbol;
333+
return name;
296334
}
297335
}
298336

299337
@Override
300338
public String toString() {
301339
return String.format("(%s (%s) %s)",
302340
type,
303-
bindings.stream().map(TypedVar::toStringSExpr).collect(Collectors.joining(" ")),
341+
bindings.stream().map(TypedVar::toString).collect(Collectors.joining(" ")),
304342
child);
305343
}
306344

307345
}
308346

347+
/**
348+
* Represents a lambda abstraction in an SMT formula.
349+
*
350+
* @param arguments The names of the lambda arguments. Beware of conflicts with variables in the outer context!
351+
* @param body The body of the lambda term, which may contain the arguments as bound variables.
352+
*/
353+
record Lambda(List<String> arguments, SmtTerm body) implements SmtTerm {
354+
355+
@Override
356+
public String toString() {
357+
return String.format("(lambda (%s) %s)", String.join(" ", arguments), body);
358+
}
359+
360+
}
361+
362+
/**
363+
* Represents a pattern-matching expression in an SMT formula. Used to match against constructors for inductive
364+
* types as defined by {@link org.semgus.java.event.SmtSpecEvent.DefineDatatypeEvent}.
365+
*
366+
* @param matchTerm The term being matched on.
367+
* @param cases The match cases.
368+
*/
369+
record Match(SmtTerm matchTerm, List<Case> cases) implements SmtTerm {
370+
371+
@Override
372+
public String toString() {
373+
return String.format("(match %s (%s))",
374+
matchTerm,
375+
cases.stream().map(Case::toString).collect(Collectors.joining(" ")));
376+
}
377+
378+
/**
379+
* A match case in a {@link Match} pattern-matching expression.
380+
*
381+
* @param opName The name of the operator to match against.
382+
* @param arguments The names to which the operator's arguments should be bound in the result term.
383+
* @param result The match result.
384+
*/
385+
public record Case(String opName, List<String> arguments, SmtTerm result) {
386+
387+
@Override
388+
public String toString() {
389+
return arguments.isEmpty() ? String.format("(%s %s)", opName, result)
390+
: String.format("((%s %s) %s)", opName, String.join(" ", arguments), result);
391+
}
392+
393+
}
394+
395+
}
396+
309397
/**
310398
* Represents a variable in an SMT formula.
311399
*
@@ -359,11 +447,11 @@ record CBitVector(int size, BitSet value) implements SmtTerm {
359447

360448
@Override
361449
public String toString() {
362-
StringBuilder sb = new StringBuilder("<");
450+
StringBuilder sb = new StringBuilder("#b");
363451
for (int i = size - 1; i >= 0; i--) {
364452
sb.append(value.get(i) ? '1' : '0');
365453
}
366-
return sb.append(">").toString();
454+
return sb.toString();
367455
}
368456

369457
}

src/main/java/org/semgus/java/object/TypedVar.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,17 @@ public static List<TypedVar> fromNamesAndTypes(List<String> names, List<Identifi
2828

2929
@Override
3030
public String toString() {
31-
return String.format("%s: %s", name, type);
31+
return String.format("(%s %s)", name, type);
3232
}
3333

3434
/**
35-
* Stringifies this typed variable as an s-expression of the form "(var type)".
35+
* Use {@link #toString()} instead.
3636
*
3737
* @return The stringified typed variable.
3838
*/
39+
@Deprecated(since = "1.1.0")
3940
public String toStringSExpr() {
40-
return String.format("(%s %s)", name, type);
41+
return toString();
4142
}
4243

4344
}

src/main/java/org/semgus/java/problem/ProblemGenerator.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,14 @@ public static SemgusProblem fromEvents(Iterable<SpecEvent> events) {
9696
private final Map<String, AttributeValue> metadata = new HashMap<>();
9797

9898
/**
99-
* Collected auxiliary datatype definitions.
99+
* Collected datatype definitions.
100100
*/
101-
private final Map<String, SmtContext.Datatype> auxDatatypeDefs = new HashMap<>();
101+
private final Map<String, SmtContext.Datatype> datatypeDefs = new HashMap<>();
102102

103103
/**
104-
* Collected auxiliary function definitions.
104+
* Collected function definitions.
105105
*/
106-
private final Map<String, SmtContext.Function> auxFunctionDefs = new HashMap<>();
106+
private final Map<String, SmtContext.Function> functionDefs = new HashMap<>();
107107

108108
/**
109109
* Collected term types.
@@ -156,21 +156,21 @@ private void consumeSetInfo(MetaSpecEvent.SetInfoEvent event) {
156156
}
157157

158158
/**
159-
* Collects an auxiliary function definition from a "declare-datatype" event.
159+
* Collects a function definition from a "declare-datatype" event.
160160
*
161161
* @param event The event.
162162
*/
163163
private void consumeDefineFunction(SmtSpecEvent.DefineFunctionEvent event) {
164-
auxFunctionDefs.put(event.name(), new SmtContext.Function(event.name(), event.arguments(), event.body()));
164+
functionDefs.put(event.name(), new SmtContext.Function(event.name(), event.arguments(), event.body()));
165165
}
166166

167167
/**
168-
* Collects an auxiliary datatype definition from a "declare-datatype" event.
168+
* Collects a datatype definition from a "declare-datatype" event.
169169
*
170170
* @param event The event.
171171
*/
172172
private void consumeDefineDatatype(SmtSpecEvent.DefineDatatypeEvent event) {
173-
auxDatatypeDefs.put(event.name(), new SmtContext.Datatype(event.name(), event.constructors().stream()
173+
datatypeDefs.put(event.name(), new SmtContext.Datatype(event.name(), event.constructors().stream()
174174
.map(c -> new SmtContext.Datatype.Constructor(c.name(), c.argumentTypes()))
175175
.collect(Collectors.toUnmodifiableMap(SmtContext.Datatype.Constructor::name, c -> c))));
176176
}
@@ -289,7 +289,7 @@ public SemgusProblem end() {
289289
nonTerminals,
290290
new ArrayList<>(constraints),
291291
new HashMap<>(metadata),
292-
new SmtContext(auxDatatypeDefs, auxFunctionDefs));
292+
new SmtContext(datatypeDefs, functionDefs));
293293
}
294294

295295
/**

0 commit comments

Comments
 (0)