Úvod do práce s PyTorch pro HPC

Tento návod slouží jako startovní bod pro studenty zapojené do ESF projektu pro podporu AI. Cílem není vysvětlovat, jak fungují jednotlivé modely neuronových sítí, ale jak správně technicky nastavit Python skript, aby byl robustní, konfigurovatelný a připravený pro běh na výpočetním clusteru (pomocí systému Slurm) i na vašem lokálním počítači.

Co je v tomto návodu?

  1. Příkazová řádka (argparse): Jak měnit parametry Python scriptu bez přepisování kódu.
  2. Hardware (CPU vs GPU): Automatická detekce a využití grafické karty.
  3. Vlastní data: Jak načíst data, která nejsou součástí standardních knihoven.
  4. Checkpointy: Ukládání a načítání stavu trénování (záloha proti pádům).

1. Předávání argumentů (Argparse)

Když spouštíte úlohy na clusteru (přes Slurm), nechcete pro každou změnu learning_rate nebo počtu epoch editovat zdrojový kód. Python má vestavěnou knihovnu argparse.

import argparse

def get_args():
    parser = argparse.ArgumentParser(description="PyTorch Training Script")

    # Základní parametry
    parser.add_argument("--epochs", type=int, default=10, help="Počet epoch")
    parser.add_argument("--batch-size", type=int, default=32, help="Velikost dávky (batch)")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")

    # Cesty k souborům
    parser.add_argument("--data-path", type=str, default="./data", help="Cesta k datům")
    parser.add_argument("--save-path", type=str, default="./checkpoints", help="Kam ukládat model")

    # Přepínač pro pokračování v tréninku
    parser.add_argument("--resume", action="store_true", help="Načíst poslední checkpoint, pokud existuje")

    return parser.parse_args()

# Použití v kódu:
# args = get_args()
# print(f"Trénuji {args.epochs} epoch...")

2. Automatické přepínání CPU a GPU

PyTorch vyžaduje, aby data i model byly na stejném zařízení. Na clusteru může být dostupná GPU, na vašem notebooku třeba jen CPU. Tento kód to vyřeší univerzálně:

import torch

# Zjistíme, zda je dostupná CUDA (Nvidia GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Používám zařízení: {device}")

# Později v kódu musíte poslat model i data na toto zařízení:
# model = MyModel().to(device)
# inputs = inputs.to(device)

3. Načítání nestandardních dat

Většina tutoriálů používá MNIST. V praxi ale máte vlastní CSV, obrázky nebo texty. V PyTorch se k tomu dědí třída Dataset. Musíte implementovat tři metody: __init__, __len__ a __getitem__.

from torch.utils.data import Dataset, DataLoader
import numpy as np # Příklad, pokud máte data v numpy nebo CSV

class MujVlastniDataset(Dataset):
    def __init__(self, data_path):
        # Zde načtěte seznam souborů nebo celou tabulku do paměti
        # Příklad: Generujeme náhodná data, vy byste zde použili např. pandas.read_csv(data_path)
        self.data = np.random.randn(100, 10).astype(np.float32) # 100 vzorků, 10 features
        self.labels = np.random.randint(0, 2, size=(100))       # Binární klasifikace

    def __len__(self):
        # Musí vrátit celkový počet vzorků
        return len(self.data)

    def __getitem__(self, idx):
        # Zde vrátíte jeden konkrétní vzorek a jeho label
        sample = self.data[idx]
        label = self.labels[idx]

        # Převedeme na PyTorch tenzory
        return torch.from_numpy(sample), torch.tensor(label)

4. Checkpointy (Ukládání a načítání)

Výpočty na clusteru mohou trvat dny a mohou spadnout neby mohou být ukončeny z důvodu časového limitu clusteru. Je tedy kritické ukládat průběžný stav. Neukládáme jen váhy modelu (state_dict), ale i stav optimalizátoru a číslo epochy, abychom mohli navázat.

Uložení:

checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}
torch.save(checkpoint, "cesta/k/souboru.pth")

Načtení:

checkpoint = torch.load("cesta/k/souboru.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

Kompletní šablona (Copy & Paste)

Tento skript spojuje vše výše uvedené. Můžete ho zkopírovat do souboru train.py a rovnou spustit.

import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

# -------------------------------------------------------------------
# 1. Definice Datasetu (Zde upravte pro svá data)
# -------------------------------------------------------------------
class MujDataset(Dataset):
    def __init__(self, num_samples=1000, input_dim=20):
        # Simulace dat - nahraďte načítáním z disku
        self.x = np.random.randn(num_samples, input_dim).astype(np.float32)
        self.y = np.random.randint(0, 2, size=(num_samples)).astype(np.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return torch.from_numpy(self.x[idx]), torch.tensor(self.y[idx]).unsqueeze(0)

# -------------------------------------------------------------------
# 2. Jednoduchý Model (Zde nahraďte svou architekturou)
# -------------------------------------------------------------------
class SimpleModel(nn.Module):
    def __init__(self, input_dim):
        super(SimpleModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

# -------------------------------------------------------------------
# 3. Pomocné funkce pro Checkpointy
# -------------------------------------------------------------------
def save_checkpoint(state, filename="checkpoint.pth"):
    print(f"=> Ukládám checkpoint do {filename}")
    torch.save(state, filename)

def load_checkpoint(checkpoint_path, model, optimizer):
    print(f"=> Načítám checkpoint z {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

# -------------------------------------------------------------------
# 4. Hlavní trénovací smyčka
# -------------------------------------------------------------------
def main():
    # --- Nastavení argumentů ---
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=5, help="Počet epoch")
    parser.add_argument("--batch-size", type=int, default=32, help="Velikost batch")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--save-path", type=str, default="checkpoints", help="Složka pro ukládání")
    parser.add_argument("--resume", type=str, default=None, help="Cesta k checkpointu pro navázání (např. checkpoints/last.pth)")
    args = parser.parse_args()

    # Vytvoření složky pro checkpointy, pokud neexistuje
    os.makedirs(args.save_path, exist_ok=True)

    # --- Nastavení zařízení (CPU/GPU) ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Běžím na zařízení: {device}")

    # --- Příprava Dat ---
    dataset = MujDataset()
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # --- Inicializace modelu ---
    model = SimpleModel(input_dim=20).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.BCELoss()

    start_epoch = 0

    # --- Načtení checkpointu (pokud je zadán) ---
    if args.resume and os.path.isfile(args.resume):
        start_epoch = load_checkpoint(args.resume, model, optimizer)
        start_epoch += 1 # Pokračujeme následující epochou
        print(f"Pokračuji od epochy {start_epoch}")

    # --- Trénovací smyčka ---
    for epoch in range(start_epoch, args.epochs):
        model.train() # Přepnutí do trénovacího módu
        running_loss = 0.0

        for i, (inputs, labels) in enumerate(dataloader):
            # Přesun dat na GPU/CPU
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass a optimalizace
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epocha [{epoch+1}/{args.epochs}], Loss: {running_loss/len(dataloader):.4f}")

        # Uložení checkpointu po každé epoše
        checkpoint_path = os.path.join(args.save_path, "last_checkpoint.pth")
        save_checkpoint({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': running_loss,
        }, filename=checkpoint_path)

    print("Trénování dokončeno.")

if __name__ == "__main__":
    main()

Jak skript spustit?

Na lokálním počítači:

python train.py --epochs 10 --batch-size 64

Příklad pro SLURM pro Laniakea cluster (run_job.sh):

#!/bin/bash
#SBATCH --partition=Virgo_A 

# Spuštění s parametry
python train.py --epochs 100 --batch-size 128 --save-path ./checkpoints