timm 是一个用于深度学习的开源库,全称是 “PyTorch Image Models”。该库由 Ross Wightman 创建并维护,旨在提供高效且易于使用的图像模型,包括大量预训练的模型和实用工具。timm 库基于 PyTorch 框架,主要特点包括:
下面是一个使用 timm 库加载预训练模型并进行推理的简单示例:
import timm import torch from PIL import Image from torchvision import transforms # 加载预训练的 ResNet50 模型 model = timm.create_model('resnet50', pretrained=True) model.eval() # 图像预处理 preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # 加载并预处理图像 img = Image.open('path_to_your_image.jpg') img_tensor = preprocess(img).unsqueeze(0) # 推理 with torch.no_grad(): output = model(img_tensor) # 输出结果 print(output)
timm
(PyTorch Image Models)库包含了众多预训练的图像分类模型,这些模型在各种流行的数据集上进行了训练。以下是一些主要的预训练模型类别和具体模型名称:
ResNet 系列
EfficientNet 系列
Vision Transformers (ViT)
DeiT(Data-efficient Image Transformers)
MobileNet 系列
RegNet 系列
DenseNet 系列
Inception 系列
SENet 系列
EfficientDet(用于目标检测)
NFNet 系列
ConvNeXt 系列
Swin Transformer
CaiT(Class-Attention in Image Transformers)
可以通过以下代码查看 timm
支持的所有预训练模型:
import timm # 列出所有可用的模型 model_names = timm.list_models(pretrained=True) print(model_names)
这些预训练模型已经在ImageNet等大型数据集上进行了训练,因此在迁移学习任务中通常表现良好。选择适合你任务的模型架构,可以加快训练过程,并提高模型的性能。