CLIP在Github上的使用教程

CLIP的github链接:https://github.com/openai/CLIP

CLIP

Blog,Paper,Model Card,Colab
CLIP(对比语言-图像预训练)是一个在各种(图像、文本)对上进行训练的神经网络。可以用自然语言指示它在给定图像的情况下预测最相关的文本片段,而无需直接对任务进行优化,这与 GPT-2 和 3 的零镜头功能类似。我们发现,CLIP 无需使用任何 128 万个原始标注示例,就能在 ImageNet "零拍摄 "上达到原始 ResNet50 的性能,克服了计算机视觉领域的几大挑战。

Usage用法

首先,安装 PyTorch 1.7.1(或更高版本)和 torchvision,以及少量其他依赖项,然后将此 repo 作为 Python 软件包安装。在 CUDA GPU 机器上,完成以下步骤即可:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git

将上面的 cudatoolkit=11.0 替换为机器上相应的 CUDA 版本,如果在没有 GPU 的机器上安装,则替换为 cpuonly

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

API

CLIP 模块提供以下方法:

clip.available_models()

返回可用 CLIP 模型的名称。例如下面就是我执行的结果。
在这里插入图片描述

clip.load(name, device=..., jit=False)

返回模型和模型所需的 TorchVision 变换(由 clip.available_models() 返回的模型名称指定)。它将根据需要下载模型。name参数也可以是本地检查点的路径。
可以选择指定运行模型的设备,默认情况下,如果有第一个 CUDA 设备,则使用该设备,否则使用 CPU。当 jitFalse 时,将加载模型的非 JIT 版本。

clip.tokenize(text: Union[str, List[str]], context_length=77)

返回包含给定文本输入的标记化序列的 LongTensor。这可用作模型的输入。

clip.load() 返回的模型支持以下方法:

model.encode_image(image: Tensor)

给定一批图像,返回 CLIP 模型视觉部分编码的图像特征。

model.encode_text(text: Tensor)

给定一批文本标记,返回 CLIP 模型语言部分编码的文本特征。

model(image: Tensor, text: Tensor)

给定一批图像和一批文本标记,返回两个张量,其中包含与每张图像和每个文本输入相对应的 logit 分数。这些值是相应图像和文本特征之间的余弦相似度乘以 100。

More Examples更多实例

Zero-Shot预测

下面的代码使用 CLIP 执行零点预测,如论文附录 B 所示。该示例从 CIFAR-100 数据集中获取一张图片,并预测数据集中 100 个文本标签中最有可能出现的标签。

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

输出结果如下(具体数字可能因计算设备而略有不同):

Top predictions:

           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%

请注意,本示例使用的 encode_image()encode_text() 方法可返回给定输入的编码特征。

Linear-probe evaluation线性探针评估

下面的示例使用 scikit-learn 对图像特征进行逻辑回归。

import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

请注意,C 值应通过使用验证分割进行超参数扫描来确定。

See Also

OpenCLIP:包括更大的、独立训练的 CLIP 模型,最高可达 ViT-G/14
Hugging Face implementation of CLIP:更易于与高频生态系统集成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/219595.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

千方百计阻止内网崩溃:上海迅软DSE答疑深度解析攻击后果

如今企业数据安全已成为一大议题,内网安全作为企业数据安全的第一步,一旦遭到攻击有多严重? 1.数据泄露:非法攻击企业内网者可能获取到重要的公司数据,包括客户信息、财务数据、知识产权等,导致隐私泄露和经…

linux系统下农场种菜小游戏!

linux系统下农场种菜小游戏! 今天给大家分享一个linux系统下一个简单的小游戏 源码如下,在linux系统下创建一个.sh的脚本文件,复制粘贴进去即可! #!/bin/bash# 初始化变量 vegetables("生菜" "西兰花" &qu…

Netty03-核心组件NioEventLoopGroup解读

NioEventLoopGroup 可以看到NioEventLoopGroup继承了MultithreadEventExecutorGroup并且实现了EventLoopGroup接口,而这两个类被ExecutorService修饰,所以NioEventLoopGroup实际上是一个线程池,池中的对象其实就是单个的NioEventLoop。 源码…

Comprehension from Chaos: Towards Informed Consent for Private Computation

目录 笔记后续的研究方向摘要引言 Comprehension from Chaos: Towards Informed Consent for Private Computation CCS 2023 笔记 本文探讨了用户对私有计算的理解和期望,其中包括多方计算和私有查询执行等技术。该研究进行了 22 次半结构化访谈,以调查…

DataGrip连接虚拟机上Docker部署的Mysql出错解决

1.1 首先判断CentOS的防火墙,如果开启就关闭 //查看防火墙状态 systemctl status firewalld //关闭防火墙systemctl stop firewalld.service//关闭防火墙开机自启systemctl disable firewalld.service而后可以打开DataGrip连接了,如果连接不上执行如下…

如何衡量和提高测试覆盖率?

