Skip to content

Commit b5cceba

Browse files
committed
Correctly synthesize semantic member _modify accessor. Support
differentiation of _modify accessor for wrapped values. Fixes #55084
1 parent 7fbdbfd commit b5cceba

File tree

5 files changed

+116
-19
lines changed

5 files changed

+116
-19
lines changed

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ bool isSemanticMemberAccessor(SILFunction *original) {
6161
auto *accessor = dyn_cast<AccessorDecl>(decl);
6262
if (!accessor)
6363
return false;
64-
// Currently, only getters and setters are supported.
65-
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
64+
// Currently, only getters, setters and _modify accessors are supported.
6665
if (accessor->getAccessorKind() != AccessorKind::Get &&
67-
accessor->getAccessorKind() != AccessorKind::Set)
66+
accessor->getAccessorKind() != AccessorKind::Set &&
67+
accessor->getAccessorKind() != AccessorKind::Modify)
6868
return false;
6969
// Accessor must come from a `var` declaration.
7070
auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage());

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ class PullbackCloner::Implementation final
905905
bool runForSemanticMemberAccessor();
906906
bool runForSemanticMemberGetter();
907907
bool runForSemanticMemberSetter();
908+
bool runForSemanticMemberModify();
908909

909910
/// If original result is non-varied, it will always have a zero derivative.
910911
/// Skip full pullback generation and simply emit zero derivatives for wrt
@@ -2452,7 +2453,8 @@ bool PullbackCloner::Implementation::run() {
24522453

24532454
// If the original function is an accessor with special-case pullback
24542455
// generation logic, do special-case generation.
2455-
if (isSemanticMemberAccessor(&original)) {
2456+
bool isSemanticMemberAcc = isSemanticMemberAccessor(&original);
2457+
if (isSemanticMemberAcc) {
24562458
if (runForSemanticMemberAccessor())
24572459
return true;
24582460
}
@@ -2730,7 +2732,8 @@ bool PullbackCloner::Implementation::run() {
27302732
#endif
27312733

27322734
LLVM_DEBUG(getADDebugStream()
2733-
<< "Generated pullback for " << original.getName() << ":\n"
2735+
<< "Generated " << (isSemanticMemberAcc ? "semantic member accessor" : "normal")
2736+
<< " pullback for " << original.getName() << ":\n"
27342737
<< pullback);
27352738
return errorOccurred;
27362739
}
@@ -3205,7 +3208,8 @@ bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
32053208
return runForSemanticMemberGetter();
32063209
case AccessorKind::Set:
32073210
return runForSemanticMemberSetter();
3208-
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
3211+
case AccessorKind::Modify:
3212+
return runForSemanticMemberModify();
32093213
default:
32103214
llvm_unreachable("Unsupported accessor kind; inconsistent with "
32113215
"`isSemanticMemberAccessor`?");
@@ -3389,6 +3393,83 @@ bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
33893393
return false;
33903394
}
33913395

3396+
bool PullbackCloner::Implementation::runForSemanticMemberModify() {
3397+
auto &original = getOriginal();
3398+
auto &pullback = getPullback();
3399+
auto pbLoc = getPullback().getLocation();
3400+
3401+
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
3402+
assert(accessor->getAccessorKind() == AccessorKind::Modify);
3403+
3404+
auto *origEntry = original.getEntryBlock();
3405+
// We assume that the accessor has a simple 3-BB structure with yield in the entry BB
3406+
// plus resume and unwind BBs
3407+
auto *yi = cast<YieldInst>(origEntry->getTerminator());
3408+
auto *origResumeBB = yi->getResumeBB();
3409+
3410+
auto *pbEntry = pullback.getEntryBlock();
3411+
builder.setCurrentDebugScope(
3412+
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
3413+
builder.setInsertionPoint(pbEntry);
3414+
3415+
// Get _modify accessor argument values.
3416+
// Accessor type : $(inout Self) -> @yields @inout Argument
3417+
// Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
3418+
// Normally pullbacks for semantic member accessors are single BB and
3419+
// therefore have empty linear map tuple, however, coroutines have a branching
3420+
// control flow due to possible coroutine abort, so we need to accommodate for
3421+
// this. We keep branch tracing enums in order not to special case in many
3422+
// other places. As there is no way to return to coroutine via abort exit, we
3423+
// essentially "linearize" a coroutine.
3424+
auto loweredFnTy = original.getLoweredFunctionType();
3425+
auto pullbackLoweredFnTy = pullback.getLoweredFunctionType();
3426+
3427+
assert(loweredFnTy->getNumParameters() == 1 &&
3428+
loweredFnTy->getNumYields() == 1);
3429+
assert(pullbackLoweredFnTy->getNumParameters() == 2);
3430+
assert(pullbackLoweredFnTy->getNumYields() == 1);
3431+
3432+
SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
3433+
3434+
SmallVector<SILValue, 8> origFormalResults;
3435+
collectAllFormalResultsInTypeOrder(original, origFormalResults);
3436+
3437+
assert(getConfig().resultIndices->getNumIndices() == 2 &&
3438+
"Modify accessor should have two semantic results");
3439+
3440+
auto origYield = origFormalResults[*std::next(getConfig().resultIndices->begin())];
3441+
3442+
// Look up the corresponding field in the tangent space.
3443+
auto *origField = cast<VarDecl>(accessor->getStorage());
3444+
auto baseType = remapType(origSelf->getType()).getASTType();
3445+
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
3446+
pbLoc, getInvoker());
3447+
if (!tanField) {
3448+
errorOccurred = true;
3449+
return true;
3450+
}
3451+
3452+
auto adjSelf = getAdjointBuffer(origResumeBB, origSelf);
3453+
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
3454+
// Modify accessors have inout yields and therefore should yield addresses.
3455+
assert(getTangentValueCategory(origYield) == SILValueCategory::Address &&
3456+
"Modify accessors should yield indirect");
3457+
3458+
// Yield the adjoint buffer and do everything else in the resume
3459+
// destination. Unwind destination is unreachable as the coroutine can never
3460+
// be aborted.
3461+
auto *unwindBB = getPullback().createBasicBlock();
3462+
auto *resumeBB = getPullbackBlock(origEntry);
3463+
builder.createYield(yi->getLoc(), {adjSelfElt}, resumeBB, unwindBB);
3464+
builder.setInsertionPoint(unwindBB);
3465+
builder.createUnreachable(SILLocation::invalid());
3466+
3467+
builder.setInsertionPoint(resumeBB);
3468+
addToAdjointBuffer(origEntry, origSelf, adjSelf, pbLoc);
3469+
3470+
return false;
3471+
}
3472+
33923473
//--------------------------------------------------------------------------//
33933474
// Adjoint buffer mapping
33943475
//--------------------------------------------------------------------------//

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,16 @@ class VJPCloner::Implementation final
460460
TypeSubstCloner::visitEndApplyInst(eai);
461461
return;
462462
}
463+
// If the original function is a semantic member accessor, do standard
464+
// cloning. Semantic member accessors have special pullback generation
465+
// logic, so all `end_apply` instructions can be directly cloned to the VJP.
466+
if (isSemanticMemberAccessor(original)) {
467+
LLVM_DEBUG(getADDebugStream()
468+
<< "Cloning `end_apply` in semantic member accessor:\n"
469+
<< *eai << '\n');
470+
TypeSubstCloner::visitEndApplyInst(eai);
471+
return;
472+
}
463473

