296 changes: 213 additions & 83 deletions pina/label_tensor.py
@@ -65,15 +65,56 @@ def __init__(self, x, labels):
[0.9518, 0.1025],
[0.8066, 0.9615]])
'''
if x.ndim == 1:
x = x.reshape(-1, 1)
# if x.ndim == 1:
# x = x.reshape(-1, 1)

if isinstance(labels, str):
labels = [labels]

if len(labels) != x.shape[-1]:
raise ValueError('the tensor has not the same number of columns of '
'the passed labels.')
# print(labels)
if (isinstance(labels, (tuple, list))
and not isinstance(labels[0], (tuple, list))):
labels = [labels]

print(labels)
print(x.dim)
if len(labels) > x.ndim:
raise ValueError(
'The number of label groups is greater than the number of '
'dimensions of the tensor.')

# print(len(labels), x.ndim, range(1-x.ndim, len(labels)-x.ndim, 1))
k_ = [-k for k in range(1, len(labels)+1, 1)]
if isinstance(labels, (tuple, list)):
self.dim_labels = list(k_)
labels = dict(zip(k_, labels))
elif isinstance(labels, dict):
self.dim_labels = list(labels.keys())
labels = dict(zip(k_, labels.values()))
# print(labels)


else:
raise TypeError(
'`labels` should be a str, a list of str, a list of lists of str, or a dict')

assert isinstance(labels, dict)
print(labels)

# print(x.shape)
for d in labels:
# print(x.shape[d], len(labels[d]), d)
if x.shape[d] != len(labels[d]):
err = (
f'The tensor shape and the passed labels do not match along '
f'dimension {d}: {x.shape[d]} != {len(labels[d])}.'
)
raise ValueError(err)

# if len(labels) != x.shape[-1]:
# raise ValueError('the tensor has not the same number of columns of '
# 'the passed labels.')
self._labels = labels
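
# Usage sketch (illustrative, not part of the diff): how the new constructor
# appears intended to be called once the leftover debug prints are removed.
# A flat list labels the last dimension; a list of lists labels the trailing
# dimensions in order, starting from dim -1.
import torch
from pina.label_tensor import LabelTensor

pts = LabelTensor(torch.rand(4, 2), ['x', 'y'])        # stored as {-1: ['x', 'y']}
grid = LabelTensor(torch.rand(3, 2),
                   [['x', 'y'], ['r0', 'r1', 'r2']])   # dim -1 columns, dim -2 rows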

@property
@@ -93,27 +134,27 @@ def labels(self, labels):

self._labels = labels # assign the label

@staticmethod
def vstack(label_tensors):
"""
Stack tensors vertically. For more details, see
:meth:`torch.vstack`.
# @staticmethod
# def vstack(label_tensors):
# """
# Stack tensors vertically. For more details, see
# :meth:`torch.vstack`.

:param list(LabelTensor) label_tensors: the tensors to stack. They need
to have equal labels.
:return: the stacked tensor
:rtype: LabelTensor
"""
if len(label_tensors) == 0:
return []
# :param list(LabelTensor) label_tensors: the tensors to stack. They need
# to have equal labels.
# :return: the stacked tensor
# :rtype: LabelTensor
# """
# if len(label_tensors) == 0:
# return []

all_labels = [label for lt in label_tensors for label in lt.labels]
if set(all_labels) != set(label_tensors[0].labels):
raise RuntimeError('The tensors to stack have different labels')
# all_labels = [label for lt in label_tensors for label in lt.labels]
# if set(all_labels) != set(label_tensors[0].labels):
# raise RuntimeError('The tensors to stack have different labels')

labels = label_tensors[0].labels
tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels)
# labels = label_tensors[0].labels
# tensors = [lt.extract(labels) for lt in label_tensors]
# return LabelTensor(torch.vstack(tensors), labels)

def clone(self, *args, **kwargs):
"""
@@ -167,6 +208,59 @@ def cpu(self, *args, **kwargs):
new.data = tmp.data
return tmp

def extract_(self, label_to_extract):
"""
"""
if isinstance(label_to_extract, str):
label_to_extract = [label_to_extract]

if isinstance(label_to_extract, (tuple, list)):
# TODO:
# comment, factorize, and improve
# "Abandon all hope, ye who enter here"
print(self.labels)
dim_mask = []
new_labels = []
new_shape = []
for j in range(-self.ndim, 0, 1):
jcomp_valid_indeces = [True] * self.shape[j]
print(self.dim_labels)
if j in self.dim_labels:
jcomp_labels = self.labels[j]

for i, label in enumerate(label_to_extract):
if label in jcomp_labels:
index = jcomp_labels.index(label)
jcomp_valid_indeces[index] = False

if all(jcomp_valid_indeces):
new_labels.append(jcomp_labels)
else:
new_labels.append([
jcomp_labels[i] for i, valid in enumerate(jcomp_valid_indeces) if not valid])

new_shape.append(len(new_labels[-1]))

else: # j not in self.dim_labels
new_shape.append(self.shape[j])
print(j, new_labels)
dim_mask.append(torch.tensor(jcomp_valid_indeces))

def create_mask(dim_mask):
grids = torch.meshgrid(dim_mask)
f = grids[0]
for g in grids[1:]:
f = f & g
return f
mask = create_mask(dim_mask)
print(mask.shape)
print(new_labels)
print(self.tensor[~mask].reshape(new_shape[::-1]).shape)

new_t = LabelTensor(self.tensor[~mask].reshape(new_shape[::-1]).T, labels=new_labels[::-1])

return new_t
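
# Usage sketch (illustrative, not part of the diff): extracting a single
# labelled component with the new `extract_`, assuming the debug prints are
# stripped. The selected column comes back as a LabelTensor with that label.
import torch
from pina.label_tensor import LabelTensor

coords = LabelTensor(torch.rand(4, 2), ['x', 'y'])
x_only = coords.extract_('x')   # shape (4, 1), labels {-1: ['x']}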

def extract(self, label_to_extract):
"""
Extract the subset of the original tensor by returning all the columns
Expand Down Expand Up @@ -202,84 +296,120 @@ def extract(self, label_to_extract):
return extracted_tensor

def detach(self):
"""
Return a new Tensor, detached from the current graph.
"""
detached = super().detach()
if hasattr(self, '_labels'):
detached._labels = self._labels
return detached


def requires_grad_(self, mode = True):
"""
Set tensor's ``requires_grad`` attribute in-place.
"""
lt = super().requires_grad_(mode)
lt.labels = self.labels
return lt
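
# Usage sketch (illustrative, not part of the diff): both overrides return a
# LabelTensor and keep the labels attached, so label metadata survives the
# usual autograd round trip.
import torch
from pina.label_tensor import LabelTensor

pts = LabelTensor(torch.rand(5, 2), ['x', 'y'])
pts = pts.requires_grad_(True)   # still a LabelTensor, labels preserved
frozen = pts.detach()            # detached copy, labels carried over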

def append(self, lt, mode='std'):
"""
Return a copy of the merged tensors.
# def append(self, lt, mode='std'):
# """
# Return a copy of the merged tensors.

# :param LabelTensor lt: The tensor to merge.
# :param str mode: {'std', 'first', 'cross'}
# :return: The merged tensors.
# :rtype: LabelTensor
# """
# if set(self.labels).intersection(lt.labels):
# raise RuntimeError('The tensors to merge have common labels')

# new_labels = self.labels + lt.labels
# if mode == 'std':
# new_tensor = torch.cat((self, lt), dim=1)
# elif mode == 'first':
# raise NotImplementedError
# elif mode == 'cross':
# tensor1 = self
# tensor2 = lt
# n1 = tensor1.shape[0]
# n2 = tensor2.shape[0]

# tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
# tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
# labels=tensor2.labels)
# new_tensor = torch.cat((tensor1, tensor2), dim=1)

# new_tensor = new_tensor.as_subclass(LabelTensor)
# new_tensor.labels = new_labels
# return new_tensor
def append(self, lt, dim=None, component=None):


if dim is None and component is None:
pass

:param LabelTensor lt: The tensor to merge.
:param str mode: {'std', 'first', 'cross'}
:return: The merged tensors.
:rtype: LabelTensor
"""
if set(self.labels).intersection(lt.labels):
raise RuntimeError('The tensors to merge have common labels')

new_labels = self.labels + lt.labels
if mode == 'std':
new_tensor = torch.cat((self, lt), dim=1)
elif mode == 'first':
raise NotImplementedError
elif mode == 'cross':
tensor1 = self
tensor2 = lt
n1 = tensor1.shape[0]
n2 = tensor2.shape[0]

tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
labels=tensor2.labels)
new_tensor = torch.cat((tensor1, tensor2), dim=1)

new_tensor = new_tensor.as_subclass(LabelTensor)
new_tensor.labels = new_labels
return new_tensor
if dim is None and component is not None:

def __getitem__(self, index):
"""
Return a copy of the selected tensor.
"""
if self.ndim != lt.ndim:
raise RuntimeError('The tensors to merge have different dimensions')

if isinstance(index, str) or (isinstance(index, (tuple, list))and all(isinstance(a, str) for a in index)):
return self.extract(index)
common_labels = [i for i in self.labels.values()
for j in lt.labels.values() if i == j]

selected_lt = super(Tensor, self).__getitem__(index)

try:
len_index = len(index)
except TypeError:
len_index = 1

if isinstance(index, int) or len_index == 1:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(1, -1)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels
elif len_index == 2:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(-1, 1)
if hasattr(self, 'labels'):
if isinstance(index[1], list):
selected_lt.labels = [self.labels[i] for i in index[1]]
else:
selected_lt.labels = self.labels[index[1]]
else:
selected_lt.labels = self.labels
# if len(common_labels) > 1:
# raise RuntimeError(f'The tensors to merge have too many common labels: {common_labels}')

if len(common_labels) == 0:
raise RuntimeError(f'The tensors to merge have no common labels')

common_labels = common_labels[0]
for k, v in self.labels.items():
if v == common_labels:
dim1 = [True] * self.ndim
dim1[k] = False

return selected_lt
for k, v in lt.labels.items():
if v == common_labels:
dim2 = [True] * lt.ndim
dim2[k] = False

if dim1 == dim2:
print(dim1, common_labels)
dim_to_append = [i for i, j in enumerate(dim1) if j == True]
print(dim_to_append)
if len(dim_to_append) > 1:
raise RuntimeError(f'The tensors to merge have too many dimensions and only {component} is given')
result = LabelTensor(
torch.cat((self.tensor, lt.tensor), dim=dim_to_append[0]),
labels={k: common_labels}
)
print(result)
print('ggggggggg')

return result
else:
raise NotImplementedError
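
# Usage sketch (illustrative, not part of the diff): the draft `append`
# concatenates two tensors that share a label group along the remaining free
# dimension. Passing the shared labels as `component` is an assumption here;
# in this draft `component` is only used in the error message.
import torch
from pina.label_tensor import LabelTensor

a = LabelTensor(torch.rand(3, 2), ['x', 'y'])
b = LabelTensor(torch.rand(5, 2), ['x', 'y'])
ab = a.append(b, component=['x', 'y'])   # rows stacked: shape (8, 2), labels ['x', 'y']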

def _append(self, lt, mode):
print(self.labels, lt.labels)

def __getitem__(self, index):
"""
Disable the slicing of the labels.
"""
text = (
'LabelTensor does not support slicing. '
'Use `extract` instead, or `tensor` to get the underlying tensor.'
)
raise RuntimeError(text)
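
# Usage sketch (illustrative, not part of the diff): with slicing disabled,
# the error message points callers to `extract` for label-based selection and
# to `tensor` for positional indexing.
import torch
from pina.label_tensor import LabelTensor

pts = LabelTensor(torch.rand(10, 2), ['x', 'y'])
x = pts.extract('x')    # label-based selection, as the message above suggests
row0 = pts.tensor[0]    # positional indexing on the underlying torch.Tensor
# pts[0]                # would raise RuntimeError with the new __getitem__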

@property
def tensor(self):
"""
Return the underlying tensor.
"""
return self.as_subclass(Tensor)

def __len__(self) -> int:
@@ -290,5 +420,5 @@ def __str__(self):
s = f'labels({str(self.labels)})\n'
else:
s = 'no labels\n'
s += super().__str__()
s += self.tensor.__str__()
return s
Empty file added pina/optimizer/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions pina/optimizer/optimizer_interface.py
@@ -0,0 +1,4 @@
""" Abstract class for Optimizer """

class Optimizer:
pass
17 changes: 17 additions & 0 deletions pina/optimizer/torch_optimizer.py
@@ -0,0 +1,17 @@
import torch

from .optimizer_interface import Optimizer
from ..utils import check_consistency

class TorchOptimizer(Optimizer):

def __init__(self, optimizer_class, **kwargs):
check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)

self.optimizer_class = optimizer_class
self.kwargs = kwargs

def hook(self, parameters):
self.optimizer_instance = self.optimizer_class(
parameters, **self.kwargs
)
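
# Usage sketch (illustrative, not part of the diff): the wrapper stores the
# optimizer class and its keyword arguments, and only builds the concrete
# torch.optim optimizer once `hook` receives the parameters.
import torch
from pina.optimizer.torch_optimizer import TorchOptimizer

model = torch.nn.Linear(2, 1)
adam = TorchOptimizer(torch.optim.Adam, lr=1e-3)   # configured, not yet built
adam.hook(model.parameters())                      # instantiates torch.optim.Adam
adam.optimizer_instance.zero_grad()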


1 change: 1 addition & 0 deletions pina/problem/abstract_problem.py
@@ -249,6 +249,7 @@ def have_sampled_points(self):
Check if all points for
``Location`` are sampled.
"""
print(self._have_sampled_points)
return all(self._have_sampled_points.values())

@property