From 84fddc8416317ca14547bc08dce74b9be7b20c1f Mon Sep 17 00:00:00 2001 From: Sagar Shahari Date: Tue, 8 Apr 2025 02:30:38 +0530 Subject: [PATCH 1/2] Update neurongroup.py --- brian2/groups/neurongroup.py | 52 +++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index b8f3e7f6e..2f2244256 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -518,7 +518,7 @@ def __init__( method_options=None, threshold=None, reset=None, - refractory=False, + refractory=None, # Updated to None instead of False events=None, namespace=None, dtype=None, @@ -576,6 +576,26 @@ def __init__( } ) + # Handle events + if events is None: + events = {} + self.events = {'spike': threshold} if threshold else {} + self.events.update(events) + + # Handle refractory + if refractory is not None: + if isinstance(refractory, (str, Quantity)): + refractory = {'spike': refractory} + elif isinstance(refractory, dict): + for event in refractory: + if event not in self.events: + raise ValueError(f"Unknown event '{event}' in refractory dictionary.") + else: + raise TypeError("refractory must be a string, Quantity, or dictionary") + else: + refractory = {} + self._refractory = refractory + # add refractoriness #: The original equations as specified by the user (i.e. without #: the multiplied `int(not_refractory)` term for equations marked as @@ -836,6 +856,36 @@ def _create_variables(self, user_dtype, events): else: raise AssertionError(f"Unknown type of equation: {eq.eq_type}") + # refractory variable setup + if 'spike' in self.events: + self.variables.add_array('_lastspike', size=self.N, dtype=float, constant=False, value=-1e100) + + for event, refr in self._refractory.items(): + event_name = event.replace(' ', '_') + self.variables.add_array(f'_lastevent_{event_name}', size=self.N, dtype=float, + constant=False, value=-1e100) + + if isinstance(refr, Quantity): + self.variables.add_array(f'_refractory_until_{event_name}', size=self.N, dtype=float, + constant=False, value=-1e100) + self.variables.add_subexpression(f'not_refractory_{event_name}', + f't >= _refractory_until_{event_name}') + elif isinstance(refr, str): + self.variables.add_subexpression(f'not_refractory_{event_name}', f'not ({refr})') + + + for eq in self.equations.values(): + if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags: + not_refractory_var = self.variables[f'not_refractory_{event_name}'] + var = self.variables[eq.varname] + var.set_conditional_write(not_refractory_var) + + # Events without refractory + for event in self.events: + event_name = event.replace(' ', '_') + if event not in self._refractory: + self.variables.add_subexpression(f'not_refractory_{event_name}', 'True') + # Add the conditional-write attribute for variables with the # "unless refractory" flag if self._refractory is not False: From fecec97834f8ce07ed9ba6128a725e9eb3dfb874 Mon Sep 17 00:00:00 2001 From: Sagar Shahari Date: Fri, 11 Apr 2025 20:17:22 +0530 Subject: [PATCH 2/2] Added special case spike event variables for backward compatibility --- brian2/groups/neurongroup.py | 79 ++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index 2f2244256..78944f6e3 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -814,9 +814,12 @@ def _create_variables(self, user_dtype, events): # Standard variables always present for event in events: - self.variables.add_array( - f"_{event}space", size=self._N + 1, dtype=np.int32, constant=False - ) + if event == 'spike': + self.variables.add_array("_spikespace", size=self._N + 1, dtype=np.int32, constant=False) + else: + eventspace_name = f"_{event.replace(' ', '_')}space" + self.variables.add_array(eventspace_name, size=self._N + 1, dtype=np.int32, constant=False) + # Add the special variable "i" which can be used to refer to the neuron index self.variables.add_arange("i", size=self._N, constant=True, read_only=True) # Add the clock variables @@ -856,42 +859,48 @@ def _create_variables(self, user_dtype, events): else: raise AssertionError(f"Unknown type of equation: {eq.eq_type}") - # refractory variable setup + # refractory variable setup for spike event if 'spike' in self.events: - self.variables.add_array('_lastspike', size=self.N, dtype=float, constant=False, value=-1e100) - - for event, refr in self._refractory.items(): - event_name = event.replace(' ', '_') - self.variables.add_array(f'_lastevent_{event_name}', size=self.N, dtype=float, - constant=False, value=-1e100) - - if isinstance(refr, Quantity): - self.variables.add_array(f'_refractory_until_{event_name}', size=self.N, dtype=float, - constant=False, value=-1e100) - self.variables.add_subexpression(f'not_refractory_{event_name}', - f't >= _refractory_until_{event_name}') - elif isinstance(refr, str): - self.variables.add_subexpression(f'not_refractory_{event_name}', f'not ({refr})') - - - for eq in self.equations.values(): - if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags: - not_refractory_var = self.variables[f'not_refractory_{event_name}'] - var = self.variables[eq.varname] - var.set_conditional_write(not_refractory_var) + self.variables.add_array('lastspike', size=self._N, dtype=float, constant=False, value=-1e100) + if self._refractory.get('spike', False): + self.variables.add_subexpression('not_refractory', 'True') + # For other events + for event in self.events: + if event != 'spike' and event in self._refractory: + event_name = event.replace(' ', '_') + self.variables.add_array(f'_lastevent_{event_name}', size=self._N, dtype=float, + constant=False, value=-1e100) + refr = self._refractory[event] + if isinstance(refr, Quantity): + self.variables.add_array(f'_refractory_until_{event_name}', size=self._N, dtype=float, + constant=False, value=-1e100) + self.variables.add_subexpression(f'not_refractory_{event_name}', + f't >= _refractory_until_{event_name}') + elif isinstance(refr, str): + self.variables.add_subexpression(f'not_refractory_{event_name}', f'not ({refr})') + # Events without refractory for event in self.events: - event_name = event.replace(' ', '_') - if event not in self._refractory: - self.variables.add_subexpression(f'not_refractory_{event_name}', 'True') - - # Add the conditional-write attribute for variables with the - # "unless refractory" flag - if self._refractory is not False: - for eq in self.equations.values(): - if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags: - not_refractory_var = self.variables["not_refractory"] + event_name = event.replace(' ', '_') + if event not in self._refractory: + self.variables.add_subexpression(f'not_refractory_{event_name}', 'True') + + if 'spike' in self.events and 'spike' in self._refractory: + refr = self._refractory['spike'] + if isinstance(refr, Quantity): + self.variables.add_array('_refractory_until', size=self._N, dtype=float, + constant=False, value=-1e100) + self.variables.add_subexpression('not_refractory', 't >= _refractory_until') + elif isinstance(refr, str): + self.variables.add_subexpression('not_refractory', f'not ({refr})') + + for eq in self.equations.values(): + if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags: + for event in self.events: + event_name = event.replace(' ', '_') + not_refractory_var = self.variables.get(f'not_refractory_{event_name}', None) + if not_refractory_var: var = self.variables[eq.varname] var.set_conditional_write(not_refractory_var)