464474
Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope()));
465475
auto loc = eai->getLoc();
@@ -607,6 +617,16 @@ class VJPCloner::Implementation final
607617
TypeSubstCloner::visitBeginApplyInst(bai);
608618
return;
609619
}
620+
// If the original function is a semantic member accessor, do standard
621+
// cloning. Semantic member accessors have special pullback generation
622+
// logic, so all `begin_apply` instructions can be directly cloned to the VJP.
623+
if (isSemanticMemberAccessor(original)) {
624+
LLVM_DEBUG(getADDebugStream()
625+
<< "Cloning `begin_apply` in semantic member accessor:\n"
626+
<< *bai << '\n');
627+
TypeSubstCloner::visitBeginApplyInst(bai);
628+
return;
629+
}
610630

611631
Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope()));
612632
auto loc = bai->getLoc();

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -680,9 +680,7 @@ extension DifferentiableWrapper: Differentiable where Value: Differentiable {}
680680
// accesses.
681681

682682
struct Struct: Differentiable {
683-
// expected-error @+4 {{expression is not differentiable}}
684-
// expected-error @+3 {{expression is not differentiable}}
685-
// expected-note @+2 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
683+
// expected-error @+2 {{expression is not differentiable}}
686684
// expected-note @+1 {{cannot differentiate access to property 'Struct._x' because 'Struct.TangentVector' does not have a stored property named '_x'}}
687685
@DifferentiableWrapper @DifferentiableWrapper var x: Float = 10
688686

test/AutoDiff/validation-test/property_wrappers.swift

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ struct Wrapper<Value> {
1919
var wrappedValue: Value { // computed property
2020
get { value }
2121
set { value = newValue }
22+
_modify { yield &value }
2223
}
2324

2425
init(wrappedValue: Value) {
@@ -46,16 +47,13 @@ PropertyWrapperTests.test("SimpleStruct") {
4647
expectEqual((.init(x: 60, y: 0, z: 20), 300),
4748
gradient(at: Struct(), 2, of: setter))
4849

49-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
50-
/*
5150
func modify(_ s: Struct, _ x: Tracked<Float>) -> Tracked<Float> {
5251
var s = s
5352
s.x *= x * s.z
5453
return s.x
5554
}
5655
expectEqual((.init(x: 60, y: 0, z: 20), 300),
5756
gradient(at: Struct(), 2, of: modify))
58-
*/
5957
}
6058

6159
struct GenericStruct<T> {
@@ -86,16 +84,13 @@ PropertyWrapperTests.test("GenericStruct") {
8684
expectEqual((.init(x: 60, y: 0, z: 20), 300),
8785
gradient(at: GenericStruct<Tracked<Float>>(y: 20), 2, of: setter))
8886

89-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
90-
/*
9187
func modify<T>(_ s: GenericStruct<T>, _ x: Tracked<Float>) -> Tracked<Float> {
9288
var s = s
9389
s.x *= x * s.z
9490
return s.x
9591
}
9692
expectEqual((.init(x: 60, y: 0, z: 20), 300),
9793
gradient(at: GenericStruct<Tracked<Float>>(y: 1), 2, of: modify))
98-
*/
9994
}
10095

10196
// TF-1149: Test class with loadable type but address-only `TangentVector` type.
@@ -131,16 +126,18 @@ PropertyWrapperTests.test("SimpleClass") {
131126
gradient(at: Class(), 2, of: setter))
132127
*/
133128

134-
// TODO: Support `modify` accessors (https://github.com/apple/swift/issues/55084).
135-
/*
129+
// FIXME(TF-1175): Same issue as above
136130
func modify(_ c: Class, _ x: Tracked<Float>) -> Tracked<Float> {
137131
var c = c
138132
c.x *= x * c.z
139133
return c.x
140134
}
135+
/*
141136
expectEqual((.init(x: 60, y: 0, z: 20), 300),
142137
gradient(at: Class(), 2, of: modify))
143138
*/
139+
expectEqual((.init(x: 1, y: 0, z: 0), 0),
140+
gradient(at: Class(), 2, of: modify))
144141
}
145142

146143
// From: https://github.com/apple/swift-evolution/blob/master/proposals/0258-property-wrappers.md#proposed-solution
@@ -157,12 +154,13 @@ enum Lazy<Value> {
157154

158155
var wrappedValue: Value {
159156
// TODO(TF-1250): Replace with actual mutating getter implementation.
160-
// Requires differentiation to support functions with multiple results.
161-
get {
157+
// Requires support for mutating semantic member accessor
158+
/* mutating */ get {
162159
switch self {
163160
case .uninitialized(let initializer):
164161
let value = initializer()
165162
// NOTE: Actual implementation assigns to `self` here.
163+
// self = .initialized(value)
166164
return value
167165
case .initialized(let value):
168166
return value

0 commit comments

Comments
 (0)