msdd’s blog

deep learning勉強中。プログラム関連のこと書きます。

PyTorchでニューラルネットワークのパラメータ数を取得する方法

パラメータ取得方法

PyTorchでニューラルネットワークのパラメータを取得する方法として、自分で関数を 書いて求める方法、ライブラリを使って求める方法がある。 その方法を説明していく。

1. 自作関数を書く

自作の関数を使って、PyTorchのネットワークのパラメータ数を求めることが出来る。 作る関数は、下のような関数である。

count_parameters(model)がネットワーク(model)の全パラメータ数を取得するもので、 count_trainable_parameters(model)が全ての学習可能なパラメータを取得するものである。 引数は、いずれもパラメータ数を知りたいネットワーク。

中身はすごく単純で、model.parameters()でネットワークの層を取得して、 そのパラメータ数を数え上げている。count_trainable_parametersの方では、さらにrequires_gradがTrue、つまりパラメータが学習可能なもののみを数え上げる。

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

使い方は、下のようになる。 まず、パラメータ数をしりたいネットワークを読み込む。 下の例では、resnet18を読み込んでいる。そして、上の関数を適用して パラメータ数を取得する(count_parameters(model)の部分)。 2つの関数を2つの条件で試している。 1つ目の条件は重みを固定しない時、もう一つは、ネットワークの重みを固定した時である。

import torch
import torchvision

#モデル読み込み
model=torchvision.models.resnet18()

#重み固定なし
#パラメータ取得して表示
num_parameters=count_parameters(model)
print(num_parameters)
num_parameters=count_trainable_parameters(model)
print(num_parameters)

#重み固定してみる
for param in model.parameters():
    param.requires_grad=False
#パラメータ取得して表示
num_parameters=count_parameters(model)
print(num_parameters)
num_parameters=count_trainable_parameters(model)
print(num_parameters)

結果は、下のようになった。 上の2つは、ネットワークの重みを固定していないときで、どちらのパラメータ取得関数も同じ値となり、ネットワークの全パラメータ数が表示されている。 一方、下の2つは、ネットワークの重みを固定した時の結果で、重みを固定しているので、 パラメータは学習できない状態になるので、学習可能なパラメータ数を得る、count_trainable_parameters()は、結果として、0を返した。

11689512
11689512
11689512
0

注意点としては、モデルのパラメータをlistで保持しているときちんと数えれない。 listではなくnn.ModuleList()を使うことで、数えることができる。

2. torchsummaryというライブラリを使う方法

PyTorchのパラメータ数を取得するライブラリに、 torchsummaryというライブラリがある。

GitHub - sksq96/pytorch-summary: Model summary in PyTorch similar to `model.summary()` in Keras

このライブラリは便利なもので、 レイヤーごとのに出力サイズとパラメータ数を出してくれて、さらに全体でのパラメータ数も表示してくれる。

使い方

pipでインストールできる。

pip install torchsummary

使い方は、簡単で、 下のmodelのところに、自分のネットワークモデルを入れる。(channel,H,W)を 自分のネットワークの入力の次元に変える。

from torchsummary import summary
summary(model, input_size=(channels, H, W))

resnet18で試してみる。使うPyTorchのモデルをsummary()の第1引数に 入れて、第2引数に、入力するtensorのサイズ、(3,224,224)を入れた。

import torchvision.models as models
from torchsummary import summary

model=models.resnet18()
summary(model,(3,224,224))

結果は、下のようなものが出力される。 それぞれのLayerにレイヤーの名前、Output Shapeに第2引数で入れた時の出力のサイズ、 そして、Param #にパラメータの数が表示される。そして、最後の部分に、 全体のパラメータ数と学習可能なパラメータ数が表示される。

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
       BasicBlock-66            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 44.59
Estimated Total Size (MB): 107.96
----------------------------------------------------------------

前の自作関数のものと比べると、パラメータ数が同じになっているので、正しいことがわかる。

各層ごとのパラメータも見れて、とても便利である。

バグ

しかし、githubのissueを見てみると、パラメータ共有をしていると、バグってしまう と書いてあった。パラメータ共有をしているネットワークを使う時には、このライブラリ 以外でも確認した方がよさそうである。

まとめ

PyTorchのネットワークのパラメータ数を取得する方法として、自分で関数を作って求める方法と、torchsummaryというライブラリを使って取得する方法を紹介した。 torchsummaryは、importして、summary関数のみで、各層のパラメータ数など様々な情報を表示できるので、個人的にはtorchsummaryを使う方法の方がいいと思う。しかし、パラメータ共有などを使っている時は、バグっているので注意が必要。

参考サイト