Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update torch.cuda.amp to torch.amp #13244

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions classify/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import torch.hub as hub
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
from torch.cuda import amp
from tqdm import tqdm

FILE = Path(__file__).resolve()
Expand All @@ -48,6 +47,7 @@
check_git_info,
check_git_status,
check_requirements,
check_version,
colorstr,
download,
increment_path,
Expand Down Expand Up @@ -198,7 +198,13 @@ def lf(x):
t0 = time.time()
criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing) # loss function
best_fitness = 0.0
scaler = amp.GradScaler(enabled=cuda)

scaler = None
if check_version(torch.__version__, "2.4.0"):
scaler = torch.amp.GradScaler("cuda", enabled=cuda)
else:
scaler = torch.cuda.amp.GradScaler(enabled=cuda)

val = test_dir.stem # 'val' or 'test'
LOGGER.info(
f'Image sizes {imgsz} train, {imgsz} test\n'
Expand All @@ -218,8 +224,14 @@ def lf(x):
for i, (images, labels) in pbar: # progress bar
images, labels = images.to(device, non_blocking=True), labels.to(device)

amp_autocast = None
if check_version(torch.__version__, "2.4.0"):
amp_autocast = torch.amp.autocast("cuda", enabled=device.type != "cpu")
else:
amp_autocast = torch.cuda.amp.autocast(enabled=device.type != "cpu")

# Forward
with amp.autocast(enabled=cuda): # stability issues when enabled
with amp_autocast: # stability issues when enabled
loss = criterion(model(images), labels)

# Backward
Expand Down
10 changes: 9 additions & 1 deletion classify/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Profile,
check_img_size,
check_requirements,
check_version,
colorstr,
increment_path,
print_args,
Expand Down Expand Up @@ -108,7 +109,14 @@ def run(
action = "validating" if dataloader.dataset.root.stem == "val" else "testing"
desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0)
with torch.cuda.amp.autocast(enabled=device.type != "cpu"):

amp_autocast = None
if check_version(torch.__version__, "2.4.0"):
amp_autocast = torch.amp.autocast("cuda", enabled=device.type != "cpu")
else:
amp_autocast = torch.cuda.amp.autocast(enabled=device.type != "cpu")

with amp_autocast:
for images, labels in bar:
with dt[0]:
images, labels = images.to(device, non_blocking=True), labels.to(device)
Expand Down
16 changes: 13 additions & 3 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import torch
import torch.nn as nn
from PIL import Image
from torch.cuda import amp

# Import 'ultralytics' package or install if missing
try:
Expand Down Expand Up @@ -862,7 +861,12 @@ def forward(self, ims, size=640, augment=False, profile=False):
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
autocast = self.amp and (p.device.type != "cpu") # Automatic Mixed Precision (AMP) inference
if isinstance(ims, torch.Tensor): # torch
with amp.autocast(autocast):
amp_autocast = None
if check_version(torch.__version__, "2.4.0"):
amp_autocast = torch.amp.autocast("cuda", enabled=autocast)
else:
amp_autocast = torch.cuda.amp.autocast(enabled=autocast)
with amp_autocast:
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference

# Pre-process
Expand All @@ -889,7 +893,13 @@ def forward(self, ims, size=640, augment=False, profile=False):
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32

with amp.autocast(autocast):
amp_autocast = None
if check_version(torch.__version__, "2.4.0"):
amp_autocast = torch.amp.autocast("cuda", enabled=autocast)
else:
amp_autocast = torch.cuda.amp.autocast(enabled=autocast)

with amp_autocast:
# Inference
with dt[1]:
y = self.model(x, augment=augment) # forward
Expand Down
17 changes: 15 additions & 2 deletions segment/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
check_img_size,
check_requirements,
check_suffix,
check_version,
check_yaml,
colorstr,
get_latest_run,
Expand Down Expand Up @@ -320,7 +321,13 @@ def lf(x):
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = torch.cuda.amp.GradScaler(enabled=amp)

scaler = None
if check_version(torch.__version__, "2.4.0"):
scaler = torch.amp.GradScaler("cuda", enabled=amp)
else:
scaler = torch.cuda.amp.GradScaler(enabled=amp)

stopper, stop = EarlyStopping(patience=opt.patience), False
compute_loss = ComputeLoss(model, overlap=overlap) # init loss class
# callbacks.run('on_train_start')
Expand Down Expand Up @@ -379,8 +386,14 @@ def lf(x):
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

amp_autocast = None
if check_version(torch.__version__, "2.4.0"):
amp_autocast = torch.amp.autocast("cuda", enabled=amp)
else:
amp_autocast = torch.cuda.amp.autocast(enabled=amp)

# Forward
with torch.cuda.amp.autocast(amp):
with amp_autocast:
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
if RANK != -1:
Expand Down
17 changes: 15 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
check_img_size,
check_requirements,
check_suffix,
check_version,
check_yaml,
colorstr,
get_latest_run,
Expand Down Expand Up @@ -352,7 +353,13 @@ def lf(x):
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = torch.cuda.amp.GradScaler(enabled=amp)

scaler = None
if check_version(torch.__version__, "2.4.0"):
scaler = torch.amp.GradScaler("cuda", enabled=amp)
else:
scaler = torch.cuda.amp.GradScaler(enabled=amp)

stopper, stop = EarlyStopping(patience=opt.patience), False
compute_loss = ComputeLoss(model) # init loss class
callbacks.run("on_train_start")
Expand Down Expand Up @@ -408,8 +415,14 @@ def lf(x):
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

amp_autocast = None
if check_version(torch.__version__, "2.4.0"):
amp_autocast = torch.amp.autocast("cuda", enabled=amp)
else:
amp_autocast = torch.cuda.amp.autocast(amp)

# Forward
with torch.cuda.amp.autocast(amp):
with amp_autocast:
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1:
Expand Down
5 changes: 4 additions & 1 deletion utils/autobatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
import numpy as np
import torch

from utils.general import LOGGER, colorstr
from utils.general import LOGGER, check_version, colorstr
from utils.torch_utils import profile


def check_train_batch_size(model, imgsz=640, amp=True):
"""Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting."""
if check_version(torch.__version__, "2.4.0"):
with torch.amp.autocast("cuda", enabled=amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
with torch.cuda.amp.autocast(amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size

Expand Down