Commit 1eb27c3

Updated python code to match original repository again
1 parent f0d447a commit 1eb27c3

2 files changed: +11 −20 lines

code/model.py

Lines changed: 4 additions & 12 deletions
@@ -11,7 +11,6 @@ class Seq2Seq(nn.Module):
         Build Seqence-to-Sequence.
 
         Parameters:
-
         * `encoder`- encoder of seq2seq model. e.g. roberta
         * `decoder`- decoder of seq2seq model. e.g. transformer
         * `config`- configuration of encoder model.
@@ -73,9 +72,8 @@ def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=N
             return outputs
         else:
             #Predict
-            preds=[]
-            #zero=torch.cuda.LongTensor(1).fill_(0)
-            zero=torch.LongTensor(1).fill_(0)
+            preds=[]
+            zero=torch.cuda.LongTensor(1).fill_(0)
             for i in range(source_ids.shape[0]):
                 context=encoder_output[:,i:i+1]
                 context_mask=source_mask[i:i+1,:]
@@ -108,9 +106,7 @@ def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=N
 class Beam(object):
     def __init__(self, size,sos,eos):
         self.size = size
-        #TODO: Make tt with switch on version
-        #self.tt = torch.cuda
-        self.tt = torch
+        self.tt = torch.cuda
         # The score for each translation on the beam.
         self.scores = self.tt.FloatTensor(size).zero_()
         # The backpointers at each time-step.
@@ -138,12 +134,9 @@ def advance(self, wordLk):
         """
         Given prob over words for every last beam `wordLk` and attention
         `attnOut`: Compute and update the beam search.
-
         Parameters:
-
         * `wordLk`- probs of advancing from the last step (K x words)
         * `attnOut`- attention at the last step
-
         Returns: True if beam search is complete.
         """
         numWords = wordLk.size(1)
@@ -218,5 +211,4 @@ def buildTargetTokens(self, preds):
                 break
             tokens.append(tok)
             sentence.append(tokens)
-        return sentence
-
+        return sentence
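
The `self.tt = torch.cuda` and `zero=torch.cuda.LongTensor(1).fill_(0)` lines restore the upstream GPU-only behaviour; the deleted "TODO: Make tt with switch on version" pointed at a runtime switch instead. A minimal sketch of such a switch, assuming only that `torch.cuda.is_available()` is the intended check (this variant is not part of the commit):

import torch

class Beam(object):
    # Hypothetical device-aware variant of the constructor hinted at by the
    # removed TODO; falls back to CPU tensors when no GPU is present.
    def __init__(self, size, sos, eos):
        self.size = size
        self.tt = torch.cuda if torch.cuda.is_available() else torch
        # Both namespaces expose the same legacy constructors, so
        # self.tt.FloatTensor works unchanged on either backend.
        self.scores = self.tt.FloatTensor(size).zero_()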

code/run.py

Lines changed: 7 additions & 8 deletions
@@ -8,8 +8,7 @@
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, softwareFuchs.Mader.Luchs
-
+# Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
@@ -304,17 +303,19 @@ def main():
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
+    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = get_linear_schedule_with_warmup(optimizer,
-                                                num_warmup_steps=int(len(train_dataloader)*args.num_train_epochs*0.1),
-                                                num_training_steps=len(train_dataloader)*args.num_train_epochs)
+    scheduler = get_linear_schedule_with_warmup(optimizer,
+                                                num_warmup_steps=int(t_total*0.1),
+                                                num_training_steps=t_total)
 
     #Start training
     logger.info("***** Running training *****")
     logger.info("  Num examples = %d", len(train_examples))
     logger.info("  Batch size = %d", args.train_batch_size)
     logger.info("  Num epoch = %d", args.num_train_epochs)
 
+
     model.train()
     dev_dataset={}
     nb_tr_examples, nb_tr_steps,tr_loss,global_step,best_bleu,best_loss = 0, 0,0,0,0,1e6
@@ -515,6 +516,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
-
-
+    main()
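
The new `t_total` folds `args.gradient_accumulation_steps` into the step count, so warmup and decay now follow optimizer steps rather than raw batches. A quick sanity check with assumed values (illustrative only, none of these numbers come from the repository):

# Illustrative values, not from the commit itself.
steps_per_epoch = 1000              # stands in for len(train_dataloader)
gradient_accumulation_steps = 4
num_train_epochs = 10

# One optimizer step per `gradient_accumulation_steps` batches.
t_total = steps_per_epoch // gradient_accumulation_steps * num_train_epochs
num_warmup_steps = int(t_total * 0.1)

print(t_total, num_warmup_steps)    # 2500 250

The replaced code scheduled over len(train_dataloader)*args.num_train_epochs (10000 in this example), four times the actual number of optimizer steps when accumulating over 4 batches, so the learning rate would never have fully warmed up or decayed.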
