频道栏目
首页 > 程序开发 > Web开发 > Python > 正文
python3.4之决策树
2016-12-23 09:18:00         来源:baidu_15113429的博客  
收藏   我要投稿
#!/usr/bin/env python
# coding=utf-8

import numpy as np
from sklearn import tree
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.cross_validation import train_test_split
import pydot
from sklearn.externals.six import StringIO

def loadDataSet():
    data = []
    label = []
    with open('D:python/fat.txt') as file:
        for line in file:
            tokens = line.strip().split(' ')
            data.append([float(tk) for tk in tokens[:-1]])
            label.append(tokens[-1])
    x = np.array(data)
    print('x:')
    print(x)
    label = np.array(label)
    y = np.zeros(label.shape)
    y[label == 'fat'] = 1
    print('y:')
    print(y)
    return x, y

def decisionTreeClf():
    x, y = loadDataSet()

    # 拆分数据集和训练集
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    print('x_train:');
    print(x_train)
    print('x_test:');
    print(x_test)
    print('y_train:');
    print(y_train)
    print('y_test:');
    print(y_test)
    # 使用信息熵作为划分标准
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    print(clf)
    clf.fit(x_train, y_train)
    dot_data = StringIO() 
    with open("iris.dot", 'w') as f: 
        f=tree.export_graphviz(clf, out_file=f)
        tree.export_graphviz(clf, out_file=dot_data)
        graph = pydot.graph_from_dot_data(dot_data.getvalue())  
        graph[0].write_pdf("ex.pdf")  
#         Image(graph.create_png())
    # 打印特征在分类起到的作用性
    print(clf.feature_importances_)

    # 打印测试结果
    answer = clf.predict(x_train)
    print('x_train:')
    print(x_train)
    print('answer:')
    print(answer)
    print('y_train:')
    print(y_train)
    print('计算正确率:')
    print(np.mean(answer == y_train))

    # 准确率与召回率
    precision, recall, thresholds = precision_recall_curve(y_train, clf.predict(x_train)
)
    answer = clf.predict_proba(x)[:, 1]
    print(classification_report(y, answer, target_names=['thin', 'fat']))

decisionTreeClf()
# print('ll')

数据集fat.txt文件内容如下:

1.5 50 thin
1.5 60 fat
1.6 40 thin
1.6 60 fat
1.7 60 thin
1.7 80 fat
1.8 60 thin
1.8 90 fat
1.9 70 thin
1.9 80 fat

所需要的Python包有:

pygraphviz (1.3.1)

pyparsing (2.1.10)

scikit-learn (0.18.1)

pygraphviz (1.3.1)包是可视化包。

下载可视化工具:

graphviz-2.38.msi

百度搜索安装可视化工具。

点击复制链接 与好友分享!回本站首页
上一篇:python之socket编程(二)
下一篇:文章标题
相关文章
图文推荐
点击排行

关于我们 | 联系我们 | 广告服务 | 投资合作 | 版权申明 | 在线帮助 | 网站地图 | 作品发布 | Vip技术培训 | 举报中心

版权所有: 红黑联盟--致力于做实用的IT技术学习网站