miao blog
2448 words
12 minutes
从梯度裁剪到模型剪枝:SEVEN方法的探索

摘要#

论文简介#

《SEVEN- Pruning Transformer Model by Reserving Sentinels》(SEVEN:通过保留Sentinels修剪Transformer模型)是一篇关注于Transformer模型压缩的研究论文。Transformer模型虽然在各种任务上表现出色,但其庞大的参数规模限制了在移动设备上的应用。常见的修剪方法通常倾向于保留梯度噪声较大的权重,这会使修剪后的模型对稀疏度和数据集敏感,表现不佳。为了改进这一点,本文提出了一种名为SEVEN的新方法,特别偏好那些始终保持高敏感性的权重(即梯度噪声小的权重),并倾向于保留这些权重。

SEVEN方法基于一种名为符号下降(Symbolic Descent, SD)的梯度处理技术。通过SD的累积过程描述Transformer模型(TM)上的噪声批量梯度序列,动态地评估权重的重要性分数。

我们先从论文本身介绍去了解它所使用的技术是如何实现的,然后再进一步去根据代码一步步去复现它的操作。

SEVEN方法的两种形式:#

  • SEVENpre:是在训练前应用的预修剪方法。
  • SEVENdyn:是在训练过程中逐渐进行的动态修剪方法。

SEVEN在各种自然语言处理、问答任务和图像分类领域的Transformer模型上进行了广泛的实验,证明了其在多个修剪场景和不同稀疏度水平上的有效性。此外,SEVEN在多种微调策略下也表现出了稳健的性能。

方法论与基本原理#

SEVEN方法主要关注于如何通过管理梯度噪声来优化Transformer模型的剪枝过程。模型训练中的梯度噪声是由小批量数据的随机性引入的,这可以帮助模型探索更广的参数空间,但也可能引起梯度的不稳定。SEVEN通过区分两种类型的权重来应对这一挑战:

权重分类#

临时哨兵权重(Temporary Sentinel Weights, TSW)#

  • 定义:这些权重的梯度值很高,但是包含大量的噪声。

哨兵权重(Sentinel Weights, SW)#

  • 定义:这些权重在多个训练周期中展示出稳定的梯度,且梯度噪声较低。

剪枝方法#

SEVEN提出了两种剪枝策略,分别针对训练前和训练中的不同需求,接下来我们根据论文来介绍这两个最重要的两个算法

SEVENpre(预剪枝方法)#

  • 目标:在训练开始前根据权重的潜在重要性进行剪枝。
  • 实现:使用权重的初始梯度信息来预测其对最终模型性能的影响。具体来说,选择那些具有最小梯度噪声和合适梯度大小的权重。
  • 步骤
    • 计算每个权重的梯度和梯度的统计特性(如均值和方差)。
    • 剪枝那些具有高梯度方差(噪声大)且梯度大小不是特别突出的权重。

Local image

Algorithm 1: SEVENpre 详细介绍#

输入需求:#

  • 预训练模型 (θ0\theta_0):已经训练过的初始模型参数。
  • 训练数据 (DD):用于评估和进一步训练模型的数据集。
  • 稀疏度 (ss):期望的模型稀疏程度,即最终模型中应被置为零的权重比例。
  • 总迭代次数 (TT):算法的总运行轮数。
  • 剪枝步骤 (KK):在这些初期步骤中进行权重剪枝。

算法步骤:#

  1. 迭代循环:从 1 到 (TT),对每一轮迭代执行以下步骤。
  2. 剪枝判断:如果当前迭代次数 (tt) 小于或等于剪枝步骤 (KK),则执行剪枝相关的操作。
  3. 更新累积梯度统计:更新 (gˉt\bar{g}_t) 和 (gˉt2\bar{g}^2_t)。
  4. 计算得分 (S(θt)S(\theta_t)):根据公式 15 计算每个权重的重要性得分。
  5. 计算当前步骤的剪枝率 (PP):这是基于预设的稀疏度 (ss) 和当前迭代步骤的函数。
  6. 计算得分的百分位数 (τ\tau):从所有权重的得分 (S(θt)S(\theta_t)) 中计算得出阈值 (τ\tau)。
  7. 生成掩码 (MM):对于所有权重,如果其得分低于 (τ\tau),则在掩码 (MM) 中对应位置标记为 0,否则标记为 1。
  8. 应用掩码:使用掩码 (MM) 来更新模型,即 (θtM\theta_t \odot M),这一步实际上是将部分权重置零,完成剪枝。
  9. 结束剪枝判断:如果当前迭代次数超过剪枝步骤 (KK),不再进行剪枝。
  10. 更新模型:无论是否进行剪枝,都需要根据损失函数通过梯度下降更新模型权重,即 (θt+1=θtηθtL(θt)\theta_{t+1} = \theta_t - \eta \nabla_{\theta_t} L(\theta_t))。
  11. 重复迭代:直到完成所有迭代。
  12. 返回模型:返回经过剪枝和训练的模型权重 (θM\theta \odot M)。

