深度学习系列64:数字人wav2lip详解

1. 整体流程

第一步,加载视频/图片和音频/tts。用melspectrogram将wav文件拆分成mel_chunks。
第二步,调用face_detect模型,给出人脸检测结果(可以改造成从文件中读取),包装成4个数组batch:img_batch(人脸),mel_batch(语音),frame_batch(原图),coords_batch(坐标)
第三步,加载模型,进行计算。这个模型目前看下来就是简单的resnet,没有transfomer。另外mask也不是用分割模型,而是直接将图片下半部分全部作为mask😄,然后将mask图片拼接到原图片的色彩通道上作为输入。
第四步:预测出来的人脸拼接到原图上,输出位视频。

2. 优缺点

优点:极其简单,一个人脸检测模型+一个基于CNN的lipsync模型,速度很快。
缺点:嘴唇经常是歪的,而且有变形;牙齿不断在闪烁。经过图像增强后,我们取出截图如下:
在这里插入图片描述

3. 其他版本

3.1 Easy_Wav2Lip

这个版本相当好用。首先执行python install.py来下载模型文件。然后配置一下config.ini,执行python run.py即可。
生成配置文件的代码可以在目录下的Easy_Wav2Lip_v8.3.ipynb中来修改;也可以通过执行python GUI.py打开图形界面来修改:
在这里插入图片描述
执行代码的入口仍然是inference.py。这里说明一下分支内容:

  1. 基础人脸检测模型为RetinaFace,模型文件为checkpoints/mobilenet.pth。
  2. 如果使用Imporved模式,会调用load_sr()方法加载sr_model(gfpgan做super resolution,参数文件);如果使用Enhanced,会进行upscale。具体的表现是:如果仅使用imporved模式,嘴部会比较模糊;使用enhanced模式会得到清晰度统一的视频。
  3. 如果mouth_tracking为true,则会调用复杂一些的create_tracked_mask;否则仅启用create_mask
  4. 模型可选用"Wav2Lip", "Wav2Lip_GAN"两种。

在github的项目文件里面有一个ipynb文件可供学习。

3.2 Wav2Lip-fast

使用如下代码执行:
python inference.py --checkpoint_path <ckpt> --face <video.mp4> --audio <an-audio-source> --multiplier <multiplier-to-fasten-process>
这里的multiplier,指的是每隔多少帧进行一次face detection。
简化版代码如下:

import cv2,audio,face_detection,subprocess,torch,platform,sys
from models import Wav2Lip
from tqdm import tqdm
import numpy as np
facefile = '../openheygen/video-retalking/examples/face/2.mp4'
audiofile = '../openheygen/video-retalking/examples/audio/1.wav'
checkpoint_path = 'checkpoints/wav2lip_gan.pth'
base_name = facefile.split('/')[-1]
device = 'mps'
fps = 25
mel_step_size = 16
multiplier = 1
img_size = 96
face_det_batch_size = 16
batch_size = 128
wav = audio.load_wav(audiofile, 16000)
mel = audio.melspectrogram(wav)
mel_chunks = []
mel_idx_multiplier = 80./fps 
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,flip_input=False, device=device)

def face_detect(images, multiplier=1):
    predictions = []
    batch_size = face_det_batch_size
    for i in range(0, len(images), batch_size * multiplier):
        predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]), multiplier))
    
    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])
    boxes = np.array(results)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    return results 

def datagen(frames, mels, multiplier):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    face_det_results = face_detect(frames, multiplier) 
    for i, m in enumerate(mels):
        idx = i%len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()
        face = cv2.resize(face, (img_size, img_size))
        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)
        if len(img_batch) >= batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
            img_masked = img_batch.copy()
            img_masked[:, img_size//2:] = 0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
        img_masked = img_batch.copy()
        img_masked[:, img_size//2:] = 0
        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
        yield img_batch, mel_batch, frame_batch, coords_batch

def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path)) #torch.load(checkpoint_path)
    checkpoint = torch.load(path,map_location=torch.device(device))
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()

print('step1: read files...')    
if facefile.split('.')[-1] in ['png','jpg','jpeg']:
    full_frames = [cv2.imread(facefile)]
else:
    full_frames = []
    video_stream = cv2.VideoCapture(facefile)
    fps = video_stream.get(cv2.CAP_PROP_FPS)
    while 1:
        still_reading, frame = video_stream.read()
        if not still_reading:
            video_stream.release()
            break
        full_frames.append(frame)
i = 0
while 1:
    start_idx = int(i * mel_idx_multiplier)
    if start_idx + mel_step_size > len(mel[0]):
        mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
        break
    mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
    i += 1
full_frames = full_frames[:len(mel_chunks)]
gen = datagen(full_frames.copy(), mel_chunks, multiplier)


