Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ezyrb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
__all__ = [
'Database', 'Snapshot', 'Reduction', 'POD', 'Approximation', 'RBF', 'Linear', 'GPR',
'ANN', 'KNeighborsRegressor', 'RadiusNeighborsRegressor', 'AE',
'ReducedOrderModel', 'PODAE', 'RegularGrid'
'ReducedOrderModel', 'PODAE', 'RegularGrid',
'MultiReducedOrderModel'
]

from .meta import *
from .database import Database
from .snapshot import Snapshot
from .parameter import Parameter
from .reducedordermodel import ReducedOrderModel
from .reducedordermodel import ReducedOrderModel, MultiReducedOrderModel
from .reduction import *
from .approximation import *
from .regular_grid import RegularGrid
5 changes: 4 additions & 1 deletion ezyrb/approximation/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ def _build_model(self, points, values):
layers.insert(0, points.shape[1])
layers.append(values.shape[1])

self.model = self._list_to_sequential(layers, self.function)
if self.model is None:
self.model = self._list_to_sequential(layers, self.function)
else:
self.model = self.model

def fit(self, points, values):
"""
Expand Down
28 changes: 24 additions & 4 deletions ezyrb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Database():
None meaning no scaling.
:param array_like space: the input spatial data
"""
def __init__(self, parameters=None, snapshots=None):
def __init__(self, parameters=None, snapshots=None, space=None):
self._pairs = []

if parameters is None and snapshots is None:
Expand All @@ -30,13 +30,21 @@ def __init__(self, parameters=None, snapshots=None):

if len(parameters) != len(snapshots):
raise ValueError('parameters and snapshots must have the same length')

for param, snap in zip(parameters, snapshots):
param = Parameter(param)
snap = Snapshot(snap)
if isinstance(space, dict):
snap_space = space.get(tuple(param.values), None)
# print('snap_space', snap_space)
else:
snap_space = space
snap = Snapshot(snap, space=snap_space)

self.add(param, snap)

# TODO: eventually improve the `space` assignment in the snapshots,
# snapshots can have different space coordinates

@property
def parameters_matrix(self):
"""
Expand Down Expand Up @@ -113,7 +121,7 @@ def split(self, chunks, seed=None):
>>> train, test = db.split([80, 20]) # n snapshots

"""

if seed is not None:
np.random.seed(seed)

Expand Down Expand Up @@ -152,3 +160,15 @@ def split(self, chunks, seed=None):
new_database[i].add(p, s)

return new_database

def get_snapshot_space(self, index):
"""
Get the space coordinates of a snapshot by its index.

:param int index: The index of the snapshot.
:return: The space coordinates of the snapshot.
:rtype: numpy.ndarray
"""
if index < 0 or index >= len(self._pairs):
raise IndexError("Snapshot index out of range.")
return self._pairs[index][1].space
6 changes: 6 additions & 0 deletions ezyrb/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
'DatabaseScaler',
'ShiftSnapshots',
'AutomaticShiftSnapshots',
'Aggregation',
'DatabaseSplitter',
'DatabaseDictionarySplitter'
]

from .scaler import DatabaseScaler
from .plugin import Plugin
from .shift import ShiftSnapshots
from .automatic_shift import AutomaticShiftSnapshots
from .aggregation import Aggregation
from .database_splitter import DatabaseSplitter
from .database_splitter import DatabaseDictionarySplitter
Loading