Skip to content

Commit b4e3b5c

Browse files
committed
ggml : add ggml_get_tensor_by_name()
1 parent 15f12f8 commit b4e3b5c

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

examples/mnist/main.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ int mnist_eval(
193193

194194
// soft max
195195
ggml_tensor * probs = ggml_soft_max(ctx0, fc2);
196+
ggml_set_name(probs, "probs");
196197

197198
// build / export / run the computation graph
198199
ggml_build_forward_expand(&gf, probs);
@@ -201,25 +202,27 @@ int mnist_eval(
201202
//ggml_graph_print (&gf);
202203
ggml_graph_dump_dot(&gf, NULL, "mnist.dot");
203204

205+
ggml_graph_export(&gf, "mnist.ggml");
206+
207+
#if 0
204208
const float * probs_data = ggml_get_data_f32(probs);
205209

206210
const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
211+
#else
212+
struct ggml_context * ctx_data = NULL;
213+
struct ggml_context * ctx_eval = NULL;
207214

208-
ggml_free(ctx0);
215+
struct ggml_cgraph gfi = ggml_graph_import("mnist.ggml", &ctx_data, &ctx_eval);
216+
gfi.n_threads = n_threads;
209217

210-
ggml_graph_export(&gf, "mnist.ggml");
218+
ggml_graph_compute(ctx0, &gfi);
211219

212-
// TMP
213-
// import the computation graph
214-
{
215-
struct ggml_context * ctx_data = NULL;
216-
struct ggml_context * ctx_eval = NULL;
220+
const float * probs_data = ggml_get_data_f32(ggml_get_tensor_by_name(&gfi, "probs"));
217221

218-
struct ggml_cgraph gfi = ggml_graph_import("mnist.ggml", &ctx_data, &ctx_eval);
219-
gfi.n_threads = n_threads;
222+
const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
223+
#endif
220224

221-
ggml_graph_compute(ctx0, &gfi);
222-
}
225+
ggml_free(ctx0);
223226

224227
return prediction;
225228
}

include/ggml/ggml.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,8 @@ extern "C" {
978978
GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
979979
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
980980

981+
GGML_API struct ggml_tensor * ggml_get_tensor_by_name(struct ggml_cgraph * cgraph, const char * name);
982+
981983
// print info and performance information for the graph
982984
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
983985

src/ggml.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14526,6 +14526,18 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
1452614526
}
1452714527
}
1452814528

14529+
struct ggml_tensor * ggml_get_tensor_by_name(struct ggml_cgraph * cgraph, const char * name) {
14530+
for (int i = 0; i < cgraph->n_nodes; i++) {
14531+
struct ggml_tensor * node = cgraph->nodes[i];
14532+
14533+
if (strcmp(node->name, name) == 0) {
14534+
return node;
14535+
}
14536+
}
14537+
14538+
return NULL;
14539+
}
14540+
1452914541
void ggml_graph_print(const struct ggml_cgraph * cgraph) {
1453014542
int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};
1453114543

0 commit comments

Comments
 (0)