print('step2: load model and predict lip...')
results =[]
model = load_model(checkpoint_path)
frame_h, frame_w = full_frames[0].shape[:-1]
for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
    img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
    mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
    with torch.no_grad():
        pred = model(mel_batch, img_batch)
    pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
    for p, f, c in zip(pred, frames, coords):
        y1, y2, x1, x2 = c
        p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
        f[y1:y2, x1:x2] = p
        results.append(f)
        
print('step3: write file with audio...')
import matplotlib.pyplot as plt
out = cv2.VideoWriter('temp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_w, frame_h))
for pp in results:
    out.write(pp)
out.release()
command = 'ffmpeg -loglevel error -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audiofile, 'temp.mp4', 'result.mp4')
subprocess.call(command, shell=platform.system() != 'Windows')
from IPython.display import HTML
display(HTML("""
  <video height=400 controls>
        <source src=result.mp4 type="video/mp4">
  </video>"""))

如果需要的话,可以进行一次画质增强:

sys.path.insert(0, 'third_part/GFPGAN')
from third_part.GFPGAN.gfpgan import GFPGANer
restorer = GFPGANer(model_path='checkpoints/GFPGANv1.3.pth', 
                    upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
final_results = []
for r in tqdm(results):
    final_results.append(restorer.enhance(r, has_aligned=False, only_center_face=True, paste_back=True)[2])  
    
import matplotlib.pyplot as plt
plt.imshow(cv2.cvtColor(final_results[0], cv2.COLOR_BGR2RGB))
out = cv2.VideoWriter('temp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_w, frame_h))
for pp in final_results:
    out.write(pp)
out.release()
command = 'ffmpeg -loglevel error -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audiofile, 'temp.mp4', 'result.mp4')
subprocess.call(command, shell=platform.system() != 'Windows')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/585539.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

74、堆-数组中的第K个最大元素

思路&#xff1a; 直接排序是可以的&#xff0c;但是时间复杂度不符合。可以使用优先队列&#xff0c;代码如下&#xff1a; class Solution {public int findKthLargest(int[] nums, int k) {if (numsnull||nums.length0||k<0||k>nums.length){return Integer.MAX_VAL…

神之浩劫2测试资格100%获取教程 测试资格获取方法教程

《神之浩劫》是一款基于Unreal 3&#xff08;虚幻3&#xff09;游戏引擎开发的3D团队竞技游戏&#xff0c;由美国Hi-Rez工作室开发、腾讯全球代理。2013年10月31日&#xff0c;游戏开启国服首测&#xff0c;并于2014年3月25日在美国公测。2018年1月20日&#xff0c;国服并入全球…

shell脚本-监控系统内存和磁盘容量

监控内存和磁盘容量除了可以使用zabbix监控工具来监控&#xff0c;还可以通过编写Shell脚本来监控。 #! /bin/bash #此脚本用于监控内存和磁盘容量&#xff0c;内存小于500MB且磁盘容量小于1000MB时报警#提取根分区剩余空间 disk_size$(df / | awk /\//{print $4})#提取内存剩…

Redis(七) zset有序集合类型

文章目录 前言命令ZADDZCARDZCOUNTZRANGEZREVRANGEZRANGEBYSCOREZPOPMAXZPOPMIN两个阻塞版本的POP命令BZPOPMAX BZPOPMINZRANKZREVRANKZSCOREZREMZREMRANGEBYRANKZREMRANGEBYSCOREZINCRBY集合间操作ZINTERSTOREZUNIONSTORE 命令小结 内部编码使用场景 前言 对于有序集合这个名…

Java核心技术.卷I-上-笔记

目录 面向对象程序设计 使用命令行工具简单的编译源码 数据类型 StringBuilder 数组 对象与类 理解方法调用 继承 代理 异常 断言 日志 面向对象程序设计 面向对象的程序是由对象组成的&#xff0c;每个对象包含对用户公开的特定功能部分和隐藏的实现部分从根本上…

高校宿舍管理

在高等教育的迅猛发展浪潮中&#xff0c;大学校园正经历着前所未有的变革。随着招生规模的不断扩大&#xff0c;学生宿舍管理工作变得日益繁重和复杂。传统的管理方法&#xff0c;如使用Word和Excel进行数据记录和整理&#xff0c;已经无法满足现代高效、精准的管理需求。此外&…

关于几个水表术语的理解

GB/T778.1-2018《饮用冷水水表和热水水表 第 1 部分&#xff1a;量值要求和技术要求》、JJG162-2019《饮 用冷水水表检定规程》和 JJF1777-2019《饮用冷 水水表型式评价大纲》不仅规范了水表行业的专业名词解释&#xff0c;而且给出了影响水表性能的主要因素的定义。本文从影响…

Mellanox网卡打流命令ib_write_bw执行遇到Couldn‘t listen to port 18515原因与解决办法?

要点 要点&#xff1a; ib默认使用18515命令 相关命令&#xff1a; netstat -tuln | grep 18515 ib_write_bw --help |grep port# server ib_write_bw --ib-devmlx5_1 --port 88990 # client ib_write_bw --ib-devmlx5_0 1.1.1.1 --port88990现象&#xff1a; 根因&#xf…

Spring Cloud Feign

序言 本文给大家介绍一下 Spring Cloud Feign 的基础概念以及使用方式。 一、远程调用 在传统的单体系统中&#xff0c;我们通常是客户端去请求服务端的接口。但是在分布式的系统中&#xff0c;常常需要一个服务去调用另外一个服务的接口。在服务端如何去调用另外一个服务端…

docker compose mysql主从复制及orchestrator高可用使用

1.orchestrator 功能演示&#xff1a; 1.1 多级级联&#xff1a; 1.2 主从切换&#xff1a; 切换成功后&#xff0c;原来的主库是红色的&#xff0c;需要在主库的配置页面点击“start replication ”&#xff0c;重新连接上新的主库。 1.3 主从故障&#xff0c;从库自动切换新…

pyqt字体选择器

pyqt字体选择器 pyqt字体选择器效果代码 pyqt字体选择器 pyqt中QFontDialog 类是一个预定义的对话框&#xff0c;允许用户选择一个字体并设置其样式、大小等属性。 效果 代码 from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QLabel, QPushButton, QFontD…

信息收集。

信息收集 接着使用cs进行信息收集 发现域内管理员账号。 然后查看pc信息&#xff0c; 查看进程。 发现域为god.org 尝试定位域控。 提权 使用cs的功能进行权限提权 成功获取管理员权限。 hash抓取 接着抓hash 成功抓到管理员账号、密码。 接着进行横向传递 成功获取AD和…

从曝光到安装:App传参安装的关键步骤与数据指标

随着移动互联网的普及&#xff0c;手游市场日益繁荣&#xff0c;手游推广方式也日新月异。在这个竞争激烈的市场中&#xff0c;如何有效地推广手游&#xff0c;吸引更多的用户&#xff0c;成为了开发者和广告主关注的焦点。而Xinstall作为国内专业的App全渠道统计服务商&#x…

Android Widget开发代码示例详细说明

因为AppWidgetProvider扩展自BroadcastReceiver, 所以你不能保证回调函数完成调用后&#xff0c;AppWidgetProvider还在继续运行。 a. AppWidgetProvider 的实现 /*** Copyright(C):教育电子有限公司 * Project Name: NineSync* Filename: SynWidgetProvider.java * Author(S…

ORACLE 11G RAC 访问SQLSERVER

平时都是单机&#xff0c;RAC有点不一样&#xff0c;其实也一样。 目录 1.操作环境信息 2.安装GATEWAY 3.配置实例信息 4.配置监听 5.配置网络别名 6.创建到SQLSERVER的DBLINK 7.测试DBLINK有效性 1.操作环境信息 HIS PACS 数据库版本 ORACLE 11.2.0.4 RAC MS SQLSE…

C++多态(全)

多态 概念 调用函数的多种形态&#xff0c; 多态构成条件 1&#xff09;父子类完成虚函数的重写&#xff08;三同&#xff1a;函数名&#xff0c;参数&#xff0c;返回值相同&#xff09; 2&#xff09;父类的指针或者引用调用虚函数 虚函数 被virtual修饰的类成员函数 …

Llama images - 记录我看到的那些羊驼

来自 &#xff1a; DREAM: Distributed RAG Experimentation Framework

73、栈-柱状图中最大的矩形

思路&#xff1a; 矩形面积&#xff1a;宽度*高度 高度如何确定呢&#xff1f;就是在宽度中最矮的元素。如何确定宽度&#xff0c;就是要确定左右边界。 当我们在处理直方图最大矩形面积问题时&#xff0c;遇到一个比栈顶柱子矮的新柱子时开始计算面积的原因关键在于如何确定…

Hotcoin Research|玩赚WEB3:Seraph零成本赚取技巧

在《Seraph》这款游戏里&#xff0c;要提升自己的游戏技能和体验&#xff0c;了解如何免费赚取游戏货币灵魂晶石并挑战游戏主线是非常重要的。你可以通过卖东西、参加虚空异界地图和混沌秘境来在游戏里赚更多的钱&#xff0c;并更享受游戏的乐趣。最酷的是&#xff0c;得到的灵…

低功耗数字IC后端设计实现典型案例| UPF Flow如何避免工具乱用Always On Buffer?

下图所示为咱们社区低功耗四核A7 Top Hierarchical Flow后端训练营中的一个案例&#xff0c;设计中存在若干个Power Domain&#xff0c;其中Power Domain2(简称PD2)为default Top Domain&#xff0c;Power Domain1&#xff08;简称PD1&#xff09;为一个需要power off的domain&…
最新文章