LLM的长度外推浅谈

一、nbce
nbce:使用朴素贝叶斯扩展llm的context处理长度  
苏神最早提出的扩展llm的context方法,基于bayes启发得到的公式:
在问答下实测确实不错,在较长context下的阅读理解还算好用。
局限性是,无序性,即无法识别context的输入顺序,这在续写故事等场景可能表现欠佳,做一些依赖每个context生成答案,比如提取文档摘要,效果较差。
outputs = model(input_ids=input_ids,                        attention_mask=attention_mask,                        return_dict=true,                        use_cache=true,                        past_key_values=past_key_values                       )past_key_values = outputs.past_key_values        # ===== 核心代码开始 =====beta = 0.25probas = torch.nn.functional.softmax(outputs.logits[:, -1], dim=-1)logits = probas.log()k = (probas * logits).sum(dim=-1)[1:].argmax() + 1logits_max = logits[k]logits_uncond = logits[0]logits = (1 + beta) * logits_max - beta * logits_uncond# ===== 核心代码结束 =====        # 构建分布,采样tau = 0.01  # tau = 1是标准的随机采样,tau->0则是贪心搜索probas = torch.nn.functional.softmax(logits[none] / tau , dim=-1)next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)      
此处代码,图片,文本均选自科学空间。
二、线性内插
llama基于rotary embedding在2048长度上预训练,该方法通过将position压缩到0~2048之间,从而达到长度外推的目的。
longchat将模型微调为上下文长度外扩为16384,压缩比为 8。例如,position_ids = 10000 的 token 变为position_ids = 10000 / 8 = 1250,相邻 token 10001 变为 10001 / 8 = 1250.125
该方法的缺陷是需要进行一定量的微调,让模型来适应这种改变。
import torchimport transformersimport transformers.models.llama.modeling_llamafrom einops import rearrangefrom functools import partialclass condenserotaryembedding(torch.nn.module):    def __init__(self, dim, ratio, max_position_embeddings=2048, base=10000, device=none):        super().__init__()        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))        self.register_buffer(inv_freq, inv_freq)                # build here to make `torch.jit.trace` work.        self.ratio = ratio        max_position_embeddings *= ratio        print(fcondensing positional embeddings from {max_position_embeddings} to {max_position_embeddings // ratio})        self.max_seq_len_cached = max_position_embeddings        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) / ratio        freqs = torch.einsum(i,j->ij, t, self.inv_freq)        # different from paper, but it uses a different permutation in order to obtain the same calculation        emb = torch.cat((freqs, freqs), dim=-1)        dtype = torch.get_default_dtype()        self.register_buffer(cos_cached, emb.cos()[none, none, :, :].to(dtype), persistent=false)        self.register_buffer(sin_cached, emb.sin()[none, none, :, :].to(dtype), persistent=false)    def forward(self, x, seq_len=none):        # x: [bs, num_attention_heads, seq_len, head_size]        # this `if` block is unlikely to be run after we build sin/cos in `__init__`. keep the logic here just in case.        if seq_len > self.max_seq_len_cached:            self.max_seq_len_cached = seq_len            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) / self.ratio            freqs = torch.einsum(i,j->ij, t, self.inv_freq)            # different from paper, but it uses a different permutation in order to obtain the same calculation            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)            self.register_buffer(cos_cached, emb.cos()[none, none, :, :].to(x.dtype), persistent=false)            self.register_buffer(sin_cached, emb.sin()[none, none, :, :].to(x.dtype), persistent=false)        return (            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),        )def replace_llama_with_condense(ratio):    transformers.models.llama.modeling_llama.llamarotaryembedding = partial(condenserotaryembedding, ratio=ratio)  
三、ntk-aware scaled rope
rope是一种β进制编码//spaces.ac.cn/archives/9675
有意思的解释一下,rope 的行为就像一个时钟。12小时时钟基本上是一个维度为 3、底数为 60 的 rope。因此,每秒钟,分针转动 1/60 分钟,每分钟,时针转动 1/60。
现在,如果将时间减慢 4 倍,那就是二使用的线性rope 缩放。不幸的是,现在区分每一秒,因为现在秒针几乎每秒都不会移动。
因此,如果有人给你两个不同的时间,仅相差一秒,你将无法从远处区分它们。ntk-aware rope 扩展不会减慢时间。一秒仍然是一秒,但它会使分钟减慢 1.5 倍,将小时减慢 2 倍。
这样,您可以将 90 分钟容纳在一个小时中,将 24 小时容纳在半天中。
所以现在你基本上有了一个可以测量 129.6k 秒而不是 43.2k 秒的时钟。由于在查看时间时不需要精确测量时针,因此与秒相比,更大程度地缩放小时至关重要。
不想失去秒针的精度,但可以承受分针甚至时针的精度损失。
import transformersold_init = transformers.models.llama.modeling_llama.llamarotaryembedding.__init__def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=none):    #the method is just these three lines    max_position_embeddings = 16384    a = 8 #alpha value    base = base * a ** (dim / (dim-2)) #base change formula    old_init(self, dim, max_position_embeddings, base, device)transformers.models.llama.modeling_llama.llamarotaryembedding.__init__ = ntk_scaled_init  
四、dynamically scaled rope
对于上面的方法二、三,都涉及到一个超参数α,用于调节缩放比例,该方法是通过序列长度动态选择正确的比例参数,效果可以看上图。
对于线性插值,前 2k 上下文的精确位置值,然后在模型逐个生成标记时重新计算每个新序列长度的位置向量。本质上,将比例设置为原始模型上下文长度/当前序列长度。
对于动态 ntk,α 的缩放设置为 (α * 当前序列长度 / 原始模型上下文长度) - (α - 1)。随着序列长度的增加动态缩放超参数。
import mathimport torchclass llamadynamicscaledrotaryembedding(torch.nn.module):    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=false, device=none):        super().__init__()        self.ntk = ntk        self.base = base        self.dim = dim        self.max_position_embeddings = max_position_embeddings        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))        self.register_buffer(inv_freq, inv_freq)        # build here to make `torch.jit.trace` work.        self.max_seq_len_cached = max_position_embeddings        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)        freqs = torch.einsum(i,j->ij, t, self.inv_freq)        # different from paper, but it uses a different permutation in order to obtain the same calculation        emb = torch.cat((freqs, freqs), dim=-1)        dtype = torch.get_default_dtype()        self.register_buffer(cos_cached, emb.cos()[none, none, :, :].to(dtype), persistent=false)        self.register_buffer(sin_cached, emb.sin()[none, none, :, :].to(dtype), persistent=false)    def forward(self, x, seq_len=none):        # x: [bs, num_attention_heads, seq_len, head_size]        # this `if` block is unlikely to be run after we build sin/cos in `__init__`. keep the logic here just in case.        if seq_len > self.max_seq_len_cached:            self.max_seq_len_cached = seq_len            if self.ntk:                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))                self.register_buffer(inv_freq, inv_freq)            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)            if not self.ntk:                t *= self.max_position_embeddings / seq_len            freqs = torch.einsum(i,j->ij, t, self.inv_freq)            # different from paper, but it uses a different permutation in order to obtain the same calculation            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)            self.register_buffer(cos_cached, emb.cos()[none, none, :, :].to(x.dtype), persistent=false)            self.register_buffer(sin_cached, emb.sin()[none, none, :, :].to(x.dtype), persistent=false)        return (            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),        )  
五、consistent of dynamically scaled rope
方法四存在一个问题是,因为α是动态的,因为解码是有cache的,所以,在生成第100个token时,算的α和第200个token时,算的α时不一致的。
query和key的rotation base不一致,正确的应该时这样
import mathfrom typing import list, optional, tuple, unionimport torchimport torch.nn.functional as fimport torch.utils.checkpointfrom torch import nnfrom transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_embfrom transformers.models.llama.modeling_llama import llamaattentiondef forward(        self,        hidden_states: torch.tensor,        attention_mask: optional[torch.tensor] = none,        position_ids: optional[torch.longtensor] = none,        past_key_value: optional[tuple[torch.tensor]] = none,        output_attentions: bool = false,        use_cache: bool = false,) -> tuple[torch.tensor, optional[torch.tensor], optional[tuple[torch.tensor]]]:    bsz, q_len, _ = hidden_states.size()    if self.pretraining_tp > 1:        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)        query_states = [f.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]        query_states = torch.cat(query_states, dim=-1)        key_states = [f.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]        key_states = torch.cat(key_states, dim=-1)        value_states = [f.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]        value_states = torch.cat(value_states, dim=-1)    else:        query_states = self.q_proj(hidden_states)        key_states = self.k_proj(hidden_states)        value_states = self.v_proj(hidden_states)    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)    kv_seq_len = key_states.shape[-2]    if past_key_value is not none:        kv_seq_len += past_key_value[0].shape[-2]    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)    if past_key_value is not none:        # reuse k w/o rope        key_states = torch.cat([past_key_value[0], key_states], dim=2)    # apply rope after retrieving all keys and queries    query_states, rotated_key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)    if past_key_value is not none:        # reuse v, self_attention        value_states = torch.cat([past_key_value[1], value_states], dim=2)    past_key_value = (key_states, value_states) if use_cache else none # cache the key w/o rope    # repeat k/v heads if n_kv_heads  1:        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)        attn_output = sum([f.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])    else:        attn_output = self.o_proj(attn_output)    if not output_attentions:        attn_weights = none    return attn_output, attn_weights, past_key_valuedef replace_llama_attn_with_consistent_ntk_rope():    llamaattention.forward = forward


晶振的最大波特率及其误差介绍
基于C#开发一个简单的窗体应用程序
手机板对板连接器的测试及解决方案
通过CAN模块和PIC30系列芯片实现船舶电站控制系统的设计
rfid系统供应商有哪些_国内十大rfid系统供应商排名
LLM的长度外推浅谈
Safran与是德科技合作推出基于GNSS技术的5G LBS方案
Maxim EZCascade技术简化视频显示器设计
腾讯云Serverless应用最新进展,构建全云端开发体验
ICMAX大型自动化DDR测试机台即将上线运行 助力存储芯片行业国产化替代进程提速
iButton传感器和温度/湿度数据记录器综述
三菱PLC GX Developer的应用特点
户外安防摄像头气密性测试是如何做到的
光栅线位移传感器的结构原理及维护知识
美国研究人员,用机器学习,实现让机器人在不平路面上自由行走
3名航天员太空生活“剧透”
飞腾盛装亮相MWC上海展,全面展现5G行业“芯”实践
苹果13pro价格表官网报价
C与脚本的混合编程是怎样编程的
智能音箱是服务人类的还是窃听私隐的