diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index 5058e1bcc..877e3a27c 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -16,13 +16,16 @@ from brian2.core.spikesource import SpikeSource from brian2.core.variables import (Variables, LinkedVariable, DynamicArrayVariable, Subexpression) +from brian2.core.namespace import get_local_namespace from brian2.equations.equations import (Equations, DIFFERENTIAL_EQUATION, SUBEXPRESSION, PARAMETER, check_subexpressions, - extract_constant_subexpressions) + extract_constant_subexpressions, + SingleEquation, Expression) from brian2.equations.refractory import add_refractoriness from brian2.parsing.expressions import (parse_expression_dimensions, is_boolean_expression) +from brian2.parsing.sympytools import str_to_sympy, sympy_to_str from brian2.stateupdaters.base import StateUpdateMethod from brian2.units.allunits import second from brian2.units.fundamentalunits import (Quantity, Unit, DIMENSIONLESS, @@ -31,10 +34,16 @@ fail_for_dimension_mismatch) from brian2.utils.logger import get_logger from brian2.utils.stringtools import get_identifiers - +from brian2.codegen.runtime.numpy_rt.numpy_rt import NumpyCodeObject from .group import Group, CodeRunner, get_dtype from .subgroup import Subgroup +try: + from scipy.optimize import root + scipy_available = True +except ImportError: + scipy_available = False + __all__ = ['NeuronGroup'] logger = get_logger(__name__) @@ -920,3 +929,228 @@ def add_event_to_text(event): add_event_to_text(event) return '\n'.join(text) + + def resting_state(self, x0 = {}): + ''' + Calculate resting state of the system. + + Parameters + ---------- + x0 : dict + Initial guess for the state variables. If any of the system's state variables are not + added, default value of 0 is mapped as the initial guess to the missing state variables. + Note: Time elapsed to locate the resting state would be lesser for better initial guesses. + + Returns + ------- + rest_state : dict + Dictioary with pair of state variables and resting state values. Returned values + are represented in SI units. + ''' + # check scipy availability + if not scipy_available: + raise NotImplementedError("Scipy is not available for using `scipy.optimize.root()`") + # check state variables defined in initial guess are valid + if(x0.keys() - self.equations.diff_eq_names): + raise KeyError("Unknown State Variable: {}".format(next(iter(x0.keys() - + self.equations.diff_eq_names)))) + + # Add 0 as the intial value for non-mentioned state variables in x0 + x0.update({name : 0 for name in self.equations.diff_eq_names - x0.keys()}) + + # sort dictionary items + state_dict = dict(sorted(x0.items())) + + # helper functions to create NeuronGroup object of corresponding equation + # For example: _rhs_equation() returns NeuronGroup object with equations representing + # Right-Hand-Side of self.equations and _jacobian_equation() returns NeuronGroup object + # with equations of jacobian matrix + rhs_states, rhs_group = _rhs_equation(self.equations, get_local_namespace(1)) + jac_variables, jac_group = _jacobian_equation(self.equations, self.variables, get_local_namespace(1)) + + # solver function with _wrapper() as the callable function to be optimized + result = root(_wrapper, list(state_dict.values()), args = (rhs_states, rhs_group, jac_variables, + jac_group, state_dict.keys()), jac = True) + + # check the result message for the status of convergence + if result.success == False: + raise Exception("Root calculation failed to converge. Poor initial guess may be the cause of the failure") + + # evaluate the solution states to get state variables of jacobian + jac_state = _evaluate_states(jac_group, dict(zip(state_dict.keys(), result.x)), list(jac_variables.reshape(-1))) + + # with the state values, prepare jacobian matrix + jac_matrix = np.zeros(jac_variables.shape) + + for row in range(jac_variables.shape[0]): + for col in range(jac_variables.shape[1]): + jac_matrix[row, col] = float(jac_state[jac_variables[row, col]]) + + # check whether the solution is stable by using sign of eigenvalues + jac_eig = np.linalg.eigvals(jac_matrix) + if not np.all(np.real(jac_eig) < 0): + raise Exception('Equilibrium is not stable. Failed to converge to stable equilibrium') + + # return the soultion in dictionary form + return dict(zip(state_dict.keys(), result.x)) + +def _rhs_equation(eqs, namespace = None, level = 0): + + """ + Extract the RHS of a system of differential equations. External constants + can be provided via the namespace or will be taken from the local namespace. + Make a new set of equations, where differential equations are replaced by parameters, + and a new subexpression defines their RHS. + + E.g. for 'dv/dt = -v / tau : volt' use: + '''v : volt + RHS_v = -v / tau : volt''' + + This function could be used to find a resting state of the + system, i.e. a fixed point where the RHS of all equations are approximately 0. + + Parameters + ---------- + eqs : `Equations` + The equations + + Returns + ------- + rhs_states : list + A list with the names of all variables defined as RHS of the equations + rhs_group : `NeuronGroup` + The NeuronGroup object + """ + + if namespace is None: + namespace = get_local_namespace(level+1) + + rhs_equations = [] + for eq in eqs.values(): + if eq.type == DIFFERENTIAL_EQUATION: + rhs_equations.append(SingleEquation(PARAMETER, eq.varname, + dimensions=eq.dim, + var_type=eq.var_type)) + rhs_equations.append(SingleEquation(SUBEXPRESSION, 'RHS_'+eq.varname, + dimensions=eq.dim/second.dim, + var_type=eq.var_type, + expr=eq.expr)) + else: + rhs_equations.append(eq) + + # NeuronGroup with the obtained rhs_equations + rhs_group = NeuronGroup(1, model = Equations(rhs_equations), + codeobj_class = NumpyCodeObject, + namespace = namespace) + # states corresponding to RHS of the system of differential equations + rhs_states = ['RHS_' + name for name in eqs.diff_eq_names] + + return (rhs_states, rhs_group) + +def _jacobian_equation(eqs, group_variables, namespace = None, level = 0): + + """ + Create jacobain expressions of a system of differential equations. External constants + can be provided via the namespace or will be taken from the local namespace. + Make a new set of equations, where differential equations are replaced by parameters, + and a new subexpression defines their jacobain expression. + + This function could be used to find a resting state of the + system and check its stability + + Parameters + ---------- + eqs : `Equations` + Equations of the parent NeuronGroup + group_variables : `Variables` + Variables of the parent NeuronGroup + + Returns + ------- + jac_matrix_variables : `2D-NumPy array` + 2D- matrix of jacobian variables. + For example: jac_matrix_variables of model with two variables: u and v would be, + np.array([[J_u_u J_u_v], + [J_v_u J_v_v]]) + jac_group : `NeuronGroup` + The NeuronGroup object + """ + + if namespace is None: + namespace = get_local_namespace(level+1) + # prepare jac_eqs + diff_eqs = eqs.get_substituted_expressions(group_variables) + diff_eq_names = [name for name, _ in diff_eqs] + system = sympy.Matrix([str_to_sympy(diff_eq[1].code) + for diff_eq in diff_eqs]) + J = system.jacobian([str_to_sympy(d) for d in diff_eq_names]) + jac_eqs = [] + for diff_eq_name, diff_eq in diff_eqs: + jac_eqs.append(SingleEquation(PARAMETER, diff_eq_name, + dimensions=eqs[diff_eq_name].dim, + var_type=eqs[diff_eq_name].var_type)) + for var_idx, diff_eq_var in enumerate(diff_eq_names): + for diff_idx, diff_eq_diff in enumerate(diff_eq_names): + dimensions = eqs[diff_eq_var].dim/second.dim/eqs[diff_eq_diff].dim + expr = f'{sympy_to_str(J[var_idx, diff_idx])}' + if expr == '0': + expr = f'0*{dimensions!r}' + jac_eqs.append(SingleEquation(SUBEXPRESSION, f'J_{diff_eq_var}_{diff_eq_diff}', + dimensions=dimensions, + expr=Expression(expr))) + # NeuronGroup with the obtained jac_eqs + jac_group = NeuronGroup(1, model = Equations(jac_eqs), + codeobj_class = NumpyCodeObject, + namespace = namespace) + # prepare 2D matrix of jacobian variables + jac_matrix_variables = np.array( + [[f'J_{var}_{diff_var}' for diff_var in diff_eq_names] + for var in diff_eq_names]) + + return (jac_matrix_variables, jac_group) + +def _evaluate_states(group, values, states): + + """ + Evaluate the set of states when given values are set. + The function gets NeuronGroup object and set the given values to it; + and returns the values of states given + + Parameters + ---------- + group : `NeuronGroup` + The NeuronGroup + values : dict-like + Values of states to be set to group + states: list + State variables for which values have to be get + + Returns + ------- + state_values : dict-like + Dictionary of state variables and their values + """ + + group.set_states(values, units = False) + state_values = group.get_states(states) + return state_values + +def _wrapper(args, rhs_states, rhs_group, jac_variables, jac_group, diff_eq_names): + """ + Vector function for which root needs to be calculated. Callable function of `scipy.optimize.root()` + """ + # match the argument values with correct variables + sorted_variable_dict = {name : arg for name, arg in zip(sorted(diff_eq_names), args)} + # get the values of `rhs_states` when given values(sorted_variable_dict) are set to rhs_group + rhs = _evaluate_states(rhs_group, sorted_variable_dict, rhs_states) + # get the values of `jac_varaibles` when given values(sorted_variable_dict) are set to jac_group + jac = _evaluate_states(jac_group, sorted_variable_dict, list(jac_variables.reshape(-1))) + + # with the values prepare jacobian matrix + jac_matrix = np.zeros(jac_variables.shape) + for row in range(jac_variables.shape[0]): + for col in range(jac_variables.shape[1]): + jac_matrix[row, col] = float(jac[jac_variables[row, col]]) + + return [float(rhs['RHS_{}'.format(name)]) for name in sorted(diff_eq_names)], jac_matrix + diff --git a/brian2/tests/test_neurongroup.py b/brian2/tests/test_neurongroup.py index 85bacf38b..a58e00950 100644 --- a/brian2/tests/test_neurongroup.py +++ b/brian2/tests/test_neurongroup.py @@ -23,10 +23,16 @@ from brian2.units.allunits import second, volt from brian2.units.fundamentalunits import (DimensionMismatchError, have_same_dimensions) -from brian2.units.stdunits import ms, mV, Hz +from brian2.units.stdunits import ms, mV, Hz, cm, msiemens, nA from brian2.units.unitsafefunctions import linspace +from brian2.units.allunits import second, volt, umetre, siemens, ufarad from brian2.utils.logger import catch_logs +try: + import scipy + scipy_available = True +except ImportError: + scipy_available = False @pytest.mark.codegen_independent def test_creation(): @@ -1716,6 +1722,85 @@ def test_semantics_mod(): assert_allclose(G.x[:], float_values % 3) assert_allclose(G.y[:], float_values % 3) +def test_simple_resting_value(): + """ + Test the resting state values of the system + """ + # simple model with single dependent variable, here it is not necessary + # to run the model as the resting value is certain + El = - 100 + tau = 1 * ms + eqs = ''' + dv/dt = (El - v)/tau : 1 + ''' + grp = NeuronGroup(1, eqs, method = 'exact') + resting_state = grp.resting_state() + assert_allclose(resting_state['v'], El) + + # one more example + area = 100 * umetre ** 2 + g_L = 1e-2 * siemens * cm ** -2 * area + E_L = 1000 + Cm = 1 * ufarad * cm ** -2 * area + grp = NeuronGroup(10, '''dv/dt = I_leak / Cm : volt + I_leak = g_L*(E_L - v) : amp''') + resting_state = grp.resting_state({'v': float(10000)}) + assert_allclose(resting_state['v'], E_L) + +def test_failed_resting_state(): + # check the failed to converge system is correctly notified to the user + area = 20000 * umetre ** 2 + Cm = 1 * ufarad * cm ** -2 * area + gl = 5e-5 * siemens * cm ** -2 * area + El = -65 * mV + EK = -90 * mV + ENa = 50 * mV + g_na = 100 * msiemens * cm ** -2 * area + g_kd = 30 * msiemens * cm ** -2 * area + VT = -63 * mV + I = 0.01*nA + eqs = Equations(''' + dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt + dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/ + (exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/ + (exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1 + dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/ + (exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1 + dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1 + ''') + group = NeuronGroup(1, eqs, method='exponential_euler') + group.v = -70*mV + # very poor choice of initial values causing the convergence to fail + with pytest.raises(Exception): + group.resting_state({'v': 0, 'm': 100000000, 'n': 1000000, 'h': 100000000}) + +def test_unstable_resting_state(): + + # check the unstability of the converged solution + area = 20000 * umetre ** 2 + Cm = 1 * ufarad * cm ** -2 * area + gl = 5e-5 * siemens * cm ** -2 * area + El = -65 * mV + EK = -90 * mV + ENa = 50 * mV + g_na = 100 * msiemens * cm ** -2 * area + g_kd = 30 * msiemens * cm ** -2 * area + VT = -63 * mV + I = 0.01*nA + eqs = Equations(''' + dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt + dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/ + (exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/ + (exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1 + dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/ + (exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1 + dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1 + ''') + group = NeuronGroup(1, eqs, method='exponential_euler') + group.v = -70*mV + # converging to unstable solution + with pytest.raises(Exception): + group.resting_state() if __name__ == '__main__': test_set_states() @@ -1792,3 +1877,8 @@ def test_semantics_mod(): test_semantics_floor_division() test_semantics_floating_point_division() test_semantics_mod() + if scipy_available: + test_simple_resting_value() + test_failed_resting_state() + test_unstable_resting_state() +