Python ResNet 图像特征提取

作者:追风剑情 发布于:2026-5-26 17:55 分类:AI

安装依赖:pip install chromadb torch torchvision pillow

import chromadb
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import os

# ------------------- 1. 构建特征提取器 -------------------
def get_feature_extractor(weight_path="resnet50-0676ba61.pth"):
    # 先创建结构,不加载预训练
    model = models.resnet50(weights=None)
    # 加载本地权重
    state_dict = torch.load(weight_path, map_location='cpu')
    model.load_state_dict(state_dict)
    # 移除分类头
    model = torch.nn.Sequential(*list(model.children())[:-1])
    model.eval()
    return model

# ------------------- 2. 图片预处理 -------------------
def get_transform():
    """
    定义图片预处理流程:
    1. 缩放到 256x256
    2. 中心裁剪至 224x224(符合 ResNet 输入尺寸)
    3. 转为张量
    4. 标准化(使用 ImageNet 数据集的均值和标准差)
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

# ------------------- 3. 单张图片特征提取函数 -------------------
def extract_feature(model, transform, image_path):
    """
    从单张图片中提取特征向量
    返回一个形状为 (2048,) 的 numpy 数组
    """
    # 读取图片并转为 RGB
    img = Image.open(image_path).convert('RGB')
    # 预处理
    img_tensor = transform(img)
    # 增加 batch 维度: (1, 3, 224, 224)
    img_tensor = img_tensor.unsqueeze(0)
    # 提取特征(禁用梯度计算,节省内存)
    with torch.no_grad():
        feature = model(img_tensor)
    # 去除多余维度 -> (2048,)
    feature = feature.squeeze()
    # L2 归一化
    feature = feature / torch.norm(feature, p=2)
    return feature.numpy()

# ------------------- 4. 初始化 ChromaDB 客户端 -------------------
# 使用内存模式(演示用),生产环境可改为持久化模式
client = chromadb.Client()

# 创建或获取一个集合(Collection)
# 一个 Collection 相当于向量数据库中的一张"表",存储同类向量
collection = client.create_collection(
    name="image_features",
    metadata={"description": "使用 ResNet50 提取的图片特征向量"}
)

# ------------------- 5. 模拟准备一些图片 -------------------
# 假设当前目录下有以下图片文件
image_paths = {
    "img_001": "cat.png",
    "img_002": "dog.png",
    "img_003": "car.png"
}

# 初始化特征提取器和预处理管道
extractor = get_feature_extractor()
transform = get_transform()

# ------------------- 6. 提取特征并批量存入 ChromaDB -------------------
ids = []          # 存储每条记录的 ID
embeddings = []   # 存储对应的特征向量
metadatas = []    # 存储相关的元数据(如文件名、路径等)

for img_id, path in image_paths.items():
    if not os.path.exists(path):
        print(f"警告: 图片文件 {path} 不存在,跳过")
        continue

    # 提取特征向量
    feature_vec = extract_feature(extractor, transform, path)

    # 记录元数据,方便后续查询时获取原始图片信息
    ids.append(img_id)
    embeddings.append(feature_vec.tolist())
    metadatas.append({"filename": path, "id": img_id})

    print(f"已处理 {img_id}: {path}, 特征向量维度: {len(feature_vec)}")

# 批量添加到 ChromaDB
collection.add(
    ids=ids,
    embeddings=embeddings,
    metadatas=metadatas
)

print(f"\n成功存入 {len(ids)} 张图片的特征向量到 ChromaDB!")

# ------------------- 7. 相似度查询示例 -------------------
# 假设用第一张图片(cat.jpg)作为查询图片,检索最相似的图片
query_img_path = "cat.png"
query_feature = extract_feature(extractor, transform, query_img_path)

# 执行相似性搜索
results = collection.query(
    query_embeddings=[query_feature.tolist()],
    n_results=3,  # 返回最相似的 2 条结果
    include=["metadatas", "distances"]
)

print("\n=== 相似度检索结果 ===")
for idx, (metadata, distance) in enumerate(zip(results['metadatas'][0], results['distances'][0])):
    print(f"排名 {idx + 1}: {metadata['filename']}, 余弦距离: {distance:.4f}")

11111.png

标签: AI

Powered by emlog  蜀ICP备18021003号-1   sitemap

川公网安备 51019002001593号