
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/compose/plot_transformed_target.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_compose_plot_transformed_target.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_compose_plot_transformed_target.py:


========================================================
Effect of transforming the targets in a regression model
========================================================

In this example, we give an overview of
:class:`~sklearn.compose.TransformedTargetRegressor`. We use two examples
to illustrate the benefit of transforming the targets before learning a linear
regression model. The first example uses synthetic data while the second
example is based on the Ames housing data set.

.. GENERATED FROM PYTHON SOURCE LINES 14-29

.. code-block:: default


    # Author: Guillaume Lemaitre <guillaume.lemaitre@inria.fr>
    # License: BSD 3 clause

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt

    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import RidgeCV
    from sklearn.compose import TransformedTargetRegressor
    from sklearn.metrics import median_absolute_error, r2_score
    from sklearn.utils.fixes import parse_version








.. GENERATED FROM PYTHON SOURCE LINES 30-32

Synthetic example
#############################################################################

.. GENERATED FROM PYTHON SOURCE LINES 32-39

.. code-block:: default


    # `normed` is being deprecated in favor of `density` in histograms
    if parse_version(matplotlib.__version__) >= parse_version("2.1"):
        density_param = {"density": True}
    else:
        density_param = {"normed": True}








.. GENERATED FROM PYTHON SOURCE LINES 40-51

A synthetic random regression dataset is generated. The targets ``y`` are
modified by:

  1. translating all targets such that all entries are
     non-negative (by adding the absolute value of the lowest ``y``) and
  2. applying an exponential function to obtain non-linear
     targets which cannot be fitted using a simple linear model.

Therefore, a logarithmic function (``np.log1p``) is applied to the targets
before training a linear regression model, and its inverse, the exponential
function (``np.expm1``), maps the predictions back to the original scale.
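
As a minimal sketch of this idea, the snippet below compares transforming the
targets by hand with letting
:class:`~sklearn.compose.TransformedTargetRegressor` apply ``func`` and
``inverse_func``. The toy data and the names ``X_demo`` and ``y_demo`` are
illustrative assumptions, not part of the original example. Both routes are
expected to give the same predictions.

.. code-block:: python

    import numpy as np
    from sklearn.compose import TransformedTargetRegressor
    from sklearn.linear_model import RidgeCV

    # Toy data (assumption): non-negative targets, exponential in the features
    rng = np.random.RandomState(0)
    X_demo = rng.uniform(size=(200, 3))
    y_demo = np.expm1(X_demo @ np.array([1.0, 0.5, 0.2]))

    # Manual route: fit on log1p(y), then map predictions back with expm1
    manual = RidgeCV().fit(X_demo, np.log1p(y_demo))
    manual_pred = np.expm1(manual.predict(X_demo))

    # Wrapped route: the meta-estimator applies both functions itself
    wrapped = TransformedTargetRegressor(
        regressor=RidgeCV(), func=np.log1p, inverse_func=np.expm1
    )
    wrapped.fit(X_demo, y_demo)
    wrapped_pred = wrapped.predict(X_demo)

    # Both routes should agree up to floating point error
    agree = np.allclose(manual_pred, wrapped_pred)

The wrapped form keeps the transformation inside the estimator, so the
estimator takes and returns targets on the original scale.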

.. GENERATED FROM PYTHON SOURCE LINES 51-56

.. code-block:: default


    X, y = make_regression(n_samples=10000, noise=100, random_state=0)
    y = np.expm1((y + abs(y.min())) / 200)
    y_trans = np.log1p(y)








.. GENERATED FROM PYTHON SOURCE LINES 57-59

Below we plot the probability density functions of the target
before and after applying the logarithmic function.

.. GENERATED FROM PYTHON SOURCE LINES 59-78

.. code-block:: default


    f, (ax0, ax1) = plt.subplots(1, 2)

    ax0.hist(y, bins=100, **density_param)
    ax0.set_xlim([0, 2000])
    ax0.set_ylabel("Probability")
    ax0.set_xlabel("Target")
    ax0.set_title("Target distribution")

    ax1.hist(y_trans, bins=100, **density_param)
    ax1.set_ylabel("Probability")
    ax1.set_xlabel("Target")
    ax1.set_title("Transformed target distribution")

    f.suptitle("Synthetic data", y=0.06, x=0.53)
    f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)




.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_001.png
   :alt: Synthetic data, Target distribution, Transformed target distribution
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 79-84

First, a linear model is fitted on the original targets. Because the targets
are non-linear, the trained model is imprecise at prediction time. Next, a
logarithmic function is used to linearize the targets, which allows better
predictions with the same type of linear model, as reported by the median
absolute error (MAE).
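
Note that MAE here denotes the *median* absolute error, not the mean. The
sketch below uses hypothetical toy values (not taken from this example) to
show how the two differ when a single prediction is far off.

.. code-block:: python

    import numpy as np
    from sklearn.metrics import mean_absolute_error, median_absolute_error

    # Hypothetical values: four small errors and one large one
    y_true_demo = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
    y_pred_demo = np.array([1.5, 2.5, 3.5, 4.5, 50.0])

    # Absolute errors are 0.5, 0.5, 0.5, 0.5 and 50
    median_err = median_absolute_error(y_true_demo, y_pred_demo)
    mean_err = mean_absolute_error(y_true_demo, y_pred_demo)

The median is barely affected by the single large error, whereas the mean is
dominated by it.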

.. GENERATED FROM PYTHON SOURCE LINES 84-128

.. code-block:: default


    f, (ax0, ax1) = plt.subplots(1, 2, sharey=True)
    # Use linear model
    regr = RidgeCV()
    regr.fit(X_train, y_train)
    y_pred = regr.predict(X_test)
    # Plot results
    ax0.scatter(y_test, y_pred)
    ax0.plot([0, 2000], [0, 2000], "--k")
    ax0.set_ylabel("Target predicted")
    ax0.set_xlabel("True Target")
    ax0.set_title("Ridge regression \n without target transformation")
    ax0.text(
        100,
        1750,
        r"$R^2$=%.2f, MAE=%.2f"
        % (r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)),
    )
    ax0.set_xlim([0, 2000])
    ax0.set_ylim([0, 2000])
    # Transform targets and use same linear model
    regr_trans = TransformedTargetRegressor(
        regressor=RidgeCV(), func=np.log1p, inverse_func=np.expm1
    )
    regr_trans.fit(X_train, y_train)
    y_pred = regr_trans.predict(X_test)

    ax1.scatter(y_test, y_pred)
    ax1.plot([0, 2000], [0, 2000], "--k")
    ax1.set_ylabel("Target predicted")
    ax1.set_xlabel("True Target")
    ax1.set_title("Ridge regression \n with target transformation")
    ax1.text(
        100,
        1750,
        r"$R^2$=%.2f, MAE=%.2f"
        % (r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)),
    )
    ax1.set_xlim([0, 2000])
    ax1.set_ylim([0, 2000])

    f.suptitle("Synthetic data", y=0.035)
    f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])




.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_002.png
   :alt: Synthetic data, Ridge regression   without target transformation, Ridge regression   with target transformation
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 129-135

Real-world data set
##############################################################################

In a similar manner, the Ames housing data set is used to show the impact
of transforming the targets before learning a model. In this example, the
target to be predicted is the selling price of each house.

.. GENERATED FROM PYTHON SOURCE LINES 135-148

.. code-block:: default


    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import QuantileTransformer, quantile_transform

    ames = fetch_openml(name="house_prices", as_frame=True)
    # Keep only numeric columns
    X = ames.data.select_dtypes(np.number)
    # Remove columns with NaN or Inf values
    X = X.drop(columns=["LotFrontage", "GarageYrBlt", "MasVnrArea"])
    y = ames.target
    y_trans = quantile_transform(
        y.to_frame(), n_quantiles=900, output_distribution="normal", copy=True
    ).squeeze()






.. GENERATED FROM PYTHON SOURCE LINES 149-152

A :class:`~sklearn.preprocessing.QuantileTransformer` is used to normalize
the target distribution before applying a
:class:`~sklearn.linear_model.RidgeCV` model.
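
The following sketch illustrates the transformer in isolation. It uses
synthetic, right-skewed values as a stand-in for prices; the log-normal
parameters and the name ``prices`` are assumptions for illustration, not the
Ames data. The transformed values should be roughly centered on zero, and
``inverse_transform`` should recover the original values.

.. code-block:: python

    import numpy as np
    from sklearn.preprocessing import QuantileTransformer

    # Synthetic right-skewed "prices" (assumption, not the Ames data)
    rng = np.random.RandomState(0)
    prices = rng.lognormal(mean=12.0, sigma=0.4, size=1000).reshape(-1, 1)

    qt = QuantileTransformer(
        n_quantiles=900, output_distribution="normal", random_state=0
    )
    # Map the empirical distribution onto a standard normal distribution
    prices_trans = qt.fit_transform(prices)
    # Map back to the original scale
    prices_back = qt.inverse_transform(prices_trans)

    round_trip_ok = np.allclose(prices, prices_back, rtol=1e-3)

This invertibility is what lets
:class:`~sklearn.compose.TransformedTargetRegressor` return predictions on the
original price scale.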

.. GENERATED FROM PYTHON SOURCE LINES 152-171

.. code-block:: default


    f, (ax0, ax1) = plt.subplots(1, 2)

    ax0.hist(y, bins=100, **density_param)
    ax0.set_ylabel("Probability")
    ax0.set_xlabel("Target")
    ax0.text(s="Target distribution", x=1.2e5, y=9.8e-6, fontsize=12)
    ax0.ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    ax1.hist(y_trans, bins=100, **density_param)
    ax1.set_ylabel("Probability")
    ax1.set_xlabel("Target")
    ax1.text(s="Transformed target distribution", x=-6.8, y=0.479, fontsize=12)

    f.suptitle("Ames housing data: selling price", y=0.04)
    f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)


.. GENERATED FROM PYTHON SOURCE LINES 172-179

The effect of the transformer is weaker than on the synthetic data. However,
the transformation results in an increase in :math:`R^2` and a large decrease
in the MAE. Without target transformation, the residual plot (predicted
target - true target vs. predicted target) takes on a curved, 'reverse smile'
shape because the residuals vary with the value of the predicted target. With
target transformation, the shape is more linear, indicating a better model
fit.
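
The curved residual pattern can be reproduced on a small toy problem. The
sketch below is an illustration under assumed data (a noiseless exponential
target and a plain ``LinearRegression``), not the Ames setup. A straight line
fitted to a convex target under-predicts at both ends and over-predicts in
the middle, while fitting on the log of the target removes the pattern.

.. code-block:: python

    import numpy as np
    from sklearn.compose import TransformedTargetRegressor
    from sklearn.linear_model import LinearRegression

    # Toy data (assumption): noiseless exponential target
    x_demo = np.linspace(0, 3, 50).reshape(-1, 1)
    y_exp = np.exp(x_demo).ravel()

    # Without transformation: residual = predicted target - true target
    plain = LinearRegression().fit(x_demo, y_exp)
    residual_plain = plain.predict(x_demo) - y_exp

    # With a log transformation the target becomes linear in x
    trans = TransformedTargetRegressor(
        regressor=LinearRegression(), func=np.log, inverse_func=np.exp
    )
    trans.fit(x_demo, y_exp)
    residual_trans = trans.predict(x_demo) - y_exp

    # Negative at both ends, positive in the middle: a 'reverse smile'
    ends_negative = residual_plain[0] < 0 and residual_plain[-1] < 0
    middle_positive = residual_plain[25] > 0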

.. GENERATED FROM PYTHON SOURCE LINES 179-248

.. code-block:: default


    f, (ax0, ax1) = plt.subplots(2, 2, sharey="row", figsize=(6.5, 8))

    regr = RidgeCV()
    regr.fit(X_train, y_train)
    y_pred = regr.predict(X_test)

    ax0[0].scatter(y_pred, y_test, s=8)
    ax0[0].plot([0, 7e5], [0, 7e5], "--k")
    ax0[0].set_ylabel("True target")
    ax0[0].set_xlabel("Predicted target")
    ax0[0].text(
        s="Ridge regression \n without target transformation",
        x=-5e4,
        y=8e5,
        fontsize=12,
        multialignment="center",
    )
    ax0[0].text(
        3e4,
        64e4,
        r"$R^2$=%.2f, MAE=%.2f"
        % (r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)),
    )
    ax0[0].set_xlim([0, 7e5])
    ax0[0].set_ylim([0, 7e5])
    ax0[0].ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    ax1[0].scatter(y_pred, (y_pred - y_test), s=8)
    ax1[0].set_ylabel("Residual")
    ax1[0].set_xlabel("Predicted target")
    ax1[0].ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    regr_trans = TransformedTargetRegressor(
        regressor=RidgeCV(),
        transformer=QuantileTransformer(n_quantiles=900, output_distribution="normal"),
    )
    regr_trans.fit(X_train, y_train)
    y_pred = regr_trans.predict(X_test)

    ax0[1].scatter(y_pred, y_test, s=8)
    ax0[1].plot([0, 7e5], [0, 7e5], "--k")
    ax0[1].set_ylabel("True target")
    ax0[1].set_xlabel("Predicted target")
    ax0[1].text(
        s="Ridge regression \n with target transformation",
        x=-5e4,
        y=8e5,
        fontsize=12,
        multialignment="center",
    )
    ax0[1].text(
        3e4,
        64e4,
        r"$R^2$=%.2f, MAE=%.2f"
        % (r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)),
    )
    ax0[1].set_xlim([0, 7e5])
    ax0[1].set_ylim([0, 7e5])
    ax0[1].ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    ax1[1].scatter(y_pred, (y_pred - y_test), s=8)
    ax1[1].set_ylabel("Residual")
    ax1[1].set_xlabel("Predicted target")
    ax1[1].ticklabel_format(axis="both", style="sci", scilimits=(0, 0))

    f.suptitle("Ames housing data: selling price", y=0.035)

    plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.992 seconds)


.. _sphx_glr_download_auto_examples_compose_plot_transformed_target.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_transformed_target.py <plot_transformed_target.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_transformed_target.ipynb <plot_transformed_target.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
