改善想法

LightNetPlus 借鉴论文 MotionGRU + GAN 的改造方案(面向你现有 PyTorch 代码)

目标:在尽量少改动你当前 LightNetPlus 结构的前提下,借鉴论文里的

  1. MotionGRU(瞬态变化 + 趋势动量) 来增强多步滚动预测的稳定性与“运动一致性”
  2. GAN 对抗训练(cGAN / 序列判别) 来缓解“平均化模糊”,提升高密度雷电区域的细节、形态真实性

0. 你当前模型结构回顾(为什么好改)

你现有结构是典型的「多源编码 + 融合隐状态 + 自回归解码」:

  • 3 个 Encoder(WRF / LIG / AWS)分别提取时序特征,输出最后时刻隐状态 (h1,c1), (h2,c2), (h3,c3)
  • 通过 1×1 Conv 做通道对齐,然后 cat 得到融合的 (h,c),通道=64
  • Decoder 采用自回归循环预测 num_PRED 步,每一步用 current_input 生成 out,再喂回去

这非常契合 MotionGRU 的“每一步更新 motion state”,也很适合 GAN 让判别器直接看整段预测序列。


1. MotionGRU 借鉴思路:从“论文版”到“适配你代码的最小改动版”

论文 MotionGRU 的核心是维护两类状态:

  • F_t:运动表征(motion filter / transient + trend 的合成)
  • D_t:趋势动量(movement trend / momentum)

并通过类似 GRU 门控得到瞬态项 F'_t,再用动量更新 D_t,最后 F_t = F'_t + D_t

1.1 在你的网络里放在哪里(推荐位置)

强烈推荐放在 Decoder 的逐步预测循环里,因为你每一步都更新 ConvLSTM 隐状态 (h,c),并且滚动预测时误差会累积,MotionGRU 的趋势项最能缓解后期不稳定。

插入位置建议:

  • 你现在每步:ConvDown -> ConvLSTM -> Deconv -> out
  • 改为:ConvDown -> ConvLSTM -> MotionGRU(基于 lstm_out 或 h) -> Deconv -> out

1.2 选择“MotionGRU-lite”:先不做 warp,只做 motion state + 调制

论文有些版本会把 F_t 用于 warp(重采样对齐),但你现在的 Decoder 还没有显式 flow/warp 的结构。

建议先做一个“MotionGRU-lite”

  • 不做 warp
  • 只维护 F, D 并产生一个“运动增强特征” X
  • X 替代原本送给 deconv 的 lstm_out

这样改动量小、收益明显,且更容易训稳。


2. MotionGRU-lite 的模块定义(建议实现)

2.1 变量定义(与你的张量形状对齐)

在 decoder 中,lstm_out 的形状大概是:

  • lstm_out: (B, hidden_channels, H', W')
    其中 hidden_channels=64(你 decoder 的 ConvLSTM hidden_channels)

我们定义 MotionGRU-lite 的状态:

  • F: (B, C_f, H', W') (建议 C_f = hidden_channels,最省事)
  • D: (B, C_f, H', W')

2.2 门控设计(用 1~2 个卷积产生 u 和 z)

论文里瞬态更新是:

  • F'_t = u_t ⊙ z_t + (1-u_t) ⊙ F_{t-1}

在实现上,我们可以用当前 lstm_out 和上一步 F_prev 拼接,然后卷积输出 uz

  • u = sigmoid(Conv_u([lstm_out, F_prev]))
  • z = tanh(Conv_z([lstm_out, F_prev]))

2.3 趋势动量更新(EMA/动量累积)

论文趋势更新:

  • D_t = D_{t-1} + α (F_{t-1} - D_{t-1})

实现:

  • D = D + alpha * (F_prev - D)

2.4 合成 motion filter 并输出增强特征

  • F = F_transient + D
  • 用一个卷积把 F 映射回 hidden_channels,与 lstm_out 融合:

最简单的融合方式(建议从最简单开始):

  • X = lstm_out + Conv_f(F) (残差注入)

更强一点(门控融合):

  • g = sigmoid(Conv_g([lstm_out, F]))
  • X = g ⊙ lstm_out + (1-g) ⊙ Conv_f(F)

3. 在你的 Decoder 里怎么接(改动点清单)

3.1 Decoder.init 新增一个 motion 单元

新增:

  • self.motion = MotionGRULite(hidden_channels=hidden_channels, alpha=0.25, kernel=3)(alpha 你可以调)
  • (可选)如果你想对 motion state 降维,可以设 motion_channels < hidden_channels,但建议先同维省事。

3.2 Decoder.forward 里维护 F, D 两个状态

