
Conversation

@yuanfz98 (Contributor) commented Jul 6, 2022

Hello,

This PR adds support for aten::rnn_tanh and aten::rnn_relu. The conversion follows the same approach as the existing GRU and LSTM converters in the Relay PyTorch frontend.

Linked issue: #11827
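
For reference, the recurrence these operators implement is the Elman RNN cell, h' = act(x · W_ih^T + b_ih + h · W_hh^T + b_hh), with act = tanh for aten::rnn_tanh and act = relu for aten::rnn_relu. Below is a minimal sketch of one cell step expressed with Relay ops; rnn_cell_step and its parameters are illustrative only, not the exact helper added by this PR.

from tvm import relay

def rnn_cell_step(x, hidden, w_ih, w_hh, b_ih=None, b_hh=None, act=relay.tanh):
    # Illustrative sketch, not the code added by this PR.
    # x: (batch, input_size), hidden: (batch, hidden_size)
    # w_ih: (hidden_size, input_size), w_hh: (hidden_size, hidden_size)
    i2h = relay.nn.dense(x, w_ih)        # x @ W_ih^T
    h2h = relay.nn.dense(hidden, w_hh)   # h @ W_hh^T
    out = relay.add(i2h, h2h)
    if b_ih is not None:
        out = relay.add(out, b_ih)
    if b_hh is not None:
        out = relay.add(out, b_hh)
    # act is relay.tanh for aten::rnn_tanh and relay.nn.relu for aten::rnn_relu
    return act(out)

The script below exercises the converter and prints the resulting Relay module: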

import torch

from tvm import relay


def test_RNN_torch(num_layers: int,
                   bidirectional: bool,
                   use_bias: bool,
                   hidden_size: int,
                   input_size: int,
                   seq_len: int,
                   batch_first: bool,
                   batch_size: int):
    r''' 
    Args:
        num_layers (int): num_layers to be passed to torch.nn.RNN
        bidirectional (bool): whether to build bidirectional RNN or not
        use_bias (bool): whether to use bias or not
        hidden_size (int): hidden_size of RNN cells
        input_size (int): Input features
        seq_len (int): Timesteps in input data
        batch_first (bool): Whether batch dimension is first or second dimension in input tensor
        batch_size (int): Batch size of input. If 0, unbatched input will be fed to network
    '''

    if batch_first:
        input_shape = (batch_size, seq_len, input_size)
    else:
        input_shape = (seq_len, batch_size, input_size)
    pytorch_net = torch.nn.Sequential(
        torch.nn.RNN(input_size,
                     hidden_size,
                     batch_first=batch_first,
                     num_layers=num_layers,
                     bidirectional=bidirectional,
                     bias=use_bias)
    )

    scripted_model = torch.jit.trace(pytorch_net.eval(),
                                     torch.randn(input_shape))

    mod, params = relay.frontend.from_pytorch(scripted_model,
                                              [('input', input_shape)])
    mod = relay.transform.InferType()(mod)
    print(mod.astext())

if __name__ == "__main__":

    test_RNN_torch(num_layers=1,
                   bidirectional=False,
                   use_bias=True,
                   hidden_size=5,
                   input_size=5,
                   seq_len=15,
                   batch_first=True,
                   batch_size=32)

Out:

#[version = "0.0.5"]
type List[A] {
  Cons(A, List[A]),
  Nil,
}

type Option[A] {
  Some(A),
  None,
}

type Tree[A] {
  Rose(A, List[Tree[A]]),
}

type tensor_float16_t {
  tensor_nil_float16,
  tensor0_float16(float16),
  tensor1_float16(Tensor[(?), float16]),
  tensor2_float16(Tensor[(?, ?), float16]),
  tensor3_float16(Tensor[(?, ?, ?), float16]),
  tensor4_float16(Tensor[(?, ?, ?, ?), float16]),
  tensor5_float16(Tensor[(?, ?, ?, ?, ?), float16]),
  tensor6_float16(Tensor[(?, ?, ?, ?, ?, ?), float16]),
}

type tensor_float32_t {
  tensor_nil_float32,
  tensor0_float32(float32),
  tensor1_float32(Tensor[(?), float32]),
...
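
To go beyond printing the IR, one way to validate the conversion is to compile the module and compare its output with PyTorch. This is a sketch, not part of the PR: it assumes an llvm target on CPU and that mod, params, pytorch_net and input_shape from the test above are in scope.

import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

inp = torch.randn(input_shape)

# Compile the converted Relay module for CPU.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

runtime = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
runtime.set_input("input", inp.numpy())
runtime.run()
tvm_out = runtime.get_output(0).numpy()           # RNN output sequence

torch_out = pytorch_net(inp)[0].detach().numpy()  # nn.RNN returns (output, h_n)
np.testing.assert_allclose(tvm_out, torch_out, rtol=1e-5, atol=1e-5)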

@masahi merged commit 40d242a into apache:main on Jul 7, 2022
masahi pushed a commit to masahi/tvm that referenced this pull request Jul 15, 2022
* emptycommit 2nd try

* dev

* comments

* format

* format

Co-authored-by: yuanfz <[email protected]>
mikeseven pushed a commit to mikeseven/tvm that referenced this pull request Sep 27, 2023