-
Notifications
You must be signed in to change notification settings - Fork 247
Feat: Introduce abstract RateMonitor class for unified rate analysis #1657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
5f71c6e
ce2a50a
f73bab1
ef56ce7
1b51606
129d27b
066f06b
63be5b7
2668486
6ef761e
f2aad63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2,21 +2,154 @@ | |||||
Module defining `PopulationRateMonitor`. | ||||||
""" | ||||||
|
||||||
from abc import ABC, abstractmethod | ||||||
|
||||||
import numpy as np | ||||||
|
||||||
from brian2.core.clocks import Clock | ||||||
from brian2.core.variables import Variables | ||||||
from brian2.groups.group import CodeRunner, Group | ||||||
from brian2.units.allunits import hertz, second | ||||||
from brian2.units.fundamentalunits import Quantity, check_units | ||||||
from brian2.utils.logger import get_logger | ||||||
|
||||||
__all__ = ["PopulationRateMonitor"] | ||||||
__all__ = ["PopulationRateMonitor", "RateMonitor"] | ||||||
|
||||||
|
||||||
logger = get_logger(__name__) | ||||||
|
||||||
|
||||||
class PopulationRateMonitor(Group, CodeRunner): | ||||||
class RateMonitor(CodeRunner, Group, Clock, ABC): | ||||||
""" | ||||||
Abstract base class for monitors that record rates. | ||||||
""" | ||||||
|
||||||
@abstractmethod | ||||||
@check_units(bin_size=second) | ||||||
def binned(self, bin_size): | ||||||
""" | ||||||
Return the rate calculated in bins of a certain size. | ||||||
|
||||||
Parameters | ||||||
------------- | ||||||
bin_size : `Quantity` | ||||||
The size of the bins in seconds. Should be a multiple of dt. | ||||||
|
||||||
Returns | ||||||
------- | ||||||
bins : `Quantity` | ||||||
The midpoints of the bins. | ||||||
binned_values : `Quantity` | ||||||
The binned values. For EventMonitor subclasses, this is a 2D array | ||||||
with shape (neurons, bins). For PopulationRateMonitor, this is a 1D array. | ||||||
""" | ||||||
raise NotImplementedError() | ||||||
|
||||||
@check_units(width=second) | ||||||
def smooth_rate(self, window="gaussian", width=None): | ||||||
""" | ||||||
Returns a smoothed out version of the firing rate(s). | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
window : str, ndarray | ||||||
The window to use for smoothing. Can be a string to chose a | ||||||
predefined window(`flat` for a rectangular, and `gaussian` | ||||||
for a Gaussian-shaped window). | ||||||
|
||||||
In this case the width of the window | ||||||
is determined by the `width` argument. Note that for the Gaussian | ||||||
window, the `width` parameter specifies the standard deviation of | ||||||
the Gaussian, the width of the actual window is `4*width + dt` | ||||||
(rounded to the nearest dt). For the flat window, the width is | ||||||
rounded to the nearest odd multiple of dt to avoid shifting the rate | ||||||
in time. | ||||||
Alternatively, an arbitrary window can be given as a numpy array | ||||||
(with an odd number of elements). In this case, the width in units | ||||||
of time depends on the ``dt`` of the simulation, and no `width` | ||||||
argument can be specified. The given window will be automatically | ||||||
normalized to a sum of 1. | ||||||
width : `Quantity`, optional | ||||||
The width of the ``window`` in seconds (for a predefined window). | ||||||
|
||||||
Returns | ||||||
------- | ||||||
rate : `Quantity` | ||||||
The smoothed firing rate(s) in Hz. For EventMonitor subclasses, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this docstring is slightly incorrect: the smoothing function does not use any binning, the resulting array will always have the size of the original number of time steps. |
||||||
this returns a 2D array with shape (neurons, time_bins). | ||||||
For PopulationRateMonitor, this returns a 1D array. | ||||||
Note that the rates are smoothed and not re-binned, i.e. the length | ||||||
of the returned array is the same as the length of the binned data | ||||||
and can be plotted against the bin centers from the ``binned`` method. | ||||||
""" | ||||||
if width is None and isinstance(window, str): | ||||||
raise TypeError("Need a width when using a predefined window.") | ||||||
if width is not None and not isinstance(window, str): | ||||||
raise TypeError("Can only specify a width for a predefined window") | ||||||
|
||||||
if isinstance(window, str): | ||||||
if window == "gaussian": | ||||||
# basically Gaussian theoretically spans infinite time, but practically it falls off quickly, | ||||||
# So we choose a window of +- 2*(Standard deviations), i.e 95% of gaussian curve | ||||||
|
||||||
width_dt = int( | ||||||
np.round(2 * width / self.clock.dt) | ||||||
) # Rounding only for the size of the window, not for the standard | ||||||
# deviation of the Gaussian | ||||||
window = np.exp( | ||||||
-np.arange(-width_dt, width_dt + 1) ** 2 | ||||||
* 1.0 # hack to ensure floating-point division :) | ||||||
/ (2 * (width / self.clock.dt) ** 2) | ||||||
) | ||||||
elif window == "flat": | ||||||
width_dt = int(np.round(width / (2 * self.clock.dt))) * 2 + 1 | ||||||
used_width = width_dt * self.clock.dt | ||||||
if abs(used_width - width) > 1e-6 * self.clock.dt: | ||||||
logger.info( | ||||||
f"width adjusted from {width} to {used_width}", | ||||||
"adjusted_width", | ||||||
once=True, | ||||||
) | ||||||
window = np.ones(width_dt) | ||||||
else: | ||||||
raise NotImplementedError(f'Unknown pre-defined window "{window}"') | ||||||
else: | ||||||
try: | ||||||
window = np.asarray(window) | ||||||
except TypeError: | ||||||
raise TypeError(f"Cannot use a window of type {type(window)}") | ||||||
if window.ndim != 1: | ||||||
raise TypeError("The provided window has to be one-dimensional.") | ||||||
if len(window) % 2 != 1: | ||||||
raise TypeError("The window has to have an odd number of values.") | ||||||
|
||||||
# Get the binned rates at the finest resolution | ||||||
_, binned_values = self.binned(bin_size=self.clock.dt) | ||||||
|
||||||
# Normalize the window | ||||||
window = window * 1.0 / sum(window) | ||||||
|
||||||
# Extract the raw numpy array from the Quantity (if it is one) | ||||||
if hasattr(binned_values, "dimensions"): | ||||||
binned_array = np.asarray(binned_values) | ||||||
else: | ||||||
binned_array = binned_values | ||||||
|
||||||
# So we need to handle for both 1D (PopulationRateMonitor) and 2D (EventMonitor) cases separately as `np.convolve()` only works with 1D arrays | ||||||
if binned_values.ndim == 1: | ||||||
# PopulationRateMonitor case - 1D convolution | ||||||
smoothed = np.convolve(binned_values, window, mode="same") | ||||||
else: | ||||||
# EventMonitor/SpikeMonitor case - convolve each neuron and then we return the smoothed 2D array ( neuron * bins ) | ||||||
num_neurons, num_bins = binned_array.shape | ||||||
smoothed = np.zeros((num_neurons, num_bins)) | ||||||
for i in range(num_neurons): | ||||||
smoothed[i, :] = np.convolve(binned_array[i, :], window, mode="same") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We previously depended on |
||||||
|
||||||
return Quantity(smoothed, dim=hertz.dim) | ||||||
|
||||||
|
||||||
class PopulationRateMonitor(RateMonitor): | ||||||
""" | ||||||
Record instantaneous firing rates, averaged across neurons from a | ||||||
`NeuronGroup` or other spike source. | ||||||
|
@@ -100,82 +233,48 @@ def reinit(self): | |||||
""" | ||||||
raise NotImplementedError() | ||||||
|
||||||
@check_units(width=second) | ||||||
def smooth_rate(self, window="gaussian", width=None): | ||||||
@check_units(bin_size=second) | ||||||
def binned(self, bin_size): | ||||||
""" | ||||||
smooth_rate(self, window='gaussian', width=None) | ||||||
|
||||||
Return a smooth version of the population rate. | ||||||
Return the population rate binned with the given bin size. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
window : str, ndarray | ||||||
The window to use for smoothing. Can be a string to chose a | ||||||
predefined window(``'flat'`` for a rectangular, and ``'gaussian'`` | ||||||
for a Gaussian-shaped window). In this case the width of the window | ||||||
is determined by the ``width`` argument. Note that for the Gaussian | ||||||
window, the ``width`` parameter specifies the standard deviation of | ||||||
the Gaussian, the width of the actual window is ``4*width + dt`` | ||||||
(rounded to the nearest dt). For the flat window, the width is | ||||||
rounded to the nearest odd multiple of dt to avoid shifting the rate | ||||||
in time. | ||||||
Alternatively, an arbitrary window can be given as a numpy array | ||||||
(with an odd number of elements). In this case, the width in units | ||||||
of time depends on the ``dt`` of the simulation, and no ``width`` | ||||||
argument can be specified. The given window will be automatically | ||||||
normalized to a sum of 1. | ||||||
width : `Quantity`, optional | ||||||
The width of the ``window`` in seconds (for a predefined window). | ||||||
bin_size : `Quantity` | ||||||
The size of the bins in seconds. Should be a multiple of dt. | ||||||
|
||||||
Returns | ||||||
------- | ||||||
rate : `Quantity` | ||||||
The population rate in Hz, smoothed with the given window. Note that | ||||||
the rates are smoothed and not re-binned, i.e. the length of the | ||||||
returned array is the same as the length of the ``rate`` attribute | ||||||
and can be plotted against the `PopulationRateMonitor` 's ``t`` | ||||||
attribute. | ||||||
bins : `Quantity` | ||||||
The midpoints of the bins. | ||||||
binned_values : `Quantity` | ||||||
The binned population rates as a 1D array in Hz. | ||||||
""" | ||||||
if width is None and isinstance(window, str): | ||||||
raise TypeError("Need a width when using a predefined window.") | ||||||
if width is not None and not isinstance(window, str): | ||||||
raise TypeError("Can only specify a width for a predefined window") | ||||||
if ( | ||||||
bin_size / self.clock.dt | ||||||
) % 1 > 1e-6: # to make sure bin_size is an integer multiple of the internal time resolution dt | ||||||
raise ValueError("bin_size has to be a multiple of dt.") | ||||||
if bin_size == self.clock.dt: | ||||||
return self.t[:], self.rate | ||||||
|
||||||
if isinstance(window, str): | ||||||
if window == "gaussian": | ||||||
width_dt = int(np.round(2 * width / self.clock.dt)) | ||||||
# Rounding only for the size of the window, not for the standard | ||||||
# deviation of the Gaussian | ||||||
window = np.exp( | ||||||
-np.arange(-width_dt, width_dt + 1) ** 2 | ||||||
* 1.0 | ||||||
/ (2 * (width / self.clock.dt) ** 2) | ||||||
) | ||||||
elif window == "flat": | ||||||
width_dt = int(width / 2 / self.clock.dt) * 2 + 1 | ||||||
used_width = width_dt * self.clock.dt | ||||||
if abs(used_width - width) > 1e-6 * self.clock.dt: | ||||||
logger.info( | ||||||
f"width adjusted from {width} to {used_width}", | ||||||
"adjusted_width", | ||||||
once=True, | ||||||
) | ||||||
window = np.ones(width_dt) | ||||||
else: | ||||||
raise NotImplementedError(f'Unknown pre-defined window "{window}"') | ||||||
else: | ||||||
try: | ||||||
window = np.asarray(window) | ||||||
except TypeError: | ||||||
raise TypeError(f"Cannot use a window of type {type(window)}") | ||||||
if window.ndim != 1: | ||||||
raise TypeError("The provided window has to be one-dimensional.") | ||||||
if len(window) % 2 != 1: | ||||||
raise TypeError("The window has to have an odd number of values.") | ||||||
return Quantity( | ||||||
np.convolve(self.rate_, window * 1.0 / sum(window), mode="same"), | ||||||
dim=hertz.dim, | ||||||
) | ||||||
num_bins = int(self.clock.t / bin_size) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This is safer for numerical edge cases. |
||||||
bins = ( | ||||||
np.arange(num_bins) * bin_size + bin_size / 2.0 | ||||||
) # as we want Bin centers (not edges) | ||||||
|
||||||
t_indices = (self.t / self.clock.dt).astype(int) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, let's discuss this during our meeting, the issue is slightly more complicated than this. |
||||||
bin_indices = (t_indices * self.clock.dt / bin_size).astype(int) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
binned_values = np.zeros(num_bins) # to store total firing rate values per bin | ||||||
bin_counts = np.zeros(num_bins) # to store how many samples went into each bin | ||||||
|
||||||
np.add.at(binned_values, bin_indices, self.rate) | ||||||
np.add.at(bin_counts, bin_indices, 1) | ||||||
|
||||||
# Avoid division by zero for empty bins | ||||||
non_empty_bins = bin_counts > 0 | ||||||
binned_values[non_empty_bins] /= bin_counts[non_empty_bins] | ||||||
return bins, Quantity(binned_values, dim=hertz.dim) | ||||||
Comment on lines
+267
to
+277
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel that this complicates things a bit too much – we shouldn't have to count the number of values that go into each bin (each bin should have >>> print(ar.reshape(-1, 3).mean(axis=1))
[2. 4. 2. 0.] |
||||||
|
||||||
def __repr__(self): | ||||||
classname = self.__class__.__name__ | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,13 +7,15 @@ | |
from brian2.core.names import Nameable | ||
from brian2.core.spikesource import SpikeSource | ||
from brian2.core.variables import Variables | ||
from brian2.groups.group import CodeRunner, Group | ||
from brian2.units.fundamentalunits import Quantity | ||
from brian2.groups.group import CodeRunner | ||
from brian2.monitors.ratemonitor import RateMonitor | ||
from brian2.units.allunits import hertz, second | ||
from brian2.units.fundamentalunits import Quantity, check_units | ||
|
||
__all__ = ["EventMonitor", "SpikeMonitor"] | ||
|
||
|
||
class EventMonitor(Group, CodeRunner): | ||
class EventMonitor(RateMonitor): | ||
""" | ||
Record events from a `NeuronGroup` or another event source. | ||
|
||
|
@@ -383,6 +385,71 @@ def event_trains(self): | |
""" | ||
return self.values("t") | ||
|
||
@check_units(bin_size=second) | ||
def binned(self, bin_size): | ||
""" | ||
Return the event rates binned with the given bin size. | ||
|
||
Parameters | ||
---------- | ||
bin_size : `Quantity` | ||
The size of the bins in seconds. Should be a multiple of dt. | ||
|
||
Returns | ||
------- | ||
bins : `Quantity` | ||
The midpoints of the bins. | ||
binned_values : `Quantity` | ||
The binned rates as a 2D array (neurons × bins) in Hz. | ||
""" | ||
if (bin_size / self.clock.dt) % 1 > 1e-6: | ||
raise ValueError("bin_size has to be a multiple of dt.") | ||
|
||
# Get the total duration and number of bins | ||
duration = self.clock.t | ||
num_bins = int(duration / bin_size) | ||
bins = np.arange(num_bins) * bin_size + bin_size / 2 # As we want bin centers | ||
|
||
# Now we determine the number of neurons ( as the moniter can be only monitoring a subset of the actual Group of Neurons) | ||
if hasattr(self.source, "start") and hasattr(self.source, "stop"): | ||
# this is the case of monitoring a subgroup | ||
num_neurons = self.source.stop - self.source.start | ||
neuron_offset = ( | ||
self.source.start | ||
) # needed for calulations as we want to know from which index to start the calculations for binning from | ||
else: | ||
# Case where we are monitoring the whole Group | ||
num_neurons = len(self.source) | ||
neuron_offset = 0 # no offset as we are monitoring the whole Group | ||
|
||
# Now we initialize the binned values array (neurons × bins) | ||
binned_values = np.zeros((num_neurons, num_bins)) | ||
if self.record: | ||
# Get the event times and indices | ||
event_times = self.t[:] | ||
event_indices = ( | ||
self.i[:] - neuron_offset | ||
) # Adjust for subgroups as stated above | ||
|
||
bin_indices = (event_times / bin_size).astype(int) | ||
# Now this is the main core code , here we count the events in each bin that happened for each neuron | ||
# Like after this we should have something like : | ||
# Example : | ||
# binned_values = [ | ||
# [2.0, 0.0, 1.0, 0.0, 0.0], # Neuron 0: 2 in bin 0, 1 in bin 2 | ||
# [0.0, 1.0, 0.0, 1.0, 0.0], # Neuron 1: 1 in bin 1, 1 in bin 3 | ||
# [0.0, 2.0, 0.0, 0.0, 1.0] # Neuron 2: 2 in bin 1, 1 in bin 4 | ||
# ] | ||
for event_idx, neuron_idx in enumerate(event_indices): | ||
if 0 <= neuron_idx < num_neurons: # sanity check | ||
bin_idx = bin_indices[event_idx] | ||
if bin_idx < num_bins: # To handle edge case at the end | ||
binned_values[neuron_idx, bin_idx] += 1 | ||
Comment on lines
+426
to
+447
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we do not care about half-filled bins in the end of the array (and I think we should not), then I think all of this can be replaced by a call to numpy's |
||
|
||
# Convert counts to rates (Hz) | ||
binned_values = binned_values / float(bin_size) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Something I didn't think about earlier: if we don't return spike counts but rates here (which probably makes sense given that we are inheriting from RateMonitor, and that this is the behaviour in the |
||
return bins, Quantity(binned_values, dim=hertz.dim) | ||
|
||
@property | ||
def num_events(self): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit confused: why does
RateMonitor
has to inherit fromClock
? (It has a clock, but it isn't a clock)