网站开发设计中的收获,做娱乐网站,wordpress php7 iis,宁波seo排名费用目录 文章目录 目录网络剪枝——network-slimming 项目复现clone 存储库Baselinevgg训练结果 resnet训练结果 densenet训练结果 Sparsityvgg训练结果 resnet训练结果 densenet训练结果 Prunevgg命令结果 resnet命令结果 densenet命令结果 Fine-tunevgg训练结果 resnet训练结果 …目录 文章目录 目录网络剪枝——network-slimming 项目复现clone 存储库Baselinevgg训练结果 resnet训练结果 densenet训练结果 Sparsityvgg训练结果 resnet训练结果 densenet训练结果 Prunevgg命令结果 resnet命令结果 densenet命令结果 Fine-tunevgg训练结果 resnet训练结果 densenet训练结果 模型大小计算脚本 param_counter.py结果汇总CIFAR10 网络剪枝——network-slimming 项目复现
【GiHnub】Eric-mingjie/network-slimming: Network Slimming (Pytorch) (ICCV 2017) (github.com)【作者复现项目】通过百度网盘分享的文件network-slimming-regin.zip 链接https://pan.baidu.com/s/1vTJSLS5ZDjE8R8XaApW96A?pwdt1z2 提取码t1z2 仅以 CIFAR-10 为例CIFAR-100 同理.提供中文README_zh-CN.md.包含 CIFAR-10/100 数据集data.cifar10、data.cifar100.解决了 main.py 运行报错问题.加入了计算训练后模型的 Parameters 大小脚本param_counter.py.
clone 存储库 注若 clone 作者复现项目则忽略这一步直接进入下一步若想自行从头复现则 clone 以下存储库. 链接https://pan.baidu.com/s/1nppPLKoiPbJPW60HOa2TxQ?pwdud89 提取码ud89 Baseline
vgg
训练
【命令】
python main.py --dataset cifar10 --arch vgg --depth 19这个报错通常出现在使用 Python 的multiprocessing库来创建进程时尤其是在 Windows 操作系统上. 在 Windows 上Python 的multiprocessing模块启动新进程的方式与 Linux 或 macOS 不同它使用 “spawn” 来启动新进程这意味着每个子进程都会从头开始执行脚本. 因此如果在脚本顶层级别启动进程而不是在受保护的if __name__ __main__:块中每个子进程都会尝试再次启动子进程从而导致无限递归和上述错误. 为了解决这个问题应 确保多进程代码即main.py位于if __name__ __main__:保护块内.
# 导入部分
...def main():...if __name__ __main__:main()再次运行命令又报错 这个报错通常发生在尝试直接索引一个0维的张量tensor时. 在 PyTorch 中0 维张量是一个单一值的张量但是不能像普通的数组那样通过索引来访问。要从 0 维张量中获取其 Python 数值需要使用.item()方法. 为了解决这个问题应该 使用.item()方法来替换所有.data[0]的用法
# 在 train 函数中
if batch_idx % args.log_interval 0:print(Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# 在 test 函数中
for data, target in test_loader:if args.cuda:data, target data.cuda(), target.cuda()data, target Variable(data), Variable(target)output model(data)test_loss F.cross_entropy(output, target, reductionsum).item() # sum up batch losspred output.data.max(1, keepdimTrue)[1]correct pred.eq(target.data.view_as(pred)).cpu().sum()test_loss / len(test_loader.dataset)再次运行命令就正常运行了 结果
Terminal 在 ./logs 生成文件checkpoint.pth.tar、model_best.pth.tar resnet
训练
【命令】
python main.py --dataset cifar10 --arch resnet --depth 164结果 densenet
训练
【命令】
python main.py --dataset cifar10 --arch densenet --depth 40结果 Sparsity
vgg
训练
【命令】
python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19结果 resnet
训练
【命令】
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164结果 densenet
训练
【命令】
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40结果 Prune
vgg
命令
python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model ./results/CIFAR10_results/CIFAR10-Vgg/Sparsity/model_best.pth.tar --save ./prunes与main.py同理为了解决这个问题应 确保多进程代码位于if __name__ __main__:保护块内
# 导入部分
...def main():...if __name__ __main__:main()之后就可以正常运行了. 结果
Terminal 在./prunes生成文件prune.txt、pruned.pth.tar 在prune.txt中我们可以看到 Number of parameters、Test accuracy resnet
命令
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Resnet-164/Sparsity/model_best.pth.tar --save ./prunes结果 densenet
命令
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Densenet-40/Sparsity/model_best.pth.tar --save ./prunes结果 Fine-tune
vgg
训练
【命令】
python main.py --refine ./results/CIFAR10_results/CIFAR10-Vgg/Prune/pruned.pth.tar --dataset cifar10 --arch vgg --depth 19 --epochs 160结果 resnet
训练
【命令】
python main.py --refine ./results/CIFAR10_results/CIFAR10-Resnet-164/Prune/pruned.pth.tar --dataset cifar10 --arch resnet --depth 164 --epochs 160结果 densenet
训练
【命令】
python main.py --refine ./results/CIFAR10_results/CIFAR10-Densenet-40/Prune/pruned.pth.tar --dataset cifar10 --arch densenet --depth 40 --epochs 160结果 模型大小计算脚本 param_counter.py
【路径】./script/param_counter.py
import torchdef load_model(model_path):model torch.load(model_path, map_locationtorch.device(cpu))return modeldef count_parameters(model_state_dict):total_params sum(p.numel() for p in model_state_dict.values())return total_paramsdef get_model_parameters(model_path):# 加载模型状态字典model load_model(model_path)# 模型状态字典存储在 state_dict 键下model_state_dict model[state_dict] if state_dict in model else model# 计算参数总数total_params count_parameters(model_state_dict)return total_params在main.py中
from script.param_counter import get_model_parametersdef main():...# 计算 Parametersmodel_path logs/model_best.pth.tartotal_params get_model_parameters(model_path)print(fTotal parameters in the model: {total_params})结果汇总 注与原项目结果略有差别. CIFAR10
CIFAR10-VggBaselineSparsity(1e-4)Prune(70%)Fine-tune-160(70%)Top1 Accuracy(%)93.7293.6033.9893.75Parameters20.05M20.05M2.22M2.23M
CIFAR10-Resnet-164BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)Top1 Accuracy(%)94.9995.0094.5995.27Parameters1.74M1.74M1.46M1.49M
CIFAR10-Densenet-40BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)Top1 Accuracy(%)94.1594.3794.1494.48Parameters1.09M1.09M0.70M0.72M