🎤 PyTorch-караоке, или как перестать забывать optimizer.zero_grad()

Многие новички в DL ломаются не на сложной архитектуре сетей, а на банальном бойлерплейте PyTorch. Порядок вызовов в тренировочном цикле — это то, что приходится гуглить раз за разом, пока это не отпечатается на подкорке.

Забыл перевести модель в train()? Получи неверные веса из-за Dropout/BatchNorm.
Забыл zero_grad()? Поздравляю, градиенты аккумулируются, обучение идет в помойку.
Поставил step() перед backward()? Ну, вы поняли.

Наткнулся на видео, которое вызывает одновременно два чувства: лютый кринж и респект.
Парень в трусах и с микрофоном посреди бардака просто взял и переложил 5 шагов Backpropagation на мотивчик, который не выкинуть из головы.
Как будто это работает лучше, чем 10 часов нудных лекций от экспертов, которые читают по слайдам.

Для тех, кто в танке, напоминаю единственно верный порядок действий, который нужно вытатуировать у себя на подкорке (или выучить песню):

1️⃣ model.train() — переводим модель в боевой режим.
2️⃣ y_pred = model(x) — Forward pass.
3️⃣ loss = loss_fn(y_pred, y) — Считаем, насколько мы ошиблись.
4️⃣ optimizer.zero_grad() — Обнуляем градиенты перед новым шагом, иначе они накопятся и вы улетите в космос.
5️⃣ loss.backward() — Считаем градиенты (обратное распространение).
6️⃣ optimizer.step() — Делаем шаг оптимизатором.

Ещё бы версию для TensorFlow, но там, боюсь, придется писать оперу в трех актах, чтобы просто инициализировать переменные 🌚

#фана_ради