@@ -1334,28 +1334,31 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
13341334 //ggml_backend_synchronize(split_backend); // necessary to measure compute time
13351335 } else {
13361336 // similar to ggml_backend_compare_graph_backend
1337- for (int j = 0 ; j < split -> graph .n_nodes ; j ++ ) {
1338- struct ggml_tensor * t = split -> graph .nodes [j ];
1337+ for (int j0 = 0 ; j0 < split -> graph .n_nodes ; j0 ++ ) {
1338+ struct ggml_tensor * t = split -> graph .nodes [j0 ];
13391339
1340- int k = j ;
1340+ int j1 = j0 ;
13411341
1342- // check if the user needs data from this node
1343- while (!sched -> callback_eval (k , t , true, sched -> callback_eval_user_data ) && k < split -> graph .n_nodes - 1 ) {
1344- t = split -> graph .nodes [++ k ];
1342+ // determine the range [j0, j1] of nodes that can be computed together
1343+ while (j1 < split -> graph .n_nodes - 1 ) {
1344+ // check if the user needs data from this node
1345+ if (sched -> callback_eval (t , true, sched -> callback_eval_user_data )) {
1346+ break ;
1347+ }
1348+
1349+ t = split -> graph .nodes [++ j1 ];
13451350 }
13461351
1347- struct ggml_cgraph gv = ggml_graph_view (& split -> graph , j , k + 1 );
1352+ struct ggml_cgraph gv = ggml_graph_view (& split -> graph , j0 , j1 + 1 );
13481353
13491354 ggml_backend_graph_compute (split_backend , & gv );
13501355
1351- // TODO: k is node index in the split, not in the original graph
1352- // TODO: avoid the ask == true call here
1353- if (sched -> callback_eval (k , t , true, sched -> callback_eval_user_data ) &&
1354- !sched -> callback_eval (k , t , false, sched -> callback_eval_user_data )) {
1356+ if (sched -> callback_eval (t , true, sched -> callback_eval_user_data ) && // ask
1357+ !sched -> callback_eval (t , false, sched -> callback_eval_user_data )) { // eval
13551358 break ;
13561359 }
13571360
1358- j = k ;
1361+ j0 = j1 ;
13591362 }
13601363 }
13611364 uint64_t compute_end_us = ggml_time_us ();
0 commit comments