cBlog

Tips for you.

Stable DiffusionがIntel MacBook Proで動いたのでメモ

スポンサーリンク
※当ブログのAmazon、iTunes、サウンドハウス等のリンクはアフィリエイトを利用しています。

PyTorchのバックエンドとしてMPSを使い、Stable DiffusionがM1 Macで動いたと聞いた。MPSはMetal Performance Shaderのことらしい。

ほい? MetalならIntel MacのRadeonでも動くのでは?としてやってみた。

 

環境

  • 2.3 GHz 8コアIntel Core i9
  • AMD Radeon Pro 5500M 8 GB
  • macOS Monterey 12.5.1
  • Homebrewで入れたminiforge

 

追記4

GitHubに上げました。

github.com

 

普通に入れる

以下を参考にした:

https://rentry.org/SDInstallGuide

ダウンロードする。

% git clone https://github.com/CompVis/stable-diffusion.git
% cd stable-diffusion

environment.yamlを編集する。

CUDAを使わない:

-  - cudatoolkit=11.3
+  # - cudatoolkit=11.3

MPS対応:

-  - pytorch
+  - pytorch-nightly
-  - pytorch=1.11.0
-  - torchvision=0.12.0
+  - pytorch
+  - torchvision

仮想環境を用意する。

% conda env create -f environment.yaml
% conda activate ldm
% mkdir -p models/ldm/stable-diffusion-v1

https://huggingface.co/CompVis/stable-diffusion-v-1-4-originalでLog InやらAccess Repositoryし、https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/blob/main/sd-v1-4.ckptを(~/Downloadsに)ダウンロードする。

モデルを移動しリネームする。

% mv ~/Downloads/sd-v1-4.ckpt models/ldm/stable-diffusion-v1/model.ckpt

実行してみる。

% python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms

動かない。

 

mpsが動くか確認

以下を参考にした:

zenn.dev

mpsを認識するか確認する。

% python
>>> import torch
>>> torch.device('mps')
device(type='mps')

良さそう。

以下を実行してみる(同記事より引用)。

pytorch_m1_macbook.py

# -*- coding: utf-8 -*-
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms as tt
from torchvision.models import resnet18

import os
from argparse import ArgumentParser
import time

def main(device):
    # ResNetのハイパーパラメータ
    n_epoch = 5            # エポック数
    batch_size = 512       # ミニバッチサイズ
    momentum = 0.9         # SGDのmomentum
    lr = 0.01              # 学習率
    weight_decay = 0.00005 # weight decay

    # 訓練データとテストデータを用意
    mean = (0.491, 0.482, 0.446)
    std = (0.247, 0.243, 0.261)
    train_transform = tt.Compose([
        tt.RandomHorizontalFlip(p=0.5),
        tt.RandomCrop(size=32, padding=4, padding_mode='reflect'),
        tt.ToTensor(),
        tt.Normalize(mean=mean, std=std)
    ])
    test_transform = tt.Compose([tt.ToTensor(), tt.Normalize(mean, std)])
    root = os.path.dirname(os.path.abspath(__file__))
    train_set = CIFAR10(root=root, train=True,
                        download=True, transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, num_workers=8)

    # ResNetの準備
    resnet = resnet18()
    resnet.fc = torch.nn.Linear(resnet.fc.in_features, 10)

    # 訓練
    criterion = CrossEntropyLoss()
    optimizer = SGD(resnet.parameters(), lr=lr,
                    momentum=momentum, weight_decay=weight_decay)
    train_start_time = time.time()
    resnet.to(device)
    resnet.train()
    for epoch in range(1, n_epoch+1):
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()
            outputs = resnet(inputs)
            labels = labels.to(device)
            loss = criterion(outputs, labels)
            loss.backward()
            train_loss += loss.item()
            del loss  # メモリ節約のため
            optimizer.step()
        print('Epoch {} / {}: time = {}[s], loss = {:.2f}'.format(
            epoch, n_epoch, time.time() - train_start_time, train_loss))
    print('Train time on {}: {:.2f}[s] (Train loss = {:.2f})'.format(
        device, time.time() - train_start_time, train_loss))

    # 評価
    test_set = CIFAR10(root=root, train=False, download=True,
                       transform=test_transform)
    test_loader = DataLoader(test_set, batch_size=batch_size,
                             shuffle=False, num_workers=8)
    test_loss = 0.0
    test_start_time = time.time()
    resnet.eval()
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        outputs = resnet(inputs)
        labels = labels.to(device)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
    print('Test time on {}: {:.2f}[s](Test loss = {:.2f})'.format(
        device, time.time() - test_start_time, test_loss))


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--device', type=str, default='mps',
                        choices=['cpu', 'mps'])
    args = parser.parse_args()
    device = torch.device(args.device)
    main(device)

