Lesson 1:神经网络到底在干什么?
你已经知道 自动求导 和 梯度下降,现在要理解:
“为什么多层网络能拟合任何复杂关系?”
—— 非线性激活函数 是关键!
如果没有激活函数👇
所有层都是矩阵乘法
🚨 无论堆多少层,本质=线性方程
→ 无法拟合复杂关系、对话情绪、人格风格
常见激活函数区别:
➡ 大模型(语言模型)里最推荐:
GELU(比ReLU更平滑)
Lesson 2:CNN / RNN / Transformer ——区别与时代演进
所以你为什么要学 Transformer?👇
纳西妲 AI = 文本 → Token → Transformer → Token → 文本Lesson 3:Attention 是灵魂
一句话总结:
Attention = “模型自动找出当前词最重要的上下文”
用比喻解释:
“你和草神聊天,她会回忆你之前说过的话里,哪些和现在相关”
注意力公式(不用记)
Attention(Q, K, V) = softmax(QKᵀ / √d) V
Q=当前词
K=所有词的“标签”
V=所有词的“含义”
Lesson 4:Transformer 结构图
┌─────────── Multi-head Attention ───────────┐
Token → Embedding → Add&Norm → FeedForward → Add&Norm → Output
你需要知道三个层的职责:
问题:
1.为什么神经网络必须使用激活函数?
因为激活函数提供非线性能力,让网络能够表示复杂关系。
没有激活函数时,无论堆多少层网络,本质仍是线性映射,只能表示直线,无法表达对话语气、情绪等复杂模式。
神经网络必须使用激活函数,是因为非线性才是智能的来源。
如果没有激活函数,所有层都是矩阵乘法,会退化成一个线性模型,无法学习复杂关系、语气、情绪等信息。
2.RNN 和 Transformer 的核心区别?哪个适合人格 AI?
RNN 顺序处理文本,远处的信息会逐渐被遗忘。而 Transformer 使用注意力机制可以一次性关注整个对话,找到当前词最相关的语境。因此 Transformer 更适合训练具有稳定人格和记忆的对话模型。
1.Attention 的直觉:它在干嘛?
一句话:
当前词在理解信息时,会“关注”对话里最相关的内容
举例:
你跟草神说:
“我今天很累,你能陪我聊天吗?”
模型在生成“当然可以~”时
它会特别关注: 累、陪、聊天
而不是 我、今天
这就是 注意力 → 相关性权重分配
图解
输入序列(你说的话)
↓
Embedding(把词变向量)
↓
Q — 当前词想知道什么?
K — 每个词的“标签”,告诉别人我是啥
V — 每个词的“含义”
|
└── Attention:Q 与 K 匹配 → 确定权重 → 加权合并 VAttention = 相关性权重 × 信息本体
对应公式
Attention(Q, K, V) = softmax(QKᵀ / √d) V
import torch
import torch.nn.functional as F
def simple_attention(Q, K, V):
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)
# softmax 转成权重
weights = F.softmax(scores, dim=-1)
# 权重加权 V
output = torch.matmul(weights, V)
return output, weights
# 模拟 3 个词,每个词是 4 维向量
torch.manual_seed(0)
Q = torch.rand(1, 3, 4)
K = torch.rand(1, 3, 4)
V = torch.rand(1, 3, 4)
out, w = simple_attention(Q, K, V)
print("输出:", out)
print("注意力权重:", w)输出:模型根据注意力得到的新信息
权重:每个词对当前词的影响力度
热力图demo
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'] # 按优先级尝试字体
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# ===== 1. 准备数据 =====
corpus = [
"你怎么了?",
"我今天感觉有点累……",
"我会一直陪着你的",
"谢谢",
"不客气",
"你真温柔",
"我只是关心你",
"能陪我聊天吗?",
"当然可以呀~"
]
# 构建中文字符字典
chars = sorted(list(set("".join(corpus))))
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}
def encode(text):
return torch.tensor([stoi[c] for c in text], dtype=torch.long)
data = [encode(s) for s in corpus]
# ===== 2. Transformer Block(最小可用版本) =====
class TinyAttention(nn.Module):
def __init__(self, d_model):
super().__init__()
self.query = nn.Linear(d_model, d_model)
self.key = nn.Linear(d_model, d_model)
self.value = nn.Linear(d_model, d_model)
def forward(self, x):
Q = self.query(x)
K = self.key(x)
V = self.value(x)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (x.size(-1) ** 0.5)
weights = F.softmax(scores, dim=-1)
return torch.matmul(weights, V), weights
class MiniTransformer(nn.Module):
def __init__(self, vocab_size, d_model=32):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.attn = TinyAttention(d_model)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
x = self.embedding(x)
out, weights = self.attn(x)
out = self.fc(out)
return out, weights
model = MiniTransformer(len(chars))
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
loss_fn = nn.CrossEntropyLoss()
# ===== 3. 训练 =====
def train_step():
total_loss = 0
for seq in data:
inp = seq[:-1].unsqueeze(0)
tgt = seq[1:].unsqueeze(0)
logits, _ = model(inp)
loss = loss_fn(logits.squeeze(0), tgt.squeeze(0))
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(data)
print("训练中...")
for epoch in range(200):
loss = train_step()
if (epoch + 1) % 40 == 0:
print(f"Epoch {epoch+1}, Loss = {loss:.4f}")
print("训练完成!")
# ===== 4. 推理并展示注意力 =====
def show_attention(input_text):
model.eval()
inp = encode(input_text).unsqueeze(0)
with torch.no_grad():
_, weights = model(inp)
weights = weights[0]
plt.imshow(weights, cmap='viridis')
plt.colorbar()
plt.xticks(range(len(input_text)), list(input_text))
plt.yticks(range(len(input_text)), list(input_text))
plt.title("Attention Heatmap")
plt.show()
model.train()
while True:
txt = input("\n输入一句话(输入 q 退出):")
if txt == "q": break
show_attention(txt)有些过度拟合,之后修复下
1.1 Multi-Head Attention(为什么要多头)
一句话解释:
多头注意力 = 多种视角理解对话
就像草神既能理解:
情绪(“累”)
需求(“陪我聊天”)
关系(“你是她最喜欢的人”)
不同的 Head 理解不同东西 → 再融合
所以回答更自然、有情绪、人格一致
2.待补足
Attention
自注意力(Self-Attention)
Multi-Head Attention
Causal Mask(因果掩码)
位置偏置 (positional bias)
自我指代偏置 (self-alignment bias)
自回归预测