技術ブログ

  1. HOME
  2. ブログ
  3. IT技術
  4. PyTorch Lightningとは?Igniteとは?PyTorchとの違いは?深層学習(ディープラーニング)ライブラリを分かりやすく解説!

PyTorch Lightningとは?Igniteとは?PyTorchとの違いは?深層学習(ディープラーニング)ライブラリを分かりやすく解説!

こんにちは!PA Labメディアです。

もともと深層学習ライブラリはあちこちで開発が行われておりましたが、最近ではPyTorchとTensorFlow(Kerasを含む)が主要なフレームワークとして幅広く使われています。TensorFlowではKerasやEstimatorと呼ばれる高レベルAPIが豊富に含まれています。簡単に言うと複雑な処理を記述しなくても呼び出すだけで使用できるAPIのことです。

PyTorchでは公式の高レベルのラッパーとしてIgniteが用意されていますが、Pytorch Lightningというサードパーティのライブラリが実際の現場では使用される事が多いです。そこで今日はPyTorchの高レベルなインタフェースを提供しているPyTorch Lightingというライブラリの紹介をしていきます。

今回の記事では以下のような方を対象にしております。

  • 「Pytorch LightningとIgniteの違いが分からない」
  • 「ディープラーニングを始めてみたいが、どのライブラリを使えばよいかわからない」
  • 「Pytorchで深層学習モデルを組んでいるが、複雑になってきた」

AIの専門家であるPA Labが分かりやすく簡単にPyTorch Lightning、Ignite、Pytorchとの違いに関して紹介していきます。

目次

深層学習ライブラリPyTorchとは

深層学習のライブラリの歴史は意外と長く、昔から様々なライブラリが使用されていました。最近のライブラリとしてはTensorFlowやPyTorchが有名ですが、元々良く使用されていたライブラリとしてLuaという言語で書かれたTorchというライブラリがあります。元Googleの研究者で今はAppleの研究者であるSamy Bengioという最先端の深層学習の研究を行っている方に作成されたこのライブラリは広く使用されており、FacebookのAI研究グループでもよく使用されていました。

このTorchというライブラリと深層学習ライブラリChainerを元に、FacebookのAI研究グループが開発したライブラリがPyTorchというオープンソースの深層学習ライブラリになります。

日本語では「パイトーチ」と読み、現在ではTensorFlowと並んで最も使用されるライブラリとなっています。

PyTorchの詳細や実際の使い方に関しては、別の記事にて解説していきます。

PyTorchの高レベルAPIライブラリ:Igniteとは

IgniteはPyTochが公式で開発をしている高レベルAPIになります。
Igniteを使用すれば、PyTorchでコードを直接書くよりもコードの量を減らした上に簡単に実装することが可能になります。特にメトリックや実験管理に優れたライブラリになっています。

現在のGithubのスター数は3700となっています。(2021/09/24調べ)

PyTorchの高レベルAPIライブラリ:Pytorch Lightningとは

PyTorch Lightningはサードパーティライブラリですが、最もよく使用されるPyTorchの高レベルAPIとなっています。

Pytorch Lightningは様々な機能を含んでいて、深層学習の研究者にとっても深層学習を活用したシステムを開発するエンジニアの方にも有用なライブラリになっています。研究用のコードはもちろん、ロギング、デプロイ周りでも簡単にコードを書く事が出来るようになっています。現在ではマルチGPUの学習、16bitでの学習、TPUの活用などの実装も含まれています。

PyTorch Lightningのメリットとしては以下のようなものがあります。

  1. どのハードウェアでも学習: コードをいじることなくCPU, GPU, TPUで学習可能
  2. 16bitでの学習:メモリを半分にすることで学習速度が向上
  3. 簡潔なインターフェース: 無駄なコードを書かずに必要な箇所だけに集中可能
  4. 高度な拡張性: 複雑な関数を記述することが可能
  5. 視覚化ツールとの統合サポート: Tnesorboard, MLFlow, Comet.mlなどのツールとの連携サポート