SEVENdyn(动态剪枝方法)#

  • 目标:在模型训练过程中动态调整剪枝策略,以适应权重在训练过程中的变化。
  • 实现:基于权重在训练过程中的行为动态调整剪枝决策。这涉及到监测权重的梯度变化,并根据其稳定性做出剪枝决策。
  • 步骤
    • 在每个训练周期计算权重的梯度和相关统计量。
    • 动态调整权重的保留或移除,特别关注那些表现出稳定梯度表现的权重。

Local image

Algorithm 2: SEVENdyn 详细介绍#

输入需求:#

  • 预训练模型 (θ0\theta_0):已经训练过的模型初始参数。
  • 训练数据 (DD):用于评估和进一步训练模型的数据集。
  • 稀疏度 (ss):目标稀疏度,表示模型中应被置为零的权重比例。
  • 剪枝步骤 (KK):从开始剪枝的那一刻起进行剪枝的总迭代次数。
  • 总迭代次数 (TT):算法的总运行轮数。
  • 剪枝开始迭代 (titi):开始执行剪枝操作的迭代次数。

算法步骤:#

  1. 迭代循环:从 1 到 (TT),对每一轮迭代执行以下步骤。
  2. 判断是否执行剪枝:如果当前迭代次数 (tt) 在剪枝开始时间 (titi) 之后并且在剪枝结束时间 (ti+Kti + K) 之前,则执行剪枝相关操作。
  3. 更新累积梯度统计:更新 (gˉt\bar{g}_t) 和 (gˉt2\bar{g}^2_t)。
  4. 计算得分 (S(θt)S(\theta_t)):根据公式 15 计算每个权重的得分。
  5. 计算当前步骤的剪枝率 (PP):使用公式 (ss×(1ttiK)3s - s \times \left(1 - \frac{t - ti}{K}\right)^3) 来确定这一步的剪枝率。
  6. 计算得分的百分位数 (τ\tau):根据剪枝率 (PP),计算所有权重得分 (S(θt)S(\theta_t)) 的 (PP) 百分位数作为阈值 (τ\tau)。
  7. 生成掩码 (MM):创建一个掩码 (MM),其中权重得分小于 (τ\tau) 的对应为 0,其余为 1。
  8. 应用掩码并更新模型:使用掩码 (MM) 来更新模型,即 (θt+1M\theta_{t+1} \odot M),实际上是将部分权重置零,同时保留其他权重以进行下一轮训练。
  9. 继续迭代直至结束:如果当前迭代次数超出剪枝步骤范围,不再进行剪枝。
  10. 更新模型:无论是否进行剪枝,都根据损失函数通过梯度下降更新模型权重,即 (θt+1=θtηθtL(θt)\theta_{t+1} = \theta_t - \eta \nabla_{\theta_t} L(\theta_t))。
  11. 迭代完成:重复上述步骤直到完成所有 (TT) 次迭代。
  12. 返回模型:返回经过可能的剪枝和训练后的模型权重 (θM\theta \odot M)。

实验设计与结果#

模型选择#

为了测试SEVEN方法的有效性,实验采用了包括BERT和CLIP等流行的预训练Transformer模型。

数据集#

实验涉及多个任务和数据集,包括:

  • 图像分类任务:CIFAR10/100, ImageNet
  • 自然语言处理任务:GLUE基准测试和SQuAD问答任务

剪枝方法比较#

SEVEN与其他几种剪枝方法进行了比较,包括:

  • 随机剪枝
  • 梯度基剪枝(如SNIP, GraSP)
  • 学习型剪枝(如Lottery Ticket Hypothesis)

性能评估指标#

实验主要关注以下指标:

  • 模型准确性
  • 模型大小
  • 计算效率

实验结果#

模型性能#

SEVEN在多个数据集和任务中表现优异,即使在高剪枝率下也能保持甚至提高模型的准确性。

与其他方法的比较#

在大多数情况下,SEVEN超越了基于梯度的剪枝方法和随机剪枝,尤其是在处理复杂的Transformer模型时更为有效。

不同任务的表现#

无论是在图像分类还是在自然语言处理任务中,SEVEN都显示出强大的性能和鲁棒性。

结论与灵感来源#

SEVEN通过一种创新的剪枝策略,有效地管理了Transformer模型中的梯度噪声,并优化了模型剪枝过程。其灵感主要来源于对现有剪枝方法的局限性的观察,特别是在处理Transformer模型时,传统梯度基方法往往无法充分考虑梯度噪声的影响。

实现步骤#

  1. 问题识别:通过分析现有剪枝方法,识别出在Transformer模型中梯度噪声管理的问题。
  2. 理论提出:基于对梯度噪声和模型权重重要性的深入理解,提出了将哨兵权重(SW)和临时哨兵权重(TSW)的概念用于剪枝决策。
  3. 方法开发
    • SEVENpre:应对训练前的剪枝需求,使用初始梯度信息预测权重的潜在重要性。
    • SEVENdyn:应对训练中的剪枝需求,基于权重的行为动态调整剪枝决策。
  4. 广泛测试:在多个数据集和任务上进行广泛测试,验证方法的有效性和通用性。
  5. 结果分析与优化:根据实验结果进一步调整和优化剪枝策略。
从梯度裁剪到模型剪枝:SEVEN方法的探索
https://ruiboom.cn/posts/seven/
Author
🐱
Published at
2024-05-05