這是一個在機器學習中用來「評估模型穩定度」的絕佳方法。您可以把它想像成舉辦 N 場模擬考,來取得一個更公平、更可靠的平均分數。
為什麼我們需要 N-fold?
問題情境: 假設您有 100 筆資料,想知道您的模型 (Model) 訓練後表現好不好。
- 最糟作法: 用全部 100 筆資料訓練,再用這 100 筆資料考試。
- 結果: 分數會非常高(例如 99%),但這是「作弊」,因為模型只是把答案背起來了 (Overfitting)。它遇到新資料時表現會很差。
- 稍好作法 (單次分割): 將資料分成「70 筆訓練集」和「30 筆測試集」。
- 結果: 您得到了一個分數(例如 85%)。但這個分數可信嗎?
- 隱藏問題: 如果您運氣不好,剛好那 30 筆測試資料都特別簡單(或特別難),這個 85% 的分數就不能代表模型的真正實力。
N-fold Cross Validation (N 折交叉驗證) 的解法
N-fold 的核心精神是:「與其只考一次試,不如我們把資料分成 N 份,輪流當 N 次考卷,再取平均!」
最常見的是 5-fold (N=5),我們以此為例:
- 切割資料:
- 首先,將全部 100 筆資料洗牌 (Shuffle)。
- 然後,將資料平均切成 5 份(每份 20 筆)。我們稱之為 Fold 1, Fold 2, Fold 3, Fold 4, Fold 5。
- 輪流考試 (共 5 輪):
- 第 1 輪:
- 訓練集: 拿 Fold 2, 3, 4, 5 (共 80 筆) 來訓練模型。
- 測試集: 拿 Fold 1 (共 20 筆) 來考試。
- 得到分數 1 (例如:85%)
- 第 2 輪:
- 訓練集: 拿 Fold 1, 3, 4, 5 (共 80 筆) 來訓練模型。
- 測試集: 拿 Fold 2 (共 20 筆) 來考試。
- 得到分數 2 (例如:90%)
- 第 3 輪:
- 訓練集: 拿 Fold 1, 2, 4, 5 (共 80 筆) 來訓練模型。
- 測試集: 拿 Fold 3 (共 20 筆) 來考試。
- 得到分數 3 (例如:80%)
- 第 4 輪:
- 訓練集: 拿 Fold 1, 2, 3, 5 (共 80 筆) 來訓練模型。
- 測試集: 拿 Fold 4 (共 20 筆) 來考試。
- 得到分數 4 (例如:88%)
- 第 5 輪:
- 訓練集: 拿 Fold 1, 2, 3, 4 (共 80 筆) 來訓練模型。
- 測試集: 拿 Fold 5 (共 20 筆) 來考試。
- 得到分數 5 (例如:82%)
- 第 1 輪:
- 計算總成績:
- 我們得到了 5 個分數:[85, 90, 80, 88, 82]。
- 最終平均分數 = (85 + 90 + 80 + 88 + 82) / 5 = 85%。
結論
這個 85% 的平均分數,遠比「單次分割」得到的 85% 要可靠 (Robust) 且穩定 (Stable)。
- N-fold 的好處:
- 更可靠的評估: 避免了因「運氣不好」切到特定資料而導致的評估偏差。
- 充分利用資料: 每一筆資料都有機會被當作「測試集」,也都有機會被當作「訓練集」。
- 觀察穩定度: 如果 5 次的分數是 [85, 86, 84, 85, 85],代表模型很穩定。如果分數是 [99, 60, 80, 95, 70],代表模型非常不穩定 (高變異 High Variance)。
- N-fold 的代價:
- 您需要訓練 N 次模型,計算成本是單次分割的 N 倍。(但這個代價通常是值得的)
總結一句話:N-fold Cross Validation 就是透過「輪流」當測試集並計算「平均成績」,來對模型進行一次更公平、更嚴謹的「模擬考」。