現在のGithubのスター数は15400(2021/09/24調べ)となっており、公式で開発されているIgniteと比べても非常に多い数字になっています。

PyTorch Lightningのインストール

PyTorch Lightningはpipやcondaにより簡単にインストールすることが可能です。

pipを使用する場合は以下のコマンドで簡単にインストールする事が出来ます。

$ pip install pytorch-lightning

condaを使用する場合は以下のコマンドで簡単にインストールする事が出来ます。

$ conda install pytorch-lightning -c conda-forge

PyTorchとPyTorch Lightningのオートエンコーダの実装比較

PyTorchを使用した場合

# models
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
encoder.cuda(0)
decoder.cuda(0)

# download on rank 0 only
if global_rank == 0:
    mnist_train = MNIST(os.getcwd(), train=True, download=True)

# download on rank 0 only
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)

# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

# The dataloaders handle shuffling, batching, etc...
mnist_train = DataLoader(mnist_train, batch_size=64)
mnist_train
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_val

# optimizer
params = [encoder.parameters(), decoder.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

# TRAIN LOOP
model.train()
num_epochs = 1
for epoch in range(num_epochs):
    for train_batch in mnist_train:
        x, y = train_batch
        x = x.cuda(0)
        x = x.view(x.size(0), -1)
        z = encoder(x)
        x_hat = decoder(z)
        loss = F.mse_loss(x_hat, x)
        print('train loss: ', loss.item())
        loss.backward()
        optimizer.step()

このコードの中では

  1. 深層学習モデルの実装
  2. cuda/gpuの設定
  3. データローダの自前処理
  4. optimizerのパラメータ設定
  5. 学習・評価の処理の実装

という処理が含まれている形になります。PyTorchだけでもかなりシンプルに書く事が出来ていますが、PyTorch Lightningではどのようになるでしょうか。

PyTorch Lightningを使用した場合(2ステップ)

PyTorchだと基本的に2ステップの操作で実装をしていく事が可能になります。

LightningModuleの子クラスでの定義

LightningModuleを継承した子クラスを作成して呼び出す事で、モデルの定義を行う事が可能です。

先程のオートエンコーダのモデルの実装を行います。

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

Lightning Trainerでの学習

次にLightning Trainerを使用して学習を進めていきます。こちらにモデルを渡す事で自動的に学習・評価を行う事が可能になります。

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

autoencoder = LitAutoEncoder()

# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
# trainer = pl.Trainer(gpus=8) (if you have GPUs)
trainer = pl.Trainer()
trainer.fit(autoencoder, train_loader)

上記のようにコードがとてもスッキリしていることが分かるでしょうか。
このようにシンプルで読みやすい実装を行う事で、間違いがなく拡張性・再現性の高いコードを書き、深層学習のモデル側の設計に集中して開発を行う事が可能になります。これがPyTorch Lightningの利点であり、高レベルAPIの利点となります。また難しい拡張を行っていく上でもPyTorchを書く事に変わりはないため、簡単にスタートする事が出来る上に拡張性は高いため、基本的にはPyTorch Lightningなどの高レベルAPIでスタートする事をオススメします。

まとめ

今回はPyTorchの簡単な紹介、PyTorch Lightning、Igniteなどの高レベルAPI、PyTorchとPyTorch Lightningのライブラリの実装の比較などを紹介してきました。PyTorchは最も使われる深層学習のライブラリの一つであり、今後も必須のライブラリとなっています。初学者も最近PyTorchを初めた方もまずはPyTorch Lightningを使用した実装に切り替える事で、シンプルで読みやすく拡張性の高いコードを目指して深層学習モデルの設計に集中することで、より価値のある実装を行っていきましょう。


PA Labでは「AIを用いた自動化×サービス開発」の専門家として活動をしています。高度なデータ分析からシステム開発まで一貫したサービス提供を行っており、特に機械学習やディープラーニングを中心としたビジネス促進を得意としております。

無料で分析設計/データ活用に関するご相談も実施中なので、ご相談があればお問い合わせまで。

  1. この記事へのコメントはありません。

  1. この記事へのトラックバックはありません。

関連記事