衡量和提高测试覆盖率,对于尽早发现软件缺陷、提高软件质量和用户满意度,都具有重要意义。如果测试覆盖率低,意味着用例未覆盖到产品的所有代码路径和场景,这可能导致未及时发现潜在缺陷,代码中可能存在逻辑错误、边界…

两电脑共享鼠标键盘方案

一开始使用的是shareMouse 但是需要注册还有很多不稳定问题 后来想买个双拷线,又太贵,感觉不值的。 再后来,发现微软有自己的系统上的 共享方案 ,叫做 Mouse without Borders ,而且是免费的,只能在window电脑上使用…

一张网页截图,AI帮你写前端代码,前端窃喜,终于不用干体力活了

简介 众所周知,作为一个前端开发来说,尤其是比较偏营销和页面频繁改版的项目,大部分的时间都在”套模板“,根本没有精力学习前端技术,那么这个项目可谓是让前端的小伙伴们看到了一丝丝的曙光。将屏幕截图转换为代码&a…

12.5单端口RAM,JS计数器,流水线乘法器,不重叠序列检测器(状态机+移位寄存器),信号发生器,交通灯

单端口RAM timescale 1ns/1nsmodule RAM_1port(input clk,input rst,input enb,input [6:0]addr,input [3:0]w_data,output wire [3:0]r_data );reg [6:0]mem[127:0];integer i;always (posedge clk or negedge rst) beginif(!rst) beginfor (i0; i<127 ; ii1) beginmem[i]…

【开源】基于JAVA的桃花峪滑雪场租赁系统

项目编号&#xff1a; S 036 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S036&#xff0c;文末获取源码。} 项目编号&#xff1a;S036&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 游客服务2.2 雪场管理 三、数据库设…

中国AI大模型,应该如何商业化?

虽然大模型商业化的路径较为清晰&#xff0c;目前国内厂商也都在积极探索&#xff0c;但大模型的商业化之路&#xff0c;不能仅限于商业模式的探索尝试&#xff0c;更在于解决大模型发展的底层问题。 作者|斗斗 编辑|皮爷 出品|产业家 如今&#xff0c;大模型的商业化问题再…

制作红木家具3d模型

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 在家居行业中&#xff0c;设计师可以通过在3D建模中添加实际的家具、…

竞赛选题 题目:基于深度学习的图像风格迁移 - [ 卷积神经网络 机器视觉 ]

文章目录 0 简介1 VGG网络2 风格迁移3 内容损失4 风格损失5 主代码实现6 迁移模型实现7 效果展示8 最后 0 简介 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于深度学习卷积神经网络的花卉识别 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c…

编程模拟支付宝能量产生过程--数据控制流

#模拟支付宝蚂蚁森林的能量产生过程 behavior_points { # 定义行为对应的积分"步行": 2,"生活缴费": 10,"线下支付": 5,"网络购票": 5,"共享单车": 10 }total_points 0 # 初始化总积分while True: # 开…

Linux性能系统学习之监控工具

目录 前言linux性能度量标准监控工具topuptimeps/pstreefreempstatvmstat 前言 在实际产品开发过程中遇到一些莫名其妙的问题&#xff0c;比如swap交换分区随着时间增多影响到系统调用&#xff0c;或CPU占用以及内存的监测等&#xff0c;所以有必要系统了解Linux的性能问题。 …

待办事项app推荐哪一款?每日待办事项提醒用什么APP

每天的生活中&#xff0c;我们总是充满着各种待办事项&#xff0c;如果不及时处理&#xff0c;就会导致各种问题的出现。在众多的待办事项app中&#xff0c;如何选择一款最适合自己的app呢&#xff1f;所谓待办事项&#xff0c;通常是指尚未着手的事项。在日常生活中&#xff0…

Mysql进阶-事务锁

前置知识-事务 事务简介 事务 是一组操作的集合&#xff0c;它是一个不可分割的工作单位&#xff0c;事务会把所有的操作作为一个整体一起向系统提交或撤销操作请求&#xff0c;即这些操作要么同时成功&#xff0c;要么同时失败。 就比如: 张三给李四转账1000块钱&#xff0…

3D模型制作木质纹理贴图

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 本文将讲解如何使用GLTF 编辑器 -NSDT 在线材质编辑工具为3D模型设置…

jQuery选择器、操作DOM、事件处理机制、动画、ADJX操作知识点梳理

jQuery 核心理念就是写的更少,做的更多实现的代码更加简洁有效的提高开发效率 jQuery跟JavaScript的用法是不一样的 跟jQuery相继诞生的JavaScript库还有很多,不包括node.js 关于代码$("li").get(0),获取DOM对象 jQuery对象声明,是通过($)符号来实现的 如…

【【FPGA 之 MicroBlaze XADC 实验】】

FPGA 之 MicroBlaze XADC 实验 Vivado IP 核提供了 XADC 软核&#xff0c;XADC 包含两个模数转换器&#xff08;ADC&#xff09;&#xff0c;一个模拟多路复用器&#xff0c;片上温度和片上电压传感器等。我们可以利用这个模块监测芯片温度和供电电压&#xff0c;也可以用来测…