Úvod do práce s PyTorch pro HPC
Tento návod slouží jako startovní bod pro studenty zapojené do ESF projektu pro podporu AI. Cílem není vysvětlovat, jak fungují jednotlivé modely neuronových sítí, ale jak správně technicky nastavit Python skript, aby byl robustní, konfigurovatelný a připravený pro běh na výpočetním clusteru (pomocí systému Slurm) i na vašem lokálním počítači.
Co je v tomto návodu?
- Příkazová řádka (
argparse): Jak měnit parametry Python scriptu bez přepisování kódu. - Hardware (
CPUvsGPU): Automatická detekce a využití grafické karty. - Vlastní data: Jak načíst data, která nejsou součástí standardních knihoven.
- Checkpointy: Ukládání a načítání stavu trénování (záloha proti pádům).
1. Předávání argumentů (Argparse)
Když spouštíte úlohy na clusteru (přes Slurm), nechcete pro každou změnu learning_rate nebo počtu epoch editovat zdrojový kód. Python má vestavěnou knihovnu argparse.
import argparse
def get_args():
parser = argparse.ArgumentParser(description="PyTorch Training Script")
# Základní parametry
parser.add_argument("--epochs", type=int, default=10, help="Počet epoch")
parser.add_argument("--batch-size", type=int, default=32, help="Velikost dávky (batch)")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
# Cesty k souborům
parser.add_argument("--data-path", type=str, default="./data", help="Cesta k datům")
parser.add_argument("--save-path", type=str, default="./checkpoints", help="Kam ukládat model")
# Přepínač pro pokračování v tréninku
parser.add_argument("--resume", action="store_true", help="Načíst poslední checkpoint, pokud existuje")
return parser.parse_args()
# Použití v kódu:
# args = get_args()
# print(f"Trénuji {args.epochs} epoch...")
2. Automatické přepínání CPU a GPU
PyTorch vyžaduje, aby data i model byly na stejném zařízení. Na clusteru může být dostupná GPU, na vašem notebooku třeba jen CPU. Tento kód to vyřeší univerzálně:
import torch
# Zjistíme, zda je dostupná CUDA (Nvidia GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Používám zařízení: {device}")
# Později v kódu musíte poslat model i data na toto zařízení:
# model = MyModel().to(device)
# inputs = inputs.to(device)
3. Načítání nestandardních dat
Většina tutoriálů používá MNIST. V praxi ale máte vlastní CSV, obrázky nebo texty. V PyTorch se k tomu dědí třída Dataset. Musíte implementovat tři metody: __init__, __len__ a __getitem__.
from torch.utils.data import Dataset, DataLoader
import numpy as np # Příklad, pokud máte data v numpy nebo CSV
class MujVlastniDataset(Dataset):
def __init__(self, data_path):
# Zde načtěte seznam souborů nebo celou tabulku do paměti
# Příklad: Generujeme náhodná data, vy byste zde použili např. pandas.read_csv(data_path)
self.data = np.random.randn(100, 10).astype(np.float32) # 100 vzorků, 10 features
self.labels = np.random.randint(0, 2, size=(100)) # Binární klasifikace
def __len__(self):
# Musí vrátit celkový počet vzorků
return len(self.data)
def __getitem__(self, idx):
# Zde vrátíte jeden konkrétní vzorek a jeho label
sample = self.data[idx]
label = self.labels[idx]
# Převedeme na PyTorch tenzory
return torch.from_numpy(sample), torch.tensor(label)
4. Checkpointy (Ukládání a načítání)
Výpočty na clusteru mohou trvat dny a mohou spadnout neby mohou být ukončeny z důvodu časového limitu clusteru. Je tedy kritické ukládat průběžný stav. Neukládáme jen váhy modelu (state_dict), ale i stav optimalizátoru a číslo epochy, abychom mohli navázat.
Uložení:
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
torch.save(checkpoint, "cesta/k/souboru.pth")
Načtení:
checkpoint = torch.load("cesta/k/souboru.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
Kompletní šablona (Copy & Paste)
Tento skript spojuje vše výše uvedené. Můžete ho zkopírovat do souboru train.py a rovnou spustit.
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
# -------------------------------------------------------------------
# 1. Definice Datasetu (Zde upravte pro svá data)
# -------------------------------------------------------------------
class MujDataset(Dataset):
def __init__(self, num_samples=1000, input_dim=20):
# Simulace dat - nahraďte načítáním z disku
self.x = np.random.randn(num_samples, input_dim).astype(np.float32)
self.y = np.random.randint(0, 2, size=(num_samples)).astype(np.float32)
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
return torch.from_numpy(self.x[idx]), torch.tensor(self.y[idx]).unsqueeze(0)
# -------------------------------------------------------------------
# 2. Jednoduchý Model (Zde nahraďte svou architekturou)
# -------------------------------------------------------------------
class SimpleModel(nn.Module):
def __init__(self, input_dim):
super(SimpleModel, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.net(x)
# -------------------------------------------------------------------
# 3. Pomocné funkce pro Checkpointy
# -------------------------------------------------------------------
def save_checkpoint(state, filename="checkpoint.pth"):
print(f"=> Ukládám checkpoint do {filename}")
torch.save(state, filename)
def load_checkpoint(checkpoint_path, model, optimizer):
print(f"=> Načítám checkpoint z {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch']
# -------------------------------------------------------------------
# 4. Hlavní trénovací smyčka
# -------------------------------------------------------------------
def main():
# --- Nastavení argumentů ---
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=5, help="Počet epoch")
parser.add_argument("--batch-size", type=int, default=32, help="Velikost batch")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument("--save-path", type=str, default="checkpoints", help="Složka pro ukládání")
parser.add_argument("--resume", type=str, default=None, help="Cesta k checkpointu pro navázání (např. checkpoints/last.pth)")
args = parser.parse_args()
# Vytvoření složky pro checkpointy, pokud neexistuje
os.makedirs(args.save_path, exist_ok=True)
# --- Nastavení zařízení (CPU/GPU) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Běžím na zařízení: {device}")
# --- Příprava Dat ---
dataset = MujDataset()
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# --- Inicializace modelu ---
model = SimpleModel(input_dim=20).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.BCELoss()
start_epoch = 0
# --- Načtení checkpointu (pokud je zadán) ---
if args.resume and os.path.isfile(args.resume):
start_epoch = load_checkpoint(args.resume, model, optimizer)
start_epoch += 1 # Pokračujeme následující epochou
print(f"Pokračuji od epochy {start_epoch}")
# --- Trénovací smyčka ---
for epoch in range(start_epoch, args.epochs):
model.train() # Přepnutí do trénovacího módu
running_loss = 0.0
for i, (inputs, labels) in enumerate(dataloader):
# Přesun dat na GPU/CPU
inputs, labels = inputs.to(device), labels.to(device)
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)
# Backward pass a optimalizace
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epocha [{epoch+1}/{args.epochs}], Loss: {running_loss/len(dataloader):.4f}")
# Uložení checkpointu po každé epoše
checkpoint_path = os.path.join(args.save_path, "last_checkpoint.pth")
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': running_loss,
}, filename=checkpoint_path)
print("Trénování dokončeno.")
if __name__ == "__main__":
main()
Jak skript spustit?
Na lokálním počítači:
python train.py --epochs 10 --batch-size 64
Příklad pro SLURM pro Laniakea cluster (run_job.sh):
#!/bin/bash
#SBATCH --partition=Virgo_A
# Spuštění s parametry
python train.py --epochs 100 --batch-size 128 --save-path ./checkpoints