How to list all symbols in mxnet?

by buechel   Last Updated October 09, 2019 17:26 PM

In MXNet 1.4, using the Python API, suppose I run:

import mxnet as mx

tmp = mx.sym.var('a')
print(tmp)  # <Symbol a>

tmp = tmp + tmp
print(tmp)  # <Symbol _plus0>

tmp = mx.sym.var('b')
tmp = tmp + tmp
print(tmp)  # <Symbol _plus1>

I assume <Symbol _plus0> is still present in the graph somewhere. How can I list all symbols that currently live in my graph?

I would like to do something like mx.sym.list_all_symbols().

I have checked this tutorial, the docs, as well as the source code but couldn't find anything.
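A note on the assumption above (my reading of MXNet 1.x, worth checking against your version): there is no single global graph that every symbol is added to. Each Symbol only holds the subgraph it was built from, so once `tmp` is reassigned nothing references `_plus0` any more. The climbing suffixes (`_plus0`, then `_plus1`) come from a process-wide name counter (`mx.name.NameManager`) that is shared by all graphs. The toy class below is a hypothetical illustration of that counting behaviour, not MXNet's actual code.

```python
from collections import defaultdict


class ToyNameManager:
    """Hypothetical stand-in for a name counter shared by every graph.

    It hands out '<hint><n>' names; it is NOT MXNet's implementation.
    """

    def __init__(self):
        self._counter = defaultdict(int)

    def get(self, name, hint):
        # An explicit name (like mx.sym.var('a')) is used unchanged.
        if name:
            return name
        # Otherwise append the next number for this hint: '_plus' -> '_plus0'.
        generated = f"{hint}{self._counter[hint]}"
        self._counter[hint] += 1
        return generated


names = ToyNameManager()
# Two unrelated graphs, as in the question: a + a, then b + b.
first = names.get(None, "_plus")
second = names.get(None, "_plus")  # numbered 1 although the graphs share nothing
print(first, second)
```

So the number in `_plus1` says how many `+` nodes were named so far in the process, not how many nodes are in the current graph. To see what a particular symbol actually contains, `tmp.get_internals().list_outputs()` lists the outputs of every node in that symbol's own graph.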

Tags : python mxnet


1 Answer


Use MXNet's viz module to plot the network. You can also save the symbol graph to a JSON file and read through that file to see all the symbols:

import mxnet as mx
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = a + b
d = a * b
# matrix multiplication
e = mx.sym.dot(a, b)
# reshape
f = mx.sym.reshape(d+e, shape=(1,4))
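# save the graph to a JSON file (one entry per node)
f.save('fgraph-symbol.json')
# plot; needs the graphviz package and renders inline in Jupyter
mx.viz.plot_network(symbol=f)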
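If you'd rather list the symbols programmatically than read the file by eye, the saved JSON can be parsed directly. The sketch below assumes the layout that `Symbol.save()` / `Symbol.tojson()` write in MXNet 1.x: a top-level `"nodes"` list whose entries carry `"name"` and `"op"`, with `"null"` as the op for variables. `list_symbols` is a hypothetical helper, and the sample string is a hand-written, abbreviated stand-in for `fgraph-symbol.json` (op strings in your real file may be spelled differently).

```python
import json


def list_symbols(graph_json):
    """Return (name, op) for every node in an MXNet symbol JSON string.

    Hypothetical helper; assumes a top-level "nodes" list whose entries
    have "name" and "op" keys ("null" marks a variable).
    """
    graph = json.loads(graph_json)
    return [(node["name"], node["op"]) for node in graph["nodes"]]


# Hand-written excerpt in that layout, standing in for fgraph-symbol.json.
sample = """
{
  "nodes": [
    {"op": "null", "name": "a", "inputs": []},
    {"op": "null", "name": "b", "inputs": []},
    {"op": "elemwise_add", "name": "_plus0", "inputs": [[0, 0, 0], [1, 0, 0]]}
  ],
  "arg_nodes": [0, 1],
  "heads": [[2, 0, 0]]
}
"""

for name, op in list_symbols(sample):
    kind = "variable" if op == "null" else "operator"
    print(f"{name:10s} {op:15s} {kind}")
```

With a real file you would call `list_symbols(open('fgraph-symbol.json').read())`, or skip the file entirely with `list_symbols(f.tojson())`. Either way this only covers the graph behind `f`; as noted in the question, a symbol you have already discarded will not show up.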

[Figure: the graph that gets plotted]
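If `mx.viz.plot_network` isn't an option (for example, no graphviz in the environment), a rough substitute is to build DOT text yourself from the same JSON and paste it into any Graphviz viewer. This is a hypothetical helper, not part of MXNet. It assumes each entry in a node's `"inputs"` list starts with the index of the node that produces it; the oval/box shapes are just my choice.

```python
import json


def to_dot(graph_json):
    """Build Graphviz DOT text from an MXNet symbol JSON string.

    Hypothetical helper: one node per graph entry, and an edge from each
    input node to the node that consumes it.
    """
    nodes = json.loads(graph_json)["nodes"]
    lines = ["digraph plot {"]
    for i, node in enumerate(nodes):
        # Variables ("null" op) drawn as ovals, operators as boxes.
        shape = "oval" if node["op"] == "null" else "box"
        lines.append(f'  n{i} [label="{node["name"]}", shape={shape}];')
    for i, node in enumerate(nodes):
        # The first element of each input entry is the producing node's index.
        for entry in node["inputs"]:
            lines.append(f"  n{entry[0]} -> n{i};")
    lines.append("}")
    return "\n".join(lines)


# Hand-written sample in the assumed layout (a + b).
sample = json.dumps({
    "nodes": [
        {"op": "null", "name": "a", "inputs": []},
        {"op": "null", "name": "b", "inputs": []},
        {"op": "elemwise_add", "name": "_plus0", "inputs": [[0, 0, 0], [1, 0, 0]]},
    ],
    "heads": [[2, 0, 0]],
})

print(to_dot(sample))
```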

anirudh
October 09, 2019 17:24 PM
