基于 Hugging Face Datasets 和 Transformers 的图像相似性搜索
基于 HuggingFace Datasets 和 Transformers 的图像相似性搜索
通过本文,你将学习使用 🤗 Transformers 构建图像相似性搜索系统。找出查询图像和潜在候选图像之间的相似性是信息检索系统的一个重要用例,例如反向图像搜索 (即找出查询图像的原图)。此类系统试图解答的问题是,给定一个 查询 图像和一组 候选 图像,找出候选图像中哪些图像与查询图像最相似。
我们将使用 🤗 datasets
库,因为它无缝支持并行处理,这在构建系统时会派上用场。
尽管这篇文章使用了基于 ViT 的模型 (nateraw/vit-base-beans
) 和特定的 (Beans) 数据集,但它可以扩展到其他支持视觉模态的模型,也可以扩展到其他图像数据集。你可以尝试的一些著名模型有:
此外,文章中介绍的方法也有可能扩展到其他模态。
要研究完整的图像相似度系统,你可以参考 这个 Colab Notebook。
我们如何定义相似性?
要构建这个系统,我们首先需要定义我们想要如何计算两个图像之间的相似度。一种广泛流行的做法是先计算给定图像的稠密表征 (即嵌入 (embedding)),然后使用 余弦相似性度量 (cosine similarity metric) 来确定两幅图像的相似程度。
在本文中,我们将使用 “嵌入” 来表示向量空间中的图像。它为我们提供了一种将图像从高维像素空间 (例如 224 × 224 × 3) 有意义地压缩到一个低得多的维度 (例如 768) 的好方法。这样做的主要优点是减少了后续步骤中的计算时间。
计算嵌入
为了计算图像的嵌入,我们需要使用一个视觉模型,该模型知道如何在向量空间中表示输入图像。这种类型的模型通常也称为图像编码器 (image encoder)。
我们利用 AutoModel
类来加载模型。它为我们提供了一个接口,可以从 HuggingFace Hub 加载任何兼容的模型 checkpoint。除了模型,我们还会加载与模型关联的处理器 (processor) 以进行数据预处理。
from transformers import AutoFeatureExtractor, AutoModel model_ckpt = "nateraw/vit-base-beans" extractor = AutoFeatureExtractor.from_pretrained (model_ckpt) model = AutoModel.from_pretrained (model_ckpt)
本例中使用的 checkpoint 是一个在 beans
数据集 上微调过的 ViT 模型。
这里可能你会问一些问题:
Q1: 为什么我们不使用 AutoModelForImageClassification
?
这是因为我们想要获得图像的稠密表征,而 AutoModelForImageClassification
只能输出离散类别。
Q2: 为什么使用这个特定的 checkpoint?
如前所述,我们使用特定的数据集来构建系统。因此,与其使用通用模型 (例如 在 ImageNet-1k 数据集上训练的模型),不如使用使用已针对所用数据集微调过的模型。这样,模型能更好地理解输入图像。
注意 你还可以使用通过自监督预训练获得的 checkpoint, 不必得由有监督学习训练而得。事实上,如果预训练得当,自监督模型可以 获得 令人印象深刻的检索性能。
现在我们有了一个用于计算嵌入的模型,我们需要一些候选图像来被查询。
加载候选图像数据集
后面,我们会构建将候选图像映射到哈希值的哈希表。在查询时,我们会使用到这些哈希表,详细讨论的讨论稍后进行。现在,我们先使用 beans
数据集 中的训练集来获取一组候选图像。
from datasets import load_dataset dataset = load_dataset ("beans")
以下展示了训练集中的一个样本:
该数据集的三个 features
如下:
dataset ["train"].features >>> {'image_file_path': Value (dtype='string', id=None), 'image': Image (decode=True, id=None), 'labels': ClassLabel (names=['angular_leaf_spot', 'bean_rust', 'healthy'], id=None)}
为了使图像相似性系统可演示,系统的总体运行时间需要比较短,因此我们这里只使用候选图像数据集中的 100 张图像。
num_samples = 100 seed = 42 candidate_subset = dataset ["train"].shuffle (seed=seed).select (range (num_samples))
寻找相似图片的过程
下图展示了获取相似图像的基本过程。
稍微拆解一下上图,我们分为 4 步走:
- 从候选图像 (
candidate_subset
) 中提取嵌入,将它们存储在一个矩阵中。 - 获取查询图像并提取其嵌入。
- 遍历嵌入矩阵 (步骤 1 中得到的) 并计算查询嵌入和当前候选嵌入之间的相似度得分。我们通常维护一个类似字典的映射,来维护候选图像的 ID 与相似性分数之间的对应关系。
- 根据相似度得分进行排序并返回相应的图像 ID。最后,使用这些 ID 来获取候选图像。
我们可以编写一个简单的工具函数用于计算嵌入并使用 map()
方法将其作用于候选图像数据集的每张图像,以有效地计算嵌入。
import torch def extract_embeddings (model: torch.nn.Module): """Utility to compute embeddings.""" device = model.device def pp (batch): images = batch ["image"] # `transformation_chain` is a compostion of preprocessing # transformations we apply to the input images to prepare them # for the model. For more details, check out the accompanying Colab Notebook. image_batch_transformed = torch.stack ( [transformation_chain (image) for image in images] ) new_batch = {"pixel_values": image_batch_transformed.to (device)} with torch.no_grad (): embeddings = model (**new_batch).last_hidden_state [:, 0].cpu () return {"embeddings": embeddings} return pp
我们可以像这样映射 extract_embeddings()
:
device = "cuda" if torch.cuda.is_available () else "cpu" extract_fn = extract_embeddings (model.to (device)) candidate_subset_emb = candidate_subset.map (extract_fn, batched=True, batch_size=batch_size)
接下来,为方便起见,我们创建一个候选图像 ID 的列表。
candidate_ids = [] for id in tqdm (range (len (candidate_subset_emb))): label = candidate_subset_emb [id]["labels"] # Create a unique indentifier. entry = str (id) + "_" + str (label) candidate_ids.append (entry)
我们用包含所有候选图像的嵌入矩阵来计算与查询图像的相似度分数。我们之前已经计算了候选图像嵌入,在这里我们只是将它们集中到一个矩阵中。
all_candidate_embeddings = np.array (candidate_subset_emb ["embeddings"]) all_candidate_embeddings = torch.from_numpy (all_candidate_embeddings)
我们将使用 余弦相似度 来计算两个嵌入向量之间的相似度分数。然后,我们用它来获取给定查询图像的相似候选图像。
def compute_scores (emb_one, emb_two): """Computes cosine similarity between two vectors.""" scores = torch.nn.functional.cosine_similarity (emb_one, emb_two) return scores.numpy ().tolist () def fetch_similar (image, top_k=5): """Fetches the`top_k`similar images with`image`as the query.""" # Prepare the input query image for embedding computation. image_transformed = transformation_chain (image).unsqueeze (0) new_batch = {"pixel_values": image_transformed.to (device)} # Comute the embedding. with torch.no_grad (): query_embeddings = model (**new_batch).last_hidden_state [:, 0].cpu () # Compute similarity scores with all the candidate images at one go. # We also create a mapping between the candidate image identifiers # and their similarity scores with the query image. sim_scores = compute_scores (all_candidate_embeddings, query_embeddings) similarity_mapping = dict (zip (candidate_ids, sim_scores)) # Sort the mapping dictionary and return `top_k` candidates. similarity_mapping_sorted = dict ( sorted (similarity_mapping.items (), key=lambda x: x [1], reverse=True) ) id_entries = list (similarity_mapping_sorted.keys ())[:top_k] ids = list (map (lambda x: int (x.split ("_")[0]), id_entries)) labels = list (map (lambda x: int (x.split ("_")[-1]), id_entries)) return ids, labels
执行查询
经过以上准备,我们可以进行相似性搜索了。我们从 beans
数据集的测试集中选取一张查询图像来搜索:
test_idx = np.random.choice (len (dataset ["test"])) test_sample = dataset ["test"][test_idx]["image"] test_label = dataset ["test"][test_idx]["labels"] sim_ids, sim_labels = fetch_similar (test_sample) print (f"Query label: {test_label}") print (f"Top 5 candidate labels: {sim_labels}")
结果为:
Query label: 0 Top 5 candidate labels: [0, 0, 0, 0, 0]
看起来我们的系统得到了一组正确的相似图像。将结果可视化,如下:
进一步扩展与结论
现在,我们有了一个可用的图像相似度系统。但实际系统需要处理比这多得多的候选图像。考虑到这一点,我们目前的程序有不少缺点:
- 如果我们按原样存储嵌入,内存需求会迅速增加,尤其是在处理数百万张候选图像时。在我们的例子中嵌入是 768 维,这即使对大规模系统而言可能也是相对比较高的维度。
- 高维的嵌入对检索部分涉及的后续计算有直接影响。
如果我们能以某种方式降低嵌入的维度而不影响它们的意义,我们仍然可以在速度和检索质量之间保持良好的折衷。本文 附带的 Colab Notebook 实现并演示了如何通过随机投影 (random projection) 和位置敏感哈希 (locality-sensitive hashing,LSH) 这两种方法来取得折衷。
🤗 Datasets 提供与 FAISS 的直接集成,进一步简化了构建相似性系统的过程。假设你已经提取了候选图像的嵌入 (beans
数据集) 并把他们存储在称为 embedding
的 feature
中。你现在可以轻松地使用 dataset
的 add_faiss_index()
方法来构建稠密索引:
dataset_with_embeddings.add_faiss_index (column="embeddings")
建立索引后,可以使用 dataset_with_embeddings
模块的 get_nearest_examples()
方法为给定查询嵌入检索最近邻:
scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples ( "embeddings", qi_embedding, k=top_k )
该方法返回检索分数及其对应的图像。要了解更多信息,你可以查看 官方文档 和 这个 notebook。
在本文中,我们快速入门并构建了一个图像相似度系统。如果你觉得这篇文章很有趣,我们强烈建议你基于我们讨论的概念继续构建你的系统,这样你就可以更加熟悉内部工作原理。
还想了解更多吗?以下是一些可能对你有用的其他资源:
英文原文: https://hf.co/blog/image-similarity
译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。
审校、排版: zhongdongy (阿东)

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
多云和混合云场景下的 API 管理:挑战与选择
作者张超,API7 Cloud 产品负责人,Apache APISIX PMC 成员。 原文链接 一、多云和混合云 如今微服务已经成为最流行的一种软件架构,人们通过自己对业务的理解,和科学方法(比如领域驱动设计的理论)的加持将组织对外提供的产品拆分为一个个微服务,同时按照微服务的架构调整组织架构(逆康威定律)来开展业务的迭代。 以往组织习惯于自建数据中心来部署自家的产品,这些数据中心可能构建在买断或者是租赁的机房中,组织需要对机房的设备进行复杂的管理,包括交换机、服务器、磁盘、机柜等这些硬件设施,以及应对因为机房掉电、机柜温度过高、服务器崩溃等带来的故障。应对这些要问题往往会花费大量的人力财力,同时达不到很好的效果,使得自己产品的 SLA(服务等级约定)往往无法达到对客户的承诺,从而造成赔偿。 随着云的兴起,人们越来越习惯于将自己的业务部署到公有云之上,公有云帮助用户屏蔽了硬件的细节,工程师再也不用关心基础设施的搭建和维护问题,而是可以把自己的精力尽可能得放到业务的部署、维护和开发之上。然而,除了享受云带来的便利以外,云的兴起和使用也给用户带来了其他的一些问题: 被“锁定”,当深度使用...
- 下一篇
JuiceFS 在火山引擎边缘计算的应用实践
火山引擎边缘云是以云计算基础技术和边缘异构算力结合网络为基础,构建在边缘大规模基础设施之上的云计算服务,形成以边缘位置的计算、网络、存储、安全、智能为核心能力的新一代分布式云计算解决方案。 01- 边缘场景存储挑战 边缘存储主要面向适配边缘计算的典型业务场景,如边缘渲染。火山引擎边缘渲染依托底层海量算力资源,可助力用户实现百万渲染帧队列轻松编排、渲染任务就近调度、多任务多节点并行渲染,极大提升渲染 简单介绍一下在边缘渲染中遇到的存储问题: 需要对象存储与文件系统的元数据统一,实现数据通过对象存储接口上传以后,可以通过 POSIX 接口直接进行操作; 满足高吞吐量的场景需求,尤其是在读的时候; 完全实现 S3 接口和 POSIX 接口。 为了解决在边缘渲染中遇到的存储问题,团队花了将近半年的时间开展了存储选型测试。最初,团队选择了公司内部的存储组件,从可持续性和性能上来说,都能比较好的满足我们的需求。 但是落地到边缘场景,有两个具体的问题: 首先,公司内部组件是为了中心机房设计的,对于物理机资源和数量是有要求的,边缘某些机房很难满足; 其次,整个公司的存储组件都打包在一起,包括:对象存储...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- CentOS7编译安装Gcc9.2.0,解决mysql等软件编译问题
- SpringBoot2整合Redis,开启缓存,提高访问速度
- SpringBoot2整合MyBatis,连接MySql数据库做增删改查操作
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- MySQL8.0.19开启GTID主从同步CentOS8
- Mario游戏-低调大师作品
- Linux系统CentOS6、CentOS7手动修改IP地址
- Docker使用Oracle官方镜像安装(12C,18C,19C)
- Docker安装Oracle12C,快速搭建Oracle学习环境
- CentOS7安装Docker,走上虚拟化容器引擎之路