PyTorchのネットワーク構造を可視化できるものを探してみた
はじめに
PyTorchでネットワーク構造を見たいけど、何使えばいいかわからなかったので、探した。
import torch import torch.nn as nn import torch.nn.functional as F class TestModel(nn.Module): def __init__(self): super().__init__() self.fc1=nn.Linear(10,5) self.fc2=nn.Linear(5,2) def forward(self,x): x=F.relu(self.fc1(x)) x=F.softmax(self.fc2(x)) return x model=TestModel() print(model)
テストで使うネットワーク。FC層2層のテストモデル。
TestModel( (fc1): Linear(in_features=10, out_features=5, bias=True) (fc2): Linear(in_features=5, out_features=2, bias=True) )
可視化手法
1. onnxに変換してnetronで見る
まず、pytorchのモデルをonnx形式に変換する。
dummy_input=torch.randn(1,10)#ダミーの入力を用意する input_names = [ "input"] output_names = [ "output" ] torch.onnx.export(model, dummy_input, "./test_model.onnx", verbose=True,input_names=input_names,output_names=output_names)
実行すると、下のような出力とtest_model.onnxというonnxファイルが出力される。
graph(%input : Float(1, 10), %fc1.weight : Float(5, 10), %fc1.bias : Float(5), %fc2.weight : Float(2, 5), %fc2.bias : Float(2)): %5 : Float(1, 5) = onnx::Gemm[alpha=1, beta=1, transB=1](%input, %fc1.weight, %fc1.bias), scope: TestModel/Linear[fc1] # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1370:0 %6 : Float(1, 5) = onnx::Relu(%5), scope: TestModel # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:914:0 %7 : Float(1, 2) = onnx::Gemm[alpha=1, beta=1, transB=1](%6, %fc2.weight, %fc2.bias), scope: TestModel/Linear[fc2] # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1370:0 %output : Float(1, 2) = onnx::LogSoftmax[axis=1](%7), scope: TestModel # /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1317:0 return (%output)
onnx形式に変換後、Netronというもので可視化できる。こののページにonnxファイルをアップロードするとネットワーク構造が見れる。
層の名前は変わってしまうが、きちんと表示されている。
このnetronはonnxに変換せずにもpytorchの保存したファイルでも表示はできるみたい。しかし、サイトには実験的サポートと書いてある。
torch.save(model,"./test_model.pth")
保存されたファイルをアップロードしてみると、ネットワーク構造は表示されたが、relu、log_softmaxが表示されなかった。
2. tensorboard
tensorboardのグラフ表示する機能を使う方法。
google colabのtensorflowのバージョンを2系へ変更して、 tensorboardの拡張を読み込む。
%tensorflow_version 2.x %load_ext tensorboard
add_graphを用いて、モデルの構造の出力する。
from torch.utils.tensorboard import SummaryWriter dummy_input=torch.randn(1,10) writer = SummaryWriter() writer.add_graph(model, dummy_input) writer.close()
tensorboardを起動して、グラフを見る。
%tensorboard --logdir ./runs
ダブルクリックすると、中身も見れる。
3. pytorchvizで見る
pytorchvizを使い、モデルを可視化する。
ライブラリをインストールする。
!apt-get install graphviz !pip install torchviz
次に、モデルを生成し、入力xを通して出力yを得る。 make_dotの引数に出力とモデルのパラメータを入れると、グラフが出てくる。
from torchviz import make_dot model=TestModel() x=torch.randn(1,10)#ダミー入力 y=model(x) make_dot(y,params=dict(model.named_parameters()))