什麼是 網格搜尋(Grid Search)?
網格搜尋是一種超參數調校方法,它窮舉超參數空間中所有可能的組合,並評估每個組合的模型效能。
核心概念
網格搜尋 (Grid Search) 是一種超參數優化技術,用於在預定義的超參數網格中,系統性地評估所有可能的超參數組合,以找到最佳的模型配置。它是一種窮舉搜索方法,適用於超參數空間相對較小的情況。其核心思想是將每個超參數的可能取值離散化,形成一個網格,然後遍歷網格中的每個節點(即每個超參數組合),訓練並評估模型,最終選擇在驗證集上表現最佳的超參數組合。
網格搜尋的優點是簡單易懂,容易實現。但缺點也很明顯,當超參數數量較多或每個超參數的取值範圍較大時,網格搜尋的計算成本會呈指數級增長,導致搜索時間過長,甚至無法完成。
運作原理
網格搜尋的運作原理可以分解為以下幾個步驟:
- 定義超參數空間: 首先,需要確定要優化的超參數,以及每個超參數的可能取值。這些取值可以是離散的,例如
learning_rate = [0.01, 0.1, 1.0],也可以是連續的,但需要將其離散化,例如C = [0.1, 1, 10, 100](用於 SVM 的正則化參數)。 - 構建超參數網格: 將所有超參數的可能取值組合起來,形成一個網格。例如,如果有兩個超參數,第一個超參數有3個取值,第二個超參數有4個取值,那麼網格中就有 3 * 4 = 12 個節點,每個節點代表一個超參數組合。
- 訓練和評估模型: 對於網格中的每個節點(即每個超參數組合),使用訓練數據訓練模型,並在驗證集上評估模型的效能。常用的評估指標包括準確度、精確度、召回率、F1分數、AUC等。為了提高評估的可靠性,通常會使用交叉驗證 (Cross-Validation)。
- 選擇最佳超參數: 根據驗證集上的效能,選擇最佳的超參數組合。通常選擇在交叉驗證中平均效能最高的超參數組合。
- 在測試集上評估: 使用最佳超參數組合訓練的模型,在測試集上進行最終評估,以評估模型的泛化能力。
交叉驗證 (Cross-Validation):
交叉驗證是一種評估模型泛化能力的技術,它可以更可靠地評估模型在未見數據上的表現。常用的交叉驗證方法包括 k 折交叉驗證 (k-fold Cross-Validation)。在 k 折交叉驗證中,將訓練數據分成 k 個子集(稱為 folds),然後依次使用其中一個子集作為驗證集,其餘 k-1 個子集作為訓練集,訓練模型並評估效能。重複 k 次,每次使用不同的子集作為驗證集,最後將 k 次評估結果的平均值作為模型的最終效能評估。
實際應用
網格搜尋廣泛應用於各種機器學習任務中,例如:
- 支持向量機 (SVM): 在 SVM 中,需要調整核函數類型、正則化參數 C、核函數參數 gamma 等超參數。
- 決策樹 (Decision Tree): 在決策樹中,需要調整樹的最大深度、最小葉節點樣本數等超參數。
- 隨機森林 (Random Forest): 在隨機森林中,需要調整樹的數量、樹的最大深度、最小葉節點樣本數等超參數。
- K 近鄰 (KNN): 在 KNN 中,需要調整鄰居數量 k、距離度量方法等超參數。
- 神經網路 (Neural Network): 在神經網路中,需要調整學習率、批次大小、隱藏層數量、每層的神經元數量、激活函數等超參數。
在實際應用中,可以利用各種工具和框架來簡化網格搜尋的過程,例如:
- Scikit-learn: Scikit-learn 提供了
GridSearchCV類,可以方便地進行網格搜尋。 - Keras Tuner: Keras Tuner 是一個用於 Keras 模型的超參數調校庫,它也支持網格搜尋。
- 其他 AutoML 工具: 許多 AutoML 工具也提供了網格搜尋的功能。
Scikit-learn 的 GridSearchCV 示例:
python from sklearn.model_selection import GridSearchCV from sklearn.svm import SVC from sklearn.model_selection import train_test_split
準備數據
X, y = ... # 你的數據 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
定義超參數網格
param_grid = { 'C': [0.1, 1, 10, 100], 'gamma': [0.001, 0.01, 0.1, 1], 'kernel': ['rbf'] }
創建 GridSearchCV 對象
grid = GridSearchCV(SVC(), param_grid, refit=True, verbose=3, cv=5)
訓練模型並進行網格搜尋
grid.fit(X_train, y_train)
獲取最佳超參數組合
print(grid.best_params_)
獲取最佳模型
best_model = grid.best_estimator_
在測試集上評估模型
accuracy = best_model.score(X_test, y_test) print(f'Accuracy: {accuracy}')
常見誤區
- 超參數空間定義不合理: 如果超參數空間定義不合理,例如取值範圍過窄或過寬,可能會導致無法找到最佳的超參數組合。應該根據經驗和領域知識,合理地定義超參數空間。
- 計算資源不足: 網格搜尋的計算成本較高,如果計算資源不足,可能會導致搜索時間過長,甚至無法完成。可以考慮使用更高效的超參數優化方法,例如隨機搜尋或貝葉斯優化。
- 沒有使用交叉驗證: 如果沒有使用交叉驗證,可能會導致模型過度擬合驗證集,從而在測試集上的泛化能力下降。應該使用交叉驗證來更可靠地評估模型的泛化能力。
- 忽略超參數之間的相互作用: 超參數之間可能存在相互作用,如果忽略這些相互作用,可能會導致找到次優的超參數組合。可以使用更複雜的超參數優化方法,例如貝葉斯優化,來考慮超參數之間的相互作用。
- 過度擬合驗證集: 即使使用了交叉驗證,仍然有可能過度擬合驗證集。為了避免這種情況,可以使用更嚴格的交叉驗證方法,例如嵌套交叉驗證 (Nested Cross-Validation),或者使用正則化技術來降低模型的複雜度。
相關術語
常見問題
延伸學習
想看 網格搜尋 的完整影片教學?前往 美第奇 AI 學院