86 changes: 86 additions & 0 deletions models/GRU.py
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn


# Gated Path Planning Network module
class Planner(nn.Module):
"""
Implementation of the Gated Path Planning Network.
"""
def __init__(self, num_orient, num_actions, args):
super(Planner, self).__init__()

self.num_orient = num_orient
self.num_actions = num_actions

self.l_h = args.l_h
self.k = args.k
self.f = args.f

self.hid = nn.Conv2d(
in_channels=(num_orient + 1), # maze map + goal location
out_channels=self.l_h,
kernel_size=(3, 3),
stride=1,
padding=1,
bias=True)

self.h0 = nn.Conv2d(
in_channels=self.l_h,
out_channels=self.l_h,
kernel_size=(3, 3),
stride=1,
padding=1,
bias=True)

self.conv = nn.Conv2d(
in_channels=self.l_h,
out_channels=1,
kernel_size=(self.f, self.f),
stride=1,
padding=int((self.f - 1.0) / 2),
bias=True)

self.gru = nn.GRUCell(1, self.l_h)

self.policy = nn.Conv2d(
in_channels=self.l_h,
out_channels=num_actions * num_orient,
kernel_size=(1, 1),
stride=1,
padding=0,
bias=False)

self.sm = nn.Softmax2d()

def forward(self, map_design, goal_map):
maze_size = map_design.size()[-1]
X = torch.cat([map_design, goal_map], 1)

        hid = self.hid(X)
        # Flatten spatial positions into the batch dimension: (B*H*W, l_h).
        h0 = self.h0(hid).transpose(1, 3).contiguous().view(-1, self.l_h)

        last_h = h0
        for _ in range(0, self.k - 1):
            # Reshape the flat hidden state back into a (B, l_h, H, W) map,
            # aggregate neighboring states with the f x f convolution, and
            # update every position with the shared GRU cell.
            h_map = last_h.view(-1, maze_size, maze_size, self.l_h)
            h_map = h_map.transpose(3, 1)
            inp = self.conv(h_map).transpose(1, 3).contiguous().view(-1, 1)

            last_h = self.gru(inp, last_h)

hk = last_h.view(-1, maze_size, maze_size, self.l_h).transpose(3, 1)
logits = self.policy(hk)

# Normalize over actions
logits = logits.view(-1, self.num_actions, maze_size, maze_size)
probs = self.sm(logits)

# Reshape to output dimensions
logits = logits.view(-1, self.num_orient, self.num_actions, maze_size,
maze_size)
probs = probs.view(-1, self.num_orient, self.num_actions, maze_size,
maze_size)
logits = torch.transpose(logits, 1, 2).contiguous()
probs = torch.transpose(probs, 1, 2).contiguous()

return logits, probs, h0, hk
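
For reference, a minimal smoke test of the GRU planner, assuming Planner from models/GRU.py is in scope; the Namespace hyperparameter values below are illustrative, not settings taken from this repository:

import torch
from argparse import Namespace

args = Namespace(l_h=32, k=10, f=3)  # hypothetical hyperparameters
planner = Planner(num_orient=1, num_actions=4, args=args)

maze = torch.rand(2, 1, 15, 15)   # batch of two 15x15 maze maps
goal = torch.zeros(2, 1, 15, 15)  # one-hot goal location per maze
goal[:, 0, 7, 7] = 1.0

logits, probs, h0, hk = planner(maze, goal)
print(logits.shape)  # torch.Size([2, 4, 1, 15, 15]) = (batch, actions, orient, H, W)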
96 changes: 96 additions & 0 deletions models/PLSTM.py
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn
# If running from outside the models/ directory, add it to the import path:
# import sys; sys.path.append('PATH to models folder')
from peephole_lstm import PLSTMCell

class Planner(nn.Module):
"""
Implementation of the Gated Path Planning Network.
"""
def __init__(self, num_orient, num_actions, args):
super(Planner, self).__init__()

self.num_orient = num_orient
self.num_actions = num_actions

self.l_h = args.l_h
self.k = args.k
self.f = args.f

self.hid = nn.Conv2d(
in_channels=(num_orient + 1), # maze map + goal location
out_channels=self.l_h,
kernel_size=(3, 3),
stride=1,
padding=1,
bias=True)

self.h0 = nn.Conv2d(
in_channels=self.l_h,
out_channels=self.l_h,
kernel_size=(3, 3),
stride=1,
padding=1,
bias=True)

self.c0 = nn.Conv2d(
in_channels=self.l_h,
out_channels=self.l_h,
kernel_size=(3, 3),
stride=1,
padding=1,
bias=True)

self.conv = nn.Conv2d(
in_channels=self.l_h,
out_channels=1,
kernel_size=(self.f, self.f),
stride=1,
padding=int((self.f - 1.0) / 2),
bias=True)

self.plstm = PLSTMCell(1, self.l_h, self.l_h)

self.policy = nn.Conv2d(
in_channels=self.l_h,
out_channels=num_actions * num_orient,
kernel_size=(1, 1),
stride=1,
padding=0,
bias=False)

self.sm = nn.Softmax2d()

def forward(self, map_design, goal_map):
maze_size = map_design.size()[-1]
X = torch.cat([map_design, goal_map], 1)

        hid = self.hid(X)
        # Flatten spatial positions into the batch dimension: (B*H*W, l_h).
        h0 = self.h0(hid).transpose(1, 3).contiguous().view(-1, self.l_h)
        c0 = self.c0(hid).transpose(1, 3).contiguous().view(-1, self.l_h)

        last_h, last_c = h0, c0
        for _ in range(0, self.k - 1):
            # Aggregate neighboring hidden states with the f x f convolution
            # and update every position with the shared peephole-LSTM cell.
            h_map = last_h.view(-1, maze_size, maze_size, self.l_h)
            h_map = h_map.transpose(3, 1)
            inp = self.conv(h_map).transpose(1, 3).contiguous().view(-1, 1)

            last_h, last_c = self.plstm(inp, last_h, last_c)

hk = last_h.view(-1, maze_size, maze_size, self.l_h).transpose(3, 1)
logits = self.policy(hk)

# Normalize over actions
logits = logits.view(-1, self.num_actions, maze_size, maze_size)
probs = self.sm(logits)

# Reshape to output dimensions
logits = logits.view(-1, self.num_orient, self.num_actions, maze_size,
maze_size)
probs = probs.view(-1, self.num_orient, self.num_actions, maze_size,
maze_size)
logits = torch.transpose(logits, 1, 2).contiguous()
probs = torch.transpose(probs, 1, 2).contiguous()

return logits, probs, h0, hk
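
The peephole variant is driven the same way; only the recurrent cell and the extra cell-state projection c0 change. A minimal shape check, again under illustrative hyperparameters:

import torch
from argparse import Namespace

args = Namespace(l_h=32, k=10, f=3)  # illustrative values only
planner = Planner(num_orient=1, num_actions=4, args=args)

maze = torch.rand(1, 1, 9, 9)
goal = torch.zeros(1, 1, 9, 9)
goal[0, 0, 4, 4] = 1.0

logits, probs, h0, hk = planner(maze, goal)
assert probs.shape == (1, 4, 1, 9, 9)  # (batch, actions, orient, H, W)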
70 changes: 70 additions & 0 deletions models/peephole_lstm.py
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import math
from torch.nn.parameter import Parameter
import torch.nn.functional as F

class PLSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, cell_size, bias=True):
super(PLSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.cell_size = cell_size
self.bias = bias
self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(hidden_size, hidden_size))
self.weight_ch = Parameter(torch.Tensor(hidden_size, cell_size))
if bias:
self.bias_ih = Parameter(torch.Tensor(hidden_size))
self.bias_hh = Parameter(torch.Tensor(hidden_size))
self.bias_ch = Parameter(torch.Tensor(hidden_size))
else:
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)
self.register_parameter('bias_ch', None)
self.reset_parameters()

def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)

    def LSTMPCell(self, input, hidden, cell, w_ih, w_hh, w_ch, b_ih=None, b_hh=None, b_ch=None):
        hx = hidden
        cx = cell
        # Shared linear projections of the input, hidden state and cell
        # state; every gate below is built from sums of these three terms.
        input_gates = F.linear(input, w_ih, b_ih)
        hidden_gates = F.linear(hx, w_hh, b_hh)
        cell_gates = F.linear(cx, w_ch, b_ch)

        # Note: the input and forget gates end up with identical
        # pre-activations because this cell uses a single set of weights
        # rather than separate per-gate matrices.
        ingate = input_gates + hidden_gates + cell_gates
        ingate = torch.sigmoid(ingate)

        forgetgate = input_gates + hidden_gates + cell_gates
        forgetgate = torch.sigmoid(forgetgate)

        cellgate = input_gates + hidden_gates
        cellgate = torch.tanh(cellgate)
        cy = (forgetgate * cx) + (ingate * cellgate)

        # Peephole connection: the output gate also sees the new cell state.
        outgate = input_gates + hidden_gates + F.linear(cy, w_ch, b_ch)
        outgate = torch.sigmoid(outgate)

hy = outgate * torch.tanh(cy)

return hy, cy

def forward(self, input, hx, cx):
return self.LSTMPCell(
input, hx, cx,
self.weight_ih, self.weight_hh, self.weight_ch,
self.bias_ih, self.bias_hh, self.bias_ch,
)
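
A standalone sketch of the cell interface, with arbitrary batch size and hidden width chosen for illustration; the cell returns the updated hidden and cell states:

import torch

cell = PLSTMCell(input_size=1, hidden_size=16, cell_size=16)
x = torch.randn(8, 1)   # batch of 8 scalar inputs
h = torch.zeros(8, 16)  # initial hidden state
c = torch.zeros(8, 16)  # initial cell state

h, c = cell(x, h, c)
print(h.shape, c.shape)  # torch.Size([8, 16]) torch.Size([8, 16])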