Keras高级用法:回调、自定义层与多输入输出

🎙️ 语音朗读 当前: 晓晓 (温柔女声)

Keras高级用法:回调、自定义层与多输入输出

Keras作为TensorFlow的高级API,不仅简单易用,还提供了丰富的扩展能力。

回调机制(Callbacks)

回调是Keras训练过程中的钩子函数,可以在训练的不同阶段执行自定义逻辑:

from tensorflow import keras
from tensorflow.keras import layers, callbacks
import numpy as np

# Commonly used built-in callbacks, gathered into a single list.
my_callbacks = [
    # Early stopping: halt training once validation loss stops improving,
    # and roll back to the best weights seen so far.
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
    ),

    # LR scheduling: halve the learning rate when validation loss plateaus.
    callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
    ),

    # Checkpointing: keep only the model with the best validation accuracy.
    callbacks.ModelCheckpoint(
        'best_model.h5',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
    ),

    # TensorBoard logging (scalars, weight histograms, and the graph).
    callbacks.TensorBoard(
        log_dir='./logs',
        histogram_freq=1,
        write_graph=True,
    ),

    # Plain CSV log of per-epoch metrics.
    callbacks.CSVLogger('training_log.csv'),
]

# NOTE(review): `model`, `x_train`, `y_train` are assumed to be defined
# earlier in the surrounding document.
model.fit(x_train, y_train, epochs=100,
          validation_split=0.2, callbacks=my_callbacks)

自定义回调

class CustomCallback(keras.callbacks.Callback):
    """Stop training once validation accuracy crosses a threshold.

    Also prints the running loss every 100 training batches.
    """

    def __init__(self, threshold=0.95):
        """
        Args:
            threshold: validation-accuracy level (0-1) at which training
                is stopped via ``self.model.stop_training``.
        """
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # `logs` may be None, or lack 'val_accuracy' when no validation
        # data is supplied — guard both cases.
        val_acc = (logs or {}).get('val_accuracy')
        if val_acc and val_acc > self.threshold:
            print(f'\n达到 {self.threshold:.0%} 验证精度,停止训练!')
            self.model.stop_training = True

    def on_train_batch_end(self, batch, logs=None):
        # Fix: the original subscripted logs["loss"] unconditionally,
        # raising TypeError when logs is None (its documented default)
        # or KeyError when 'loss' is absent.
        if batch % 100 == 0 and logs and 'loss' in logs:
            print(f'\nBatch {batch}: loss={logs["loss"]:.4f}')

# Use the custom callback: stop as soon as validation accuracy tops 98%.
model.fit(x_train, y_train,
          epochs=100,
          validation_split=0.2,
          callbacks=[CustomCallback(threshold=0.98)])

自定义层

class DenseWithL2(layers.Layer):
    """Fully connected layer whose kernel carries an L2 weight penalty.

    NOTE(review): uses `tf.matmul`, so it relies on a module-level
    `import tensorflow as tf` — confirm it is present where this snippet
    is used (the earlier import block only imports `keras`).
    """

    def __init__(self, units, activation=None, l2_lambda=0.01, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Resolve string names ('relu', ...) to activation callables.
        self.activation = keras.activations.get(activation)
        self.l2_lambda = l2_lambda

    def build(self, input_shape):
        # Kernel: (input_dim, units), Glorot-initialized, L2-regularized.
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='glorot_uniform',
            regularizer=keras.regularizers.l2(self.l2_lambda),
            trainable=True,
            name='kernel',
        )
        # Bias: one zero-initialized value per output unit.
        self.b = self.add_weight(
            shape=(self.units,),
            initializer='zeros',
            trainable=True,
            name='bias',
        )

    def call(self, inputs):
        # Affine transform, then the optional activation.
        out = tf.matmul(inputs, self.w) + self.b
        return self.activation(out) if self.activation else out

    def get_config(self):
        # Expose constructor args so the layer can be saved and recreated.
        config = super().get_config()
        config.update({
            'units': self.units,
            'activation': keras.activations.serialize(self.activation),
            'l2_lambda': self.l2_lambda,
        })
        return config

自定义损失函数

# Function-style custom loss: a factory returning a Keras-compatible closure.
def focal_loss(gamma=2.0, alpha=0.25):
    """Build a focal loss for one-hot targets.

    Down-weights already-well-classified examples by (1 - p)^gamma and
    scales the positive class by alpha.
    """
    def loss(y_true, y_pred):
        # Clip predictions away from 0/1 to keep log() finite.
        p = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
        ce = -y_true * tf.math.log(p)
        modulator = alpha * y_true * tf.math.pow(1 - p, gamma)
        return tf.reduce_mean(modulator * ce)
    return loss

