timm代码解读
1 创建dataset
timm 库通过 create_dataset 函数来得到 dataset_train 和 dataset_eval 这两个 dataset 对象。
# Training dataset: built from the train split; `repeats` is passed through as an extra keyword.
dataset_train = create_dataset(
    args.dataset,
    root=args.data_dir, split=args.train_split, is_training=True,
    batch_size=args.batch_size, repeats=args.epoch_repeats)
# Evaluation dataset: same dataset name and root, validation split, is_training=False.
dataset_eval = create_dataset(
    args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
2
3
4
5
6
create_dataset 函数如下面所示。对于名称不以 tfds 开头的数据集,实际返回的是:ImageDataset(root, parser=name, **kwargs)
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
    """Build the dataset object selected by ``name``.

    Names starting with ``tfds`` yield an ``IterableImageDataset``; every other
    name yields ``ImageDataset(root, parser=name, **kwargs)``.

    Args:
        name: dataset name; lower-cased before it is inspected.
        root: dataset root path.
        split: split name. Passed to the iterable dataset, or used to look for
            a split location under ``root`` on the ImageDataset path.
        search_split: when true and ``root`` is a directory, ``root`` is
            replaced by ``_search_split(root, split)``.
        is_training: forwarded to the iterable dataset only.
        batch_size: forwarded to the iterable dataset only.
        **kwargs: forwarded to the dataset constructor; ``repeats`` is dropped
            on the ImageDataset path.

    Returns:
        The constructed dataset.
    """
    name = name.lower()

    if name.startswith('tfds'):
        return IterableImageDataset(
            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)

    # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
    # FIXME currently only Iterable dataset support the repeat multiplier
    kwargs.pop('repeats', 0)
    if search_split and os.path.isdir(root):
        root = _search_split(root, split)
    return ImageDataset(root, parser=name, **kwargs)
2
3
4
5
6
7
8
9
10
11
12
13
ImageDataset 类如下所示,它的内部定义了最关键的 __getitem__ 方法。
class ImageDataset(data.Dataset):
    """Map-style dataset that reads ``(file, target)`` pairs from a parser.

    Each item is either the raw bytes of the file (``load_bytes=True``) or a
    PIL image converted to RGB, optionally passed through ``transform``.
    """

    def __init__(
            self,
            root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        # A missing or string parser spec is resolved through create_parser;
        # any other object is used as the parser directly.
        if parser is None or isinstance(parser, str):
            parser = create_parser(parser or '', root=root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
        # Counts back-to-back load failures; reset after every successful load.
        self._consecutive_errors = 0

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``.

        If loading fails, a warning is logged and the next index (wrapping
        around) is tried; once ``_ERROR_RETRY`` consecutive failures have
        accumulated the last exception is re-raised. A ``None`` target is
        replaced by a long tensor holding -1.
        """
        img, target = self.parser[index]
        try:
            if self.load_bytes:
                img = img.read()
            else:
                img = Image.open(img).convert('RGB')
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors >= _ERROR_RETRY:
                raise e
            return self.__getitem__((index + 1) % len(self.parser))

        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target

    def __len__(self):
        """Number of samples reported by the parser."""
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        """Delegate to ``parser.filename`` for a single sample."""
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Delegate to ``parser.filenames`` for all samples."""
        return self.parser.filenames(basename, absolute)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
那么这个函数里面最关键的一句是:
img, target = self.parser[index]
而这里的 parser 来自于parser = create_parser(parser or '', root=root, class_map=class_map),所以有必要看看这个 create_parser 函数。
create_parser 函数的定义如下所示。当前缀不是 tfds、且 root 不是 .tar 文件时,最后返回的 parser 来自:parser = ParserImageFolder(root, **kwargs)
def create_parser(name, root, split='train', **kwargs):
    """Create the parser that backs a dataset.

    ``name`` is lower-cased and split on ``/`` (at most two splits). When it
    has more than one part, the first part is the prefix and the last part is
    the name handed to the parser.

    Args:
        name: parser spec, optionally prefixed (e.g. ``tfds/...``).
        root: dataset root; must exist for the non-tfds path.
        split: split name; only passed on to the tfds parser.
        **kwargs: forwarded to the parser constructor.

    Returns:
        ``ParserTfds`` for the ``tfds`` prefix, ``ParserImageInTar`` when
        ``root`` is a ``.tar`` file, otherwise ``ParserImageFolder``.
    """
    parts = name.lower().split('/', 2)
    prefix = parts[0] if len(parts) > 1 else ''
    name = parts[-1]

    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
    # explicitly select other options shortly
    if prefix == 'tfds':
        from .parser_tfds import ParserTfds  # defer tensorflow import
        parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
    else:
        assert os.path.exists(root)
        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
        # FIXME support split here, in parser?
        root_is_tar = os.path.isfile(root) and os.path.splitext(root)[1] == '.tar'
        if root_is_tar:
            parser = ParserImageInTar(root, **kwargs)
        else:
            parser = ParserImageFolder(root, **kwargs)
    return parser
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
所以有必要看看这个 ParserImageFolder 类。
class ParserImageFolder(Parser):
    """Parser over the images found below a root folder.

    ``self.samples`` holds ``(path, target)`` pairs and ``self.class_to_idx``
    the class mapping, both as returned by ``find_images_and_targets``.
    """

    def __init__(
            self,
            root,
            class_map=''):
        super().__init__()

        self.root = root
        # Without a class_map, None is passed on and find_images_and_targets supplies the mapping.
        class_to_idx = load_class_map(class_map, root) if class_map else None
        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
        if len(self.samples) == 0:
            raise RuntimeError(
                f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')

    def __getitem__(self, index):
        """Return ``(binary file object, target)``; closing the file is left to the caller."""
        path, target = self.samples[index]
        return open(path, 'rb'), target

    def __len__(self):
        """Number of samples found."""
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        """Return the sample's path: basename only, as stored, or relative to ``root``."""
        filename = self.samples[index][0]
        if basename:
            return os.path.basename(filename)
        if absolute:
            return filename
        return os.path.relpath(filename, self.root)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
ParserImageFolder 类通过 self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 来找到所有的 samples(每个样本是"图片路径, 类索引"二元组;对 ImageNet-1k 而言类索引为 0-999)和一个把类名映射到索引的 class_to_idx 表。
然后直接通过 path, target = self.samples[index] 找到某个索引的图片路径和类索引,并返回 open(path, 'rb') 打开的文件对象和类索引。
所以说 img, target = self.parser[index] 的返回值其实就是 ParserImageFolder 类的 __getitem__ 方法的返回值,即:**某个索引对应的图片文件对象(以二进制只读方式打开)和类索引**。随后 ImageDataset 再把它解码成 RGB 图像并施加 transform,这也就是 dataset 的功能。
2 构建dataloader
timm 库通过 create_loader 函数来创建dataloader,需要传入上一步构建的 dataset_train。
# Build the training loader from dataset_train. Input size and mean/std come from
# data_config; the remaining options come from the parsed command-line args.
loader_train = create_loader(
    dataset_train,
    input_size=data_config['input_size'],
    batch_size=args.batch_size,
    is_training=True,
    use_prefetcher=args.prefetcher,
    no_aug=args.no_aug,
    re_prob=args.reprob,
    re_mode=args.remode,
    re_count=args.recount,
    re_split=args.resplit,
    scale=args.scale,
    ratio=args.ratio,
    hflip=args.hflip,
    vflip=args.vflip,
    color_jitter=args.color_jitter,
    auto_augment=args.aa,
    num_aug_splits=num_aug_splits,
    interpolation=train_interpolation,
    mean=data_config['mean'],
    std=data_config['std'],
    num_workers=args.workers,
    distributed=args.distributed,
    collate_fn=collate_fn,
    pin_memory=args.pin_mem,
    use_multi_epochs_loader=args.use_multi_epochs_loader
)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
create_loader 函数内部通过:
loader_class = torch.utils.data.DataLoader
得到 loader_class,再通过下面的语句建立 DataLoader (需要的参数 batch_size, shuffle, num_workers, sampler, collate_fn, drop_last 等等都以字典的形式保存在 loader_args 中):
loader_args = dict(
batch_size=batch_size,
shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
num_workers=num_workers,
sampler=sampler,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=is_training,
persistent_workers=persistent_workers)
try:
loader = loader_class(dataset, **loader_args)
2
3
4
5
6
7
8
9
10
11
最后返回 loader。
3 创建模型
timm 库通过 create_model 函数来创建模型。
# Instantiate the model named by args.model. Every option is passed by keyword;
# create_model drops the ones that are None or unsupported by the chosen architecture.
model = create_model(
    args.model,
    pretrained=args.pretrained,
    num_classes=args.num_classes,
    drop_rate=args.drop,
    drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
    drop_path_rate=args.drop_path,
    drop_block_rate=args.drop_block,
    global_pool=args.gp,
    bn_tf=args.bn_tf,
    bn_momentum=args.bn_momentum,
    bn_eps=args.bn_eps,
    scriptable=args.torchscript,
    checkpoint_path=args.initial_checkpoint)
2
3
4
5
6
7
8
9
10
11
12
13
14
函数 create_model 的具体实现是:
def create_model(
        model_name,
        pretrained=False,
        checkpoint_path='',
        scriptable=None,
        exportable=None,
        no_jit=None,
        **kwargs):
    """Instantiate a registered model by name.

    Args:
        model_name (str): name of model to instantiate, optionally carrying a source prefix
        pretrained (bool): load pretrained ImageNet-1k weights if true
        checkpoint_path (str): path of checkpoint to load after model is initialized
        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
        no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)

    Keyword Args:
        drop_rate (float): dropout rate for training (default: 0.0)
        global_pool (str): global pool type (default: 'avg')
        **: other kwargs are model specific

    Returns:
        The constructed model, with ``checkpoint_path`` loaded into it when given.

    Raises:
        RuntimeError: if ``model_name`` is not a registered model.
    """
    source_name, model_name = split_model_name(model_name)

    # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
    if not is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']):
        for bn_key in ('bn_tf', 'bn_momentum', 'bn_eps'):
            kwargs.pop(bn_key, None)

    # handle backwards compat with drop_connect -> drop_path change
    drop_connect_rate = kwargs.pop('drop_connect_rate', None)
    if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
        print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
              " Setting drop_path to %f." % drop_connect_rate)
        kwargs['drop_path_rate'] = drop_connect_rate

    # Parameters that aren't supported by all models or are intended to only override model defaults if set
    # should default to None in command line args/cfg. Remove them if they are present and not set so that
    # non-supporting models don't break and default args remain in effect.
    kwargs = {key: value for key, value in kwargs.items() if value is not None}

    if source_name == 'hf_hub':
        # For model names specified in the form `hf_hub:path/architecture_name#revision`,
        # load model weights + default_cfg from Hugging Face hub.
        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
        kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday

    # Reject unknown names before looking up the entrypoint.
    if not is_model(model_name):
        raise RuntimeError('Unknown model (%s)' % model_name)
    create_fn = model_entrypoint(model_name)

    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = create_fn(pretrained=pretrained, **kwargs)

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path)

    return model
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
timm 库每次新定义一个模型,都类似于这样的形式 (这里以 vit_base_patch32_384 为例):
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.

    The fixed architecture arguments (patch size 32, embed dim 768, depth 12,
    12 heads) are combined with the caller's ``kwargs`` and handed to
    ``_create_vision_transformer``.
    """
    return _create_vision_transformer(
        'vit_base_patch32_384',
        pretrained=pretrained,
        **dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs))
2
3
4
5
6
7
8
这里的 register_model 来自 registry.py 文件的 register_model 函数,如下。
register_model 函数的输入是 fn,也就是例子里面的 vit_base_patch32_384。register_model 函数的功能是把这个模型的函数的信息存到 _model_to_module 和 _model_entrypoints 等等的字典里面,相当于把 vit_base_patch32_384 给注册一下。
def register_model(fn):
    """Decorator that records a model entrypoint in the module-level registries.

    The function is added to its module's ``__all__`` and to the
    ``_model_entrypoints``, ``_model_to_module`` and ``_module_to_models``
    tables. If the module has a ``default_cfgs`` entry under the same name, a
    deep copy is stored in ``_model_default_cfgs``, and the model is added to
    ``_model_has_pretrained`` when that entry has a ``url`` containing ``http``.

    Args:
        fn: the model entrypoint function being registered.

    Returns:
        ``fn`` itself, unchanged.
    """
    # lookup containing module
    mod = sys.modules[fn.__module__]
    # Last dotted component of the module path; split() always yields at least one part.
    module_name = fn.__module__.split('.')[-1]

    # add model to __all__ in module
    model_name = fn.__name__
    if not hasattr(mod, '__all__'):
        mod.__all__ = [model_name]
    else:
        mod.__all__.append(model_name)

    # add entries to registry dict/sets
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)

    has_pretrained = False  # check if model has a pretrained url to allow filtering on this
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        cfg = mod.default_cfgs[model_name]
        has_pretrained = 'url' in cfg and 'http' in cfg['url']
        _model_default_cfgs[model_name] = deepcopy(cfg)
    if has_pretrained:
        _model_has_pretrained.add(model_name)
    return fn
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
注册完以后,通过 create_model 函数中的 create_fn = model_entrypoint(model_name) 语句得到 vit_base_patch32_384 函数。所以 create_fn() 就相当于是 vit_base_patch32_384()。
最后就是使用 create_fn 函数得到模型并返回 model。
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(pretrained=pretrained, **kwargs)
if checkpoint_path:
load_checkpoint(model, checkpoint_path)
return model
2
3
4
5
6
7
4 构建优化器
timm 库通过 create_optimizer_v2 函数来构建优化器。
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
create_optimizer_v2 函数的具体实现如下,需要传入的参数是:模型(或模型参数)、优化器类型、初始学习率、weight_decay 等等。之后根据 opt_lower 的取值来构建不同类型的优化器。
def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        filter_bias_and_bn: bool = True,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    The optimizer name is lower-cased and split on ``_``. The last part selects
    the optimizer; when there is more than one part and the first part is
    ``lookahead``, the result is wrapped in ``Lookahead``.

    Args:
        model_or_params (nn.Module): model containing parameters to optimize, or an
            iterable of parameters / param groups that is used as-is
        opt: name of optimizer to create
        lr: initial learning rate (only passed on when not None)
        weight_decay: weight decay to apply in optimizer
        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        AssertionError: for an unknown optimizer name, or a fused optimizer without APEX and CUDA.
        ValueError: for an unknown optimizer name when asserts are stripped (``python -O``).
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        if weight_decay and filter_bias_and_bn:
            skip = {}
            if hasattr(model_or_params, 'no_weight_decay'):
                skip = model_or_params.no_weight_decay()
            parameters = add_weight_decay(model_or_params, weight_decay, skip)
            # decay now lives in the param groups, so the optimizer-level value is zeroed
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if lr is not None:
        opt_args.setdefault('lr', lr)

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        # NOTE PyTorch >= 1.10 should have native NAdam; the attribute is spelled `NAdam`.
        # Looking it up with getattr keeps the fallback from masking an AttributeError
        # raised inside the optimizer constructor itself.
        native_nadam = getattr(optim, 'NAdam', None)
        if native_nadam is not None:
            optimizer = native_nadam(parameters, **opt_args)
        else:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    else:
        # Same exception types as before (AssertionError normally, ValueError under -O),
        # but both now name the optimizer that was requested.
        invalid_msg = f'Invalid optimizer: {opt}'
        assert False, invalid_msg
        raise ValueError(invalid_msg)

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
5 构建scheduler
timm 库通过 create_scheduler 函数来构建 scheduler。
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
内部通过 args.sched 参数控制具体创建什么类型的 scheduler。
6 构建训练 engine
timm 库通过 train_one_epoch 函数来构建训练 engine。
def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Run one training epoch over ``loader`` and return the average loss.

    For every batch: forward pass under ``amp_autocast``, backward pass and
    optimizer step (through ``loss_scaler`` when one is given), optional EMA
    update, periodic logging, optional recovery checkpoint, and a per-update
    LR scheduler step.

    Args:
        epoch: current epoch index; used for logging, the mixup cut-off and the update counter.
        model: model to train; put into train mode here.
        loader: iterable of ``(input, target)`` batches that also supports ``len()``.
        optimizer: optimizer to step.
        loss_fn: callable ``loss_fn(output, target)`` returning the loss.
        args: parsed training arguments (prefetcher, channels_last, clip_grad, log_interval, ...).
        lr_scheduler: optional scheduler exposing ``step_update``.
        saver: optional checkpoint saver exposing ``save_recovery``.
        output_dir: directory for debug image dumps when ``args.save_images`` is set.
        amp_autocast: context manager factory wrapping the forward pass.
        loss_scaler: optional callable that performs backward, clipping and the optimizer step.
        model_ema: optional EMA wrapper exposing ``update(model)``.
        mixup_fn: optional callable applied to ``(input, target)`` when the prefetcher is off.

    Returns:
        ``OrderedDict([('loss', <average loss>)])``.
    """
    # Turn mixup off once the configured epoch is reached.
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # In distributed mode the meter is updated with the reduced loss at log time instead.
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                loss, optimizer,
                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                dispatch_clip_grad(
                    model_parameters(model, exclude_head='agc' in args.clip_mode),
                    value=args.clip_grad, mode=args.clip_mode)
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)

        torch.cuda.synchronize()
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                # With a single-batch loader last_idx is 0; report 100% instead of dividing by zero.
                progress_pct = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        progress_pct,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
这个函数里面值得注意的是 loss_scaler,它的作用本质上是 loss.backward(create_graph=create_graph) 和 optimizer.step()。
loss_scaler 是 NativeScaler 这个类的实例。这个类的实例在调用时需要传入 loss, optimizer, clip_grad 等参数,在 __call__ 方法的内部实现了 loss.backward(create_graph=create_graph) 功能(先对 loss 做缩放)、可选的梯度裁剪,以及 optimizer.step() 功能(通过 self._scaler.step(optimizer),随后调用 self._scaler.update())。
class NativeScaler:
    """Callable wrapper around ``torch.cuda.amp.GradScaler``.

    Calling an instance performs the scaled backward pass, optional gradient
    clipping, the optimizer step and the scaler update.
    """

    # Key under which the scaler state is stored in checkpoints.
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
        """Backpropagate the scaled ``loss`` and step ``optimizer``.

        When ``clip_grad`` is given, ``parameters`` must be provided; gradients
        are unscaled first and then clipped according to ``clip_mode``.
        """
        scaled_loss = self._scaler.scale(loss)
        scaled_loss.backward(create_graph=create_graph)

        wants_clipping = clip_grad is not None
        if wants_clipping:
            assert parameters is not None
            # unscale the gradients of optimizer's assigned params in-place
            self._scaler.unscale_(optimizer)
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)

        self._scaler.step(optimizer)
        self._scaler.update()

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped scaler from ``state_dict``."""
        self._scaler.load_state_dict(state_dict)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20