使用 PyTorch 和 Milvus 进行图像搜索
本指南介绍了一个集成 PyTorch 和 Milvus 使用嵌入进行图像搜索的示例。PyTorch 是一个强大的开源深度学习框架,广泛用于构建和部署机器学习模型。在这个示例中,我们将利用其 Torchvision 库和预训练的 ResNet50 模型来生成表示图像内容的特征向量(嵌入)。这些嵌入将存储在 Milvus 中,这是一个高性能向量数据库,以实现高效的相似性搜索。使用的数据集是来自 Kaggle 的印象派分类器数据集。通过将 PyTorch 的深度学习能力与 Milvus 的可扩展搜索功能相结合,本示例演示了如何构建一个强大且高效的图像检索系统。
让我们开始吧!
安装要求
在这个示例中,我们将使用 pymilvus
连接和使用 Milvus,使用 torch
运行嵌入模型,使用 torchvision
进行实际模型和预处理,使用 gdown
下载示例数据集,使用 tqdm
显示加载进度条。
pip install pymilvus torch gdown torchvision tqdm
获取数据
我们将使用 gdown
从 Google Drive 获取 zip 文件,然后使用内置的 zipfile
库解压缩它。
import gdown
import zipfile
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)
with zipfile.ZipFile("./paintings.zip","r") as zip_ref:
zip_ref.extractall("./paintings")
数据集的大小为 2.35 GB,下载时间取决于您的网络条件。
全局参数
这些是我们将使用的一些主要全局参数,用于更容易的跟踪和更新。
# Milvus Setup Arguments
COLLECTION_NAME = 'image_search' # Collection name
DIMENSION = 2048 # Embedding vector size in this example
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
# Inference Arguments
BATCH_SIZE = 128
TOP_K = 3
设置 Milvus
此时,我们开始设置 Milvus。步骤如下:
-
使用提供的 URI 连接到 Milvus 实例。
from pymilvus import connections
# Connect to the instance
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT) -
如果 Collection 已存在,则删除它。
from pymilvus import utility
# Remove any previous collections with the same name
if utility.has_collection(COLLECTION_NAME):
utility.drop_collection(COLLECTION_NAME) -
创建包含 ID、图像文件路径和其嵌入的 Collection。
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
# Create collection which includes the id, filepath of the image, and image embedding
fields = [
FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters
FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema) -
在新创建的 Collection 上创建索引并将其加载到内存中。
# Create an AutoIndex index for collection
index_params = {
'metric_type':'L2',
'index_type':"IVF_FLAT",
'params':{'nlist': 16384}
}
collection.create_index(field_name="image_embedding", index_params=index_params)
collection.load()
一旦这些步骤完成,Collection 就可以插入数据并进行搜索了。任何添加的数据都将自动建立索引并立即可用于搜索。如果数据很新,搜索可能会较慢,因为仍在建立索引过程中的数据将使用暴力搜索。
插入数据
在这个示例中,我们将使用 torch
及其模型中心提供的 ResNet50 模型。为了获得嵌入,我们去掉了最终的分类层,这使得模型为我们提供 2048 维的嵌入。在 torch
上找到的所有视觉模型都使用我们在这里包含的相同预处理。
在接下来的几个步骤中,我们将:
-
加载数据。
import glob
# Get the filepaths of the images
paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)
len(paths) -
将数据预处理成批次。
import torch
# Load the embedding model with the last layer removed
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.eval() -
嵌入数据。
from torchvision import transforms
# Preprocessing for images
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]) -
插入数据。
from PIL import Image
from tqdm import tqdm
# Embed function that embeds the batch and inserts it
def embed(data):
with torch.no_grad():
output = model(torch.stack(data[0])).squeeze()
collection.insert([data[1], output.tolist()])
data_batch = [[],[]]
# Read the images into batches for embedding and insertion
for path in tqdm(paths):
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
if len(data_batch[0]) % BATCH_SIZE == 0:
embed(data_batch)
data_batch = [[],[]]
# Embed and insert the remainder
if len(data_batch[0]) != 0:
embed(data_batch)
# Call a flush to index any unsealed segments.
collection.flush()- 这个步骤相对耗时,因为嵌入需要时间。喝杯咖啡放松一下吧。
- PyTorch 可能与 Python 3.9 及更早版本不太兼容。请考虑使用 Python 3.10 及更高版本。
执行搜索
将所有数据插入 Milvus 后,我们可以开始执行搜索。在这个示例中,我们将搜索两个示例图像。因为我们进行的是批量搜索,搜索时间在批次的图像之间共享。
import glob
# Get the filepaths of the search images
search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)
len(search_paths)
import time
from matplotlib import pyplot as plt
# Embed the search images
def embed(data):
with torch.no_grad():
ret = model(torch.stack(data))
# If more than one image, use squeeze
if len(ret) > 1:
return ret.squeeze().tolist()
# Squeeze would remove batch for single image, so using flatten
else:
return torch.flatten(ret, start_dim=1).tolist()
data_batch = [[],[]]
for path in search_paths:
im = Image.open(path).convert('RGB')
data_batch[0].append(preprocess(im))
data_batch[1].append(path)
embeds = embed(data_batch[0])
start = time.time()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.time()
# Show the image results
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
axarr[hits_i][0].set_axis_off()
axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
for hit_i, hit in enumerate(hits):
axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
axarr[hits_i][hit_i + 1].set_axis_off()
axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
# Save the search result in a separate image file alongside your script.
plt.savefig('search_result.png')
搜索结果图像应该与以下类似: