diff --git a/ddp.slurm b/ddp.slurm
deleted file mode 100644
index 100d3a7..0000000
--- a/ddp.slurm
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=ddp-training
-#SBATCH --partition=a100
-#SBATCH --nodes=2
-#SBATCH --gres=gpu:2    # 2 gpus per node
-#SBATCH --ntasks=4      # 4 processes per job
-#SBATCH --array=0-26%3  # 27 jobs, max 3 in parallel (27 unique models, given hyperparemeter configurations)
-#SBATCH --output=/nfs/nhome/live/jbhagat/nanogpt_ddp_runs/job_%j.out
-#SBATCH --error=/nfs/nhome/live/jbhagat/nanogpt_ddp_runs/job_%j.err
-
-# Set first node as the master
-MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
-
-# Activate env
-source /nfs/nhome/live/jbhagat/mambaforge/etc/profile.d/mamba.sh
-mamba activate nanogpt
-
-# Run ddp
-srun python /nfs/nhome/live/jbhagat/nanogpt/ddp.py \
-    --config-idx="$SLURM_ARRAY_TASK_ID" \
-    --world-size="$SLURM_NTASKS" \
-    --rank="$SLURM_PROCID" \
-    --master-addr="$MASTER_ADDR"
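The deleted launcher wired the rendezvous together by hand: Slurm passed each task its `--rank`, `--world-size`, and `--master-addr`, and the Python side exported `MASTER_ADDR`/`MASTER_PORT` before calling `init_process_group`. A minimal sketch of that manual pattern (function names here are illustrative, not part of the repo):

```python
import os

import torch
from torch.distributed import destroy_process_group, init_process_group


def manual_setup(rank: int, world_size: int, master_addr: str, master_port: str = "4444"):
    """Form the process group by hand; every process must agree on these values."""
    os.environ["MASTER_ADDR"] = master_addr  # rendezvous host (the Slurm head node)
    os.environ["MASTER_PORT"] = master_port  # any free port on that host
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # Ranks are global but GPUs are per node, so map each rank onto a local device.
    torch.cuda.set_device(rank % torch.cuda.device_count())


def manual_cleanup():
    destroy_process_group()
```

Each flag is a failure point: a stale `--master-addr` or a miscounted `--world-size` hangs every process inside `init_process_group`. The rewrite below retires this bookkeeping in favor of `torchrun`.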
diff --git a/ddp.py b/ddp_and_fsdp/ddp.py
similarity index 81%
rename from ddp.py
rename to ddp_and_fsdp/ddp.py
index accf998..e443edc 100644
--- a/ddp.py
+++ b/ddp_and_fsdp/ddp.py
@@ -2,6 +2,7 @@
 import argparse  # noqa: I001
 import os
+import sys
 import time
 from itertools import product
 from pathlib import Path
@@ -17,6 +18,9 @@
 from torch.utils.data import DataLoader, TensorDataset, random_split
 from torch.utils.data.distributed import DistributedSampler
 
+# Import nanogpt from the parent directory.
+nanogpt_dir = Path.cwd().parent
+sys.path.append(str(nanogpt_dir))
 from nanogpt import NanoGPT, build_dataset
 
 # Hyperparameters for model setup.
@@ -28,17 +32,11 @@
     {"ctx_len": 2048, "emb_dim": 1024, "n_heads": 20, "head_sz": 80, "n_blocks": 12},
 ]
 
-def setup(
-    rank: int,  # rank of current process
-    world_size: int,  # number of processes
-    master_addr: str,  # master machine address (IP or hostname)
-    master_port: str,  # master machine port
-):
+def setup(backend: str):
     """Sets up the DDP environment."""
-    os.environ["MASTER_ADDR"] = master_addr
-    os.environ["MASTER_PORT"] = master_port
-    # Create distributed process group.
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    # Create distributed process group and set CUDA device from torchrun's LOCAL_RANK env var.
+    init_process_group(backend=backend)
+    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 
 def cleanup():
     """Cleans up and kills DDP environment."""
@@ -51,9 +49,10 @@ def train(
     val_loader: DataLoader,  # batched dataset for validation
     optimizer: optim,  # optimizer
     loss_fn: nn.modules.loss,  # loss function
-    rank: int,  # rank of current process
+    global_rank: int,  # rank of current process across all nodes
+    local_rank: int,  # rank of current process within node
     max_epochs: int = 5,  # max n training epochs
-    max_batches: int = 1e9,  # max n batches to train
+    max_batches: int = 1000,  # max n batches to train
     val_chk_interval: int = 200,  # check val loss every `val_chk_interval` batches & print losses
     val_iter: int = 5,  # number of batches on val_loader to run and avg when computing val loss
     patience_thresh: int = 1e9,  # consecutive batches without val loss decrease for early stopping
@@ -69,8 +68,8 @@ def estimate_losses(
         """Estimate losses on val_loader, and return val loss and train loss avg."""
         model.eval()
         for val_i, (x_val, y_val) in enumerate(val_loader):
-            logits = model(x_val.to(rank))
-            val_loss = loss_fn(logits.view(-1, n_tokens), y_val.to(rank).view(-1))
+            logits = model(x_val.to(local_rank))
+            val_loss = loss_fn(logits.view(-1, n_tokens), y_val.to(local_rank).view(-1))
             val_losses.append(val_loss.item())
             if val_i >= (val_iter - 1):
                 break
@@ -101,7 +100,7 @@ def apply_gradient_centralization(optimizer):
     train_losses, val_losses, train_losses_avg, val_losses_avg = [], [], [], []
     init_loss, best_val_loss = float("inf"), float("inf")
     patience_ct = 0
-    if rank == 0:
+    if global_rank == 0:
         wandb.log({"expected_total_batches": batch_lim})
     # /s>
@@ -111,9 +110,9 @@ def apply_gradient_centralization(optimizer):
         for batch_i, (x_train, y_train) in enumerate(train_loader):
-            logits = model(x_train.to(rank))  # -> [batch_sz, ctx_len, n_tokens], but...
+            logits = model(x_train.to(local_rank))  # -> [batch_sz, ctx_len, n_tokens], but...
             # must reshape to compare against batch_sz vector of targets for cross-entropy loss
-            loss = loss_fn(logits.view(-1, n_tokens), y_train.to(rank).view(-1))
+            loss = loss_fn(logits.view(-1, n_tokens), y_train.to(local_rank).view(-1))
             loss.backward()
             apply_gradient_centralization(optimizer)
             optimizer.step()
@@ -125,7 +124,7 @@ def apply_gradient_centralization(optimizer):
                 estimate_losses(
                     model, val_loader, val_losses, val_losses_avg, train_losses, train_losses_avg
                 )
-                if rank == 0:
+                if global_rank == 0:
                     wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
                 # Return if patience check reached (early stopping).
                 patience_ct = (
@@ -133,21 +132,21 @@ def apply_gradient_centralization(optimizer):
                 )
                 best_val_loss = min(best_val_loss, val_losses_avg[-1])
                 if patience_ct >= patience_thresh:
-                    if rank == 0:
+                    if global_rank == 0:
                         wandb.log(
                             {"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]}
                         )
                     return loss, train_losses_avg, val_losses_avg
                 # Return if max_batches reached.
                 if (batch_i + 1) * (epoch + 1) >= max_batches:
-                    if rank == 0:
+                    if global_rank == 0:
                         wandb.log({"train_loss": train_losses_avg[-1], "val_loss": val_losses_avg[-1]})
                     return loss, train_losses_avg, val_losses_avg
                 # Save checkpoint check.
                 if (
                     Path(save_chkpt_dir).exists()
                     and (init_loss - loss.item()) > save_chkpt_thresh
-                    and rank == 0
+                    and global_rank == 0
                 ):
                     torch.save(
                         model.module.state_dict(),
@@ -156,7 +155,7 @@ def apply_gradient_centralization(optimizer):
                     init_loss = loss.item()
         # /ss>
     # /s>
     # Return after max_epochs reached.
-    if rank == 0:
+    if global_rank == 0:
         wandb.log(
             {
                 "train_loss": train_losses_avg[-1],
@@ -179,7 +178,7 @@ def apply_gradient_centralization(optimizer):
                 "estimated_time_remaining": est_remaining_t
             }
         )
-    if Path(save_chkpt_dir).exists() and rank == 0:
+    if Path(save_chkpt_dir).exists() and local_rank == 0:
         torch.save(
             model.module.state_dict(),
             Path(save_chkpt_dir) / f"model_chkpt_loss{loss.item():.3f}.pth"
@@ -187,10 +186,9 @@ def apply_gradient_centralization(optimizer):
     return loss, train_losses_avg, val_losses_avg
 def main(
-    rank: int,  # rank of current process
-    world_size: int,  # number of processes
-    master_addr: str,  # master machine address (IP or hostname)
-    master_port: str,  # master machine port
+    backend: str,  # DDP backend to use
+    global_rank: int,  # rank of current process across all nodes
+    local_rank: int,  # rank of current process within node
     text_file: str,  # path to text file to train on
     train_config: tuple[float, optim.Optimizer, list[dict]],  # lr, optimizer, model config
 ):
     """
@@ -199,7 +197,7 @@ def main(
     Sets up DDP env, creates dataset from text file, creates and trains model, cleans up DDP env.
     """
     # Set up DDP environment.
-    setup(rank, world_size, master_addr, master_port)
+    setup(backend)
     # Set up dataset.
     with open(text_file) as f:
         text = f.read()
@@ -215,7 +213,7 @@
     )
     # Set up model.
     model = NanoGPT(n_tokens=len(tokens), **train_config[2])
-    model = DDP(model.to(rank), device_ids=[rank])
+    model = DDP(model.to(local_rank), device_ids=[local_rank])
     # Initialize wandb config and run.
     param_bytes = 4  # 32-bit floats
     bytes_in_gb = 1024**3
@@ -223,7 +221,7 @@
     n_tot_params_b = round(n_tot_params / 1e9, 3)
     tot_sz_gb = n_tot_params * param_bytes / bytes_in_gb
     run_name = f"{train_config[1].__name__}-{train_config[0]}_{n_tot_params_b}B"
-    if rank == 0:
+    if global_rank == 0:
         wandb_config = {
             "n_params_bil": n_tot_params_b,
             "sz_gb": tot_sz_gb,
@@ -240,7 +238,16 @@
     optimizer = train_config[1](model.parameters(), lr=train_config[0])
     loss_fn = nn.CrossEntropyLoss()
     save_chkpt_dir = Path.home() / "nanogpt_ddp_runs" / "chkpts" / run_name
-    train(model, train_loader, val_loader, optimizer, loss_fn, rank, save_chkpt_dir=save_chkpt_dir)
+    train(
+        model,
+        train_loader,
+        val_loader,
+        optimizer,
+        loss_fn,
+        global_rank,
+        local_rank,
+        save_chkpt_dir=save_chkpt_dir
+    )
     # Clean up DDP environment.
     cleanup()
 
@@ -250,42 +257,29 @@ def main(
     # Parse args.
     parser = argparse.ArgumentParser(description="Run DDP distributed training of NanoGPTs.")
     parser.add_argument(
-        "--train-config-idx",
+        "--ddp_backend",
+        type=str,
+        default="nccl",
+        help="DDP backend to use (typically 'nccl' on Unix-like systems, 'gloo' on Windows)."
+    )
+    parser.add_argument(
+        "--train_config_idx",
         type=int,
         required=True,
         help="Index of train config to run. (See `train_configs` var)"
     )
     parser.add_argument(
-        "--world-size", type=int, required=True, help="Number of processes to use for DDP."
-    )
-    parser.add_argument("--rank", type=int, required=True, help="Rank of current process.")
-    parser.add_argument(
-        "--master-addr", type=str, required=True, help="Master address (or hostname) for DDP."
-    )
-    parser.add_argument("--master-port", type=str, default="4444", help="Master port for DDP.")
-    parser.add_argument(
-        "--text-file",
+        "--text_file",
         type=str,
-        default=(Path.cwd() / "data/tiny_austen.txt"),
+        default=(Path.cwd().parent / "data/tiny_austen.txt"),
         help="Path to text file to train on."
     )
     args = parser.parse_args()
+    # Get ranks from torchrun env vars.
+    global_rank = int(os.environ["RANK"])  # rank of current process across all nodes
+    local_rank = int(os.environ["LOCAL_RANK"])  # rank of current process within node
     # Set training config.
     train_configs = list(product(LR_SET, OPTIM_SET, ARCH_SET))
     train_config = train_configs[args.train_config_idx]
     # Run DDP training.
-    main(args.rank, args.world_size, args.master_addr, args.master_port, args.text_file, train_config)
-
-    # Use `mp.spawn` and 'gloo' (as backend device comm library) for local testing.
-    # mp.spawn(  # passes `rank` to `main` as first arg automatically
-    #     main,
-    #     args=(
-    #         args.world_size,
-    #         args.master_addr,
-    #         args.master_port,
-    #         args.text_file,
-    #         train_config,
-    #     ),
-    #     nprocs=args.world_size,
-    #     join=True,
-    # )
+    main(args.ddp_backend, global_rank, local_rank, args.text_file, train_config)
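With the manual plumbing gone, `ddp.py` leans on the environment variables `torchrun` exports to every worker it spawns: `RANK` (the global rank, used to gate wandb logging and checkpointing to a single process), `LOCAL_RANK` (the index of the worker's GPU on its node), plus `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`, which let `init_process_group` be called with nothing but a backend. A runnable toy version of the same pattern (the `nn.Linear` stands in for NanoGPT; it is not part of the repo):

```python
import os

import torch
import torch.nn as nn
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP


def run():
    # torchrun has already set RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and
    # MASTER_PORT, so no rank or address arguments are needed here.
    init_process_group(backend="nccl")
    global_rank = int(os.environ["RANK"])       # unique across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # indexes a GPU on this node
    torch.cuda.set_device(local_rank)
    model = DDP(nn.Linear(8, 8).to(local_rank), device_ids=[local_rank])
    x = torch.randn(4, 8).to(local_rank)
    model(x).sum().backward()  # backward() all-reduces gradients across workers
    if global_rank == 0:  # log from exactly one process
        print(f"step done across {os.environ['WORLD_SIZE']} workers")
    destroy_process_group()


if __name__ == "__main__":
    run()  # e.g.: torchrun --standalone --nproc_per_node=2 toy_ddp.py
```

The same split shows up throughout the patch: `local_rank` decides device placement (`model.to(local_rank)`, `x_train.to(local_rank)`), while `global_rank == 0` guards anything that should happen once per job.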
diff --git a/ddp_and_fsdp/ddp.slurm b/ddp_and_fsdp/ddp.slurm
new file mode 100644
index 0000000..e1c55d5
--- /dev/null
+++ b/ddp_and_fsdp/ddp.slurm
@@ -0,0 +1,37 @@
+#!/bin/bash
+#SBATCH --job-name=ddp-training
+#SBATCH --partition=a100
+#SBATCH --nodes=1
+#SBATCH --mem=128G
+#SBATCH --ntasks=2      # processes per job
+#SBATCH --gres=gpu:2    # gpus per node
+#SBATCH --array=0-26%3  # 27 jobs, max 3 in parallel (27 unique models, given hyperparameter configurations)
+#SBATCH --output=/nfs/nhome/live/jbhagat/nanogpt_ddp_runs/job_%j.out
+#SBATCH --error=/nfs/nhome/live/jbhagat/nanogpt_ddp_runs/job_%j.err
+
+# Set first node as the master
+HEAD_NODE_HOSTNAME=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+HEAD_NODE_IP=$(nslookup $HEAD_NODE_HOSTNAME | grep 'Address:' | awk 'NR==2 {print $2}')
+
+# Dynamically calculate number of processes per node, based on number of nodes assigned for this job
+PROCS_PER_NODE=$(($SLURM_NTASKS / $SLURM_JOB_NUM_NODES))
+
+# Echo vars to .out file
+echo "HEAD_NODE_HOSTNAME: $HEAD_NODE_HOSTNAME, HEAD_NODE_IP: $HEAD_NODE_IP, PROCS_PER_NODE: $PROCS_PER_NODE"
+
+# Activate env
+source /nfs/nhome/live/jbhagat/mambaforge/etc/profile.d/conda.sh
+conda activate nanogpt
+
+# Run ddp
+srun torchrun \
+    --standalone \
+    --nnodes=${SLURM_JOB_NUM_NODES} \
+    --nproc_per_node=${PROCS_PER_NODE} \
+    /nfs/nhome/live/jbhagat/nanoGPT/ddp_and_fsdp/ddp.py \
+    --train_config_idx="$SLURM_ARRAY_TASK_ID"
+
+# rdzv args for multinode
+#--rdzv_id=4444 \
+#--rdzv_backend="c10d" \
+#--rdzv_endpoint="$HEAD_NODE_IP:44444"
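The new launcher splits responsibilities: Slurm allocates the resources, while `torchrun` handles rendezvous, spawns `PROCS_PER_NODE` workers per node, and hands each worker its `RANK` and `LOCAL_RANK`. `--standalone` keeps the rendezvous on a single node; for a true multi-node run, the commented-out `--rdzv_*` flags would replace `--standalone` and point every node at the head node's IP and port (which is why the script computes `HEAD_NODE_IP` even though the single-node path never uses it). Each array index selects one entry of `train_configs`, so a single `sbatch ddp.slurm` sweeps all 27 hyperparameter configurations, at most 3 at a time. One caveat: `srun` replicates its command once per task, so with `--ntasks=2` on one node it will start two competing `torchrun` instances; pairing `srun` with one task per node (e.g. `--ntasks-per-node=1`) is the more conventional arrangement.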
diff --git a/readme.md b/readme.md
index eeb1851..acb4bf2 100644
--- a/readme.md
+++ b/readme.md
@@ -12,6 +12,9 @@
 Multi-head self-attention is implemented "from scratch", at the level of pytorch
 While the overall architecture is similar, this nanoGPT makes departures from Karpathy's nanoGPT in: naming conventions, data loading and training configuration, projecting embedding dimensions to attention heads, the format of operations in self-attention units and transformer blocks, output model generation (by adding parameters such as `temp` and `top_k`), and more.
 
+Additionally, examples of distributed training of models across multiple GPUs using PyTorch
+Distributed Data Parallel (DDP) and Fully Sharded Data Parallel (FSDP) via Slurm can be found in the `ddp_and_fsdp` directory.
+
 ## Examples
 
 ### nanoGPT-Shakespeare
@@ -41,6 +44,10 @@
 Output generated from models trained after approximately 320000 (top), 640000 (m
 
 - `tests/` contains tests that can be run via pytest for verifying components of nanoGPT work as expected.
+- `ddp_and_fsdp/` contains python modules and slurm scripts for:
+  - 1: speeding up training of a single model across multiple GPUs via model copying and distributed batching using DDP.
+  - 2: training a single large model across multiple GPUs via sharding using FSDP.
+
 - `.github/workflows/` contains a github actions workflow for building the python environment, running tests, and uploading the results to codecov.
 
 ## Usage
@@ -75,7 +82,7 @@
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Import nanogpt
 nanogpt_dir = Path.cwd()
-sys.path.append(nanogpt_dir)
+sys.path.append(str(nanogpt_dir))
 import nanogpt
 
 # Load in text file to train on and build dataloaders
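Finally, the readme fix mirrors the `sys.path.append(str(nanogpt_dir))` change made in `ddp.py`: `sys.path` entries are expected to be plain strings, so a `pathlib.Path` should be converted before it is appended. A self-contained sketch of the readme's usage snippet with the fix applied (assuming it runs from the repo root, where the `nanogpt` module lives):

```python
import sys
from pathlib import Path

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import nanogpt from the current working directory (the repo root).
nanogpt_dir = Path.cwd()
sys.path.append(str(nanogpt_dir))  # str(): sys.path entries should be strings
import nanogpt  # noqa: E402
```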