RANDOM NOTES
ush
ush
ush

master of universe

torch.save(model)をするときはモジュールに気をつける


pytorchでモデルを保存する方法は僕の知っているところ2通りあって、

torch.save(model, filename)

torhc.save(model.state_dict(), filename)

である。

後者はモデルというより、モデルのパラメータを保存してるらしい。

モデルを保存するサンプルコードを探すと、公式ドキュメントでもだいたいのサイトでも後者のmodel.state_dict()を使うサンプルコードになっていると思う。

今回は、なぜ前者でなく後者なのか知る機会があったので、メモする

なぜtorch.save(model, filename)だとダメなのか

理由: モデルをロードしたとき、モデルがインポートしたモジュールが必要になる。

例えば自作のモジュールを使ってモデルを学習したとする。

そのモデルがtorch.save(model, filename)で保存されたとすると、そのモデルをロードするときにその自作のモジュールも必要となる(使う使わないに限らず)。

今回、学習したモデルをロードして推論に使おうとしたが、モデル内のスクリプトでインポートされていたモジュールがインストールできなくて、モデルをロードできないということがあった。

正確には、モデルのモジュールがロードする時に、saveした時と同じモジュールがimport可能になっていないとダメって感じっぽい。

https://github.com/pytorch/pytorch/issues/3678

調べた感じだと、他にもGPUに送ったモデルをtorch.save(model, filename)した場合、そのモデルをロードするときにGPUに置かれるから、GPUが使えないとモデルをロードできなくなるらしい。

ということがあるので、

torhc.save(model.state_dict(), filename)

を使います。