Python CLIP 实现文搜图

作者:追风剑情 发布于:2026-5-28 18:13 分类:AI

"""
Chinese-CLIP 文搜图(中文查询)完整示例
- 使用本地 OFA-Sys/chinese-clip-vit-base-patch16 模型
- ChromaDB 余弦距离检索
- 国内镜像加速下载
"""

# ==================== 0. 设置镜像(必须在导入 transformers 之前) ====================
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# ==================== 1. 导入依赖 ====================
import torch
import numpy as np
from PIL import Image
from transformers import ChineseCLIPProcessor, ChineseCLIPModel
import chromadb

# ==================== 2. 加载 CLIP 模型 ====================
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "OFA-Sys/chinese-clip-vit-base-patch16"

model = ChineseCLIPModel.from_pretrained(model_name).to(device)
processor = ChineseCLIPProcessor.from_pretrained(model_name)
model.eval()
print(f"模型加载完成,设备: {device}")

# ==================== 3. 特征提取函数(使用底层 API,稳定可靠) ====================
def get_image_embedding(image_path):
    """返回归一化的 512 维图像特征向量 (list)"""
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        # 视觉模型提取特征
        vision_outputs = model.vision_model(pixel_values=inputs['pixel_values'])
        # 取 [CLS] 标记(第一个 token),与文本处理保持一致
        cls_token = vision_outputs.last_hidden_state[:, 0, :]   # [1, hidden_size]
        # 通过 visual_projection 映射到 512 维
        image_features = model.visual_projection(cls_token)
    # L2 归一化
    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
    return image_features[0].cpu().numpy().tolist()

def get_text_embedding(text):
    """返回归一化的 512 维文本特征向量 (list)"""
    inputs = processor(text=text, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        # 1. 获取文本模型的原始输出
        text_outputs = model.text_model(**inputs)
        # 2. 取 [CLS] 标记的隐藏状态(第一个 token)
        #    last_hidden_state shape: [batch_size, seq_len, hidden_size]
        cls_token = text_outputs.last_hidden_state[:, 0, :]   # [1, hidden_size]
        # 3. 通过 text_projection 层映射到 512 维
        text_features = model.text_projection(cls_token)
    # L2 归一化
    text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
    return text_features[0].cpu().numpy().tolist()

# ==================== 4. 初始化 ChromaDB(余弦距离) ====================
client = chromadb.PersistentClient(path="./chroma_clip_cn_data")
collection = client.get_or_create_collection(
    name="clip_images_cn",
    metadata={"hnsw:space": "cosine"}
)

# 清空旧数据(如果有)
if collection.count() > 0:
    collection.delete(ids=collection.get()["ids"])
    print("已清空旧集合")

# ==================== 5. 准备图片数据 ====================
image_files = ["cat.png", "dog.png", "car.png"]
image_descriptions = [
    "一只毛茸茸的可爱猫咪,坐在柔软的垫子上",
    "一只快乐的狗在阳光明媚的公园里玩接球游戏,尾巴摇来摇去",
    "一辆现代红色汽车在风景优美的沿海公路上行驶"
]

ids = []
embeddings = []
metadatas = []

print("\n正在提取图片特征...")
for i, img_file in enumerate(image_files):
    if not os.path.exists(img_file):
        print(f"❌ 文件不存在: {img_file}")
        continue
    emb = get_image_embedding(img_file)
    ids.append(f"img_{i+1}")
    embeddings.append(emb)
    metadatas.append({"filename": img_file, "description": image_descriptions[i]})
    print(f"✅ {img_file} -> 维度 {len(emb)}, 范数 {np.linalg.norm(emb):.6f}")

# 存入 ChromaDB
collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)
print(f"\n✅ 成功存入 {collection.count()} 张图片\n")

# ==================== 6. 文搜图查询函数 ====================
def search_text_to_image(query_text, n_results=3):
    print(f"查询文本: \"{query_text}\"")
    query_vec = get_text_embedding(query_text)
    results = collection.query(
        query_embeddings=[query_vec],
        n_results=n_results,
        include=["metadatas", "distances"]
    )
    print("搜索结果(距离值越小越相似):")
    for i, (meta, dist) in enumerate(zip(results["metadatas"][0], results["distances"][0])):
        # 核心修改:直接使用 distance 进行排序,这就是最正确的用法
        print(f"   {i+1}. {meta['filename']} (余弦距离: {dist:.4f}) - {meta['description']}")
    print()

# ==================== 7. 测试查询 ====================
if __name__ == "__main__":
    search_text_to_image("一只长胡须的毛茸茸动物")
    search_text_to_image("摇尾巴的四条腿朋友")
    search_text_to_image("公路上行驶的轿车")

运行测试
11111.png

标签: AI

Powered by emlog  蜀ICP备18021003号-1   sitemap

川公网安备 51019002001593号