Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

    • 模型

    • 训练

    • 并行计算

    • 可视化

    • 实战

    • timm

      • timm概述
      • timm使用教程
      • timm代码解读
      • timm-vit代码解读
      • create_model解读
      • Pytorch Lightning

      • 数据增强

      • 面经与bug解决

      • 常用代码片段

      • Reference
    • CL

    • CIL

    • 小样本类增量学习FSCIL

    • UCIL

    • 多模态增量学习MMCL

    • LTCIL

    • DIL

    • 论文阅读与写作

    • 分布外检测

    • GPU

    • 深度学习调参指南

    • AINotes
    • PyTorch
    • timm
    Geeks_Z
    2024-10-16
    目录

    create_model解读

    使用

    import timm
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
    
    1
    2

    pretrained

    True or False

    True:timm直接根据对应的URL下载模型权重参数并加载到模型,注意,只有本地没有对应模型参数时才会下载,也就是说,通常在第一次运行时下载对应模型参数,之后会直接从本地加载模型权重参数。

    registry

    create_model主体只有50行左右的代码,那么,如何实现从模型到特征提取器的转换?已知timm.list_models()函数中的每一个模型名字(str)实际上都是一个函数。

    ###输入
    import timm
    import random 
    from timm.models import registry
     
    m = timm.list_models()[-1]
    print(m)
    registry.is_model(m)
     
    ###输出
    xception71
    True
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12

    实际上,在 timm 内部,有一个字典称为 _model_entrypoints 包含了所有的模型名称和他们各自的函数。比如说,可以通过 model_entrypoint 函数从 _model_entrypoints 内部得到 xception71 模型的构造函数。

    ###输入
    constuctor_fn = registry.model_entrypoint(m)
    print(constuctor_fn)
    ###输出
    <function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>
    or
    <function xception71 at 0x7fc0cba0eca0>
    
    1
    2
    3
    4
    5
    6
    7

    在 timm.models.xception_aligned 模块中有一个函数称为 xception71 。类似的,timm 中的每一个模型都有着一个这样的构造函数。事实上,内部的 _model_entrypoints 字典大概长这个样子:

    _model_entrypoints
    > > 
    {
    'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
    'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
    'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
    'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
    'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
    'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
    'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
    'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
    'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
    'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
    'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
    'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
    'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
    'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,
     
    }
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19

    所以说,在 timm 对应的模块中,每个模型都有一个构造器。比如说 ResNets 系列模型被定义在 timm.models.resnet 模块中。因此,实际上有两种方式来创建一个 resnet34 模型:

    import timm
    from timm.models.resnet import resnet34
     
    # 使用 create_model
    m = timm.create_model('resnet34')
     
    # 直接调用构造函数
    m = resnet34()
    
    1
    2
    3
    4
    5
    6
    7
    8

    但使用上,无须调用构造函数。所用模型都可以通过create_model函数来将创建。

    Register model

    resnet34构造函数的源码如下:

    @register_model
    def resnet34(pretrained=False, **kwargs):
        """Constructs a ResNet-34 model.
        """
        model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
        return _create_resnet('resnet34', pretrained, **model_args)
    
    1
    2
    3
    4
    5
    6

    会发现 timm 中的每个模型都有一个 register_model 装饰器。最开始, _model_entrypoints 是一个空字典。通过 register_model 装饰器来不断地像其中添加模型名称和它对应的构造函数。该装饰器的定义如下:

    def register_model(fn):
        # lookup containing module
        mod = sys.modules[fn.__module__]
        module_name_split = fn.__module__.split('.')
        module_name = module_name_split[-1] if len(module_name_split) else ''
     
        # add model to __all__ in module
        model_name = fn.__name__
        if hasattr(mod, '__all__'):
            mod.__all__.append(model_name)
        else:
            mod.__all__ = [model_name]
     
        # add entries to registry dict/sets
        _model_entrypoints[model_name] = fn
        _model_to_module[model_name] = module_name
        _module_to_models[module_name].add(model_name)
        has_pretrained = False  # check if model has a pretrained url to allow filtering on this
        if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
            # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
            # entrypoints or non-matching combos
            has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
        if has_pretrained:
            _model_has_pretrained.add(model_name)
        return fn
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25

    可以看到, register_model 函数完成了一些比较基础的步骤,但这里需要指出的是这一句:

    _model_entrypoints[model_name] = fn
    
    1

    它将给定的 fn 添加到 _model_entrypoints 其键名为 fn.name。所以说 resnet34 函数上的装饰器 @register_model 在 _model_entrypoints 中创建一个新的条目,像这样:

    {&#8217;resnet34&#8217;: <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}
    
    1

    同样可以看到在 resnet34 构造函数的源码中,在设置完一些 model_args 之后,它会随后调用 _create_resnet 函数。再来看一下该函数的源码:

    def _create_resnet(variant, pretrained=False, **kwargs):
        return build_model_with_cfg(
            ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
    
    1
    2
    3

    所以在 _create_resnet 函数之中,会再调用 build_model_with_cfg 函数并将一个构造器类 ResNet 、变量名 resnet34、一个 default_cfg 和一些 **kwargs 传入其中。

    default config

    timm 中所有的模型都有一个默认的配置,包括指向它的预训练权重参数的URL、类别数、输入图像尺寸、池化尺寸等。resnet34 的默认配置如下:

    {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
    'num_classes': 1000,
    'input_size': (3, 224, 224),
    'pool_size': (7, 7),
    'crop_pct': 0.875,
    'interpolation': 'bilinear',
    'mean': (0.485, 0.456, 0.406),
    'std': (0.229, 0.224, 0.225),
    'first_conv': 'conv1',
    'classifier': 'fc'}
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10

    build model with config

    这个 build_model_with_cfg 函数负责:

    • 真正地实例化一个模型类来创建一个模型
    • 若 pruned=True,对模型进行剪枝
    • 若 pretrained=True,加载预训练模型参数
    • 若 features_only=True,将模型转换为特征提取器
    def build_model_with_cfg(
            model_cls: Callable,
            variant: str,
            pretrained: bool,
            default_cfg: dict,
            model_cfg: dict = None,
            feature_cfg: dict = None,
            pretrained_strict: bool = True,
            pretrained_filter_fn: Callable = None,
            pretrained_custom_load: bool = False,
            **kwargs):
        pruned = kwargs.pop('pruned', False)
        features = False
        feature_cfg = feature_cfg or {}
     
        if kwargs.pop('features_only', False):
            features = True
            feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
            if 'out_indices' in kwargs:
                feature_cfg['out_indices'] = kwargs.pop('out_indices')
     
        model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
        model.default_cfg = deepcopy(default_cfg)
     
        if pruned:
            model = adapt_model_from_file(model, variant)
     
        # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
        num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
        if pretrained:
            if pretrained_custom_load:
                load_custom_pretrained(model)
            else:
                load_pretrained(
                    model,
                    num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                    filter_fn=pretrained_filter_fn, strict=pretrained_strict)
     
        if features:
            feature_cls = FeatureListNet
            if 'feature_cls' in feature_cfg:
                feature_cls = feature_cfg.pop('feature_cls')
                if isinstance(feature_cls, str):
                    feature_cls = feature_cls.lower()
                    if 'hook' in feature_cls:
                        feature_cls = FeatureHookNet
                    else:
                        assert False, f'Unknown feature class {feature_cls}'
            model = feature_cls(model, **feature_cfg)
            model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
     
        return model
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52

    可以看到,模型在这一步被创建出来:model = model_cls(**kwargs)。

    总结

    • 每个模型有不同的构造函数,可以传入不同的参数, _model_entrypoints 字典包括了所有的模型名称及其对应的构造函数
    • build_with_model_cfg 函数接收模型构造器类和其中的一些具体参数,真正地实例化一个模型
    • load_pretrained 会加载预训练参数
    • FeatureListNet 类可以将模型转换为特征提取器

    【Timm】create_model全面详实概念理解及实践篇 (opens new window)

    上次更新: 2025/06/25, 11:25:50
    timm-vit代码解读
    概述

    ← timm-vit代码解读 概述→

    最近更新
    01
    帮助信息查看
    06-08
    02
    常用命令
    06-08
    03
    学习资源
    06-07
    更多文章>
    Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
    京公网安备 11010802040735号 | 京ICP备2022029989号-1
    • 跟随系统
    • 浅色模式
    • 深色模式
    • 阅读模式