Transformer 模型是 AI 体系的根底。已经有了数不清的关于 “Transformer 如何作业” 的核心结构图表。

了解 Transformers 是如何“思考”的

可是这些图表没有供给任何直观的核算该模型的结构表明。当研究者关于 Transformer 如何作业抱有兴趣时,直观的获取他运行的机制变得非常有用。

Thinking Like Transformers 这篇论文中提出了 transformer 类的核算结构,这个结构直接核算和模仿 Transformer 核算。运用 RASP 编程言语,使每个程序编译成一个特别的 Transformer。

在这篇博客中,我用 Python 复现了 RASP 的变体 (RASPy)。该言语大致与原始版别适当,可是多了一些我认为很风趣的变化。经过这些言语,作者 Gail Weiss 的作业,供给了一套具有应战性的风趣且正确的方法能够协助了解其作业原理。

!pip install git+https://github.com/srush/RASPy

在说起言语自身前,让咱们先看一个比方,看看用 Transformers 编码是什么样的。这是一些核算翻转的代码,即反向输入序列。代码自身用两个 Transformer 层运用 attention 和数学核算到达这个结果。

def flip():
    length = (key(1) == query(1)).value(1)
    flip = (key(length - indices - 1) == query(indices)).value(tokens)
    return flip
flip()

了解 Transformers 是如何“思考”的

文章目录

  • 榜首部分:Transformers 作为代码
  • 第二部分:用 Transformers 编写程序

Transformers 作为代码

咱们的方针是定义一套核算形式来最小化 Transformers 的表达。咱们将经过类比,描绘每个言语构造及其在 Transformers 中的对应。(正式言语规范请在本文底部查看论文全文链接)。

这个言语的核心单元是将一个序列转化成相同长度的另一个序列的序列操作。我后面将其称之为 transforms。

输入

在一个 Transformer 中,基本层是一个模型的前馈输入。这个输入一般包括原始的 token 和方位信息。

了解 Transformers 是如何“思考”的

在代码中,tokens 的特征表明最简略的 transform,它返回经过模型的 tokens,默许输入序列是 “hello”:

tokens

了解 Transformers 是如何“思考”的

假如咱们想要改变 transform 里的输入,咱们运用输入办法进行传值。

tokens.input([5, 2, 4, 5, 2, 2])

了解 Transformers 是如何“思考”的

作为 Transformers,咱们不能直接承受这些序列的方位。可是为了模拟方位嵌入,咱们能够获取方位的索引:

indices

了解 Transformers 是如何“思考”的

sop = indices
sop.input("goodbye")

了解 Transformers 是如何“思考”的

前馈网络

经过输入层后,咱们到达了前馈网络层。在 Transformer 中,这一步能够关于序列的每一个元素独立的运用数学运算。

了解 Transformers 是如何“思考”的

在代码中,咱们经过在 transforms 上核算表明这一步。在每一个序列的元素中都会进行独立的数学运算。

tokens == "l"

了解 Transformers 是如何“思考”的

结果是一个新的 transform,一旦重构新的输入就会按照重构方法核算:

model = tokens * 2 - 1
model.input([1, 2, 3, 5, 2])

了解 Transformers 是如何“思考”的

该运算能够组合多个 Transforms,举个比方,以上述的 token 和 indices 为例,这儿能够类别 Transformer 能够盯梢多个片段信息:

model = tokens - 5 + indices
model.input([1, 2, 3, 5, 2])

了解 Transformers 是如何“思考”的

(tokens == "l") | (indices == 1)

了解 Transformers 是如何“思考”的

咱们供给了一些辅助函数让写 transforms 变得更简略,举例来说,where 供给了一个相似 if 功用的结构。

where((tokens == "h") | (tokens == "l"), tokens, "q")

了解 Transformers 是如何“思考”的

map 使咱们能够定义自己的操作,例如一个字符串以 int 转化。(用户应慎重运用能够运用的简略神经网络核算的操作)

atoi = tokens.map(lambda x: ord(x) - ord('0'))
atoi.input("31234")

了解 Transformers 是如何“思考”的

函数 (functions) 能够简单的描绘这些 transforms 的级联。举例来说,下面是运用了 where 和 atoi 和加 2 的操作

def atoi(seq=tokens):
    return seq.map(lambda x: ord(x) - ord('0')) 
