diff --git a/.github/workflows/push_pr.yml b/.github/workflows/push_pr.yml index aa559d7b..00705e57 100644 --- a/.github/workflows/push_pr.yml +++ b/.github/workflows/push_pr.yml @@ -39,4 +39,4 @@ jobs: shell: bash -l {0} run: | pip install . - pytest pypulseq/tests + pytest -m "not matlab_seq_comp" pypulseq/tests diff --git a/pypulseq/make_block_pulse.py b/pypulseq/make_block_pulse.py index 34f399d6..406889d6 100644 --- a/pypulseq/make_block_pulse.py +++ b/pypulseq/make_block_pulse.py @@ -1,5 +1,6 @@ from types import SimpleNamespace from typing import Tuple, Union +from warnings import warn import numpy as np @@ -11,29 +12,36 @@ def make_block_pulse( flip_angle: float, - bandwidth: float = 0, delay: float = 0, - duration: float = 4e-3, + duration: float = None, + bandwidth: float = None, + time_bw_product: float = None, freq_offset: float = 0, phase_offset: float = 0, return_delay: bool = False, system: Opts = Opts(), - time_bw_product: float = 0, use: str = str(), ) -> Union[SimpleNamespace, Tuple[SimpleNamespace, SimpleNamespace]]: """ - Create a block pulse with optional slice selectiveness. + Create a block (RECT or hard) pulse. + + Define duration, or bandwidth, or bandwidth and time_bw_product. + If none are provided a default 4 ms pulse will be generated. Parameters ---------- flip_angle : float Flip angle in radians. - bandwidth : float, default=0 - Bandwidth in Hertz (hz). delay : float, default=0 - Delay in seconds (s) of accompanying slice select trapezoidal event. - duration : float, default=4e-3 + Delay in seconds (s). + duration : float, default=None Duration in seconds (s). + bandwidth : float, default=None + Bandwidth in Hertz (Hz). + If supplied without time_bw_product duration = 1 / (4 * bandwidth) + time_bw_product : float, default=None + Time-bandwidth product. + If supplied with bandwidth, duration = time_bw_product / bandwidth freq_offset : float, default=0 Frequency offset in Hertz (Hz). phase_offset : float, default=0 @@ -42,40 +50,57 @@ def make_block_pulse( Boolean flag to indicate if the delay event has to be returned. system : Opts, default=Opts() System limits. - time_bw_product : float, default=0 - Time-bandwidth product. use : str, default=str() - Use of radio-frequency block pulse event. Must be one of 'excitation', 'refocusing' or 'inversion'. + Use of radio-frequency block pulse event. Returns ------- rf : SimpleNamespace Radio-frequency block pulse event. delay : SimpleNamespace, optional - Slice select trapezoidal gradient event accompanying the radio-frequency block pulse event. + Delay event. Raises ------ ValueError - If invalid `use` parameter is passed. Must be one of 'excitation', 'refocusing' or 'inversion'. - If neither `bandwidth` nor `duration` are passed. - If `return_gz=True`, and `slice_thickness` is not passed. + If invalid `use` parameter is passed. + One of bandwidth or duration must be defined, but not both. + One of bandwidth or duration must be defined and be > 0. """ valid_use_pulses = get_supported_rf_uses() if use != "" and use not in valid_use_pulses: raise ValueError( - f"Invalid use parameter. Must be one of 'excitation', 'refocusing' or 'inversion'. Passed: {use}" + "Invalid use parameter. " + f"Must be one of {valid_use_pulses}. Passed: {use}" ) - if duration == 0: - if time_bw_product > 0: + if duration is None and bandwidth is None: + warn('Using default 4 ms duration for block pulse.') + duration = 4E-3 + elif duration is not None and bandwidth is not None\ + and duration > 0: + # Multiple arguments + raise ValueError( + "One of bandwidth or duration must be defined, but not both.") + elif duration is not None\ + and duration > 0: + # Explicitly handle this most expected case. + # There is probably a better way of writing this if block + pass + elif duration is None\ + and bandwidth is not None\ + and bandwidth > 0: + if time_bw_product is not None\ + and time_bw_product > 0: duration = time_bw_product / bandwidth - elif bandwidth > 0: - duration = 1 / (4 * bandwidth) else: - raise ValueError("Either bandwidth or duration must be defined") + duration = 1 / (4 * bandwidth) + else: + # Invalid arguments + raise ValueError( + "One of bandwidth or duration must be defined and be > 0. " + f"duration = {duration} s, bandwidth = {bandwidth} Hz.") - BW = 1 / (4 * duration) N = round(duration / system.rf_raster_time) t = np.array([0, N]) * system.rf_raster_time signal = flip_angle / (2 * np.pi) / duration * np.ones_like(t) diff --git a/pypulseq/make_sigpy_pulse.py b/pypulseq/make_sigpy_pulse.py index 34e6f505..99e975e6 100644 --- a/pypulseq/make_sigpy_pulse.py +++ b/pypulseq/make_sigpy_pulse.py @@ -26,6 +26,7 @@ def sigpy_n_seq( time_bw_product: float = 4, pulse_cfg: SigpyPulseOpts = SigpyPulseOpts(), use: str = str(), + plot: bool = True, ) -> Union[SimpleNamespace, Tuple[SimpleNamespace, SimpleNamespace, SimpleNamespace]]: """ Creates a radio-frequency sinc pulse event using the sigpy rf pulse library and optionally accompanying slice select, slice select rephasing @@ -62,6 +63,8 @@ def sigpy_n_seq( Time-bandwidth product. use : str, optional, default=str() Use of radio-frequency sinc pulse. Must be one of 'excitation', 'refocusing' or 'inversion'. + plot: bool, optional, default=True + Show sigpy plot outputs Returns ------- @@ -92,7 +95,7 @@ def sigpy_n_seq( duration=duration, system=system, pulse_cfg=pulse_cfg, - disp=True, + disp=plot, ) if pulse_cfg.pulse_type == "sms": [signal, t, pulse] = make_sms( @@ -101,7 +104,7 @@ def sigpy_n_seq( duration=duration, system=system, pulse_cfg=pulse_cfg, - disp=True, + disp=plot, ) rfp = SimpleNamespace() diff --git a/pypulseq/tests/pytest.ini b/pypulseq/tests/pytest.ini new file mode 100644 index 00000000..0f760eaa --- /dev/null +++ b/pypulseq/tests/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + matlab_seq_comp: marks tests as comparison with matlab generated sequence (deselect with '-m "not matlab_seq_comp"') \ No newline at end of file diff --git a/pypulseq/tests/test_MPRAGE.py b/pypulseq/tests/test_MPRAGE.py index 59eecb0d..22c11ec9 100644 --- a/pypulseq/tests/test_MPRAGE.py +++ b/pypulseq/tests/test_MPRAGE.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_MPRAGE from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestMPRAGE(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "mprage_matlab.seq" diff --git a/pypulseq/tests/test_epi.py b/pypulseq/tests/test_epi.py index 3c3d3592..53ddd1c6 100644 --- a/pypulseq/tests/test_epi.py +++ b/pypulseq/tests/test_epi.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_epi from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestEPI(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "epi_matlab.seq" diff --git a/pypulseq/tests/test_epi_label.py b/pypulseq/tests/test_epi_label.py index bfbcde62..91cad996 100644 --- a/pypulseq/tests/test_epi_label.py +++ b/pypulseq/tests/test_epi_label.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_epi_label from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestEPILabel(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "epi_label_matlab.seq" diff --git a/pypulseq/tests/test_epi_se.py b/pypulseq/tests/test_epi_se.py index 6b6713b2..4c8ddf71 100644 --- a/pypulseq/tests/test_epi_se.py +++ b/pypulseq/tests/test_epi_se.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_epi_se from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestEPISpinEcho(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "epi_se_matlab.seq" diff --git a/pypulseq/tests/test_epi_se_rs.py b/pypulseq/tests/test_epi_se_rs.py index ae4e57b3..37050265 100644 --- a/pypulseq/tests/test_epi_se_rs.py +++ b/pypulseq/tests/test_epi_se_rs.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_epi_se_rs from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestEPISpinEchoRS(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "epi_se_rs_matlab.seq" diff --git a/pypulseq/tests/test_gre.py b/pypulseq/tests/test_gre.py index dbb99685..5bd6aba2 100644 --- a/pypulseq/tests/test_gre.py +++ b/pypulseq/tests/test_gre.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_gre from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestGRE(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "gre_matlab.seq" diff --git a/pypulseq/tests/test_gre_label.py b/pypulseq/tests/test_gre_label.py index 5899f8e9..640d0087 100644 --- a/pypulseq/tests/test_gre_label.py +++ b/pypulseq/tests/test_gre_label.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_gre_label from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestGRELabel(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "gre_label_matlab.seq" diff --git a/pypulseq/tests/test_gre_radial.py b/pypulseq/tests/test_gre_radial.py index 99b6c558..878032c6 100644 --- a/pypulseq/tests/test_gre_radial.py +++ b/pypulseq/tests/test_gre_radial.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_radial_gre from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestEPISpinEchoRS(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "gre_radial_matlab.seq" diff --git a/pypulseq/tests/test_haste.py b/pypulseq/tests/test_haste.py index a7f06212..dbd1e4e6 100644 --- a/pypulseq/tests/test_haste.py +++ b/pypulseq/tests/test_haste.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_haste from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestHASTE(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "haste_matlab.seq" diff --git a/pypulseq/tests/test_make_block_pulse.py b/pypulseq/tests/test_make_block_pulse.py new file mode 100644 index 00000000..232328bd --- /dev/null +++ b/pypulseq/tests/test_make_block_pulse.py @@ -0,0 +1,89 @@ +"""Tests for the make_block_pulse.py module + +Will Clarke, University of Oxford, 2023 +""" + +from types import SimpleNamespace + +import pytest +import numpy as np + +from pypulseq import make_block_pulse + + +def test_invalid_use_error(): + + with pytest.raises( + ValueError, + match=r"Invalid use parameter."): + make_block_pulse(flip_angle=np.pi, duration=1E-3, use='foo') + + +def test_bandwidth_and_duration_error(): + + with pytest.raises( + ValueError, + match=r"One of bandwidth or duration must be defined, but not both."): + make_block_pulse(flip_angle=np.pi, duration=1E-3, bandwidth=1000) + + +def test_invalid_bandwidth_and_duration_error(): + + with pytest.raises( + ValueError, + match=r"One of bandwidth or duration must be defined and be > 0."): + make_block_pulse(flip_angle=np.pi, duration=-1E-3) + + with pytest.raises( + ValueError, + match=r"One of bandwidth or duration must be defined and be > 0."): + make_block_pulse(flip_angle=np.pi, bandwidth=-1E3) + + +def test_default_duration_warning(): + + with pytest.warns( + UserWarning, + match=r'Using default 4 ms duration for block pulse.'): + make_block_pulse(flip_angle=np.pi) + + +def test_generation_methods(): + """Test minimum input cases + Cover: + - Just flip_angle + - duration + - bandwidth + - bandwidth + time_bw_product + """ + + # Capture expected warning for default case + with pytest.warns(UserWarning): + case1 = make_block_pulse(flip_angle=np.pi) + + assert isinstance(case1, SimpleNamespace) + assert case1.shape_dur == 4E-3 + + case2 = make_block_pulse(flip_angle=np.pi, duration=1E-3) + assert isinstance(case2, SimpleNamespace) + assert case2.shape_dur == 1E-3 + + case3 = make_block_pulse(flip_angle=np.pi, bandwidth=1E3) + assert isinstance(case3, SimpleNamespace) + assert case3.shape_dur == 1 / (4 * 1E3) + + case4 = make_block_pulse(flip_angle=np.pi, bandwidth=1E3, time_bw_product=5) + assert isinstance(case4, SimpleNamespace) + assert case4.shape_dur == 5 / 1E3 + + +def test_amp_calculation(): + # A 1 ms 180 degree pulse requires 500 Hz gamma B1 + pulse = make_block_pulse(duration=1E-3, flip_angle=np.pi) + assert np.isclose(pulse.signal.max(), 500) + + pulse = make_block_pulse(duration=1E-3, flip_angle=np.pi/2) + assert np.isclose(pulse.signal.max(), 250) + + pulse = make_block_pulse(duration=2E-3, flip_angle=np.pi/2) + assert np.isclose(pulse.signal.max(), 125) diff --git a/pypulseq/tests/test_sigpy.py b/pypulseq/tests/test_sigpy.py index 2e37cf31..92fc7607 100644 --- a/pypulseq/tests/test_sigpy.py +++ b/pypulseq/tests/test_sigpy.py @@ -46,6 +46,7 @@ def test_slr(self): time_bw_product=4, return_gz=True, pulse_cfg=pulse_cfg, + plot=False, ) [a, b] = rf.sim.abrm( @@ -97,6 +98,7 @@ def test_sms(self): time_bw_product=4, return_gz=True, pulse_cfg=pulse_cfg, + plot=False ) [a, b] = rf.sim.abrm( diff --git a/pypulseq/tests/test_tse.py b/pypulseq/tests/test_tse.py index 908fba11..0a3e4127 100644 --- a/pypulseq/tests/test_tse.py +++ b/pypulseq/tests/test_tse.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_tse from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestTSE(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "tse_matlab.seq" diff --git a/pypulseq/tests/test_ute.py b/pypulseq/tests/test_ute.py index 3742a4a3..bf3b7a33 100644 --- a/pypulseq/tests/test_ute.py +++ b/pypulseq/tests/test_ute.py @@ -3,7 +3,10 @@ from pypulseq.seq_examples.scripts import write_ute from pypulseq.tests import base +import pytest + +@pytest.mark.matlab_seq_comp class TestUTE(unittest.TestCase): def test_write_epi(self): matlab_seq_filename = "ute_matlab.seq"