Python | Decision Tree Regression using sklearn
A decision tree is a decision-making tool that uses a flowchart-like tree structure to model decisions and all of their possible results, including outcomes, input costs and utility.
The decision-tree algorithm falls under the category of supervised learning algorithms. It works for both continuous and categorical output variables.
The branches/edges represent the result of a node, and each node holds either:
- Conditions [Decision Nodes]
- Result [End Nodes]
Each branch/edge represents the truth or falsity of the condition at its node, and the tree makes a decision by following the matching branch. The example below shows a decision tree that evaluates the smallest of three numbers:
[Figure: decision tree that evaluates the smallest of three numbers]
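The smallest-of-three tree described above can be sketched as nested conditions, where each `if` plays the role of a decision node and each `return` is an end node. The function name `smallest_of_three` and the order of the comparisons are assumptions made for illustration; the original figure may arrange the nodes differently.

```python
def smallest_of_three(a, b, c):
    """Walk a hand-written decision tree to find the smallest of three numbers."""
    if a < b:            # decision node: is a smaller than b?
        if a < c:        # decision node: is a also smaller than c?
            return a     # end node
        return c         # end node
    if b < c:            # decision node: is b smaller than c?
        return b         # end node
    return c             # end node


print(smallest_of_three(7, 3, 5))  # prints 3
```

Every call follows exactly one path from the root to an end node, which is the same way a trained decision tree produces a prediction.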

Decision Tree Regression:
Decision tree regression observes the features of an object and trains a tree-structured model to predict meaningful continuous output for future data. Continuous output means that the output/result is not discrete, i.e., it is not drawn from a fixed, known set of numbers or values.
Discrete output example: A weather prediction model that predicts whether or not there’ll be rain on a particular day.
Continuous output example: A profit prediction model that states the probable profit that can be generated from the sale of a product.
Here, continuous values are predicted with the help of a decision tree regression model.
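To make "continuous output" concrete, here is a minimal sketch of how a regression tree arrives at a number. It assumes scikit-learn's default squared-error criterion, under which each leaf predicts the mean target of the training samples that fall into it. The toy data and the names `X_toy`, `y_toy` and `stump` are made up for illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy data (made up for illustration): two well-separated groups of points
X_toy = np.array([[1], [2], [3], [10], [11], [12]])
y_toy = np.array([1.0, 1.2, 0.8, 5.0, 5.5, 4.5])

# max_depth=1 allows a single split, giving exactly two leaves
stump = DecisionTreeRegressor(max_depth=1, random_state=0)
stump.fit(X_toy, y_toy)

# Each leaf predicts the mean target of the training samples that fall in it
low_pred, high_pred = stump.predict([[2], [11]])
print(low_pred, y_toy[:3].mean())
print(high_pred, y_toy[3:].mean())
```

The predictions are averages of observed targets rather than class labels, which is what separates a regression tree from a classification tree.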
Let’s see the Step-by-Step implementation –
- Step 1: Import the required libraries.
    # import numpy package for arrays and stuff
    import numpy as np
    # import matplotlib.pyplot for plotting our result
    import matplotlib.pyplot as plt
    # import pandas for importing csv files
    import pandas as pd
- Step 2: Initialize and print the Dataset.
    # import dataset
    # dataset = pd.read_csv('Data.csv')
    # alternatively open up .csv file to read data
    dataset = np.array(
        [['Asset Flip', 100, 1000],
         ['Text Based', 500, 3000],
         ['Visual Novel', 1500, 5000],
         ['2D Pixel Art', 3500, 8000],
         ['2D Vector Art', 5000, 6500],
         ['Strategy', 6000, 7000],
         ['First Person Shooter', 8000, 15000],
         ['Simulator', 9500, 20000],
         ['Racing', 12000, 21000],
         ['RPG', 14000, 25000],
         ['Sandbox', 15500, 27000],
         ['Open-World', 16500, 30000],
         ['MMOFPS', 25000, 52000],
         ['MMORPG', 30000, 80000]])
    # print the dataset
    print(dataset)
- Step 3: Select all the rows and column 1 from dataset to “X”.
    # select all rows by : and column 1
    # by 1:2 representing features
    X = dataset[:, 1:2].astype(int)
    # print X
    print(X)
- Step 4: Select all of the rows and column 2 from dataset to “y”.
    # select all rows by : and column 2
    # by 2 to Y representing labels
    y = dataset[:, 2].astype(int)
    # print y
    print(y)
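The slicing in Steps 3 and 4 does two things that are easy to miss: it converts strings back to integers, and it gives X and y different shapes. The sketch below illustrates both on the first three rows of the dataset; the names `sample`, `X_sample` and `y_sample` are assumptions used only here.

```python
import numpy as np

# A small slice of the article's dataset: name, production cost, profit
sample = np.array([['Asset Flip', 100, 1000],
                   ['Text Based', 500, 3000],
                   ['Visual Novel', 1500, 5000]])

# Mixing strings and numbers makes NumPy store every entry as a string
print(sample.dtype.kind)   # 'U' means a Unicode string dtype

# 1:2 keeps the column axis, so X_sample is 2D with shape (3, 1)
X_sample = sample[:, 1:2].astype(int)
# A plain index drops the column axis, so y_sample is 1D with shape (3,)
y_sample = sample[:, 2].astype(int)

print(X_sample.shape, y_sample.shape)
```

scikit-learn estimators expect the feature matrix to be 2D (samples by features), which is why the features are selected with `1:2` rather than a bare `1`.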
- Step 5: Fit decision tree regressor to the dataset
    # import the regressor
    from sklearn.tree import DecisionTreeRegressor
    # create a regressor object
    regressor = DecisionTreeRegressor(random_state=0)
    # fit the regressor with X and Y data
    regressor.fit(X, y)
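With default settings, `DecisionTreeRegressor` keeps splitting until its leaves are pure, so on a dataset this small it can end up memorizing every training point. The following sketch compares the default tree with a depth-limited one; the `max_depth=2` value and the names `full_tree` and `shallow_tree` are illustrative assumptions, not part of the original walkthrough.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Production cost and profit columns from the article's dataset
X = np.array([100, 500, 1500, 3500, 5000, 6000, 8000, 9500,
              12000, 14000, 15500, 16500, 25000, 30000]).reshape(-1, 1)
y = np.array([1000, 3000, 5000, 8000, 6500, 7000, 15000, 20000,
              21000, 25000, 27000, 30000, 52000, 80000])

# Default settings: the tree keeps splitting until every leaf is pure
full_tree = DecisionTreeRegressor(random_state=0).fit(X, y)
# Capped depth: at most 2 levels of splits, so at most 4 leaves
shallow_tree = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

# Number of leaves and R^2 score on the training data for each tree
print(full_tree.get_n_leaves(), full_tree.score(X, y))
print(shallow_tree.get_n_leaves(), shallow_tree.score(X, y))
```

A perfect training score here reflects memorization rather than generalization, so limiting depth or leaf size is worth considering on real data.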
- Step 6: Predicting a new value
    # predicting a new value
    # test the output by changing values, like 3750
    # predict() expects a 2D array of shape (n_samples, n_features)
    y_pred = regressor.predict([[3750]])
    # print the predicted price
    print("Predicted price: % d\n" % y_pred[0])
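Because `predict` takes one row per query, several values can be predicted in a single call. The sketch below refits the same model on the article's numeric columns and queries two nearby production costs; the name `queries` and the second query value 4500 are assumptions chosen to show how the prediction jumps between neighbouring leaves.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Production cost and profit columns from the article's dataset
X = np.array([100, 500, 1500, 3500, 5000, 6000, 8000, 9500,
              12000, 14000, 15500, 16500, 25000, 30000]).reshape(-1, 1)
y = np.array([1000, 3000, 5000, 8000, 6500, 7000, 15000, 20000,
              21000, 25000, 27000, 30000, 52000, 80000])

regressor = DecisionTreeRegressor(random_state=0).fit(X, y)

# One row per query, one column per feature
queries = np.array([[3750], [4500]])
predictions = regressor.predict(queries)

# 3750 lands in the leaf of the 3500 sample, 4500 in the leaf of the 5000 sample
print(predictions)  # [8000. 6500.]
```

The two queries differ by only 750 yet return different profits, because a fully grown tree returns the target of whichever training sample's leaf the query falls into.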
- Step 7: Visualising the result
    # arange for creating a range of values
    # from min value of X to max value of X
    # with a difference of 0.01 between two
    # consecutive values
    X_grid = np.arange(X.min(), X.max(), 0.01)
    # reshape for reshaping the data into
    # a len(X_grid)*1 array, i.e. to make
    # a column out of the X_grid values
    X_grid = X_grid.reshape((len(X_grid), 1))
    # scatter plot for original data
    plt.scatter(X, y, color='red')
    # plot predicted data
    plt.plot(X_grid, regressor.predict(X_grid), color='blue')
    # specify title
    plt.title('Profit to Production Cost (Decision Tree Regression)')
    # specify X axis label
    plt.xlabel('Production Cost')
    # specify Y axis label
    plt.ylabel('Profit')
    # show the plot
    plt.show()
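The blue curve in such a plot is a step function: the prediction is constant inside each leaf and jumps at the split thresholds. A rough way to check this without plotting is to count the distinct predicted values over a grid, as in this sketch. The 2000-point `np.linspace` grid is an assumption chosen to keep the example fast; it is much coarser than the 0.01 step used above.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Production cost and profit columns from the article's dataset
X = np.array([100, 500, 1500, 3500, 5000, 6000, 8000, 9500,
              12000, 14000, 15500, 16500, 25000, 30000]).reshape(-1, 1)
y = np.array([1000, 3000, 5000, 8000, 6500, 7000, 15000, 20000,
              21000, 25000, 27000, 30000, 52000, 80000])

regressor = DecisionTreeRegressor(random_state=0).fit(X, y)

# Evenly spaced grid over the observed range, as a single-feature column
X_grid = np.linspace(X.min(), X.max(), 2000).reshape(-1, 1)
grid_pred = regressor.predict(X_grid)

# A step function takes only as many distinct values as there are leaves
n_levels = np.unique(grid_pred).size
print(n_levels, regressor.get_n_leaves())
```

However fine the grid is, the number of distinct predictions cannot exceed the number of leaves, which is why the plotted line looks like a staircase.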
- Step 8: Finally, the tree is exported to a ‘tree.dot’ file. The tree structure shown below was visualized by copying the contents of ‘tree.dot’ into http://www.webgraphviz.com/.
    # import export_graphviz
    from sklearn.tree import export_graphviz
    # export the decision tree to a tree.dot file
    # for visualizing the plot easily anywhere
    export_graphviz(regressor, out_file='tree.dot',
                    feature_names=['Production Cost'])
Output (Decision Tree):
[Figure: the exported decision tree rendered from ‘tree.dot’]
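If rendering the .dot file is inconvenient, a plain-text view of a tree can be printed with `sklearn.tree.export_text`, which needs no external tool. This is a sketch of an alternative, not the method used above; the `max_depth=2` limit and the name `small_tree` are assumptions made only to keep the printed tree short.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Production cost and profit columns from the article's dataset
X = np.array([100, 500, 1500, 3500, 5000, 6000, 8000, 9500,
              12000, 14000, 15500, 16500, 25000, 30000]).reshape(-1, 1)
y = np.array([1000, 3000, 5000, 8000, 6500, 7000, 15000, 20000,
              21000, 25000, 27000, 30000, 52000, 80000])

# A depth-limited tree (assumption) so the text output stays readable
small_tree = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

# Render the split rules and leaf values as indented text
tree_text = export_text(small_tree, feature_names=['Production Cost'])
print(tree_text)
```

Each indented line shows either a threshold test on Production Cost or the value predicted at a leaf.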
