Skip to content

将pytorch模型转换为fastNLP可用的模型 #286

@jwc19890114

Description

@jwc19890114

你好,我在尝试使用fastNLP构建自己的模型的时候遇到了问题。
pytorch教程[https://pytorch.apachecn.org/]中有一个字符级的文本生成模型

class RNN(nn.Module):
    def __init__(self,input_size,hidden_size,output_size):
        super().__init__()
        self.input_size=input_size
        self.hidden_size=hidden_size
        self.output_size=output_size
        
        self.i2h=nn.Linear(input_size+hidden_size,hidden_size)
        self.i2o=nn.Linear(input_size+hidden_size,output_size)
        self.softmax=nn.LogSoftmax()
        
    def forward(self,input ,hidden):
        combined=torch.cat((input,hidden),1)
        hidden=self.i2h(combined)
        output=self.i2o(combined)
        output=self.softmax(output)
        return output,hidden
    def init_hidden(self):
        return V(torch.zeros(1,self.hidden_size))

我想用fastNLP重写一次,但是在写的时候遇到一个问题,fastNLP要求返回数据为一个字典类型{'pred':outputs},那对于这个模型应该如何改写呢?
谢谢~

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions