水ノ茉の宣伝
準備中...
ゲームを作る予定なの水ノ茉
作業環境
- Windows 10
- Visual Studio Code
- Python 3.11
- PyTorch Lightning 2.2.1
始まり
指定された要素(精度や損失)を元に上位や下位Nつのチェックポイントを保存してくれる ModelCheckpoint さん。
使用箇所を抜粋するとこんな感じ。
def on_validation_epoch_end(self) -> None:
super().on_validation_epoch_end()
loss = torch.stack(self.valid_outputs["loss"]).mean()
accuracy = torch.cat(self.valid_outputs["accuracy"]).mean()
self.log("val_loss", loss)
self.log("val_accuracy", accuracy)
# free up the memory
self.valid_outputs["loss"].clear()
self.valid_outputs["accuracy"].clear()
ModelCheckpoint(
monitor="val_accuracy",
filename="checkpoint-{epoch}-{val_accuracy:.8f}-{val_loss:.8f}",
save_top_k=3,
mode="max",
save_last=True,
)

| 引数名 | 説明 |
|---|---|
| monitor | 監視対象を指定 |
| filename | 保存するチェックポイント名のフォーマットを指定 |
| save_top_k | 最大で幾つ保存するかを指定 |
| mode | 精度なら上位(max)、損失なら下位(min)を指定 |
| save_last | 最後/エポック毎のチェックポイントをlast.ckptとして保存するか |
たったこれだけ。めっちゃシンプルで便利です。
※他にも指定可能な引数はありますが必要最小限が上記。
ちょっと寄り道します。
PyTorch Lightningのログはステップごとに記録されます。ステップごとの記録の欠点はデータセットが増減した際に横軸がズレるため、ぱっと見で比較がしにくいことです。そのため筆者はステップごととは別にエポックごとにも記録を付けています。

愚直にLoggerに書き込むと画像のように縦長に展開されてしまいます。気にならない方はいいですが、ちょっと縦長過ぎる気がします。
そんな時は グループ・カテゴリ名/名称 (例: "train/loss", "valid/loss") と指定することでグループ・カテゴリ別けができます。
def on_train_epoch_end(self) -> None:
super().on_train_epoch_end()
loss = torch.stack(self.train_outputs["loss"]).mean()
accuracy = torch.cat(self.train_outputs["accuracy"]).mean()
# train groups
self.log("train/loss", loss)
self.log("train/accuracy", accuracy)
# free up the memory
self.train_outputs["loss"].clear()
self.train_outputs["accuracy"].clear()
def on_validation_epoch_end(self) -> None:
super().on_validation_epoch_end()
loss = torch.stack(self.valid_outputs["loss"]).mean()
accuracy = torch.cat(self.valid_outputs["accuracy"]).mean()
# valid groups
self.log("valid/loss", loss)
self.log("valid/accuracy", accuracy)
# free up the memory
self.valid_outputs["loss"].clear()
self.valid_outputs["accuracy"].clear()

trainとvalidで纏めて表示されるため見やすいですね。
問題
本題に戻りますか。
スラッシュ(/)を用いてグループ・カテゴリ別けしたものを監視対象に指定するとどうなるでしょうか。
ModelCheckpoint(
monitor="valid/accuracy",
filename="checkpoint-{epoch}-{valid/accuracy:.8f}-{valid/loss:.8f}",
save_top_k=3,
mode="max",
save_last=True,
)

はい、ぶっ壊れました。
スラッシュがディレクトリの区切り記号として認識されてしまい、ファイル名のフォーマットが挙動不審になってしまうのです。
これの対処法です。
解決
とっても簡単です。auto_insert_metric_name=Falseを指定するだけです。
ModelCheckpoint(
monitor="valid/accuracy",
filename="checkpoint-{epoch}-{valid/accuracy:.8f}-{valid/loss:.8f}",
save_top_k=3,
mode="max",
save_last=True,
auto_insert_metric_name=False,
)

| 引数名 | 説明 |
|---|---|
| monitor | 監視対象を指定 |
| filename | 保存するチェックポイント名のフォーマットを指定 |
| save_top_k | 最大で幾つ保存するかを指定 |
| mode | 精度なら上位(max)、損失なら下位(min)を指定 |
| save_last | 最後/エポック毎のチェックポイントをlast.ckptとして保存するか |
| auto_insert_metric_name | ファイル名にメトリック名(たぶんlossとかaccuracyの総称)を含めるか |
auto_insert_metric_nameのデフォルト値はTrueです。
そのためfilename="checkpoint-{epoch}-{val_accuracy:.8f}-{val_loss:.8f}"と指定するだけでファイル名にepochやval_accuracy、val_lossなどのメトリック名が自動的に挿入されていました。
要はディレクトリの区切り記号とauto_insert_metric_nameの相性が悪いため、発生していた問題なのでした。
auto_insert_metric_name=Falseを指定するとepochやaccuracy、lossなどのメトリック名を自動で挿入してくれないためfilenameに明示的に指定しましょう。
ModelCheckpoint(
monitor="valid/accuracy",
filename="checkpoint-epoch={epoch}-accuracy={valid/accuracy:.8f}-loss={valid/loss:.8f}",
save_top_k=3,
mode="max",
save_last=True,
auto_insert_metric_name=False,
)

大 解 決
おわり!!!
かれこれ2年ほどこの問題を放置していました。
業務で扱っている訳ではないのでぶっちゃけ未来の自分が汲み取れる範囲なら、ある程度適当であったり、問題を放置したりしても困ることなかったんですよね。