shivamshinde92722

A Simple Approach to Creating Custom Transformers Using Scikit-Learn Classes

In this article, I will explain how to use Scikit-Learn classes to create a transformer tailored to our own preprocessing needs.

Preprocessing the data is one of the most important steps in the data science lifecycle. Scikit-Learn, a very popular machine learning library, provides many predefined transformers, such as StandardScaler and OneHotEncoder, that help us transform our data into the required format.


However, there may be cases where we want to perform a processing operation for which Scikit-Learn has no suitable transformer. Fortunately, we can easily create our own transformer, a so-called custom transformer, using a couple of Scikit-Learn classes. Why go through this process when the same operation could be done without a transformer? Because a custom transformer built on Scikit-Learn classes works seamlessly with Scikit-Learn functionality such as pipelines, which makes our life much easier. If you want to learn more about Scikit-Learn pipelines, you can check out my other blog post on the subject.


Let’s see how to create the custom transformers using Scikit-Learn classes.

 

First, let us understand the theory. A custom transformer needs three methods: fit(), transform(), and fit_transform(). In practice, we only have to write fit() and transform() ourselves.


You get the fit_transform() method for free if the custom transformer class inherits from TransformerMixin. Similarly, by also making BaseEstimator a base class of the custom transformer, we get two more methods, get_params() and set_params(). Scikit-Learn tools such as GridSearchCV rely on these methods to read and change the transformer’s parameters during hyperparameter tuning. For them to work, every parameter should be an explicit keyword argument of __init__() and be stored in an attribute with the same name.
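To illustrate what the two base classes provide, here is a minimal sketch of a stateless transformer. The class MultiplyBy and its factor parameter are hypothetical and exist only for this illustration: only fit() and transform() are written by hand, while fit_transform(), get_params(), and set_params() are inherited. Following the ordering used in Scikit-Learn’s own examples, the mixin is listed before BaseEstimator.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class MultiplyBy(TransformerMixin, BaseEstimator):
    """Hypothetical stateless transformer: multiplies every value by `factor`."""

    def __init__(self, factor=2.0):
        # Keep the argument under the same name: BaseEstimator.get_params()
        # looks up each __init__ parameter as an attribute.
        self.factor = factor

    def fit(self, X, y=None):
        # Nothing to learn; returning self allows fit(X).transform(X).
        return self

    def transform(self, X, y=None):
        return np.asarray(X, dtype=float) * self.factor


t = MultiplyBy(factor=3)
print(t.fit_transform([[1, 2], [3, 4]]))  # fit_transform() comes from TransformerMixin
print(t.get_params())                     # get_params() comes from BaseEstimator: {'factor': 3}
t.set_params(factor=10)                   # set_params() comes from BaseEstimator
```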


Now, let’s understand the role of the fit() and transform() methods in our custom transformer. The fit() method calculates the parameters needed for the data processing operation and then returns self (the transformer instance itself). The transform() method uses the parameters calculated in fit() to transform the data and returns the transformed data. Both methods take the following arguments:


  1. self

  2. X (independent feature values)

  3. y (dependent feature values), which is usually optional with a default of None, because TransformerMixin’s fit_transform() calls transform() with X only
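A small sketch of this division of work, assuming array-like numeric input: the hypothetical MeanCenterer learns the column means in fit() (stored in an attribute with a trailing underscore, Scikit-Learn’s naming convention for fitted values) and subtracts them in transform(). The y argument is accepted but ignored.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class MeanCenterer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer: subtracts the column means seen during fit()."""

    def fit(self, X, y=None):
        # fit(): compute the parameters the transformation needs ...
        self.means_ = np.asarray(X, dtype=float).mean(axis=0)
        # ... and return the instance itself.
        return self

    def transform(self, X, y=None):
        # transform(): apply the parameters learned in fit() to (possibly new) data.
        return np.asarray(X, dtype=float) - self.means_


train = [[1.0, 10.0], [3.0, 30.0]]
test = [[2.0, 50.0]]
centerer = MeanCenterer().fit(train)  # fit() returns self, so calls can be chained
print(centerer.transform(test))       # the training means (2 and 20) are subtracted
```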

To make this idea concrete, let’s take an example.


Let’s say we want to create a transformer that removes outliers from the data. To do this, we need the values of Q1 (the first quartile), Q3 (the third quartile), and the interquartile range (IQR = Q3 - Q1). These values are calculated in the fit() method. The transform() method then uses them to find the outliers, i.e., the values below Q1 - 1.5 × IQR or above Q3 + 1.5 × IQR.
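A minimal sketch of this fit-based approach, assuming the data is a pandas DataFrame. The class name IQROutlierRemover and its columns and factor parameters are hypothetical. Because the bounds are learned in fit(), any data passed to transform() later is judged against the training quartiles.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class IQROutlierRemover(TransformerMixin, BaseEstimator):
    """Hypothetical fit-based outlier remover for pandas DataFrames."""

    def __init__(self, columns, factor=1.5):
        self.columns = columns
        self.factor = factor

    def fit(self, X, y=None):
        # Learn Q1, Q3 and the IQR of each column from the training data.
        q1 = X[self.columns].quantile(0.25)
        q3 = X[self.columns].quantile(0.75)
        iqr = q3 - q1
        # Trailing underscore: Scikit-Learn's naming convention for fitted attributes.
        self.lower_ = q1 - self.factor * iqr
        self.upper_ = q3 + self.factor * iqr
        return self

    def transform(self, X, y=None):
        # Keep only rows whose values lie inside the learned bounds in every column.
        values = X[self.columns]
        inside = ((values >= self.lower_) & (values <= self.upper_)).all(axis=1)
        return X[inside]


train = pd.DataFrame({"a": [1, 2, 3, 4, 100]})
remover = IQROutlierRemover(columns=["a"]).fit(train)
new_data = pd.DataFrame({"a": [0, 5, 50]})
print(remover.transform(new_data))  # bounds come from `train`, not from `new_data`
```

One design caveat: dropping rows changes the number of samples, and a Scikit-Learn Pipeline does not filter y to match, so inside a pipeline X and y would get out of sync. Replacing outliers instead of dropping them avoids this.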


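The other approach is to calculate the parameters and perform the transformation in the transform() method alone. In this case, the fit() method only returns self. Note that the statistics are then recomputed from whatever data is passed to transform() (for example, the test set) instead of being learned once from the training data.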


 

Let’s try creating a custom transformer ourselves, using the outlier example explained above with one slight change: instead of removing the outliers completely, we will replace them with the median of the data.
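Here is a minimal sketch of such a transformer, written to follow the step-by-step description below; it is an illustration under stated assumptions, not necessarily the exact original code. It assumes X is a pandas DataFrame, computes the quartiles separately for each listed column, and replaces that column’s outliers with its median (quartile 2). Working on a copy of X and using pandas Series.mask() are choices made for this sketch.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class Outlier_Remover(TransformerMixin, BaseEstimator):
    """Replace IQR outliers in the given numerical columns with that column's median."""

    def __init__(self, numerical_features):
        # Step 2: keep the list of numerical column names, stored under the
        # same name so that BaseEstimator.get_params() can find it.
        self.numerical_features = numerical_features

    def fit(self, X, y=None):
        # Step 3: nothing is learned here.
        return self

    def transform(self, X, y=None):
        X = X.copy()  # work on a copy so the caller's DataFrame is not modified
        for col in self.numerical_features:
            # Step 4: quartiles and interquartile range of this column.
            q1 = X[col].quantile(0.25)
            q2 = X[col].quantile(0.50)  # quartile 2 is the median
            q3 = X[col].quantile(0.75)
            iqr = q3 - q1
            # Step 5: values outside [q1 - 1.5*iqr, q3 + 1.5*iqr] become the median.
            is_outlier = (X[col] < q1 - 1.5 * iqr) | (X[col] > q3 + 1.5 * iqr)
            X[col] = X[col].mask(is_outlier, other=q2)
        # Step 6: return the transformed data.
        return X


df = pd.DataFrame({"age": [25, 30, 35, 40, 300], "city": ["A", "B", "C", "D", "E"]})
cleaned = Outlier_Remover(numerical_features=["age"]).fit_transform(df)
print(cleaned)  # the 300 in "age" is replaced by the median, 35
```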



Let’s understand the above code line by line.


  1. Our custom transformer, named Outlier_Remover, has two base classes: BaseEstimator and TransformerMixin.

  2. We pass the list of numerical feature names to the custom transformer through its __init__() method, because we look for outliers in numerical features only.

  3. We do not do anything in the fit() method; it simply returns self.

  4. All the work related to outliers is performed in the transform() method. First, we compute quartile1, quartile2 (the median), and quartile3. Next, we use these values to compute the interquartile range (quartile3 - quartile1).

  5. Now that we have quartile1, quartile3, and the interquartile range, we replace the outliers, i.e., values greater than (quartile3 + 1.5 × interquartile range) or smaller than (quartile1 - 1.5 × interquartile range), with the median of the data.

  6. Lastly, we return the transformed data, X.


Let’s take another example: we will create a custom transformer that removes useless features (i.e., columns) from our data.
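A minimal sketch of this transformer, following the description below: no __init__(), a fit() that only returns self, and the pandas DataFrame.drop() method in transform(). The column names in useless_features are hypothetical placeholders for whatever columns your data does not need.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class Remove_Useless_Features(TransformerMixin, BaseEstimator):
    # HYPOTHETICAL column names; replace them with the columns your data does not need.
    useless_features = ["id", "notes"]

    def fit(self, X, y=None):
        # Nothing to learn for this transformer.
        return self

    def transform(self, X, y=None):
        # DataFrame.drop() returns a new DataFrame, so the input is left unchanged.
        return X.drop(columns=self.useless_features)


df = pd.DataFrame({"id": [1, 2], "notes": ["a", "b"], "price": [10, 20]})
print(Remove_Useless_Features().fit_transform(df))  # only the "price" column remains
```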



Again, let us understand this code line by line.


  1. Our custom transformer, named Remove_Useless_Features, has two base classes: BaseEstimator and TransformerMixin. This gives us the fit_transform(), get_params(), and set_params() methods.

  2. In this case, there is nothing to implement in __init__(), and the fit() method simply returns self.

  3. The transform() method uses the pandas drop() method to remove the useless features from the dataset and returns the transformed dataset.
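One possible variant, sketched under the assumption that the columns to drop should be configurable: the hypothetical DropColumns accepts the column list through __init__(), so that, thanks to BaseEstimator, it can be changed with set_params() even when the transformer is a step inside a Pipeline. The step names "drop" and "scale" are arbitrary.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


class DropColumns(TransformerMixin, BaseEstimator):
    """Hypothetical variant of Remove_Useless_Features with a configurable column list."""

    def __init__(self, columns=None):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        # Drop the configured columns (none if the list is empty or None).
        return X.drop(columns=self.columns or [])


df = pd.DataFrame({
    "id": [1, 2, 3],
    "height": [150.0, 160.0, 170.0],
    "weight": [50.0, 60.0, 70.0],
})

pipe = Pipeline([
    ("drop", DropColumns(columns=["id"])),
    ("scale", StandardScaler()),
])
scaled = pipe.fit_transform(df)  # 2 scaled columns: height and weight

# "<step>__<parameter>" routes the new value to the DropColumns step.
pipe.set_params(drop__columns=["id", "weight"])
scaled_height_only = pipe.fit_transform(df)  # now only height remains
```

The step__parameter syntax (here drop__columns) is how a Pipeline routes a parameter to one of its steps, and GridSearchCV uses the same names in its param_grid.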

 

I hope the above examples made the concept of custom transformers clear to you. If you have any doubts, please feel free to comment.


You can find step-by-step code for the examples used in the article here.


Have a great day!
