100字范文,内容丰富有趣,生活中的好帮手!
100字范文 > 机器学习实验:使用sklearn的决策树算法对葡萄酒数据集进行分类

机器学习实验:使用sklearn的决策树算法对葡萄酒数据集进行分类

时间:2022-07-10 09:54:30

相关推荐

机器学习实验:使用sklearn的决策树算法对葡萄酒数据集进行分类

机器学习实验:使用sklearn的决策树算法对葡萄酒数据集进行分类

问题如下:

使用sklearn的决策树算法对葡萄酒数据集进行分类,要求:

①划分训练集和测试集(测试集占20%)

②对测试集的预测类别标签和真实标签进行对比

③输出分类的准确率

④调整参数比较不同算法(ID3,C4.5,CART)的分类效果。

代码实现:

导入依赖包

#导入相关库import sklearnfrom sklearn.model_selection import train_test_splitfrom sklearn import tree #导入tree模块from sklearn.datasets import load_winefrom math import log2import pandas as pdimport graphvizimport treePlotter

导入数据集

#导入数据集wine = load_wine()X = wine.data #XY = wine.target #Yfeatures_name = wine.feature_namesprint(features_name)pd.concat([pd.DataFrame(X),pd.DataFrame(Y)],axis=1)#打印数据

划分数据集,数据集划分为测试集占20%;

#划分数据集,数据集划分为测试集占20%;x_train, x_test, y_train, y_test = train_test_split(X, Y,test_size=0.2)# print(x_train.shape) #(142, 13)# print(x_test.shape)#(36, 13)

导入模型,进行训练

#采用C4.5算法进行计算#获取模型model = tree.DecisionTreeClassifier(criterion="entropy",splitter="best",max_depth=None,min_samples_split=2,min_samples_leaf=1,min_weight_fraction_leaf=0.0,max_features=None,random_state=None,max_leaf_nodes=None,class_weight=None);model.fit(x_train,y_train)score = model.score(x_test,y_test)y_predict = model.predict(x_test)print('准确率为:',score)#准确率为: 0.9444444444444444

对测试集的预测类别标签和真实标签进行对比

pd.concat([pd.DataFrame(x_test),pd.DataFrame(y_test),pd.DataFrame(y_predict)],axis=1)#打印数据,对测试集的预测类别标签和真实标签进行对比

最后两列为真实标签和预测类别标签

调整参数比较不同算法(ID3,C4.5,CART)的分类效果

#采用CART算法进行计算#获取模型model = tree.DecisionTreeClassifier(criterion="gini",splitter="best",max_depth=None,min_samples_split=2,min_samples_leaf=1,min_weight_fraction_leaf=0.0,max_features=None,random_state=None,max_leaf_nodes=None,class_weight=None);model.fit(x_train,y_train)score = model.score(x_test,y_test)y_predict = model.predict(x_test)print('准确率为:',score)#准确率为: 1.0

画出最后预测的树

feature_name = ['酒精','苹果酸','灰','灰的碱性','镁','总酚','类黄酮','非黄烷类酚类','花青素','颜色强度','色调','od280/od315稀释葡萄酒','脯氨酸']dot_data = tree.export_graphviz(model,out_file=None,feature_names=feature_name,class_names=['二锅头','苦荞','江小白'],filled=True,rounded=True)graph = graphviz.Source(dot_data)graph#graph.render('tree')

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。