@@ -31,6 +31,17 @@ float frand_normal(struct random_normal_distribution * rnd) {
3131 return ((r < rnd->min ) ? (rnd->min ) : (r > rnd->max ) ? (rnd->max ) : r);
3232}
3333
34+ void ggml_graph_compute_helper (std::vector<uint8_t > & buf, ggml_cgraph * graph, int n_threads) { // Builds a compute plan for `graph` with `n_threads`, backs the plan's work area with `buf` when one is requested, then computes the graph.
35+ struct ggml_cplan plan = ggml_graph_plan (graph, n_threads); // the plan is derived from the graph and the requested thread count
36+ // Only touch `buf` when the plan reports a non-zero work size; otherwise plan.work_data is left exactly as ggml_graph_plan returned it.
37+ if (plan.work_size > 0 ) {
38+ buf.resize (plan.work_size ); // buf stores uint8_t, so after this it holds work_size bytes
39+ plan.work_data = buf.data (); // the plan borrows buf's storage: buf must not be resized or destroyed until the compute call below returns
40+ }
41+
42+ ggml_graph_compute (graph, &plan); // run the graph using the prepared plan
43+ }
44+
3445struct ggml_tensor * randomize_tensor (
3546 struct ggml_tensor * tensor,
3647 int ndims,
@@ -1569,6 +1580,8 @@ int main(int argc, char ** argv) {
15691580 int n_tokens = model.hparams .n_ctx ;
15701581 int n_vocab = model.hparams .n_vocab ;
15711582
1583+ std::vector<uint8_t > work_buffer;
1584+
15721585 for (int ex=0 ; ex<n_examples; ++ex) {
15731586 struct ggml_init_params params = {
15741587 /* .mem_size =*/ compute_size,
@@ -1586,7 +1599,6 @@ int main(int argc, char ** argv) {
15861599 int n_past = 0 ;
15871600
15881601 ggml_cgraph gf = {};
1589- gf.n_threads = 1 ;
15901602
15911603 get_example_targets_batch (ctx0, 64 *ex+0 , tokens_input, targets);
15921604
@@ -1595,7 +1607,7 @@ int main(int argc, char ** argv) {
15951607 struct ggml_tensor * e = square_error_loss (ctx0, targets, logits);
15961608
15971609 ggml_build_forward_expand (&gf, e);
1598- ggml_graph_compute (ctx0 , &gf);
1610+ ggml_graph_compute_helper (work_buffer , &gf, /* n_threads */ 1 );
15991611
16001612 float error_before_opt = ggml_get_f32_1d (e, 0 );
16011613
@@ -1611,7 +1623,7 @@ int main(int argc, char ** argv) {
16111623 ggml_opt (ctx0, opt_params_lbfgs, e);
16121624 //
16131625 ggml_build_forward_expand (&gf, e);
1614- ggml_graph_compute (ctx0 , &gf);
1626+ ggml_graph_compute_helper (work_buffer , &gf, /* n_threads */ 1 );
16151627
16161628 float error_after_opt = ggml_get_f32_1d (e, 0 );
16171629
@@ -1659,13 +1671,12 @@ int main(int argc, char ** argv) {
16591671 struct ggml_context * ctx0 = ggml_init (params);
16601672
16611673 ggml_cgraph gf = {};
1662- gf.n_threads = 1 ;
16631674
16641675 int n_past = 0 ;
16651676 struct ggml_tensor * logits = forward (&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
16661677
16671678 ggml_build_forward_expand (&gf, logits);
1668- ggml_graph_compute (ctx0 , &gf);
1679+ ggml_graph_compute_helper (work_buffer , &gf, /* n_threads */ 1 );
16691680
16701681 struct ggml_tensor * best_samples = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, sample_ctx);
16711682 struct ggml_tensor * probs = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
@@ -1687,10 +1698,11 @@ int main(int argc, char ** argv) {
16871698 }
16881699
16891700 print_matrix (model.tok_embeddings );
1690-
16911701 printf (" done\n " );
1702+
16921703 // ggml_free(kv_self.ctx);
16931704 // ggml_free(model_lora.ctx);
16941705 ggml_free (model.ctx );
1706+
16951707 return 0 ;
16961708}
0 commit comments