op = (atoi(where(tokens == "-", "0", tokens)) + 2)
op.input("02-13")

了解 Transformers 是如何“思考”的

注意力挑选器

到开端运用注意力机制事情就变得开端风趣起来了。这将答应序列间的不同元素进行信息交流。

了解 Transformers 是如何“思考”的

咱们开端定义 key 和 query 的概念,Keys 和 Queries 能够直接从上面的 transforms 创立。举个比方,假如咱们想要定义一个 key 咱们称作 key

key(tokens)

关于 query 也一样

query(tokens)

了解 Transformers 是如何“思考”的

标量能够作为 keyquery 运用,他们会广播到根底序列的长度。

query(1)

了解 Transformers 是如何“思考”的

咱们创立了挑选器来运用 key 和 query 之间的操作。这对应于一个二进制矩阵,指示每个 query 要关注哪个 key。与 Transformers 不同,这个注意力矩阵未参加权重

eq = (key(tokens) == query(tokens))
eq

了解 Transformers 是如何“思考”的

一些比方:

  • 挑选器的匹配方位偏移 1:
offset = (key(indices) == query(indices - 1))
offset

了解 Transformers 是如何“思考”的

  • key 早于 query 的挑选器:
before = key(indices) < query(indices)
before

了解 Transformers 是如何“思考”的

  • key 晚于 query 的挑选器:
after = key(indices) > query(indices)
after

了解 Transformers 是如何“思考”的

挑选器能够经过布尔操作合并。比方,这个挑选器将 before 和 eq 做合并,咱们经过在矩阵中包括一对键和值来显示这一点。

before & eq

了解 Transformers 是如何“思考”的

运用注意力机制

给一个注意力挑选器,咱们能够供给一个序列值做聚合操作。咱们经过累加那些挑选器选过的真值做聚合。

(请注意:在原始论文中,他们运用一个平均聚合操作而且展现了一个巧妙的结构,其中平均聚合能够代表总和核算。RASPy 默许情况下运用累加来使其简略化并防止碎片化。实际上,这意味着 raspy 或许低估了所需求的层数。基于平均值的模型或许需求这个层数的两倍)

注意聚合操作使咱们能够核算直方图之类的功用。

(key(tokens) == query(tokens)).value(1)

了解 Transformers 是如何“思考”的

视觉上咱们遵从图表结构,Query 在左面,Key 在上边,Value 在下面,输出在右边

了解 Transformers 是如何“思考”的

一些注意力机制操作甚至不需求用到输入 token 。举例来说,去核算序列长度,咱们创立一个 ” select all ” 的注意力挑选器而且给他赋值。

length = (key(1) == query(1)).value(1)
length = length.name("length")
length

了解 Transformers 是如何“思考”的

这儿有更多杂乱的比方,下面将一步一步展现。(这有点像做采访一样)

咱们想要核算一个序列的相邻值的和,首先咱们向前切断:

WINDOW=3
s1 = (key(indices) >= query(indices - WINDOW + 1))  
s1

了解 Transformers 是如何“思考”的

然后咱们向后切断:

s2 = (key(indices) <= query(indices))
s2

了解 Transformers 是如何“思考”的

两者相交:

sel = s1 & s2
sel

了解 Transformers 是如何“思考”的

终究聚合:

sum2 = sel.value(tokens)
sum2.input([1,3,2,2,2])

了解 Transformers 是如何“思考”的

这儿有个能够核算累计求和的比方,咱们这儿引入一个给 transform 命名的能力来协助你调试。

def cumsum(seq=tokens):
    x = (before | (key(indices) == query(indices))).value(seq)
    return x.name("cumsum")
cumsum().input([3, 1, -2, 3, 1])

了解 Transformers 是如何“思考”的

这个言语支撑编译更加杂乱的 transforms。他一起经过盯梢每一个运算操作核算层。

了解 Transformers 是如何“思考”的

这儿有个 2 层 transform 的比方,榜首个对应于核算长度,第二个对应于累积总和。

x = cumsum(length - indices)
x.input([3, 2, 3, 5])

了解 Transformers 是如何“思考”的

用 transformers 进行编程

运用这个函数库,咱们能够编写完结一个杂乱任务,Gail Weiss 给过我一个极端应战的问题来打破这个过程:咱们能够加载一个添加任意长度数字的 Transformer 吗?

