0此镜像用于Musetalk的泛化模型训练,提供了完成的训练数据集HDTF,共包含403个视频素材
本镜像构建和运行所需的基础环境。
# conda activate preprocess
# cd /root/MuseTalk
/root/MuseTalk/configs/training/preprocess.yaml
该配置文件需要修改两处:
训练视频所在目录,videos文件夹中提供了几个视频素材可做快速验证。/root/MuseTalk/dataset/HDTF目录下有一个videos.zip,是完整的训练视频,可解压后使用。
测试集文件列表,注意【不包括文件名的后缀 .mp4】,测试集自行在videos文件夹中进行挑选
配置文件/root/MuseTalk/configs/training/.yaml

gpu_ids: 使用的GPU的id,如果电脑只有一张显卡,默认是 0num_processes: 只有一张显卡,则写1;应该与gpu_ids中的数量一致zero_stage: DeepSpeed 里的大模型显存优化技术。如果报错,需要改成 0这个文件我已经修改好,无需修改
python -m scripts.preprocess --config ./configs/training/preprocess.yaml
conda activate train
configs/training/stage1.yaml
data.train_bs: Adjust batch size based on your GPU memory (default: 32)data.n_sample_frames: Number of sampled frames per video (default: 1)
solver.max_train_steps: 最大训练步骤,如果希望快速验证,可调整此参数

total_limit: 保存 checkpoints 最大数量save_model_epoch_interval: 模型保存间隔checkpointing_steps: checkpoint 保存间隔val_freq: 验证频率sh train.sh stage1

conda activate train
configs/training/stage2.yaml
data.train_bs: Smaller batch size due to high GPU memory cost (default: 2)data.n_sample_frames: Higher value for temporal consistency (default: 16)
solver.gradient_accumulation_steps: Increase to simulate larger batch sizes (default: 8)solver.max_train_steps: 最大训练步骤,如果希望快速验证,可调整此参数
total_limit: 保存 checkpoints 最大数量save_model_epoch_interval: 模型保存间隔checkpointing_steps: checkpoint 保存间隔val_freq: 验证频率==需要快速验证建议调小==
sh train.sh stage2

训练的模型位置在 exp_out/stage2/test 目录中
~/MuseTalk# ls -l exp_out/stage2/test/
total 16601972
drwxr-xr-x 4 root root 4096 Apr 13 10:25 ./
drwxr-xr-x 3 root root 4096 Apr 13 10:18 ../
drwxr-xr-x 2 root root 4096 Apr 13 10:25 samples/
drwxr-xr-x 3 root root 4096 Apr 13 10:19 tensorboard/
-rw-r--r-- 1 root root 3400073943 Apr 13 10:21 unet-100.pth
-rw-r--r-- 1 root root 3400073943 Apr 13 10:23 unet-150.pth
-rw-r--r-- 1 root root 3400073943 Apr 13 10:24 unet-200.pth
-rw-r--r-- 1 root root 3400073943 Apr 13 10:25 unet-250.pth
-rw-r--r-- 1 root root 3400072675 Apr 13 10:20 unet-50.pth

支持自启动