CPUの場合(抜粋):

% python pytorch_m1_macbook.py --device cpu
Epoch 1 / 5: time = 249.9817099571228[s], loss = 170.60
Epoch 2 / 5: time = 498.5888819694519[s], loss = 137.21
Epoch 3 / 5: time = 762.4725549221039[s], loss = 122.71
Epoch 4 / 5: time = 1022.5609741210938[s], loss = 112.18
Epoch 5 / 5: time = 1274.3697321414948[s], loss = 103.73
Train time on cpu: 1274.37[s] (Train loss = 103.73)
Test time on cpu: 58.76[s](Test loss = 20.09)

GPUの場合(抜粋):

% python pytorch_m1_macbook.py --device mps
Epoch 1 / 5: time = 131.3166902065277[s], loss = 170.33
Epoch 2 / 5: time = 246.86656522750854[s], loss = 137.14
Epoch 3 / 5: time = 362.39308524131775[s], loss = 122.12
Epoch 4 / 5: time = 478.34768986701965[s], loss = 113.14
Epoch 5 / 5: time = 594.5503239631653[s], loss = 104.61
Train time on mps: 594.55[s] (Train loss = 104.61)
est time on mps: 59.96[s](Test loss = 20.42)

2倍程度に速くなった。

 

Stable Diffusionのコードを修正

以下を参考にした:

zenn.dev

4つのファイルを編集。

  • scripts/txt2img.py
  • ldm/models/diffusion/plms.py
  • configs/stable-diffusion/v1-inference.yamlconfigs/stable-diffusion/v1-inference.yaml
  • /usr/local/Caskroom/miniforge/base/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py
  • ldm/modules/attention.py

scripts/txt2img.py

         print("unexpected keys:")
         print(u)
 
-    model.cuda()
+    # model.cuda()
+    model.to("mps")
     model.eval()
     return model
 
     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
     model = model.to(device)
 
     if opt.plms:
 
     precision_scope = autocast if opt.precision=="autocast" else nullcontext
     with torch.no_grad():
-        with precision_scope("cuda"):
+        # with precision_scope("cuda"):
+        with nullcontext("mps"):
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()

ldm/models/diffusion/plms.py

 
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != torch.device("mps"):
+                attr = attr.to(torch.float32).to(torch.device("mps")).contiguous()
         setattr(self, name, attr)

configs/stable-diffusion/v1-inference.yamlconfigs/stable-diffusion/v1-inference.yaml(編集必要でした)

 
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params: # edit
+        device: mps # edit

/usr/local/Caskroom/miniforge/base/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py

         return handle_torch_function(
             layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
         )
-    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
+    return torch.layer_norm(input.contiguous(), normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) # edit

ldm/modules/attention.py

         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

     def _forward(self, x, context=None):
+        x = x.contiguous()  # edit
         x = self.attn1(self.norm1(x)) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x

改めて実行する。

% python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms

動いた!

(下段中央の画像はNSFW filterに引っかかったみたい。)

 

問題点

画像6枚で1時間以上かかりました。

ただ、GPUが10-20%程度しか使われてませんでした(以下記事追記に類似)。

zenn.dev

フルで使われるなど、更なる高速化ができないか探してみます。

 

追記

よくわかりませんが、画像2枚(--n_samples 1 --n_rows 1)の方がGPU使ってくれて(70-80%くらい)、6分くらいで終わります。

あと、--n_rows 1としているのにグリッドが2行になるのもよくわからないです。

 

追記2

  • --n_iter:グリッドの行数(デフォルト2)
  • --n_rows:グリッド1行に何枚か(デフォルト=n_samples

--n_samples 1 --n_iter 1とすべきだった。

画像1枚3分くらいでできた。

 

追記3

torch/nn/functional.pyではなくldm/modules/attention.pyを編集するように変更しました。これでライブラリに手を加えなくて済みます。

目下の興味は--n_samples 2のときにAppleInternal/(中略)/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:705: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: product of dimension sizes > 2**31'というエラーが出ること。MPS側の問題でしょうか。--n_samples 3でも動くのになぜ。