人工知能に関する断創録

このブログでは人工知能のさまざまな分野について調査したことをまとめています(更新停止: 2019年12月31日)

Long Short-Term Memory Networks With Python

最近、仕事でRNNを扱うアプリケーションが多くなっています。そのようなわけで、今回からしばらくRNN(Recurrent Neural Network)についてまとめていこうと思います。参考資料は、

です*1

この本は、RNNの様々なアーキテクチャを Keras で実装して解説しています。取り上げられているアーキテクチャは

  • Vanilla LSTM
  • Stacked LSTM
  • CNN LSTM
  • Encoder Decoder LSTM
  • Bidirectional LSTM
  • Generative LSTM

などです。RNNのタスクというと機械翻訳、音声認識、Image Captioningなど大規模なデータと長い訓練時間が必要なタスクが一般的ですが、この本ではCPUでも訓練できるほど基本的なタスクが取り上られています(MNISTより簡単なレベル)

Kerasでやってはこの本の真似になってしまうので、このブログではPyTorchでやっていきます

上の本のKerasのコードをPyTorchに翻訳しているのですが、KerasとPyTorchではRNNの実装方法がだいぶ違うことを実感しています。最近はPyTorchに慣れているせいか、Kerasの実装が難しく感じます。Kerasはコード量は少ないのですが、ドキュメント当たらないと読みときにくいタイプの難しさです。

個人的にKerasとPyTorchの両方とも使いこなしたいので、これからKerasとPyTorchのRNNの実装を比較しながらまとめていきたいと思います!

*1:このブログの記事はほとんど読んだのですが、機械学習、Deep Learning、Kerasの入門としてとてもよいサイトです。説明が異常に丁寧です。英語だけどオススメ。