@@ -20,26 +20,26 @@ class AbstractProblem(metaclass=ABCMeta):
2020
2121 def __init__ (self ):
2222
23-
2423 self ._discretized_domains = {}
2524
2625 for name , domain in self .domains .items ():
2726 if isinstance (domain , (torch .Tensor , LabelTensor )):
2827 self ._discretized_domains [name ] = domain
2928
3029 for condition_name in self .conditions :
31- self .conditions [condition_name ]._problem = self
30+ self .conditions [condition_name ].set_problem (self )
31+
3232 # # variable storing all points
33- # self.input_pts = {}
33+ self .input_pts = {}
3434
3535 # # varible to check if sampling is done. If no location
3636 # # element is presented in Condition this variable is set to true
3737 # self._have_sampled_points = {}
38- # for condition_name in self.conditions:
39- # self._have_sampled_points [condition_name] = False
38+ for condition_name in self .conditions :
39+ self ._discretized_domains [condition_name ] = False
4040
4141 # # put in self.input_pts all the points that we don't need to sample
42- # self._span_condition_points()
42+ self ._span_condition_points ()
4343
4444 def __deepcopy__ (self , memo ):
4545 """
@@ -125,7 +125,7 @@ def _span_condition_points(self):
125125 if hasattr (condition , "input_points" ):
126126 samples = condition .input_points
127127 self .input_pts [condition_name ] = samples
128- self ._have_sampled_points [condition_name ] = True
128+ self ._discretized_domains [condition_name ] = True
129129 if hasattr (self , "unknown_parameter_domain" ):
130130 # initialize the unknown parameters of the inverse problem given
131131 # the domain the user gives
@@ -141,7 +141,7 @@ def _span_condition_points(self):
141141 )
142142
143143 def discretise_domain (
144- self , n , mode = "random" , variables = "all" , locations = "all"
144+ self , n , mode = "random" , variables = "all" , domains = "all"
145145 ):
146146 """
147147 Generate a set of points to span the `Location` of all the conditions of
@@ -192,31 +192,37 @@ def discretise_domain(
192192 f"should be in { self .input_variables } ." ,
193193 )
194194
195- # check consistency location
196- locations_to_sample = [
197- condition
198- for condition in self .conditions
199- if hasattr (self .conditions [condition ], "location" )
200- ]
201- if locations == "all" :
202- # only locations that can be sampled
203- locations = locations_to_sample
195+ # # check consistency location # TODO: check if this is needed (from 0.1)
196+ # locations_to_sample = [
197+ # condition
198+ # for condition in self.conditions
199+ # if hasattr(self.conditions[condition], "location")
200+ # ]
201+ # if locations == "all":
202+ # # only locations that can be sampled
203+ # locations = locations_to_sample
204+ # else:
205+ # check_consistency(locations, str)
206+
207+ # if sorted(locations) != sorted(locations_to_sample):
208+ if domains == "all" :
209+ domains = [condition for condition in self .conditions ]
204210 else :
205- check_consistency (locations , str )
206-
207- if sorted (locations ) != sorted (locations_to_sample ):
211+ check_consistency (domains , str )
212+ print ( domains )
213+ if sorted (domains ) != sorted (self . conditions ):
208214 TypeError (
209215 f"Wrong locations for sampling. Location " ,
210216 f"should be in { locations_to_sample } ." ,
211217 )
212218
213219 # sampling
214- for location in locations :
215- condition = self .conditions [location ]
220+ for d in domains :
221+ condition = self .conditions [d ]
216222
217223 # we try to check if we have already sampled
218224 try :
219- already_sampled = [self .input_pts [location ]]
225+ already_sampled = [self .input_pts [d ]]
220226 # if we have not sampled, a key error is thrown
221227 except KeyError :
222228 already_sampled = []
@@ -225,25 +231,27 @@ def discretise_domain(
225231 # but we want to sample again we set already_sampled
226232 # to an empty list since we need to sample again, and
227233 # self._have_sampled_points to False.
228- if self ._have_sampled_points [ location ]:
234+ if self ._discretized_domains [ d ]:
229235 already_sampled = []
230- self ._have_sampled_points [location ] = False
231-
236+ self ._discretized_domains [d ] = False
237+ print (condition .domain )
238+ print (d )
232239 # build samples
233240 samples = [
234- condition . location .sample (n = n , mode = mode , variables = variables )
241+ self . domains [ d ] .sample (n = n , mode = mode , variables = variables )
235242 ] + already_sampled
236243 pts = merge_tensors (samples )
237- self .input_pts [location ] = pts
244+ self .input_pts [d ] = pts
238245
239246 # the condition is sampled if input_pts contains all labels
240- if sorted (self .input_pts [location ].labels ) == sorted (
247+ if sorted (self .input_pts [d ].labels ) == sorted (
241248 self .input_variables
242249 ):
243- self ._have_sampled_points [location ] = True
244- self .input_pts [location ] = self .input_pts [location ].extract (
245- sorted (self .input_variables )
246- )
250+ # self._have_sampled_points[location] = True
251+ # self.input_pts[location] = self.input_pts[location].extract(
252+ # sorted(self.input_variables)
253+ # )
254+ self ._have_sampled_points [d ] = True
247255
248256 def add_points (self , new_points ):
249257 """
0 commit comments