PyTorch Lightning ModelCheckpointのfilename引数にスラッシュを含めるとぶっ壊れる問題

Python深層学習

水ノ茉の宣伝

準備中...
ゲームを作る予定なの
水ノ茉こおり

作業環境

  • Windows 10
  • Visual Studio Code
  • Python 3.11
  • PyTorch Lightning 2.2.1

始まり

指定された要素(精度や損失)を元に上位や下位Nつのチェックポイントを保存してくれる ModelCheckpoint さん。

使用箇所を抜粋するとこんな感じ。

def on_validation_epoch_end(self) -> None:
    super().on_validation_epoch_end()

    loss = torch.stack(self.valid_outputs["loss"]).mean()
    accuracy = torch.cat(self.valid_outputs["accuracy"]).mean()

    self.log("val_loss", loss)
    self.log("val_accuracy", accuracy)

    # free up the memory
    self.valid_outputs["loss"].clear()
    self.valid_outputs["accuracy"].clear()
ModelCheckpoint(
    monitor="val_accuracy",
    filename="checkpoint-{epoch}-{val_accuracy:.8f}-{val_loss:.8f}",
    save_top_k=3,
    mode="max",
    save_last=True,
)
引数名説明
monitor監視対象を指定
filename保存するチェックポイント名のフォーマットを指定
save_top_k最大で幾つ保存するかを指定
mode精度なら上位(max)、損失なら下位(min)を指定
save_last最後/エポック毎のチェックポイントをlast.ckptとして保存するか

たったこれだけ。めっちゃシンプルで便利です。
※他にも指定可能な引数はありますが必要最小限が上記。

ちょっと寄り道します。

PyTorch Lightningのログはステップごとに記録されます。ステップごとの記録の欠点はデータセットが増減した際に横軸がズレるため、ぱっと見で比較がしにくいことです。そのため筆者はステップごととは別にエポックごとにも記録を付けています。

愚直にLoggerに書き込むと画像のように縦長に展開されてしまいます。気にならない方はいいですが、ちょっと縦長過ぎる気がします。

そんな時は グループ・カテゴリ名/名称 (例: "train/loss", "valid/loss") と指定することでグループ・カテゴリ別けができます。

def on_train_epoch_end(self) -> None:
    super().on_train_epoch_end()

    loss = torch.stack(self.train_outputs["loss"]).mean()
    accuracy = torch.cat(self.train_outputs["accuracy"]).mean()

    # train groups
    self.log("train/loss", loss)
    self.log("train/accuracy", accuracy)

    # free up the memory
    self.train_outputs["loss"].clear()
    self.train_outputs["accuracy"].clear()

def on_validation_epoch_end(self) -> None:
    super().on_validation_epoch_end()

    loss = torch.stack(self.valid_outputs["loss"]).mean()
    accuracy = torch.cat(self.valid_outputs["accuracy"]).mean()

    # valid groups
    self.log("valid/loss", loss)
    self.log("valid/accuracy", accuracy)

    # free up the memory
    self.valid_outputs["loss"].clear()
    self.valid_outputs["accuracy"].clear()

trainとvalidで纏めて表示されるため見やすいですね。

問題

本題に戻りますか。

スラッシュ(/)を用いてグループ・カテゴリ別けしたものを監視対象に指定するとどうなるでしょうか。

ModelCheckpoint(
    monitor="valid/accuracy",
    filename="checkpoint-{epoch}-{valid/accuracy:.8f}-{valid/loss:.8f}",
    save_top_k=3,
    mode="max",
    save_last=True,
)

はい、ぶっ壊れました。

スラッシュがディレクトリの区切り記号として認識されてしまい、ファイル名のフォーマットが挙動不審になってしまうのです。

これの対処法です。

解決

とっても簡単です。
auto_insert_metric_name=Falseを指定するだけです。

ModelCheckpoint(
    monitor="valid/accuracy",
    filename="checkpoint-{epoch}-{valid/accuracy:.8f}-{valid/loss:.8f}",
    save_top_k=3,
    mode="max",
    save_last=True,
    auto_insert_metric_name=False,
)
引数名説明
monitor監視対象を指定
filename保存するチェックポイント名のフォーマットを指定
save_top_k最大で幾つ保存するかを指定
mode精度なら上位(max)、損失なら下位(min)を指定
save_last最後/エポック毎のチェックポイントをlast.ckptとして保存するか
auto_insert_metric_nameファイル名にメトリック名(たぶんlossとかaccuracyの総称)を含めるか

auto_insert_metric_nameのデフォルト値はTrueです。

そのためfilename="checkpoint-{epoch}-{val_accuracy:.8f}-{val_loss:.8f}"と指定するだけでファイル名にepochやval_accuracy、val_lossなどのメトリック名が自動的に挿入されていました。

要はディレクトリの区切り記号とauto_insert_metric_nameの相性が悪いため、発生していた問題なのでした。

auto_insert_metric_name=Falseを指定するとepochaccuracylossなどのメトリック名を自動で挿入してくれないためfilenameに明示的に指定しましょう。

ModelCheckpoint(
    monitor="valid/accuracy",
    filename="checkpoint-epoch={epoch}-accuracy={valid/accuracy:.8f}-loss={valid/loss:.8f}",
    save_top_k=3,
    mode="max",
    save_last=True,
    auto_insert_metric_name=False,
)

大 解 決

おわり!!!

かれこれ2年ほどこの問題を放置していました。

業務で扱っている訳ではないのでぶっちゃけ未来の自分が汲み取れる範囲なら、ある程度適当であったり、問題を放置したりしても困ることなかったんですよね。

参考