Skip to content

Commit f9c655a

Browse files
committed
Make Schedule::Copy non-const, fork RND seed in Copy
1 parent 76c7677 commit f9c655a

File tree

6 files changed

+7
-6
lines changed

6 files changed

+7
-6
lines changed

include/tvm/support/random_engine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class LinearCongruentialEngine {
115115
* \return The forked seed.
116116
*/
117117
TRandState ForkSeed() {
118-
// In order for reproducibility, we computer the new seed using RNG's random state and a
118+
// In order for reproducibility, we compute the new seed using RNG's random state and a
119119
// different set of parameters. Note that both 32767 and 1999999973 are prime numbers.
120120
return ((*this)() * 32767) % 1999999973;
121121
}

include/tvm/tir/schedule/schedule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class ScheduleNode : public runtime::Object {
123123
* 3) All the random variables are valid in the copy, pointing to the corresponding sref
124124
* reconstructed
125125
*/
126-
virtual Schedule Copy() const = 0;
126+
virtual Schedule Copy() = 0;
127127
/*!
128128
* \brief Seed the randomness
129129
* \param seed The new random seed, -1 if use device random, otherwise non-negative

src/tir/schedule/concrete_schedule.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,12 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb
182182
new_state->get()->DebugVerify();
183183
}
184184

185-
Schedule ConcreteScheduleNode::Copy() const {
185+
Schedule ConcreteScheduleNode::Copy() {
186186
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
187187
n->error_render_level_ = this->error_render_level_;
188188
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
189189
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
190+
n->rand_state_ = ForkSeed();
190191
return Schedule(std::move(n));
191192
}
192193

src/tir/schedule/concrete_schedule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class ConcreteScheduleNode : public ScheduleNode {
6161
public:
6262
ScheduleState state() const final { return state_; }
6363
Optional<Trace> trace() const override { return NullOpt; }
64-
Schedule Copy() const override;
64+
Schedule Copy() override;
6565
void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final;
6666
support::LinearCongruentialEngine::TRandState ForkSeed() final;
6767

src/tir/schedule/traced_schedule.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand
3333
return Schedule(std::move(n));
3434
}
3535

36-
Schedule TracedScheduleNode::Copy() const {
36+
Schedule TracedScheduleNode::Copy() {
3737
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
3838
n->error_render_level_ = this->error_render_level_;
3939
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);

src/tir/schedule/traced_schedule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
4343

4444
public:
4545
Optional<Trace> trace() const final { return trace_; }
46-
Schedule Copy() const final;
46+
Schedule Copy() final;
4747

4848
public:
4949
/******** Schedule: Sampling ********/

0 commit comments

Comments
 (0)