在进入 for i in range(num_pred) 之前初始化:

  • F = torch.zeros_like(h)(或 zeros with shape like lstm_out
  • D = torch.zeros_like(h)

然后在每个 step:

  • 你已有:lstm_out, h, c = self.convlstm(...)
  • 加一行:lstm_out, F, D = self.motion(lstm_out, F, D)

之后保持你原来的 deconv/out 流程不变。


4. GAN 对抗训练的改造方案(与你的输出形式对齐)

4.1 先明确你输出的形状(你现在做了 reshape)

你 decoder 最后返回:

  • decoder_outputs: (B, num_pred, H*W, 1)

但 GAN 判别器通常更适合吃 二维图像/序列图像

  • (B, T, 1, H, W)(B, 1, T, H, W)

建议训练阶段额外保留一个 decoder_outputs_map

  • 在 decoder 里 outputs.append(out.unsqueeze(1)) 那个地方,你已经有 (B,1,1,H,W)
  • 你可以同时返回:
    • pred_map = torch.cat(outputs, dim=1) # (B,T,1,H,W)
    • pred_flat = pred_map.view(B,T,-1,1) # 保持你原有 loss 用

这样:

  • 原来的监督 loss 不受影响(用 pred_flat
  • GAN 判别器用 pred_map 更自然

4.2 选择判别器形态:先 2D 拼帧 PatchGAN(最稳最省事)

把未来 T 帧沿通道拼起来:

  • pred_map: (B,T,1,H,W) -> pred_cat: (B, T, H, W)(把 1 合并成 channel)
  • 如果加条件 cond(最后观测帧):
    • cond: (B,1,H,W)
    • 判别器输入:inp = cat([cond, pred_cat], dim=1) -> (B, 1+T, H, W)

判别器输出一个 patch 网格(不是单标量),更利于学习局部纹理/形态:

  • D(inp) -> (B, 1, H_d, W_d)(每个 patch 一个真伪判断)

4.3 训练目标(推荐 hinge loss,更稳)

判别器:

  • L_D = mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))

生成器对抗项:

  • L_adv = - mean(D(fake))

生成器总损失:

  • L_G = L_pred + λ_adv * L_adv

其中:

  • L_pred = 你当前的主损失(BCE/MSE/L1 等)
  • λ_adv 建议从 1e-31e-4 起步,先稳再调大

4.4 “real / fake” 样本怎么构造

  • real_future: Ground truth 未来 (B,T,1,H,W)
  • fake_future: 模型预测未来 (B,T,1,H,W)

判别器输入(含条件):

  • D_in_real = cat([cond_last, real_future_cat], dim=1)
  • D_in_fake = cat([cond_last, fake_future_cat], dim=1)

cond_last 最简单就是 “历史序列最后一帧 LIG”(你 decoder_inputs 或 encoder2_inputs 的最后一帧)
也可以用更强的条件:把 WRF/AWS 编码特征投影到图上,但那会大幅改结构,建议先不做。


5. 训练流程建议(避免 GAN 一上来崩)

5.1 预训练生成器(必做)

先只用 L_pred 训练若干 epoch,让模型:

  • 能预测出大致正确的形态与位置
  • 不会输出全 0 / 全噪声

5.2 再引入 GAN 交替训练

每个 iteration:

  1. 更新 D(1 次)
  • fake_future = G(...)(注意 fake_future.detach()
  • L_D,更新判别器
  1. 更新 G(1 次)
  • 再 forward 一次(或缓存计算图)
  • L_pred + λ_adv * L_adv,更新生成器

如果你发现 D 太强导致 G 学不动:

  • 降低 D 学习率
  • 或者 D:G 更新比改成 1:2(少训 D 多训 G)
  • 或者减小 λ_adv

6. 最小可实施改动清单(你照着做基本就能跑)

6.1 模型侧(代码改动)

  • 新增 MotionGRULite 模块文件(或写在同文件)
  • Decoder.__init__ 添加 self.motion
  • Decoder.forward 初始化 F,D 并每步更新 lstm_out
  • Decoder 同时返回 pred_map(B,T,1,H,W)与 pred_flat(保持原逻辑)

6.2 训练侧(代码改动)

  • 新建 DiscriminatorPatch2D(输入通道 1+T
  • 实现 hinge loss 的 L_D / L_adv
  • 训练分两阶段:
    • Stage1:只训 G(L_pred)
    • Stage2:交替训 D 和 G

7. 一些经验性超参起点(便于你快速稳定)

MotionGRU-lite

  • alpha:0.1 ~ 0.3(趋势更新速度)
  • kernel:3(足够)
  • F,D 通道:= hidden_channels(先同维)

GAN

  • λ_adv:从 1e-41e-3 起步
  • 判别器学习率:可略大于 G 或相同(不确定就先相同)
  • 判别器结构:3~5 层 stride-2 conv,最后输出 patch map

8. 你这套任务里,这两块各自“解决什么问题”

  • MotionGRU-lite:减少滚动预测后期的漂移、抖动、形态塌陷;让“整体运动趋势”更稳
  • GAN:减少预测模糊、增强高密度区域边缘/纹理、让形态更像真实样本分布

两者叠加通常是:MotionGRU 让“动得对/稳”,GAN 让“看起来真/锐”。


9. 下一步我需要你给的两点信息(用于把方案改成具体补丁代码)

为了我能直接给你“可粘贴的代码补丁”,我需要确认:

  1. 你的 L_pred 现在是什么?
    • BCE(像素二分类)还是 MSE/L1(回归)?
  2. 训练时 decoder_inputs 具体是什么(最后一帧 LIG?还是别的)?
    • 这决定 cGAN 的 cond_last 从哪里取最合理

你不想回答也没关系:我也可以按“最常见设置”(cond=最后一帧 LIG,L_pred=BCE)直接给一版默认实现。