-
Notifications
You must be signed in to change notification settings - Fork 449
Closed
Description
你好,我在尝试使用fastNLP构建自己的模型的时候遇到了问题。
pytorch教程[https://pytorch.apachecn.org/]中有一个字符级的文本生成模型
class RNN(nn.Module):
def __init__(self,input_size,hidden_size,output_size):
super().__init__()
self.input_size=input_size
self.hidden_size=hidden_size
self.output_size=output_size
self.i2h=nn.Linear(input_size+hidden_size,hidden_size)
self.i2o=nn.Linear(input_size+hidden_size,output_size)
self.softmax=nn.LogSoftmax()
def forward(self,input ,hidden):
combined=torch.cat((input,hidden),1)
hidden=self.i2h(combined)
output=self.i2o(combined)
output=self.softmax(output)
return output,hidden
def init_hidden(self):
return V(torch.zeros(1,self.hidden_size))
我想用fastNLP重写一次,但是在写的时候遇到一个问题,fastNLP要求返回数据为一个字典类型{'pred':outputs},那对于这个模型应该如何改写呢?
谢谢~
Metadata
Metadata
Assignees
Labels
No labels