← На главную

Как не сломать PyTorch training loop: порядок операций

22.06.2026 23:44 · hackernews

Собрать training loop в PyTorch — вроде бы просто, но переставить пару строк — и всё ломается. Причём исключений не будет: модель либо не сойдётся, либо выдаст мусор, либо сожрёт всю память. Разберём типичные ошибки и правильный порядок.

Самый частый косяк — вызвать model.to(device) после создания оптимизатора. Если при этом ещё и тип данных меняется (например, .half()), nn.Module.to() создаёт новые nn.Parameter, а оптимизатор всё ещё держит ссылки на старые и обновляет их. Весёлая картина. optimiser.zero_grad() должен стоять до loss.backward(), иначе градиенты накопятся от нескольких батчей — получите сумму вместо текущего. clip_grad_norm_() — только после backward() и до step(). Поставите до — пусто, после — уже поздно. scheduler.step() вызывается раз за эпоху, снаружи батч-цикла. Если поставить внутрь — learning rate будет падать в разы чаще, чем нужно.

Забудете model.train() после валидации — Dropout не включится, BatchNorm замёрзнет. На валидации обязательно torch.no_grad(), иначе autograd построит граф на каждом батче — память утечёт до OOM. И не пишите loss в лог напрямую, берите loss.item() — иначе граф останется висеть, пока логируется тензор.

Теперь как правильно. Данные: TensorDataset упаковывает входы и метки, DataLoader дробит на батчи. num_workers > 0 ускоряет загрузку, pin_memory=True ускоряет передачу на GPU, persistent_workers=True не пересоздаёт процессы между эпохами. batch_size лучше брать кратным 16 или 8 — это удобно для tensor core.

Модель: наследуете nn.Module, обязательно вызываете super().__init__(), иначе параметры не зарегистрируются. model.to(device) — до создания оптимизатора. model.train() переключает Dropout и BatchNorm в обучающий режим. Внутри цикла: optimiser.zero_grad(), потом forward (logits = model(X_batch)), loss (criterion(logits, y_batch)), loss.backward(), clip_grad_norm_(), optimiser.step(). После батч-цикла — scheduler.step(). На валидации — model.eval() с torch.no_grad().

Если нужно увеличить эффективный batch size, используйте gradient accumulation: суммируйте градиенты на нескольких микро-батчах, потом делайте step. Всё по той же схеме, но zero_grad() только после шага, а в промежутках — накопление.

Порядок операций критичен. Запомните его — и тренировка пойдёт как надо.

Читать оригинал →