@@ -193,6 +193,7 @@ int mnist_eval(
193193
194194 // soft max
195195 ggml_tensor * probs = ggml_soft_max (ctx0, fc2);
196+ ggml_set_name (probs, " probs" );
196197
197198 // build / export / run the computation graph
198199 ggml_build_forward_expand (&gf, probs);
@@ -201,25 +202,27 @@ int mnist_eval(
201202 // ggml_graph_print (&gf);
202203 ggml_graph_dump_dot (&gf, NULL , " mnist.dot" );
203204
205+ ggml_graph_export (&gf, " mnist.ggml" );
206+
207+ #if 0
204208 const float * probs_data = ggml_get_data_f32(probs);
205209
206210 const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
211+ #else
212+ struct ggml_context * ctx_data = NULL ;
213+ struct ggml_context * ctx_eval = NULL ;
207214
208- ggml_free (ctx0);
215+ struct ggml_cgraph gfi = ggml_graph_import (" mnist.ggml" , &ctx_data, &ctx_eval);
216+ gfi.n_threads = n_threads;
209217
210- ggml_graph_export (&gf, " mnist.ggml " );
218+ ggml_graph_compute (ctx0, &gfi );
211219
212- // TMP
213- // import the computation graph
214- {
215- struct ggml_context * ctx_data = NULL ;
216- struct ggml_context * ctx_eval = NULL ;
220+ const float * probs_data = ggml_get_data_f32 (ggml_get_tensor_by_name (&gfi, " probs" ));
217221
218- struct ggml_cgraph gfi = ggml_graph_import ( " mnist.ggml " , &ctx_data, &ctx_eval) ;
219- gfi. n_threads = n_threads;
222+ const int prediction = std::max_element (probs_data, probs_data + 10 ) - probs_data ;
223+ # endif
220224
221- ggml_graph_compute (ctx0, &gfi);
222- }
225+ ggml_free (ctx0);
223226
224227 return prediction;
225228}
0 commit comments