@@ -61,16 +61,17 @@ jobs:
6161 python-version : ${{ matrix.python-version }}
6262 -
uses :
pre-commit/[email protected] 6363
64- test_ubuntu :
65- name : " Test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
64+ test :
65+ name : " ${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
6666 needs :
6767 - changes
6868 - style
69- runs-on : ubuntu-latest
69+ runs-on : ${{ matrix.os }}
7070 if : ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
7171 strategy :
7272 fail-fast : false
7373 matrix :
74+ os : ["ubuntu-latest"]
7475 python-version : ["3.10", "3.12"]
7576 fast-compile : [0, 1]
7677 float32 : [0, 1]
@@ -103,30 +104,44 @@ jobs:
103104 fast-compile : 1
104105 include :
105106 - install-numba : 1
107+ os : " ubuntu-latest"
106108 python-version : " 3.10"
107109 fast-compile : 0
108110 float32 : 0
109111 part : " tests/link/numba"
110112 - install-numba : 1
113+ os : " ubuntu-latest"
111114 python-version : " 3.12"
112115 fast-compile : 0
113116 float32 : 0
114117 part : " tests/link/numba"
115118 - install-jax : 1
119+ os : " ubuntu-latest"
116120 python-version : " 3.10"
117121 fast-compile : 0
118122 float32 : 0
119123 part : " tests/link/jax"
120124 - install-jax : 1
125+ os : " ubuntu-latest"
121126 python-version : " 3.12"
122127 fast-compile : 0
123128 float32 : 0
124129 part : " tests/link/jax"
125130 - install-torch : 1
131+ os : " ubuntu-latest"
126132 python-version : " 3.10"
127133 fast-compile : 0
128134 float32 : 0
129135 part : " tests/link/pytorch"
136+ - os : macos-latest
137+ python-version : " 3.12"
138+ fast-compile : 0
139+ float32 : 0
140+ install-numba : 0
141+ install-jax : 0
142+ install-torch : 0
143+ part : " tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
144+
130145 steps :
131146 - uses : actions/checkout@v4
132147 with :
@@ -146,15 +161,19 @@ jobs:
146161 MATRIX_CONTEXT : ${{ toJson(matrix) }}
147162 run : |
148163 echo $MATRIX_CONTEXT
149- export MATRIX_ID=`echo $MATRIX_CONTEXT | md5sum | cut -c 1-32`
164+ export MATRIX_ID=`echo $MATRIX_CONTEXT | sha256sum | cut -c 1-32`
150165 echo $MATRIX_ID
151166 echo "id=$MATRIX_ID" >> $GITHUB_OUTPUT
152167
153168 - name : Install dependencies
154169 shell : micromamba-shell {0}
155170 run : |
156171
157- micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
172+ if [[ $OS == "macos-latest" ]]; then
173+ micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" numpy scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate;
174+ else
175+ micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
176+ fi
158177 if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
159178 if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
160179 if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
@@ -163,12 +182,17 @@ jobs:
163182 pip install -e ./
164183 micromamba list && pip freeze
165184 python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
166- python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
185+ if [[ $OS == "macos-latest" ]]; then
186+ python -c 'import pytensor; assert pytensor.config.blas__ldflags.startswith("-framework Accelerate"), "Blas flags are not set to MacOS Accelerate"';
187+ else
188+ python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"';
189+ fi
167190 env :
168191 PYTHON_VERSION : ${{ matrix.python-version }}
169192 INSTALL_NUMBA : ${{ matrix.install-numba }}
170193 INSTALL_JAX : ${{ matrix.install-jax }}
171194 INSTALL_TORCH : ${{ matrix.install-torch}}
195+ OS : ${{ matrix.os}}
172196
173197 - name : Run tests
174198 shell : micromamba-shell {0}
@@ -249,10 +273,10 @@ jobs:
249273 if : ${{ always() }}
250274 runs-on : ubuntu-latest
251275 name : " All tests"
252- needs : [changes, style, test_ubuntu ]
276+ needs : [changes, style, test ]
253277 steps :
254278 - name : Check build matrix status
255- if : ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test_ubuntu .result != 'success') }}
279+ if : ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test .result != 'success') }}
256280 run : exit 1
257281
258282 upload-coverage :
0 commit comments