PyTorchでニューラルネットワークのパラメータ数を取得する方法
パラメータ取得方法
PyTorchでニューラルネットワークのパラメータを取得する方法として、自分で関数を 書いて求める方法、ライブラリを使って求める方法がある。 その方法を説明していく。
1. 自作関数を書く
自作の関数を使って、PyTorchのネットワークのパラメータ数を求めることが出来る。 作る関数は、下のような関数である。
count_parameters(model)
がネットワーク(model)の全パラメータ数を取得するもので、
count_trainable_parameters(model)
が全ての学習可能なパラメータを取得するものである。
引数は、いずれもパラメータ数を知りたいネットワーク。
中身はすごく単純で、model.parameters()
でネットワークの層を取得して、
そのパラメータ数を数え上げている。count_trainable_parameters
の方では、さらにrequires_grad
がTrue、つまりパラメータが学習可能なもののみを数え上げる。
def count_parameters(model): return sum(p.numel() for p in model.parameters()) def count_trainable_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
使い方は、下のようになる。
まず、パラメータ数をしりたいネットワークを読み込む。
下の例では、resnet18
を読み込んでいる。そして、上の関数を適用して
パラメータ数を取得する(count_parameters(model)
の部分)。
2つの関数を2つの条件で試している。
1つ目の条件は重みを固定しない時、もう一つは、ネットワークの重みを固定した時である。
import torch import torchvision #モデル読み込み model=torchvision.models.resnet18() #重み固定なし #パラメータ取得して表示 num_parameters=count_parameters(model) print(num_parameters) num_parameters=count_trainable_parameters(model) print(num_parameters) #重み固定してみる for param in model.parameters(): param.requires_grad=False #パラメータ取得して表示 num_parameters=count_parameters(model) print(num_parameters) num_parameters=count_trainable_parameters(model) print(num_parameters)
結果は、下のようになった。
上の2つは、ネットワークの重みを固定していないときで、どちらのパラメータ取得関数も同じ値となり、ネットワークの全パラメータ数が表示されている。
一方、下の2つは、ネットワークの重みを固定した時の結果で、重みを固定しているので、
パラメータは学習できない状態になるので、学習可能なパラメータ数を得る、count_trainable_parameters()
は、結果として、0を返した。
11689512 11689512 11689512 0
注意点としては、モデルのパラメータをlistで保持しているときちんと数えれない。 listではなくnn.ModuleList()を使うことで、数えることができる。
2. torchsummaryというライブラリを使う方法
PyTorchのパラメータ数を取得するライブラリに、 torchsummaryというライブラリがある。
GitHub - sksq96/pytorch-summary: Model summary in PyTorch similar to `model.summary()` in Keras
このライブラリは便利なもので、 レイヤーごとのに出力サイズとパラメータ数を出してくれて、さらに全体でのパラメータ数も表示してくれる。
使い方
pipでインストールできる。
pip install torchsummary
使い方は、簡単で、 下のmodelのところに、自分のネットワークモデルを入れる。(channel,H,W)を 自分のネットワークの入力の次元に変える。
from torchsummary import summary summary(model, input_size=(channels, H, W))
resnet18
で試してみる。使うPyTorchのモデルをsummary()
の第1引数に
入れて、第2引数に、入力するtensorのサイズ、(3,224,224)を入れた。
import torchvision.models as models from torchsummary import summary model=models.resnet18() summary(model,(3,224,224))
結果は、下のようなものが出力される。
それぞれのLayer
にレイヤーの名前、Output Shape
に第2引数で入れた時の出力のサイズ、
そして、Param #
にパラメータの数が表示される。そして、最後の部分に、
全体のパラメータ数と学習可能なパラメータ数が表示される。
---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 64, 112, 112] 9,408 BatchNorm2d-2 [-1, 64, 112, 112] 128 ReLU-3 [-1, 64, 112, 112] 0 MaxPool2d-4 [-1, 64, 56, 56] 0 Conv2d-5 [-1, 64, 56, 56] 36,864 BatchNorm2d-6 [-1, 64, 56, 56] 128 ReLU-7 [-1, 64, 56, 56] 0 Conv2d-8 [-1, 64, 56, 56] 36,864 BatchNorm2d-9 [-1, 64, 56, 56] 128 ReLU-10 [-1, 64, 56, 56] 0 BasicBlock-11 [-1, 64, 56, 56] 0 Conv2d-12 [-1, 64, 56, 56] 36,864 BatchNorm2d-13 [-1, 64, 56, 56] 128 ReLU-14 [-1, 64, 56, 56] 0 Conv2d-15 [-1, 64, 56, 56] 36,864 BatchNorm2d-16 [-1, 64, 56, 56] 128 ReLU-17 [-1, 64, 56, 56] 0 BasicBlock-18 [-1, 64, 56, 56] 0 Conv2d-19 [-1, 128, 28, 28] 73,728 BatchNorm2d-20 [-1, 128, 28, 28] 256 ReLU-21 [-1, 128, 28, 28] 0 Conv2d-22 [-1, 128, 28, 28] 147,456 BatchNorm2d-23 [-1, 128, 28, 28] 256 Conv2d-24 [-1, 128, 28, 28] 8,192 BatchNorm2d-25 [-1, 128, 28, 28] 256 ReLU-26 [-1, 128, 28, 28] 0 BasicBlock-27 [-1, 128, 28, 28] 0 Conv2d-28 [-1, 128, 28, 28] 147,456 BatchNorm2d-29 [-1, 128, 28, 28] 256 ReLU-30 [-1, 128, 28, 28] 0 Conv2d-31 [-1, 128, 28, 28] 147,456 BatchNorm2d-32 [-1, 128, 28, 28] 256 ReLU-33 [-1, 128, 28, 28] 0 BasicBlock-34 [-1, 128, 28, 28] 0 Conv2d-35 [-1, 256, 14, 14] 294,912 BatchNorm2d-36 [-1, 256, 14, 14] 512 ReLU-37 [-1, 256, 14, 14] 0 Conv2d-38 [-1, 256, 14, 14] 589,824 BatchNorm2d-39 [-1, 256, 14, 14] 512 Conv2d-40 [-1, 256, 14, 14] 32,768 BatchNorm2d-41 [-1, 256, 14, 14] 512 ReLU-42 [-1, 256, 14, 14] 0 BasicBlock-43 [-1, 256, 14, 14] 0 Conv2d-44 [-1, 256, 14, 14] 589,824 BatchNorm2d-45 [-1, 256, 14, 14] 512 ReLU-46 [-1, 256, 14, 14] 0 Conv2d-47 [-1, 256, 14, 14] 589,824 BatchNorm2d-48 [-1, 256, 14, 14] 512 ReLU-49 [-1, 256, 14, 14] 0 BasicBlock-50 [-1, 256, 14, 14] 0 Conv2d-51 [-1, 512, 7, 7] 1,179,648 BatchNorm2d-52 [-1, 512, 7, 7] 1,024 ReLU-53 [-1, 512, 7, 7] 0 Conv2d-54 [-1, 512, 7, 7] 2,359,296 BatchNorm2d-55 [-1, 512, 7, 7] 1,024 Conv2d-56 [-1, 512, 7, 7] 131,072 BatchNorm2d-57 [-1, 512, 7, 7] 1,024 ReLU-58 [-1, 512, 7, 7] 0 BasicBlock-59 [-1, 512, 7, 7] 0 Conv2d-60 [-1, 512, 7, 7] 2,359,296 BatchNorm2d-61 [-1, 512, 7, 7] 1,024 ReLU-62 [-1, 512, 7, 7] 0 Conv2d-63 [-1, 512, 7, 7] 2,359,296 BatchNorm2d-64 [-1, 512, 7, 7] 1,024 ReLU-65 [-1, 512, 7, 7] 0 BasicBlock-66 [-1, 512, 7, 7] 0 AdaptiveAvgPool2d-67 [-1, 512, 1, 1] 0 Linear-68 [-1, 1000] 513,000 ================================================================ Total params: 11,689,512 Trainable params: 11,689,512 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.57 Forward/backward pass size (MB): 62.79 Params size (MB): 44.59 Estimated Total Size (MB): 107.96 ----------------------------------------------------------------
前の自作関数のものと比べると、パラメータ数が同じになっているので、正しいことがわかる。
各層ごとのパラメータも見れて、とても便利である。
バグ
しかし、githubのissueを見てみると、パラメータ共有をしていると、バグってしまう と書いてあった。パラメータ共有をしているネットワークを使う時には、このライブラリ 以外でも確認した方がよさそうである。
まとめ
PyTorchのネットワークのパラメータ数を取得する方法として、自分で関数を作って求める方法と、torchsummaryというライブラリを使って取得する方法を紹介した。 torchsummaryは、importして、summary関数のみで、各層のパラメータ数など様々な情報を表示できるので、個人的にはtorchsummaryを使う方法の方がいいと思う。しかし、パラメータ共有などを使っている時は、バグっているので注意が必要。