YOLOv5 in Practice, Continued: QR Code Detection
Contents: Foreword; Why make the model lightweight; What is pruning; Sparse training; Pruning; Fine-tuning; Conclusion; Model download

Foreword

In the previous YOLOv5 post we trained a QR code detector that tells whether an image contains a QR code; a decoder can then be attached downstream to read out the code's content (a topic for a later post). This post covers another aspect: making the model lightweight, specifically through model pruning.
Why make the model lightweight

The models we train do not only run on high-compute hardware such as GPUs; they may also have to run on embedded CPUs or NPUs, whose compute budget is usually much smaller. On such devices we can quantize the model to int8, which cuts the computation considerably, but that alone is sometimes not enough. The most obvious way to speed a model up is to shrink it, for example by reducing the number of channels or the depth of the network, but doing that directly costs accuracy. This pushes us toward better lightweighting methods, and pruning is one of the most widely used.
What is pruning

Model pruning is a technique that optimizes a neural network by removing redundant parameters and connections. The goal is to reduce the model's size, memory footprint, and computational cost while keeping its performance as close as possible to the original.
The basic idea is to identify and remove the parameters or connections that contribute least to the model's performance, leaving a smaller, leaner network. A typical pipeline combines pruning with fine-tuning of the remaining parameters, retraining, or adjustments to the overall architecture. Done well, pruning cuts the parameter count and computation substantially without a significant loss in accuracy, improving deployment efficiency and inference speed. It is particularly suited to embedded devices, mobile devices, and edge computing, and in general to any application constrained by storage or bandwidth.

The pruning method used in this post is Network Slimming, from the paper Learning Efficient Convolutional Networks through Network Slimming. Reference implementation: https://github.com/foolwood/pytorch-slimming. The idea is to train with a sparsity penalty so that the BatchNorm scale factors automatically identify low-weight channels; removing those channels prunes the model.
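Concretely, the Network Slimming paper trains with the ordinary task loss plus an L1 penalty on the BatchNorm scale factors \gamma, where \lambda is the sparsity coefficient discussed in the next section:

L = \sum_{(x, y)} l(f(x, W), y) + \lambda \sum_{\gamma \in \Gamma} |\gamma|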
Sparse training

As described above, pruning requires sparse training so that the weights become more concentrated and the channels with small weights can be removed. Sparse training introduces a sparsity coefficient: the larger it is, the sparser the model becomes and the larger the fraction of channels that can be pruned. It has to be tuned in practice; if it is too large the model's accuracy may drop noticeably, and if it is too small hardly anything can be pruned.

To set up sparse training, first collect all of the model's BatchNorm layers:
if opt.sl > 0:
    print("Sparse Learning Model!")
    print("===> Sparse learning rate is ", hyp["sl"])
    prunable_modules = []
    prunable_module_type = (nn.BatchNorm2d, )
    for i, m in enumerate(model.modules()):
        if isinstance(m, prunable_module_type):
            prunable_modules.append(m)
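It can help to histogram the absolute BatchNorm scale factors of these layers before and after sparse training, to see how much of their mass gets pushed toward zero. A minimal standalone sketch (it mirrors what the bn_analyze helper in the pruning script further below does; matplotlib is assumed to be available):

import torch
import matplotlib.pyplot as plt

# pool the absolute BN scale factors (gamma) from the layers collected above
gammas = torch.cat([m.weight.detach().abs().cpu() for m in prunable_modules])
plt.hist(gammas.numpy(), bins=100, log=True)
plt.savefig("bn_gamma_hist.jpg")
# fraction of channels that would fall under a 0.01 threshold
print("gammas below 0.01: %.1f%%" % (100.0 * (gammas < 0.01).float().mean().item()))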
Then add the sparsity loss to the training loss:
def compute_pruning_loss(p, prunable_modules, model, loss):
    """
    Compute the pruning loss
    :param p: predicted output
    :param prunable_modules: list of prunable modules
    :param model: model
    :param loss: original yolo loss
    :return: loss
    """
    float_tensor = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
    sl_loss = float_tensor([0])
    hyp = model.hyp  # hyperparameters
    red = "mean"  # Loss reduction (sum or mean)
    if prunable_modules is not None:
        for m in prunable_modules:
            sl_loss += m.weight.norm(1)  # L1 norm of the BN scale factors
        sl_loss /= len(prunable_modules)
        sl_loss *= hyp["sl"]
    bs = p[0].shape[0]  # batch size
    loss += sl_loss * bs
    return loss
# Forward
with amp.autocast(enabled=cuda):
    pred = model(imgs)  # forward
    loss, loss_items = compute_loss(pred, targets.to(device), model)  # loss scaled by batch_size
    # Sparse Learning
    if opt.sl > 0:
        loss = compute_pruning_loss(pred, prunable_modules, model, loss)
    if rank != -1:
        loss *= opt.world_size  # gradient averaged between devices in DDP mode
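For reference, the same penalty can be applied without an extra loss term, by adding its subgradient to the BatchNorm weight gradients between backward() and the optimizer step; the pytorch-slimming reference code linked above applies the penalty in essentially this way. A minimal sketch, reusing hyp["sl"] and prunable_modules from the snippets above:

# after loss.backward() and before optimizer.step()
for m in prunable_modules:
    # hyp["sl"] * sign(gamma) is the (sub)gradient of hyp["sl"] * |gamma|
    m.weight.grad.data.add_(hyp["sl"] * torch.sign(m.weight.data))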
Pick a suitable sparsity coefficient and train; apart from the extra loss term, this is exactly the same as ordinary YOLOv5 training.
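As an illustration only (the exact flag names depend on this fork's train.py, which as shown above reads opt.sl, and the value 2e-3 is merely what the run directory names below suggest), a sparse-training run might be launched along these lines:

python train.py --cfg models/yolov5s_voc.yaml --sl 2e-3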
Pruning

The pruning script (prune.py):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Copyright (c) 2019 luozw, Inc. All Rights Reserved

Authors: luozhiwang(luozw1994@outlook.com)
Date: 2020/9/7
"""
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch_pruning as tp
import copy
import matplotlib.pyplot as plt
from models.yolo import Model
import math


def load_model(cfg="models/mobile-yolo5l_voc.yaml", weights="./outputs/mvoc/weights/best_mvoc.pt"):
    restor_num = 0
    ommit_num = 0
    model = Model(cfg).to(device)
    ckpt = torch.load(weights, map_location=device)  # load checkpoint
    names = ckpt["model"].names
    dic = {}
    for k, v in ckpt["model"].float().state_dict().items():
        if k in model.state_dict() and model.state_dict()[k].shape == v.shape:
            dic[k] = v
            restor_num += 1
        else:
            ommit_num += 1
    print("Build model from", cfg)
    print("Restore weight from", weights)
    print("Restore %d vars, omit %d vars" % (restor_num, ommit_num))
    ckpt["model"] = dic
    model.load_state_dict(ckpt["model"], strict=False)
    del ckpt
    model.float()
    model.model[-1].export = True
    return model, names


def bn_analyze(prunable_modules, save_path=None):
    # histogram the absolute BN scale factors to visualize how sparse the model is
    bn_val = []
    max_val = []
    for layer_to_prune in prunable_modules:
        # select a layer
        weight = layer_to_prune.weight.data.detach().cpu().numpy()
        max_val.append(max(weight))
        bn_val.extend(weight)
    bn_val = np.abs(bn_val)
    max_val = np.abs(max_val)
    bn_val = sorted(bn_val)
    max_val = sorted(max_val)
    plt.hist(bn_val, bins=101, align="mid", log=True, range=(0, 1.0))
    if save_path is not None:
        if os.path.isfile(save_path):
            os.remove(save_path)
        plt.savefig(save_path)
    return bn_val, max_val


def channel_prune(ori_model, example_inputs, output_transform, pruned_prob=0.3, thres=None, rules=1):
    model = copy.deepcopy(ori_model)
    model.cpu().eval()

    prunable_module_type = (nn.BatchNorm2d)
    ignore_idx = []  # [230, 260, 290]

    prunable_modules = []
    for i, m in enumerate(model.modules()):
        if i in ignore_idx:
            continue
        if isinstance(m, nn.Upsample):
            continue
        if isinstance(m, prunable_module_type):
            prunable_modules.append(m)
    ori_size = tp.utils.count_params(model)
    DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs,
                                               output_transform=output_transform)
    bn_val, max_val = bn_analyze(prunable_modules, "render_img/before_pruning.jpg")
    if thres is None:
        # derive the threshold from the requested pruning ratio (quantile of all |gamma|)
        thres_pos = int(pruned_prob * len(bn_val))
        thres_pos = min(thres_pos, len(bn_val) - 1)
        thres_pos = max(thres_pos, 0)
        thres = bn_val[thres_pos]
    print("Min val is %f, Max val is %f, Thres is %f" % (bn_val[0], bn_val[-1], thres))

    for layer_to_prune in prunable_modules:
        # select a layer
        weight = layer_to_prune.weight.data.detach().cpu().numpy()
        if isinstance(layer_to_prune, nn.Conv2d):
            if layer_to_prune.groups > 1:
                prune_fn = tp.prune_group_conv
            else:
                prune_fn = tp.prune_conv
            L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        elif isinstance(layer_to_prune, nn.BatchNorm2d):
            prune_fn = tp.prune_batchnorm
            L1_norm = np.abs(weight)

        pos = np.array([i for i in range(len(L1_norm))])
        pruned_idx_mask = L1_norm < thres
        prun_index = pos[pruned_idx_mask].tolist()
        if rules != 1:
            # keep the number of remaining channels a multiple of `rules`
            prune_channel_nums = len(L1_norm) - max(rules, int((len(L1_norm) - pruned_idx_mask.sum()) / rules + 0.5) * rules)
            _, index = torch.topk(torch.tensor(L1_norm), prune_channel_nums, largest=False)
            prun_index = index.numpy().tolist()

        if len(prun_index) == len(L1_norm):
            # never remove every channel of a layer
            del prun_index[np.argmax(L1_norm)]

        plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index)
        plan.exec()

    bn_analyze(prunable_modules, "render_img/after_pruning.jpg")

    with torch.no_grad():
        out = model(example_inputs)
        if output_transform:
            out = output_transform(out)
        print("  Params: %s => %s" % (ori_size, tp.utils.count_params(model)))
        if isinstance(out, (list, tuple)):
            for o in out:
                print("  Output: ", o.shape)
        else:
            print("  Output: ", out.shape)
        print("------------------------------------------------------\n")
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", default="models/yolov5s_voc.yaml", type=str, help="*.cfg path")
    parser.add_argument("--weights", default="runs/exp7_sl-2e-3-yolov5s/weights/last.pt", type=str, help="*.data path")
    parser.add_argument("--save-dir", default="runs/exp7_sl-2e-3-yolov5s/weights", type=str, help="*.data path")
    parser.add_argument("-r", "--rate", default=1, type=int, help="keep channel counts a multiple of rate")
    parser.add_argument("-p", "--prob", default=0.5, type=float, help="pruning prob")
    parser.add_argument("-t", "--thres", default=0, type=float, help="pruning thres")
    opt = parser.parse_args()

    cfg = opt.cfg
    weights = opt.weights
    save_dir = opt.save_dir
    device = torch.device("cpu")
    model, names = load_model(cfg, weights)

    example_inputs = torch.zeros((1, 3, 64, 64), dtype=torch.float32).to(device)
    output_transform = None
    # for prob in [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:
    if opt.thres != 0:
        thres = opt.thres
        prob = "p.auto"
    else:
        thres = None
        prob = opt.prob
    pruned_model = channel_prune(model, example_inputs=example_inputs, output_transform=output_transform,
                                 pruned_prob=prob, thres=thres, rules=opt.rate)
    pruned_model.model[-1].export = False
    pruned_model.names = names
    save_path = os.path.join(save_dir, "pruned_" + str(prob).split(".")[-1] + ".pt")
    print(pruned_model)
    torch.save({"model": pruned_model.module if hasattr(pruned_model, "module") else pruned_model}, save_path)
You can prune by ratio, for example with a pruning ratio of 0.5:
python prune.py --cfg models/yolov5s_voc.yaml --weights runs/exp7_sl-2e-3-yolov5s/weights/last.pt --save-dir runs/exp7_sl-2e-3-yolov5s/weights/ --prob 0.5
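With --prob, the threshold is derived from the ratio: the script pools and sorts the absolute BN scale factors of the whole network and uses the value at that quantile as the cut-off, so 0.5 removes roughly the smaller half of the channels. A toy trace of that logic from channel_prune (the gamma values are made up):

bn_val = sorted([0.001, 0.003, 0.02, 0.4, 0.7, 1.1])  # hypothetical pooled |gamma| values
thres = bn_val[int(0.5 * len(bn_val))]                 # -> 0.4
# channels whose |gamma| is below 0.4 are marked for pruning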
You can also prune by weight magnitude, for example removing every channel whose weight is below 0.01:
python prune.py --cfg models/yolov5s_voc.yaml --weights runs/exp7_sl-2e-3-yolov5s/weights/last.pt --save-dir runs/exp7_sl-2e-3-yolov5s/weights/ --thres 0.01
Inference is often faster when channel counts are multiples of 8, which the --rate option enforces:
python prune.py --cfg models/yolov5s_voc.yaml --weights runs/exp7_sl-2e-3-yolov5s/weights/last.pt --save-dir runs/exp7_sl-2e-3-yolov5s/weights/ --thres 0.01 --rate 8
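With --rate, the number of channels kept in each layer is rounded to a multiple of the given rate (with that rate as a floor), so slightly more or fewer channels may be removed than the threshold alone would dictate. A small trace of the rounding formula used in channel_prune (the layer size is made up):

channels, below_thres, rate = 64, 29, 8  # hypothetical layer: 64 channels, 29 under the threshold
kept = max(rate, int((channels - below_thres) / rate + 0.5) * rate)  # round kept channels to a multiple of 8
pruned = channels - kept
print(kept, pruned)  # 32 32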
Once pruning has run, the resulting model is smaller (the script prints the parameter counts before and after).
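You can also reload the saved checkpoint and run a dummy forward pass as a sanity check. A minimal sketch (run it from the repository root so the pickled model class can be resolved; the file name assumes --prob 0.5, for which the script writes pruned_5.pt):

import torch

ckpt = torch.load("runs/exp7_sl-2e-3-yolov5s/weights/pruned_5.pt", map_location="cpu")
model = ckpt["model"].float().eval()
print("parameters after pruning:", sum(p.numel() for p in model.parameters()))
with torch.no_grad():
    out = model(torch.zeros(1, 3, 640, 640))  # input size must be a multiple of the model stride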
Fine-tuning

Pruning costs some accuracy, so the pruned model needs to be fine-tuned; the training procedure is the same as before pruning. In most cases the fine-tuned model gets close to its pre-pruning performance.
Conclusion

Pruning speeds up model inference at the cost of only a small loss in accuracy, which makes it very useful for tasks that need real-time analysis.
Model download

Lightweight QR code detection model: model download