Python ResNet 图像分类

作者:追风剑情 发布于:2026-5-20 18:40 分类:AI

示例演示了如何使用预训练好的 ResNet 神经网络来识别图片中的物品类别。ResNet 使用的是 ImageNet 数据集。

# pip install torch torchvision pillow
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO

# ------------------- 1. 加载模型结构 -------------------
model = models.resnet50(weights=None)   # 不自动下载

# 会自动从网上下载模型权重文件,国内通常会下载失败或者卡很久。
# model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# ------------------- 2. 加载本地权重 -------------------
# https://download.pytorch.org/models/resnet50-0676ba61.pth
weight_path = r"resnet50-0676ba61.pth"
state_dict = torch.load(weight_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()

# ------------------- 3. 从本地文件加载 ImageNet 标签 -------------------
# https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
with open("imagenet_classes.txt", "r") as f:
    labels = [line.strip() for line in f.readlines()]

# ------------------- 4. 图像预处理 -------------------
preprocess = transforms.Compose([
    # 缩放
    transforms.Resize(256),
    # 裁剪
    transforms.CenterCrop(224),
    # 转张量
    transforms.ToTensor(),
    # 标准化
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ------------------- 5. 加载本地图片 -------------------
img = Image.open("American_Eskimo_Dog.jpg").convert('RGB')


# ------------------- 6. 推理 -------------------
# 将原始图片 img 通过预处理流程(缩放、裁剪、转张量、标准化)转换成模型需要的输入格式
input_tensor = preprocess(img)

# 在 batch 维度上增加一个维度,因为 PyTorch 模型要求输入必须是四维: [batch_size, channels, height, width]
# unsqueeze(0) 在索引 0 处插入一维,形状变为 [1, 3, 224, 224],即 batch_size = 1
input_batch = input_tensor.unsqueeze(0)

# torch.no_grad() 上下文管理器:在此区域内不计算梯度,节省内存和计算时间
# 因为在推理(预测)时不需要反向传播,所以关闭自动求导功能
with torch.no_grad():
    # 将预处理好的 batch 数据送入 ResNet50 模型进行前向传播
    # 模型输出原始 logits(未经过 softmax 的分数),形状为 [1, 1000]
    output = model(input_batch)

# ------------------- 7. 输出结果 -------------------
# 对模型输出的原始 logits(形状 [1000])应用 softmax 函数,将其转换为概率分布
# softmax 会将每个类别的分数映射到 [0, 1] 区间,且所有类别的概率之和为 1
# dim=0 表示在第 0 维(即类别这一维)上计算 softmax
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# 在概率分布中找出最大值及其对应的索引位置
# top_prob 保存最大的概率值(标量),top_idx 保存该概率对应的类别索引(0~999)
# dim=0 表示沿着类别维度搜索最大值
top_prob, top_idx = torch.max(probabilities, 0)
print(f"预测类别: {labels[top_idx]} (置信度: {top_prob.item():.4f})")

# 获取置信度最高的前5项
top_probs, top_indices = torch.topk(probabilities, k=5, dim=0)
print("预测结果(Top-5):")
for i in range(5):
    class_name = labels[top_indices[i]]
    prob = top_probs[i].item()
    print(f"  第{i+1}名: {class_name} (置信度: {prob:.4f})")

运行测试
11111.png

标签: AI

Powered by emlog  蜀ICP备18021003号-1   sitemap

川公网安备 51019002001593号