shivamshinde92722

Clear Your Doubts About the Usage of Scikit-Learn’s Fit, Transform, and Predict Methods

It's important to know the distinction between the fit and fit_transform methods in Scikit-Learn and when it's appropriate to use each. It also helps to understand the purpose of the predict method.

If you’re a data science enthusiast, you may have struggled at some point to tell the scikit-learn fit and fit_transform methods apart. This article aims to clear up that confusion. But before diving into the topic, let’s first go over some basic terminology.

Transformers and Estimators

Transformer refers to an object having fit, transform, and fit_transform methods. Transformers are used for cleaning, reducing, expanding, or generating features in the data. Some examples of transformers include scikit-learn's OneHotEncoder, MinMaxScaler, and SimpleImputer.

Estimator refers to a machine learning model. Estimators have basic methods like fit and predict. (Scikit-learn's own documentation uses "estimator" more broadly, for any object with a fit method, and calls models that can predict "predictors". This article uses the narrower meaning.)
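As a quick illustration of the two terms, the sketch below checks which methods a transformer and a model expose. The choice of MinMaxScaler and LinearRegression is only an example, and the variable names are made up for this snippet.

```python
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()      # a transformer
model = LinearRegression()   # an estimator in this article's sense

# A transformer changes the data, so it exposes transform.
scaler_has_transform = hasattr(scaler, "transform")

# A model produces predictions, so it exposes predict instead.
model_has_predict = hasattr(model, "predict")
model_has_transform = hasattr(model, "transform")

print(scaler_has_transform, model_has_predict, model_has_transform)
```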

Use of the fit, fit_transform, and predict methods

Which method to use depends on two factors:

  1. Type of object in question (Transformer or Estimator)

  2. Type of data in question (Training data or Test/New data)

For a transformer working on the training data, the suitable method is fit_transform. Instead of fit_transform, we can also call the fit method followed by the transform method on the same data; the result is the same (fit_transform ~ fit followed by transform).
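A minimal sketch of this equivalence, using MinMaxScaler on a small made-up training array (the data and variable names are assumptions for illustration):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Toy training data: one feature, four samples.
X_train = np.array([[1.0], [2.0], [3.0], [5.0]])

# Option 1: learn the min/max and scale in a single call.
one_step = MinMaxScaler().fit_transform(X_train)

# Option 2: learn the min/max first, then scale.
scaler = MinMaxScaler()
scaler.fit(X_train)
two_step = scaler.transform(X_train)

# Both routes should give the same scaled training data.
same = np.allclose(one_step, two_step)
print(same)
```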

For a transformer working on testing or new data, the suitable method is transform. Note that the transformer must already have been fitted on the training data before transform is called on testing or new data.
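The following sketch shows the idea with MinMaxScaler: the scaler learns the minimum and maximum from the training data only, then reuses them on the test data. The arrays are invented for illustration.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

X_train = np.array([[0.0], [5.0], [10.0]])
X_test = np.array([[2.5], [12.0]])

scaler = MinMaxScaler()

# Training data: learn min (0) and max (10), then scale.
X_train_scaled = scaler.fit_transform(X_train)

# Test data: only transform, reusing the min and max learned above.
X_test_scaled = scaler.transform(X_test)

print(X_test_scaled)
```

Because the test data is scaled with the training range, a test value above the training maximum (12.0 here) ends up above 1. That is expected behaviour, not an error.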

With an estimator, use the fit method on the training data and the predict method on the testing or new data.
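Here is a small sketch of that pattern with LinearRegression. The toy data follows y = 2x + 1 exactly and is made up for this example.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Toy training data generated from y = 2x + 1.
X_train = np.array([[1.0], [2.0], [3.0], [4.0]])
y_train = np.array([3.0, 5.0, 7.0, 9.0])

# New, unseen inputs.
X_new = np.array([[5.0], [6.0]])

model = LinearRegression()
model.fit(X_train, y_train)     # learn from the training data
y_pred = model.predict(X_new)   # predict for the new data

print(y_pred)
```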

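In short: on training data, use fit_transform with a transformer and fit with an estimator; on testing or new data, use transform with a transformer and predict with an estimator.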

The predict method is usually not used on the training data, since the estimator was trained on it. But if you want to compare actual and predicted values visually, by plotting them together, you will need to call predict on the training data.

Why not use the fit method on testing or new data?


Let’s understand this using an example. With the scikit-learn SimpleImputer transformer, the fit method learns the imputation value from the training data, and the transform method uses that value to fill in the missing values. Ideally, we should use the same imputation value to fill in the missing values in the testing or new data.

But if we called the fit method again on the testing data, the imputer would learn a new imputation value (the mean, median, or mode of the feature, depending on the strategy) from the testing data. The value learned from the training data would be thrown away, and the testing data, which should stay foreign to the transformer, would be shaping its own preprocessing.
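The sketch below makes this concrete with a mean imputer. The tiny arrays are made up, and the name `refit_imputer` is just a label for the mistaken approach.

```python
import numpy as np
from sklearn.impute import SimpleImputer

X_train = np.array([[1.0], [2.0], [np.nan], [3.0]])
X_test = np.array([[np.nan], [100.0]])

# Correct: learn the mean from the training data only.
imputer = SimpleImputer(strategy="mean")
imputer.fit(X_train)
train_mean = imputer.statistics_[0]        # mean of 1, 2, 3

# The missing test value is filled with the training mean.
X_test_filled = imputer.transform(X_test)

# Mistake: fitting again on the test data learns a different value.
refit_imputer = SimpleImputer(strategy="mean").fit(X_test)
test_mean = refit_imputer.statistics_[0]   # mean of the single value 100

print(train_mean, test_mean)
```

The two learned values differ a lot here (2.0 versus 100.0), which shows how refitting on test data changes the preprocessing.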


How fit, transform, fit_transform, and predict work with scikit-learn pipelines


There are two types of scikit-learn pipelines:

  1. A pipeline that ends with an estimator (e.g., a classifier or regressor)

  2. A pipeline that ends with a transformer

The pipeline type determines which methods you can use with it, and what those methods do.

If a pipeline ends with an estimator, you can use the fit and predict methods:

  1. pipeline_object.fit(): All steps before the final one run the fit_transform method, each passing its output to the next step, and the final step runs the fit method.

  2. pipeline_object.predict(): All steps before the final one run the transform method, and the final step runs the predict method.
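The two calls above can be sketched as follows. The step names ("imputer", "scaler", "clf"), the choice of LogisticRegression, and the toy data are all assumptions made for illustration. The last lines repeat by hand what predict does, to show that the two routes agree.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X_train = np.array([[1.0, 10.0], [2.0, np.nan], [3.0, 30.0],
                    [4.0, 40.0], [5.0, 50.0], [6.0, 60.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])
X_new = np.array([[1.0, 10.0], [6.0, np.nan]])

pipe = Pipeline([
    ("imputer", SimpleImputer(strategy="median")),
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression()),
])

# fit: imputer and scaler are fitted and applied, then clf is fitted.
pipe.fit(X_train, y_train)

# predict: imputer and scaler only transform, then clf predicts.
y_pred = pipe.predict(X_new)

# The same thing done manually with the already fitted steps.
steps = pipe.named_steps
X_tmp = steps["imputer"].transform(X_new)
X_tmp = steps["scaler"].transform(X_tmp)
y_manual = steps["clf"].predict(X_tmp)

print(np.array_equal(y_pred, y_manual))
```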

If a pipeline ends with a transformer, you can use the transform and fit_transform methods:

  1. pipeline_object.transform(): All steps run the transform method.

  2. pipeline_object.fit_transform(): All steps run the fit_transform method.
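A preprocessing-only pipeline might look like the sketch below. The step names and data are invented for this example; fit_transform is used on the training data and transform on the test data, as with a single transformer.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

X_train = np.array([[1.0], [np.nan], [3.0]])
X_test = np.array([[np.nan], [5.0]])

preprocess = Pipeline([
    ("imputer", SimpleImputer(strategy="mean")),
    ("scaler", MinMaxScaler()),
])

# Training data: every step runs fit_transform in order.
X_train_prepared = preprocess.fit_transform(X_train)

# Test data: every step runs transform with what it learned above.
X_test_prepared = preprocess.transform(X_test)

print(X_train_prepared.ravel(), X_test_prepared.ravel())
```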

 

I hope you liked this article. Have a great day!
