Hyperparameter tuning using GridSearchCV and KerasClassifier

Last Updated: 20 Mar, 2024

Hyperparameters are settings chosen before training rather than learned from the data. Examples include the number of neurons in a layer, the batch size and the number of training epochs. Hyperparameter tuning searches for the values of these settings that give the best model performance. scikit-learn provides GridSearchCV and RandomizedSearchCV for this search. In this article, you'll learn how to use GridSearchCV to tune the hyperparameters of a Keras neural network.

Approach:

- Wrap the Keras model for use in scikit-learn with KerasClassifier, a wrapper that implements the scikit-learn classifier API for Keras.
- Use GridSearchCV to evaluate the wrapped model with cross-validation for every combination of candidate values.
- Tune hyperparameters such as the number of epochs, the number of neurons and the batch size.

The wrapper's constructor has the following signature:

```python
tf.keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)
```

This wrapper is deprecated and has been removed from newer TensorFlow releases, so the code in this article needs a TensorFlow version that still ships it. The SciKeras package provides a maintained KerasClassifier with a similar, but not identical, interface.

First, import the libraries:

```python
# import the libraries
import numpy as np  # needed for np.array during preprocessing
import pandas as pd
import tensorflow as tf
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
# legacy scikit-learn wrapper from tf.keras (see the note above)
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
```

Next, import the dataset, which we'll use to predict whether a customer stays or leaves.

```python
# The last column is a binary value: the target to predict
dataset = pd.read_csv('Churn_Modelling.csv')
# skip the first three columns (row number, customer ID, surname); keep the features
X = dataset.iloc[:, 3:-1].values
y = dataset.iloc[:, -1].values
```

Then preprocess the data: label-encode the gender column, one-hot encode the geography column and standardize all features.

```python
# label-encode column 2 of X (Gender)
le = LabelEncoder()
X[:, 2] = le.fit_transform(X[:, 2])

# perform one-hot encoding on column 1 of X (Geography)
ct = ColumnTransformer(transformers=[('encoder', OneHotEncoder(), [1])],
                       remainder='passthrough')
X = np.array(ct.fit_transform(X))

# perform standardization of the data
sc = StandardScaler()
X = sc.fit_transform(X)
```

To use the KerasClassifier wrapper, we need to build our model in a function, which is then passed to the build_fn argument of the KerasClassifier constructor.
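The wrapper decides which of its keyword arguments (sk_params) are given to the build function and which are given to the model's fit() call by checking whether each function's signature accepts that name. The following is a minimal pure-Python sketch of that routing idea, not the wrapper's actual code; route_params, toy_build and toy_fit are hypothetical names used only for illustration.

```python
import inspect


def route_params(fn, params):
    """Return the subset of `params` whose names appear in fn's signature."""
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in params.items() if name in accepted}


# Hypothetical stand-in for a model-building function
def toy_build(unit):
    return f"model with {unit} units"


# Hypothetical stand-in for a model's fit method
def toy_fit(x=None, y=None, batch_size=None, epochs=1):
    return batch_size, epochs


sk_params = {'unit': 10, 'batch_size': 32, 'epochs': 5, 'nb_epoch': 200}

build_args = route_params(toy_build, sk_params)  # {'unit': 10}
fit_args = route_params(toy_fit, sk_params)      # {'batch_size': 32, 'epochs': 5}

print(toy_build(**build_args))  # model with 10 units
print(toy_fit(**fit_args))      # (32, 5)
```

In this sketch, a key that matches neither signature, such as nb_epoch, is simply dropped and has no effect on building or training.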
```python
def build_clf(unit):
    # creating the layers of the NN: two hidden layers with `unit` neurons
    # each, and one sigmoid output neuron for binary classification
    ann = tf.keras.models.Sequential()
    ann.add(tf.keras.layers.Dense(units=unit, activation='relu'))
    ann.add(tf.keras.layers.Dense(units=unit, activation='relu'))
    ann.add(tf.keras.layers.Dense(units=1, activation='sigmoid'))
    ann.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return ann
```

Create the KerasClassifier object:

```python
model = KerasClassifier(build_fn=build_clf)
```

Now we create a dictionary of the hyperparameters we want to tune and pass it to GridSearchCV as param_grid. The wrapper forwards each key to the build function or to the model's fit() method, depending on which one accepts it: here unit goes to build_clf, while batch_size and epochs go to fit(). The epoch count must be given as epochs; the legacy Keras 1 name nb_epoch is accepted without an error but is not passed to fit().

```python
params = {'batch_size': [100, 20, 50, 25, 32],
          'epochs': [200, 100, 300, 400],
          'unit': [5, 6, 10, 11, 12, 15]}

# evaluate every combination with 10-fold cross-validation
gs = GridSearchCV(estimator=model, param_grid=params, cv=10)

# now fit the dataset to the GridSearchCV object
gs = gs.fit(X, y)
```

The best_score_ attribute gives the best mean cross-validated score observed during the search, and best_params_ holds the combination of hyperparameters that achieved it.

```python
best_params = gs.best_params_
accuracy = gs.best_score_
print('Accuracy:', accuracy)
print('Best Params:', best_params)
```

Output:

Accuracy: 0.80325
Best Params: {'batch_size': 20, 'nb_epoch': 200, 'unit': 15}

This output comes from a run whose grid used the key nb_epoch instead of epochs. Because that key is not forwarded to fit(), every candidate in that run was most likely trained for Keras's default of one epoch, so the nb_epoch value shown should not be read as the best number of epochs.

Article by maryamnadeem20
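GridSearchCV trains one model per fold for every combination in the grid, so its cost grows multiplicatively with each list of candidate values. The sketch below uses only the standard library's itertools.product to enumerate the same value lists as the grid in this article (the key names do not affect the count) and to compute how many fits a 10-fold search performs. It mirrors the size of GridSearchCV's search, not its internal iteration order.

```python
from itertools import product

# Same candidate values as the article's grid
params = {'batch_size': [100, 20, 50, 25, 32],
          'epochs': [200, 100, 300, 400],
          'unit': [5, 6, 10, 11, 12, 15]}
cv = 10

# Every combination of one value per hyperparameter
candidates = [dict(zip(params, values)) for values in product(*params.values())]

n_candidates = len(candidates)  # 5 * 4 * 6 = 120
n_fits = n_candidates * cv      # one model per fold per candidate
# total training epochs across all cross-validation fits
total_epochs = sum(c['epochs'] for c in candidates) * cv

print(n_candidates, n_fits, total_epochs)  # 120 1200 300000
```

With refit=True (the default), GridSearchCV also retrains the best combination once on the full dataset after the search, so the real number of fits is one higher. Shortening the value lists, or switching to RandomizedSearchCV, which evaluates only n_iter sampled combinations, reduces this cost.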