本文最后更新于14 天前,如有版本迭代或环境切变,可告知邮箱到xianghy_m@sina.com指正修改。
决策树是一个预测模型,它代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表某个可能的属性值,而每个叶节点则对应从根节点到该叶节点所经历的路径所表示的对象的值。
目标:
路径:C:\Users\文件名.csv,文件名“test”,有三组数据,E列为决策结果,存放数值为0或1,F列为决策条件1,G列为决策条件2, 表格中的E列决策结果的值,是根据对应行的F和G共同决定的。
尝试通过决策树模型建立条件和结果的内在逻辑,借助决策树算法自动学习从而得到决策判据。第一步:给出python代码,要求标明代码注释;第二步:输出决策规则和决策推算表,也就是不同条件组合输出的决策结果;
背景:决策树方法在分类、预测、规则提取等领域有着广泛应用。20世纪70年代后期和80年代初期,机器学习研究者J.Ross Quinlan提出了ID3算法以后,决策树在机器学习、数据挖掘领域得到极大的发展。Quinlan后来又提出了C4.5,成为新的监督学习算法。1984年,几位统计学家 提出了CART分类算法。ID3和CART算法几乎同时被提出,但都是采用类似的方法从训练样本中学习决策树。——《Python数据分析与挖掘实战》
工具:
库名 | 主要功能 |
sklearn | 提供决策树分类器和回归器,支持模型训练和评估 |
matplotlib | 用于绘制图表,可视化决策树等模型结果(需结合其他工具) |
pandas | 用于数据整理和分析,方便处理输入给决策树的数据 |
准备工作:
pip install pandas
pip install chardet
pip install scikit-learn
pip install matplotlib
Mermaid格式输出
graph TD
A[条件1 <= 79.50] -->|Yes| B[条件2 <= 106.50]
B -->|Yes| C[条件2 <= 104.50]
C -->|Yes| D[class: 0]
C -->|No| E[class: 0]
B -->|No| F[条件1 <= 62.50]
F -->|Yes| G[class: 0]
F -->|No| H[class: 1]
A -->|No| I[条件1 > 79.50]
I -->|Yes| J[条件2 <= 98.50]
J -->|Yes| K[条件1 <= 116.50]
K -->|Yes| L[class: 0]
K -->|No| M[class: 1]
J -->|No| N[条件2 <= 102.50]
N -->|Yes| O[class: 1]
N -->|No| P[class: 1]
决策树
代码:
import pandas as pd
import chardet
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn import tree
import matplotlib.pyplot as plt
import itertools
# 1. 数据加载
file_path = r"文件路径"
# 尝试自动检测文件编码
with open(file_path, 'rb') as f:
result = chardet.detect(f.read())
detected_encoding = result['encoding']
print(f"检测到的编码格式:{detected_encoding}")
# 尝试使用检测到的编码读取文件
try:
df = pd.read_csv(file_path, encoding=detected_encoding)
except UnicodeDecodeError:
print(f"使用检测到的编码 {detected_encoding} 读取失败,尝试使用常见编码格式。")
try:
df = pd.read_csv(file_path, encoding='gbk') # 尝试使用 GBK 编码
except UnicodeDecodeError:
df = pd.read_csv(file_path, encoding='ISO-8859-1') # 尝试使用 ISO-8859-1 编码
print("列名:", df.columns)
# 2. 数据准备
df.columns = df.columns.str.strip() # 去除列名中的多余空格
try:
X = df[['决策条件1Index', '决策条件2Index']]
y = df['决策结果']
except KeyError as e:
print(f"列名错误:{e}")
print("请检查列名是否正确,并确保列名与文件中的列名一致。")
exit()
# 3. 模型训练
model = DecisionTreeClassifier(random_state=42, max_depth=3)
model.fit(X, y)
# 4. 输出决策规则
tree_rules = export_text(model, feature_names=['决策条件1Index', '决策条件2Index'])
print("决策树规则:")
print(tree_rules)
# 5. 可视化决策树(可选)
fig, ax = plt.subplots(figsize=(12, 8))
tree.plot_tree(model, feature_names=['决策条件1Index', '决策条件2Index'], class_names=['0', '1'], filled=True, ax=ax)
plt.title("决策树可视化")
plt.show()
# 6. 输出决策推算表
conditions = list(itertools.product([0, 1], repeat=2))
decision_table = pd.DataFrame(conditions, columns=['决策条件1Index', '决策条件2Index'])
decision_table['预测决策结果'] = model.predict(decision_table[['决策条件1Index', '决策条件2Index']])
print("\n决策推算表:")
print(decision_table)