从零实现TensorFlow 2.x CRF层:代码反推NER核心原理
在自然语言处理领域,命名实体识别(NER)任务常采用条件随机场(CRF)作为解码层。但大多数教程停留在数学公式推导,让开发者陷入"看得懂推不通,推得通写不出"的困境。本文将以代码实现为主导,通过手写TensorFlow 2.x的CRF层,逆向解析其核心原理。我们将从定义发射分数和转移矩阵开始,逐步实现前向计算、损失函数和维特比解码,最终整合成可复用的CRF层模块。
1. CRF核心概念与实现准备
CRF作为判别式概率模型,其核心在于考虑相邻标签间的转移特性。例如在BIO标注体系中,"B-PER"后面只能接"I-PER"或"O",而不能接"B-ORG"。这种强约束特性使CRF成为序列标注任务的首选。
实现CRF层需要三个关键组件:
- 发射分数(Emission Scores):由上层模型输出的每个标签的未归一化分数
- 转移矩阵(Transition Matrix):存储标签间转移概率的参数矩阵
- 维特比算法(Viterbi):计算最优标签序列的动态规划算法
先导入必要库并定义超参数:
import tensorflow as tf from tensorflow.keras.layers import Layer class CRF(Layer): def __init__(self, num_tags, **kwargs): super(CRF, self).__init__(**kwargs) self.num_tags = num_tags # 标签数量 self.transitions = tf.Variable( tf.random.uniform(shape=(num_tags, num_tags)), name="transitions", trainable=True)这里num_tags表示标签数量,transitions是随机初始化的可训练转移矩阵。例如BIO标注体系有3个标签,则num_tags=3。
提示:转移矩阵的维度是[标签数量, 标签数量],每个元素
transitions[i][j]表示从标签i转移到标签j的分数
2. 实现CRF的前向计算
前向计算需要完成两项工作:计算所有可能路径的分数(用于训练时的损失计算)和计算真实路径的分数。
定义前向计算函数:
def call(self, inputs, targets=None, sequence_lengths=None, training=None): emissions = inputs # 上层模型输出的发射分数 [batch_size, seq_len, num_tags] if training and targets is not None: # 训练阶段计算损失 log_likelihood = self._compute_log_likelihood(emissions, targets, sequence_lengths) self.add_loss(-log_likelihood) # 预测阶段返回维特比解码结果 return self._viterbi_decode(emissions, sequence_lengths)其中_compute_log_likelihood函数计算负对数似然损失,_viterbi_decode函数实现维特比解码算法。
2.1 计算真实路径分数
真实路径分数的计算需要考虑发射分数和转移分数:
def _compute_log_likelihood(self, emissions, tags, sequence_lengths): batch_size = tf.shape(emissions)[0] seq_len = tf.shape(emissions)[1] num_tags = tf.shape(emissions)[2] # 创建掩码处理变长序列 mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len, dtype=tf.float32) # 计算发射分数 emit_scores = tf.gather_nd(emissions, tf.stack([tf.range(batch_size)[:, None], tf.range(seq_len)[None, :], tags], axis=-1)) emit_scores = tf.reduce_sum(emit_scores * mask, axis=1) # 计算转移分数 tags_transposed = tf.transpose(tags, perm=[1, 0]) prev_tags = tf.concat([tf.fill([1, batch_size], -1), tags_transposed[:-1]], axis=0) prev_tags = tf.transpose(prev_tags, perm=[1, 0]) transition_scores = tf.gather_nd(self.transitions, tf.stack([prev_tags, tags], axis=-1)) transition_scores = tf.reduce_sum(transition_scores * mask, axis=1) # 计算序列开始和结束的分数 start_tags = tf.gather(tags, [0], axis=1) start_scores = tf.gather_nd(self.transitions, tf.stack([tf.zeros_like(start_tags), start_tags], axis=-1)) end_tags = tf.gather(tags, sequence_lengths-1, axis=1, batch_dims=1) end_scores = tf.gather_nd(self.transitions, tf.stack([end_tags, tf.fill(tf.shape(end_tags), self.num_tags-1)], axis=-1)) # 计算对数似然 log_numerator = tf.reduce_sum(emit_scores + transition_scores + start_scores + end_scores) log_denominator = self._compute_log_partition_function(emissions, sequence_lengths) return log_numerator - log_denominator2.2 计算所有路径分数(配分函数)
配分函数Z的计算采用动态规划方法,避免直接计算所有可能路径的高复杂度:
def _compute_log_partition_function(self, emissions, sequence_lengths): batch_size, seq_len, num_tags = tf.unstack(tf.shape(emissions)) # 初始化前向变量alpha log_alpha = tf.TensorArray(tf.float32, size=seq_len) init_alpha = tf.fill([batch_size, num_tags], -1e4) init_alpha = tf.tensor_scatter_nd_update(init_alpha, [[i, 0] for i in range(batch_size)], tf.zeros(batch_size)) log_alpha = log_alpha.write(0, init_alpha) # 递归计算前向变量 mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len, dtype=tf.float32) emissions_t = tf.transpose(emissions, perm=[1, 0, 2]) def loop_fn(i, log_alpha): prev_log_alpha = log_alpha.read(i-1) curr_emissions = emissions_t[i] # 广播相加:prev_log_alpha [batch, num_tags] + transitions [num_tags, num_tags] # 得到 [batch, num_tags, num_tags] log_alpha_i = prev_log_alpha[:, None] + self.transitions[None, :, :] log_alpha_i += curr_emissions[:, None, :] # logsumexp沿最后一个维度计算 log_alpha_i = tf.reduce_logsumexp(log_alpha_i, axis=-1) # 应用掩码 log_alpha_i = log_alpha_i * mask[:, i, None] + log_alpha.read(i-1) * (1 - mask[:, i, None]) return log_alpha.write(i, log_alpha_i) # 执行循环 for i in tf.range(1, seq_len): log_alpha = loop_fn(i, log_alpha) # 最终计算配分函数 log_alpha_final = log_alpha.read(seq_len-1) log_z = tf.reduce_logsumexp(log_alpha_final + self.transitions[:, -1], axis=-1) return tf.reduce_sum(log_z)3. 维特比解码算法实现
预测阶段需要找到分数最高的标签序列,这可以通过维特比算法实现:
def _viterbi_decode(self, emissions, sequence_lengths): batch_size, seq_len, num_tags = tf.unstack(tf.shape(emissions)) # 初始化维特比变量 viterbi = tf.TensorArray(tf.float32, size=seq_len) init_viterbi = tf.fill([batch_size, num_tags], -1e4) init_viterbi = tf.tensor_scatter_nd_update(init_viterbi, [[i, 0] for i in range(batch_size)], tf.zeros(batch_size)) viterbi = viterbi.write(0, init_viterbi) # 初始化反向指针 backpointers = tf.TensorArray(tf.int32, size=seq_len) init_ptrs = tf.fill([batch_size, num_tags], 0) backpointers = backpointers.write(0, init_ptrs) # 递归计算 emissions_t = tf.transpose(emissions, perm=[1, 0, 2]) mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len, dtype=tf.float32) def loop_fn(i, viterbi, backpointers): prev_viterbi = viterbi.read(i-1) curr_emissions = emissions_t[i] # 广播相加:prev_viterbi [batch, num_tags] + transitions [num_tags, num_tags] # 得到 [batch, num_tags, num_tags] curr_viterbi = prev_viterbi[:, None] + self.transitions[None, :, :] curr_viterbi += curr_emissions[:, None, :] # 记录最大值和反向指针 max_logp = tf.reduce_max(curr_viterbi, axis=-1) argmax_tags = tf.argmax(curr_viterbi, axis=-1, output_type=tf.int32) # 应用掩码 masked_max_logp = max_logp * mask[:, i, None] + prev_viterbi * (1 - mask[:, i, None]) masked_argmax_tags = argmax_tags * tf.cast(mask[:, i, None], tf.int32) + backpointers.read(i-1) * (1 - tf.cast(mask[:, i, None], tf.int32)) viterbi = viterbi.write(i, masked_max_logp) backpointers = backpointers.write(i, masked_argmax_tags) return viterbi, backpointers # 执行循环 for i in tf.range(1, seq_len): viterbi, backpointers = loop_fn(i, viterbi, backpointers) # 回溯找到最优路径 def get_best_path(i, best_tags, backpointers_t): best_tags = tf.concat([tf.expand_dims(backpointers_t[i, tf.range(batch_size), best_tags[:, 0]], 1), best_tags], axis=1) return i-1, best_tags, backpointers_t backpointers_t = tf.transpose(backpointers.stack(), perm=[1, 0, 2]) best_tags = tf.expand_dims(tf.argmax(viterbi.read(seq_len-1), axis=-1, output_type=tf.int32), 1) _, best_tags, _ = tf.while_loop( lambda i, *_: i >= 1, get_best_path, (seq_len-1, best_tags, backpointers_t) ) return best_tags4. 整合CRF层与模型训练
将实现的CRF层整合到模型中,以BIO标注任务为例:
class NERModel(tf.keras.Model): def __init__(self, num_tags): super().__init__() self.embedding = tf.keras.layers.Embedding(10000, 128) self.bilstm = tf.keras.layers.Bidirectional( tf.keras.layers.LSTM(64, return_sequences=True) ) self.dense = tf.keras.layers.Dense(num_tags) self.crf = CRF(num_tags) def call(self, inputs, targets=None, training=None): x = self.embedding(inputs) x = self.bilstm(x) logits = self.dense(x) if training: return self.crf(logits, targets) else: return self.crf(logits)训练时直接使用CRF层计算出的负对数似然作为损失函数:
model = NERModel(num_tags=3) optimizer = tf.keras.optimizers.Adam(0.001) @tf.function def train_step(x, y, lengths): with tf.GradientTape() as tape: logits = model(x, y, training=True) loss = sum(model.losses) grads = tape.gradient(loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) return loss预测时使用维特比解码得到最优标签序列:
def predict(model, x, lengths): logits = model(x, training=False) return logits5. CRF实现中的关键细节
5.1 处理变长序列
NER任务中文本长度不一,需要正确处理变长序列。我们通过sequence_lengths参数和掩码机制实现:
# 创建变长序列的掩码 sequence_lengths = tf.constant([10, 7, 12], dtype=tf.int32) # 三个样本的实际长度 max_len = tf.reduce_max(sequence_lengths) mask = tf.sequence_mask(sequence_lengths, maxlen=max_len, dtype=tf.float32) # 应用掩码 transition_scores = transition_scores * mask5.2 数值稳定性
CRF计算涉及大量指数运算,容易导致数值不稳定。我们采用log空间计算提升稳定性:
# 原始空间计算(不稳定) exp_scores = tf.exp(some_scores) sum_exp = tf.reduce_sum(exp_scores) prob = exp_scores / sum_exp # log空间计算(稳定) log_scores = some_scores log_sum_exp = tf.reduce_logsumexp(log_scores) log_prob = log_scores - log_sum_exp5.3 转移矩阵约束
某些标签转移在业务中不可能发生(如B→I的非法转移),可通过约束转移矩阵实现:
# 定义不可能转移的掩码 constraint_mask = tf.constant([[0, 1, 1], # B不能转移到B [1, 0, 1], # I不能转移到I [1, 1, 0]], dtype=tf.float32) # 训练前应用约束 self.transitions.assign(self.transitions * constraint_mask)6. 完整代码示例
以下是完整可运行的TensorFlow 2.x CRF层实现:
import tensorflow as tf from tensorflow.keras.layers import Layer class CRF(Layer): def __init__(self, num_tags, **kwargs): super(CRF, self).__init__(**kwargs) self.num_tags = num_tags self.transitions = tf.Variable( tf.random.uniform(shape=(num_tags, num_tags)), name="transitions", trainable=True) def call(self, inputs, targets=None, sequence_lengths=None, training=None): emissions = inputs if training and targets is not None: log_likelihood = self._compute_log_likelihood(emissions, targets, sequence_lengths) self.add_loss(-log_likelihood) return self._viterbi_decode(emissions, sequence_lengths) def _compute_log_likelihood(self, emissions, tags, sequence_lengths): batch_size = tf.shape(emissions)[0] seq_len = tf.shape(emissions)[1] num_tags = tf.shape(emissions)[2] mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len, dtype=tf.float32) # 计算发射分数 emit_scores = tf.gather_nd(emissions, tf.stack([tf.range(batch_size)[:, None], tf.range(seq_len)[None, :], tags], axis=-1)) emit_scores = tf.reduce_sum(emit_scores * mask, axis=1) # 计算转移分数 tags_transposed = tf.transpose(tags, perm=[1, 0]) prev_tags = tf.concat([tf.fill([1, batch_size], -1), tags_transposed[:-1]], axis=0) prev_tags = tf.transpose(prev_tags, perm=[1, 0]) transition_scores = tf.gather_nd(self.transitions, tf.stack([prev_tags, tags], axis=-1)) transition_scores = tf.reduce_sum(transition_scores * mask, axis=1) # 计算序列开始和结束的分数 start_tags = tf.gather(tags, [0], axis=1) start_scores = tf.gather_nd(self.transitions, tf.stack([tf.zeros_like(start_tags), start_tags], axis=-1)) end_tags = tf.gather(tags, sequence_lengths-1, axis=1, batch_dims=1) end_scores = tf.gather_nd(self.transitions, tf.stack([end_tags, tf.fill(tf.shape(end_tags), self.num_tags-1)], axis=-1)) # 计算对数似然 log_numerator = tf.reduce_sum(emit_scores + transition_scores + start_scores + end_scores) log_denominator = self._compute_log_partition_function(emissions, sequence_lengths) return log_numerator - log_denominator def _compute_log_partition_function(self, emissions, sequence_lengths): batch_size, seq_len, num_tags = tf.unstack(tf.shape(emissions)) log_alpha = tf.TensorArray(tf.float32, size=seq_len) init_alpha = tf.fill([batch_size, num_tags], -1e4) init_alpha = tf.tensor_scatter_nd_update(init_alpha, [[i, 0] for i in range(batch_size)], tf.zeros(batch_size)) log_alpha = log_alpha.write(0, init_alpha) mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len, dtype=tf.float32) emissions_t = tf.transpose(emissions, perm=[1, 0, 2]) def loop_fn(i, log_alpha): prev_log_alpha = log_alpha.read(i-1) curr_emissions = emissions_t[i] log_alpha_i = prev_log_alpha[:, None] + self.transitions[None, :, :] log_alpha_i += curr_emissions[:, None, :] log_alpha_i = tf.reduce_logsumexp(log_alpha_i, axis=-1) log_alpha_i = log_alpha_i * mask[:, i, None] + log_alpha.read(i-1) * (1 - mask[:, i, None]) return log_alpha.write(i, log_alpha_i) for i in tf.range(1, seq_len): log_alpha = loop_fn(i, log_alpha) log_alpha_final = log_alpha.read(seq_len-1) log_z = tf.reduce_logsumexp(log_alpha_final + self.transitions[:, -1], axis=-1) return tf.reduce_sum(log_z) def _viterbi_decode(self, emissions, sequence_lengths): batch_size, seq_len, num_tags = tf.unstack(tf.shape(emissions)) viterbi = tf.TensorArray(tf.float32, size=seq_len) init_viterbi = tf.fill([batch_size, num_tags], -1e4) init_viterbi = tf.tensor_scatter_nd_update(init_viterbi, [[i, 0] for i in range(batch_size)], tf.zeros(batch_size)) viterbi = viterbi.write(0, init_viterbi) backpointers = tf.TensorArray(tf.int32, size=seq_len) init_ptrs = tf.fill([batch_size, num_tags], 0) backpointers = backpointers.write(0, init_ptrs) emissions_t = tf.transpose(emissions, perm=[1, 0, 2]) mask = tf.sequence_mask(sequence_lengths, maxlen=seq_len, dtype=tf.float32) def loop_fn(i, viterbi, backpointers): prev_viterbi = viterbi.read(i-1) curr_emissions = emissions_t[i] curr_viterbi = prev_viterbi[:, None] + self.transitions[None, :, :] curr_viterbi += curr_emissions[:, None, :] max_logp = tf.reduce_max(curr_viterbi, axis=-1) argmax_tags = tf.argmax(curr_viterbi, axis=-1, output_type=tf.int32) masked_max_logp = max_logp * mask[:, i, None] + prev_viterbi * (1 - mask[:, i, None]) masked_argmax_tags = argmax_tags * tf.cast(mask[:, i, None], tf.int32) + backpointers.read(i-1) * (1 - tf.cast(mask[:, i, None], tf.int32)) viterbi = viterbi.write(i, masked_max_logp) backpointers = backpointers.write(i, masked_argmax_tags) return viterbi, backpointers for i in tf.range(1, seq_len): viterbi, backpointers = loop_fn(i, viterbi, backpointers) def get_best_path(i, best_tags, backpointers_t): best_tags = tf.concat([tf.expand_dims(backpointers_t[i, tf.range(batch_size), best_tags[:, 0]], 1), best_tags], axis=1) return i-1, best_tags, backpointers_t backpointers_t = tf.transpose(backpointers.stack(), perm=[1, 0, 2]) best_tags = tf.expand_dims(tf.argmax(viterbi.read(seq_len-1), axis=-1, output_type=tf.int32), 1) _, best_tags, _ = tf.while_loop( lambda i, *_: i >= 1, get_best_path, (seq_len-1, best_tags, backpointers_t) ) return best_tags