# Class-style custom loss: subclass keras.losses.Loss and implement call().
class ContrastiveLoss(keras.losses.Loss):
    """Contrastive loss for pairwise (siamese-style) training.

    y_true is 1 for similar pairs and 0 for dissimilar pairs; y_pred is
    the predicted distance. Dissimilar pairs are only penalized while
    their distance stays inside `margin`.
    """

    def __init__(self, margin=1.0, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def call(self, y_true, y_pred):
        # Similar pairs: penalize squared distance.
        pos_term = tf.math.square(y_pred)
        # Dissimilar pairs: penalize squared shortfall from the margin.
        neg_term = tf.math.square(tf.maximum(self.margin - y_pred, 0))
        return tf.reduce_mean(y_true * pos_term + (1 - y_true) * neg_term)

多输入模型

# Multi-input model combining a text branch and a numeric-feature branch.
text_input = keras.Input(shape=(100,), name='text')
text_features = layers.Embedding(10000, 64)(text_input)
text_features = layers.LSTM(64)(text_features)

numeric_input = keras.Input(shape=(10,), name='numeric')
numeric_features = layers.Dense(32, activation='relu')(numeric_input)

# Fuse both branches into a single representation.
concat = layers.concatenate([text_features, numeric_features])
x = layers.Dense(64, activation='relu')(concat)
x = layers.Dropout(0.3)(x)
output = layers.Dense(1, activation='sigmoid')(x)

model = keras.Model(
    inputs=[text_input, numeric_input],
    outputs=output,
)

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Feed inputs as a dict keyed by the Input names declared above.
# NOTE(review): `text_data`, `numeric_data`, `labels` come from elsewhere.
model.fit(
    {'text': text_data, 'numeric': numeric_data},
    labels,
    epochs=10,
    batch_size=32,
)

多输出模型

# Multi-output model predicting a class label and a set of attributes.
input_layer = keras.Input(shape=(224, 224, 3))

# Small convolutional trunk shared by both output heads.
x = layers.Conv2D(32, 3, activation='relu')(input_layer)
x = layers.MaxPooling2D()(x)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.MaxPooling2D()(x)
x = layers.GlobalAveragePooling2D()(x)

# Head 1: 10-way softmax classification.
class_output = layers.Dense(10, activation='softmax', name='class')(x)

# Head 2: 5 independent sigmoid attributes.
attr_output = layers.Dense(5, activation='sigmoid', name='attributes')(x)

model = keras.Model(inputs=input_layer, outputs=[class_output, attr_output])

# Per-output losses keyed by layer name; attribute head weighted at half.
model.compile(
    optimizer='adam',
    loss={
        'class': 'categorical_crossentropy',
        'attributes': 'binary_crossentropy',
    },
    loss_weights={'class': 1.0, 'attributes': 0.5},
    metrics=['accuracy'],
)

自定义训练步骤

class CustomModel(keras.Model):
    """Two-layer softmax classifier with a hand-written training step.

    NOTE(review): uses `tf.GradientTape`, so it relies on a module-level
    `import tensorflow as tf` — confirm it is present where this snippet
    is used.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = layers.Dense(128, activation='relu')
        self.dense2 = layers.Dense(10, activation='softmax')
        # Running mean of the training loss reported back to fit().
        self.loss_tracker = keras.metrics.Mean(name='loss')

    def call(self, inputs):
        hidden = self.dense1(inputs)
        return self.dense2(hidden)

    def train_step(self, data):
        features, targets = data

        # Forward pass under the tape so gradients can be computed.
        with tf.GradientTape() as tape:
            predictions = self(features, training=True)
            loss = self.compiled_loss(targets, predictions)

        # Backward pass and parameter update.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)

        return {'loss': self.loss_tracker.result()}

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it between epochs.
        return [self.loss_tracker]

# Compile supplies the loss consumed by the custom train_step above.
model = CustomModel()
model.compile(optimizer='adam', loss='categorical_crossentropy')
model.fit(x_train, y_train, epochs=10)

总结

Keras的高级特性——回调、自定义层、自定义损失、多输入输出模型和自定义训练步骤——为复杂场景提供了强大的扩展能力。回调机制实现训练过程的精细控制,自定义层和损失函数满足特殊需求,多输入输出模型处理复杂数据结构。掌握这些高级用法,才能在实际项目中灵活应对各种挑战。

© 2019-2026 ovo$^{mc^2}$ All Rights Reserved. | 站点总访问 28969 次 | 访客 19045
Theme by hiero