例如: 给一个字符串 “19492+23919”, 咱们能够加载正确的输出吗?

假如你想自己尝试,咱们供给了一个 版别 你能够自己试试。

应战一:挑选一个给定的索引

加载一个在索引 i 处全元素都有值的序列

def index(i, seq=tokens):
    x = (key(indices) == query(i)).value(seq)
    return x.name("index")
index(1)

了解 Transformers 是如何“思考”的

应战二:转化

经过 i 方位将一切 token 移动到右侧。

def shift(i=1, default="_", seq=tokens):
    x = (key(indices) == query(indices-i)).value(seq, default)
    return x.name("shift")
shift(2)

了解 Transformers 是如何“思考”的

应战三:最小化

核算序列的最小值。(这一步开端变得困难,咱们版别用了 2 层注意力机制)

def minimum(seq=tokens):
    sel1 = before & (key(seq) == query(seq))
    sel2 = key(seq) < query(seq)
    less = (sel1 | sel2).value(1)
    x = (key(less) == query(0)).value(seq)
    return x.name("min")
minimum()([5,3,2,5,2])

了解 Transformers 是如何“思考”的

应战四:榜首索引

核算有 token q 的榜首索引 (2 层)

def first(q, seq=tokens):
    return minimum(where(seq == q, indices, 99))
first("l")

了解 Transformers 是如何“思考”的

应战五:右对齐

右对齐一个填充序列。例:”ralign().inputs('xyz___') ='—xyz'” (2 层)

def ralign(default="-", sop=tokens):
    c = (key(sop) == query("_")).value(1)
    x = (key(indices + c) == query(indices)).value(sop, default)
    return x.name("ralign")
ralign()("xyz__")

了解 Transformers 是如何“思考”的

应战六:别离

把一个序列在 token “v” 处别离成两部分然后右对齐 (2 层):

def split(v, i, sop=tokens):
    mid = (key(sop) == query(v)).value(indices)
    if i == 0:
        x = ralign("0", where(indices < mid, sop, "_"))
        return x
    else:
        x = where(indices > mid, sop, "0")
        return x
split("+", 1)("xyz+zyr")

了解 Transformers 是如何“思考”的

split("+", 0)("xyz+zyr")

了解 Transformers 是如何“思考”的

应战七:滑动

将特别 token “<” 替换为最接近的 “<” value (2 层):

def slide(match, seq=tokens):
    x = cumsum(match) 
    y = ((key(x) == query(x + 1)) & (key(match) == query(True))).value(seq)
    seq =  where(match, seq, y)
    return seq.name("slide")
slide(tokens != "<").input("xxxh<<<l")

了解 Transformers 是如何“思考”的

应战八:增加

你要执行两个数字的添加。这是过程。

add().input("683+345")
  1. 分成两部分。转制成整形。参加

“683+345” => [0, 0, 0, 9, 12, 8]

  1. 核算带着条款。三种或许性:1 个带着,0 不带着,< 也许有带着。

[0, 0, 0, 9, 12, 8] => “00<100”

  1. 滑动进位系数

“00<100” => 001100″

  1. 完结加法

这些都是 1 行代码。完整的体系是 6 个注意力机制。(虽然 Gail 说,假如你足够仔细则能够在 5 个中完结!)。

def add(sop=tokens):
    # 0) Parse and add
    x = atoi(split("+", 0, sop)) + atoi(split("+", 1, sop))
    # 1) Check for carries 
    carry = shift(-1, "0", where(x > 9, "1", where(x == 9, "<", "0")))
    # 2) In parallel, slide carries to their column                                         
    carries = atoi(slide(carry != "<", carry))
    # 3) Add in carries.                                                                                  
    return (x + carries) % 10
add()("683+345")

了解 Transformers 是如何“思考”的

683 + 345
1028

完美搞定!

参考资料 & 文内链接:

  • 假如你对这个主题感兴趣想了解更多,请查看论文:Thinking Like Transformers
  • 以及了解更多 RASP 言语
  • 假如你对「形式言语和神经网络」(FLaNN) 感兴趣或者有认识感兴趣的人,欢迎约请他们参加咱们的 线上社区!
  • 本篇博文,包括库、Notebook 和博文的内容
  • 本博客文章由 Sasha Rush 和 Gail Weiss 共同编写

英文原文:Thinking Like Transformers

译者:innovation64 (李洋)