Long Short-Term Memory Networks With Python
最近、仕事でRNNを扱うアプリケーションが多くなっています。そのようなわけで、今回からしばらくRNN(Recurrent Neural Network)についてまとめていこうと思います。参考資料は、
です*1。
この本は、RNNの様々なアーキテクチャを Keras で実装して解説しています。取り上げられているアーキテクチャは
- Vanilla LSTM
- Stacked LSTM
- CNN LSTM
- Encoder Decoder LSTM
- Bidirectional LSTM
- Generative LSTM
などです。RNNのタスクというと機械翻訳、音声認識、Image Captioningなど大規模なデータと長い訓練時間が必要なタスクが一般的ですが、この本ではCPUでも訓練できるほど基本的なタスクが取り上られています(MNISTより簡単なレベル)
Kerasでやってはこの本の真似になってしまうので、このブログではPyTorchでやっていきます。
上の本のKerasのコードをPyTorchに翻訳しているのですが、KerasとPyTorchではRNNの実装方法がだいぶ違うことを実感しています。最近はPyTorchに慣れているせいか、Kerasの実装が難しく感じます。Kerasはコード量は少ないのですが、ドキュメント当たらないと読みときにくいタイプの難しさです。
個人的にKerasとPyTorchの両方とも使いこなしたいので、これからKerasとPyTorchのRNNの実装を比較しながらまとめていきたいと思います!
*1:このブログの記事はほとんど読んだのですが、機械学習、Deep Learning、Kerasの入門としてとてもよいサイトです。説明が異常に丁寧です。英語だけどオススメ。