import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
def get_class(T_class, info, num):
    """Return the rows of one class from a grade table, with the grade column renamed.

    Args:
        T_class: Class name to select; compared for equality against the
            ``StuClass`` column.
        info: Anything accepted by ``pd.DataFrame`` that has the columns
            ``StuName``, ``StuClass`` and ``总评``.
        num: New name for the ``总评`` column.

    Returns:
        A new DataFrame with columns ``StuName``, ``StuClass`` and ``num``,
        containing only the rows whose ``StuClass`` equals ``T_class``. The
        original row index and row order are kept. If no row matches, the
        result is empty (it does NOT fall back to the whole table).
    """
    info = pd.DataFrame(info)[['StuName', 'StuClass', '总评']]
    info = info.rename(columns={'总评': num})
    # A boolean mask keeps the original index/order, like a groupby group
    # would. .copy() detaches the result from `info` so callers can assign to
    # its columns without chained-assignment warnings.
    return info[info['StuClass'] == T_class].copy()
# Load the exported grade tables of the three courses.
data1_import = pd.read_csv('output/高等数学(一)-1_output.csv', encoding='GBK')
data2_import = pd.read_csv('output/高等数学(一)-2_output.csv', encoding='GBK')
data3_import = pd.read_csv('output/线性代数_output.csv', encoding='GBK')

# Classes whose students form the data set.
CLASSES = ('18大数据1', '18大数据2')
# Columns that identify one student across the three files.
# NOTE(review): assumes (StuName, StuClass) is unique per file — confirm;
# duplicates would multiply rows in the merges below.
KEYS = ['StuName', 'StuClass']
# Model inputs (the two calculus grades) in the order the model sees them.
FEATURES = ['高1', '高2']

info1 = pd.concat([get_class(c, data1_import, '高1') for c in CLASSES], axis=0)
info2 = pd.concat([get_class(c, data2_import, '高2') for c in CLASSES], axis=0)
info3 = pd.concat([get_class(c, data3_import, '线代') for c in CLASSES], axis=0)

# Align the courses per student by name and class, not by CSV row index: the
# three files are read independently, so equal row indexes need not refer to
# the same student. The inner join keeps only students that have all three
# grades, which is exactly the set the model can be trained on.
out = info1.merge(info2, on=KEYS).merge(info3, on=KEYS).reset_index(drop=True)
print(out)

x1 = input('请输入高数一:')
x2 = input('请输入高数二:')

# Standardise features with the statistics of the training rows, and apply the
# very same statistics to the query point so both live on one scale.
feat_mean = out[FEATURES].mean()
feat_std = out[FEATURES].std()
# float() (not int()) so decimal grades such as "85.5" are accepted.
xin1 = (float(x1) - feat_mean['高1']) / feat_std['高1']
xin2 = (float(x2) - feat_mean['高2']) / feat_std['高2']
predict_input = pd.DataFrame({'高1': [xin1], '高2': [xin2]})

# Features and target come from the same merged rows, so row i of x_train and
# row i of y_train always belong to the same student.
x_train = (out[FEATURES] - feat_mean) / feat_std
print(x_train.shape)
y_train = out['线代']

model = MLPRegressor(solver='lbfgs', activation='relu', hidden_layer_sizes=(12, 12), random_state=100)
model.fit(x_train, y_train)
# R^2 on the training rows themselves (in-sample fit, not a held-out score).
score = model.score(x_train, y_train)
test = model.predict(x_train)
result = model.predict(predict_input)
# Configure a CJK-capable font before creating any artists. Draw the minus
# sign as ASCII '-', because SimHei typically lacks the U+2212 glyph that
# matplotlib uses for negative tick labels by default.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Compare the true 线代 grades with the network's in-sample predictions,
# one point per training row.
plt.figure()
plt.plot(np.arange(len(y_train)), y_train, "bo-", label="真实值")
plt.plot(np.arange(len(test)), test, "ro-", label="预测值")
# `result` is the prediction for a single query row (a length-1 array), so
# show its scalar value instead of the array repr.
plt.title(f"sklearn神经网络---拟合度:{score:.4f}\n高数一:{x1}--高数二:{x2}--线代预测值:{result[0]:.2f}")
plt.legend(loc="best")
plt.show()
