使用LSTM预测股票数据

数据集:股票数据集.

数据集来源:https://www.kaggle.com/dsadads/databases

1 加载数据集

import numpy as np
import pandas as pd
import datetime
# Read the CSV once. `stock` keeps the default integer index, while
# `stock_data` is a separate frame indexed by the `date` column.
stock = pd.read_csv('dataset/SH600519.csv')
stock_data = stock.set_index('date')
stock_data
Unnamed: 0openclosehighlowvolumecode
date
2010-04-267488.70287.38189.07287.362107036.13600519
2010-04-277587.35584.84187.35584.68158234.48600519
2010-04-287684.23584.31885.12883.59726287.43600519
2010-04-297784.59285.67186.31584.59234501.20600519
2010-04-307883.87182.34083.87181.52385566.70600519
........................
2020-04-2024951221.0001227.3001231.5001216.80024239.00600519
2020-04-2124961221.0201200.0001223.9901193.00029224.00600519
2020-04-2224971206.0001244.5001249.5001202.22044035.00600519
2020-04-2324981250.0001252.2601265.6801247.77026899.00600519
2020-04-2424991248.0001250.5601259.8901235.18019122.00600519

2426 rows × 7 columns

2 绘制收盘价图

import matplotlib.pyplot as plt
from matplotlib import ticker # 调整坐标轴
from matplotlib.pylab import date2num # 日期转换
# Keep rows 100-199 (positional) and draw their `close` column with a grid.
stock = stock.iloc[100:200]
stock['close'].plot(grid=True)
<AxesSubplot:>

在这里插入图片描述

3 计算涨跌幅

stock_data.shape[0]
2426
stock_data.iloc[101:102,].values
array([[1.75000e+02, 1.06990e+02, 1.08749e+02, 1.08858e+02, 1.06475e+02,
        1.85480e+04, 6.00519e+05]])
# Day-over-day relative change of the price stored in positional column 1.
# NOTE(review): with `date` as the index, positional column 1 is `open`
# (column 0 is `Unnamed: 0`), so this is the change of the opening price even
# though the surrounding text talks about closing prices -- confirm the column.
prices = stock_data.iloc[:, 1].astype(float)
previous = prices.shift(1)
quote_change = ((prices - previous) / previous).tolist()
if quote_change:
    # The first trading day has no predecessor; its change is defined as 0.
    quote_change[0] = 0.0
stock_data['quote_change'] = quote_change
stock_data
Unnamed: 0openclosehighlowvolumecodequote_change
date
2010-04-267488.70287.38189.07287.362107036.136005190
2010-04-277587.35584.84187.35584.68158234.48600519-0.015185677887758948
2010-04-287684.23584.31885.12883.59726287.43600519-0.03571632991815013
2010-04-297784.59285.67186.31584.59234501.206005190.00423814328960645
2010-04-307883.87182.34083.87181.52385566.70600519-0.008523264611310805
...........................
2020-04-2024951221.0001227.3001231.5001216.80024239.006005190.00909090909090909
2020-04-2124961221.0201200.0001223.9901193.00029224.006005191.6380016380001484e-05
2020-04-2224971206.0001244.5001249.5001202.22044035.00600519-0.012301190807685363
2020-04-2324981250.0001252.2601265.6801247.77026899.006005190.03648424543946932
2020-04-2424991248.0001250.5601259.8901235.18019122.00600519-0.0016

2426 rows × 8 columns

20天最大涨幅的计算
len(stock_data)
2426
封装成函数
def up(min_data, i, m):
    """Return the smaller of ``min_data`` and the column-1 value ``m`` rows before row ``i``.

    Reads the module-level ``stock_data`` frame. ``min_data`` is returned
    unchanged unless the looked-up value is strictly smaller.
    """
    candidate = stock_data.iloc[i - m, 1]
    return candidate if min_data > candidate else min_data
def down(min_data, i, k):
    """Return the smaller of ``min_data`` and the column-1 value ``k`` rows after row ``i``.

    Reads the module-level ``stock_data`` frame. ``min_data`` is returned
    unchanged unless the looked-up value is strictly smaller.
    """
    candidate = stock_data.iloc[i + k, 1]
    return candidate if min_data > candidate else min_data
sequence = 20

new_feature = []

# For every day: relative rise of the column-1 price above the lowest
# column-1 price in a window reaching `half_window` days back and
# `half_window` days forward (both counted from the day itself).
# NOTE(review): the window looks into the future, so this feature is only
# valid for offline analysis -- confirm this is intended for a forecasting input.
half_window = sequence // 2
n_rows = stock_data.shape[0]

for i in range(n_rows):
    min_data = stock_data.iloc[i, 1]
    # Clamp both directions at the table edges: a negative position would
    # silently wrap around to the end of the frame, and a position past the
    # last row would raise IndexError.
    for m in range(min(half_window, i + 1)):
        min_data = up(min_data, i, m)
    for k in range(min(half_window, n_rows - i)):
        min_data = down(min_data, i, k)

    new_feature.append(float((stock_data.iloc[i, 1] - min_data) / min_data))
直接求
sequence = 20

# For every day: relative rise of the column-1 price above the lowest
# column-1 price in a window centred on that day. The window spans
# `half_window` days back and `half_window` days forward, both counted from
# the day itself, i.e. 2 * half_window - 1 distinct days; `min_periods=1`
# shortens it at the first and last rows instead of producing NaN.
# NOTE(review): the window looks into the future -- confirm this is intended
# for a feature that is later fed to a forecasting model.
half_window = sequence // 2
prices = stock_data.iloc[:, 1].astype(float)
window_min = prices.rolling(2 * half_window - 1, center=True, min_periods=1).min()
new_feature = ((prices - window_min) / window_min).tolist()

new_feature

len(new_feature)
2426
stock_data['max_increase'] = new_feature
stock_data
Unnamed: 0openclosehighlowvolumecodequote_changemax_increase
date
2010-04-267488.70287.38189.07287.362107036.1360051900.10317637987214874
2010-04-277587.35584.84187.35584.68158234.48600519-0.0151856778877589480.08642389871402628
2010-04-287684.23584.31885.12883.59726287.43600519-0.035716329918150130.0476208243165932
2010-04-297784.59285.67186.31584.59234501.206005190.004238143289606450.052060791483222554
2010-04-307883.87182.34083.87181.52385566.70600519-0.0085232646113108050.043093798970225965
..............................
2020-04-2024951221.0001227.3001231.5001216.80024239.006005190.009090909090909090.059895833333333336
2020-04-2124961221.0201200.0001223.9901193.00029224.006005191.6380016380001484e-050.05991319444444443
2020-04-2224971206.0001244.5001249.5001202.22044035.00600519-0.0123011908076853630.04155871074722759
2020-04-2324981250.0001252.2601265.6801247.77026899.006005190.036484245439469320.07955919438974668
2020-04-2424991248.0001250.5601259.8901235.18019122.00600519-0.00160.07124463519313305

2426 rows × 9 columns

选择use_cols作为特征

#X.append(np.array(stock_data.iloc[i:(i+sequence),].values, dtype=np.float))
# label 取当前日期后的30天收盘价涨幅
#y.append(np.array(stock_data.iloc[(i + sequence,5)],dtype=np.float))

4 归一化

sklearn
from sklearn.preprocessing import MinMaxScaler 
# stock_data = stock_data[2000:2420]
# Feature columns kept for modelling; later code addresses them by position,
# so the order matters.
columns = ['open','close','high','low','volume','quote_change','max_increase']
stock_data = stock_data.loc[:, columns]
stock_data
openclosehighlowvolumequote_changemax_increase
date
2010-04-2688.70287.38189.07287.362107036.1300.10317637987214874
2010-04-2787.35584.84187.35584.68158234.48-0.0151856778877589480.08642389871402628
2010-04-2884.23584.31885.12883.59726287.43-0.035716329918150130.0476208243165932
2010-04-2984.59285.67186.31584.59234501.200.004238143289606450.052060791483222554
2010-04-3083.87182.34083.87181.52385566.70-0.0085232646113108050.043093798970225965
........................
2020-04-201221.0001227.3001231.5001216.80024239.000.009090909090909090.059895833333333336
2020-04-211221.0201200.0001223.9901193.00029224.001.6380016380001484e-050.05991319444444443
2020-04-221206.0001244.5001249.5001202.22044035.00-0.0123011908076853630.04155871074722759
2020-04-231250.0001252.2601265.6801247.77026899.000.036484245439469320.07955919438974668
2020-04-241248.0001250.5601259.8901235.18019122.00-0.00160.07124463519313305

2426 rows × 7 columns

# Min-max scale every feature column and wrap the result in a DataFrame that
# carries the feature names again (it gets a fresh integer index).
scaler = MinMaxScaler()
stock_scaler = pd.DataFrame(scaler.fit_transform(stock_data), columns=columns)
stock_scaler
openclosehighlowvolumequote_changemax_increase
00.0070930.0058060.0065090.0062300.3535560.4369610.299871
10.0059410.0036380.0050590.0039340.1803170.3782080.251182
20.0032740.0031920.0031790.0030060.0669090.2987760.138405
30.0035790.0043470.0041810.0038580.0960670.4533580.151309
40.0029630.0015040.0021180.0012300.2773420.4039850.125247
........................
24210.9752050.9786970.9711390.9734770.0596370.4721330.174081
24220.9752220.9553970.9647980.9530950.0773330.4370240.174131
24230.9623800.9933770.9863380.9609910.1299100.3893680.120786
24241.0000001.0000001.0000001.0000000.0690800.5781170.231230
24250.9982900.9985490.9951110.9892180.0414730.4307700.207065

2426 rows × 7 columns

# Remember the bounds needed to map normalised values back to the original scale.
# NOTE: despite the `close_` prefix these are the min/max of `quote_change`.
close_min = stock_data['quote_change'].min()
close_max = stock_data['quote_change'].max()
# Manual min-max normalisation of every column: (x - min) / (max - min).
stock=stock_data.apply(lambda x:(x-min(x))/(max(x)-min(x)))
stock
openclosehighlowvolumequote_changemax_increase
date
2010-04-260.0070930.0058060.0065090.0062300.3535560.4369610.299871
2010-04-270.0059410.0036380.0050590.0039340.1803170.3782080.251182
2010-04-280.0032740.0031920.0031790.0030060.0669090.2987760.138405
2010-04-290.0035790.0043470.0041810.0038580.0960670.4533580.151309
2010-04-300.0029630.0015040.0021180.0012300.2773420.4039850.125247
........................
2020-04-200.9752050.9786970.9711390.9734770.0596370.4721330.174081
2020-04-210.9752220.9553970.9647980.9530950.0773330.4370240.174131
2020-04-220.9623800.9933770.9863380.9609910.1299100.3893680.120786
2020-04-231.0000001.0000001.0000001.0000000.0690800.5781170.23123
2020-04-240.9982900.9985490.9951110.9892180.0414730.430770.207065

2426 rows × 7 columns

5 前20天的数据预测之后的数据

stock = stock_scaler
pd.DataFrame(stock.iloc[201:400,].values)
0123456
00.0296740.0294360.0291610.0294940.0554480.4242310.000000
10.0293930.0286500.0286550.0285070.0875250.4259030.000000
20.0288810.0274350.0284860.0278110.1075460.4167710.000000
30.0276310.0279460.0274480.0276740.0919840.3874240.000000
40.0280940.0287430.0278560.0284130.0927010.4555290.013949
........................
1940.0518610.0519450.0510490.0520220.0287540.3808210.000000
1950.0514500.0524100.0522710.0518080.0387000.4237960.000000
1960.0518300.0509920.0515930.0511020.0266450.4491800.009179
1970.0525220.0545920.0534680.0528750.0823920.4591550.025905
1980.0543410.0541150.0531120.0536780.0639580.4950080.069899

199 rows × 7 columns

stock
openclosehighlowvolumequote_changemax_increase
00.0070930.0058060.0065090.0062300.3535560.4369610.299871
10.0059410.0036380.0050590.0039340.1803170.3782080.251182
20.0032740.0031920.0031790.0030060.0669090.2987760.138405
30.0035790.0043470.0041810.0038580.0960670.4533580.151309
40.0029630.0015040.0021180.0012300.2773420.4039850.125247
........................
24210.9752050.9786970.9711390.9734770.0596370.4721330.174081
24220.9752220.9553970.9647980.9530950.0773330.4370240.174131
24230.9623800.9933770.9863380.9609910.1299100.3893680.120786
24241.0000001.0000001.0000001.0000000.0690800.5781170.231230
24250.9982900.9985490.9951110.9892180.0414730.4307700.207065

2426 rows × 7 columns

# Sequence length is 20: the previous 20 rows are used to predict the next row.
sequence = 20

X = []
y = []
label = []

# Convert to a float matrix once instead of slicing the DataFrame per window.
values = np.asarray(stock.values, dtype=float)
for i in range(values.shape[0] - sequence):
    # Input: every feature column of the `sequence` rows starting at row i.
    X.append(values[i:i + sequence])
    # Target: positional column 5 of the row right after the window
    # (presumably `quote_change`, given the column order used for scaling).
    y.append(values[i + sequence, 5])
len(X) , len(X[1]) 
(2406, 20)
len(y)
2406
划分数据集
from sklearn.model_selection import train_test_split
# Split chronologically (first 80 % train, last 20 % test). Consecutive windows
# overlap in all but one row, so a shuffled split would put near-copies of
# every test window -- including its target -- into the training set.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
len(X_train)
1924
len(X_test)
482
import torch
import torch.utils.data as Data 
torch.manual_seed(1)
<torch._C.Generator at 0x2393bdb43f0>
# list -> numpy -> torch, one step per split
X_train = torch.from_numpy(np.array(X_train))
y_train = torch.from_numpy(np.array(y_train))
X_test = torch.from_numpy(np.array(X_test))
y_test = torch.from_numpy(np.array(y_test))

# Report the resulting tensor shapes.
print('X_train size: ', X_train.size())
print('y_train size: ', y_train.size())
print('X_test size: ', X_test.size())
print('y_test size: ', y_test.size())
X_train size:  torch.Size([1924, 20, 7])
y_train size:  torch.Size([1924])
X_test size:  torch.Size([482, 20, 7])
y_test size:  torch.Size([482])

6 定义网络模型

# Mini-batches of 32 samples.
train_data = Data.TensorDataset(X_train, y_train)
test_data = Data.TensorDataset(X_test, y_test)

train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
    num_workers=2
)

# Evaluation gains nothing from shuffling; keeping the dataset order makes the
# collected predictions line up with X_test / y_test, so index ranges in the
# later plots refer to well-defined samples.
test_loader = Data.DataLoader(
    dataset=test_data,
    batch_size=32,
    shuffle=False,
    num_workers=2
)
input_size = 7  # features per time step (passed to nn.LSTM as input_size)
seq_len = 20  # time steps per window; sizes the final linear layer
hidden_size = 32  # LSTM hidden-state width
output_size = 1  # width of the network output
import torch.nn as nn
import torch.nn.functional as F 
from torch.autograd import Variable
class MyNet(nn.Module):
    """Single-layer LSTM followed by one linear layer over all time steps.

    The LSTM is batch-first, so inputs are ``(batch, seq_len, input_size)``.
    The hidden states of every time step are flattened and mapped to
    ``output_size`` values.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden state.
        output_size: number of values predicted per sample.
        seq_len: number of time steps per input window. It fixes the input
            width of the linear layer, so inputs must have exactly this length.
    """

    def __init__(self, input_size=input_size, hidden_size=hidden_size,
                 output_size=output_size, seq_len=seq_len):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Stored on the instance so the layer sizes no longer depend on a
        # module-level global at construction time.
        self.seq_len = seq_len
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(self.hidden_size * self.seq_len, self.output_size)

    def forward(self, input):
        """Return a ``(batch, output_size)`` tensor for a batch-first input."""
        out, _ = self.lstm(input)
        b, s, h = out.size()
        # Flatten (seq, hidden) so the linear layer sees every time step.
        out = self.fc(out.reshape(b, s * h))
        return out

net = MyNet()
print(net)
MyNet(
  (lstm): LSTM(7, 32, batch_first=True)
  (fc): Linear(in_features=640, out_features=1, bias=True)
)

7 选择损失函数和优化器

import torch.optim as optim 
from tqdm import tqdm

loss_function = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

net.train()
for epoch in tqdm(range(100)):
    # Sum (not mean) of the per-batch losses over one epoch.
    total_loss = 0
    for data, label in train_loader:
        # Tensors can be used directly (no Variable wrapper needed); cast to
        # float32 to match the dtype of the network parameters.
        data = data.float()
        # Add a trailing dimension so the target matches the (batch, 1) output.
        label = label.float().unsqueeze(1)

        optimizer.zero_grad()
        pred = net(data)
        loss = loss_function(pred, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print('Epoch: ', epoch+1, ' loss: ', total_loss)
 10%|████████                                                                         | 10/100 [00:15<02:20,  1.56s/it]

Epoch:  10  loss:  0.39631053362973034


 20%|████████████████▏                                                                | 20/100 [00:30<01:59,  1.49s/it]

Epoch:  20  loss:  0.38432611781172454


 30%|████████████████████████▎                                                        | 30/100 [00:46<01:48,  1.56s/it]

Epoch:  30  loss:  0.3594463015906513


 40%|████████████████████████████████▍                                                | 40/100 [01:02<01:37,  1.62s/it]

Epoch:  40  loss:  0.2304799237754196


 50%|████████████████████████████████████████▌                                        | 50/100 [01:17<01:17,  1.55s/it]

Epoch:  50  loss:  0.19538419507443905


 60%|████████████████████████████████████████████████▌                                | 60/100 [01:33<01:06,  1.66s/it]

Epoch:  60  loss:  0.15910959872417152


 70%|████████████████████████████████████████████████████████▋                        | 70/100 [01:50<00:51,  1.73s/it]

Epoch:  70  loss:  0.10108215303625911


 80%|████████████████████████████████████████████████████████████████▊                | 80/100 [02:07<00:32,  1.63s/it]

Epoch:  80  loss:  0.09773806459270418


 90%|████████████████████████████████████████████████████████████████████████▉        | 90/100 [02:23<00:16,  1.67s/it]

Epoch:  90  loss:  0.06839053201838396


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [02:41<00:00,  1.62s/it]

Epoch:  100  loss:  0.06617061665747315

8 测试

pred_list = []
label_list = []

# Inference only: switch to eval mode and disable gradient tracking so no
# autograd graph is built for the test batches.
net.eval()
with torch.no_grad():
    for data, label in test_loader:
        pred = net(data.float())
        # (batch, 1) -> (batch,) so predictions and labels are flat lists.
        pred_list.extend(pred.squeeze(1).tolist())
        label_list.extend(label.tolist())
pred_list[:5]
[0.4532894492149353,
 0.4373990297317505,
 0.4552455246448517,
 0.3734513521194458,
 0.4417729675769806]
len(pred_list)
482
label_list[:5]
[0.4526440093601085,
 0.4235539693300972,
 0.4670736837074137,
 0.3081553795002458,
 0.5096712136068774]

9 可视化

import matplotlib.pyplot as plt  
plt.rcParams['font.sans-serif'] = [u'SimHei']
plt.rcParams['axes.unicode_minus'] = False
import matplotlib.pyplot as plt 

# First 50 test samples, mapped back from the normalised scale with the
# stored min/max, prediction against ground truth.
restored_pred = [value*(close_max-close_min)+close_min for value in pred_list[:50]]
restored_real = [value*(close_max-close_min)+close_min for value in label_list[:50]]

plt.figure(figsize=(20, 6))
plt.plot(restored_pred, label='pred')
plt.plot(restored_real, label='real')
plt.title('Stock Forecast(前50条数据)')
plt.legend()
plt.show()

在这里插入图片描述

import matplotlib.pyplot as plt 
# Test samples 50-99, mapped back from the normalised scale.
restored_pred = [value*(close_max-close_min)+close_min for value in pred_list[50:100]]
restored_real = [value*(close_max-close_min)+close_min for value in label_list[50:100]]

plt.figure(figsize=(20, 6))
plt.plot(restored_pred, label='pred')
plt.plot(restored_real, label='real')
plt.title('Stock Forecast')
plt.legend()
plt.show()

在这里插入图片描述

import matplotlib.pyplot as plt 

# Normalised predictions and labels for test samples 400-479; the chart is
# also written to dataset/some.jpg.
sample_window = slice(400, 480)
plt.figure(figsize=(20, 6))
for series, series_name in ((pred_list, 'pred'), (label_list, 'real')):
    plt.plot(series[sample_window], label=series_name)
plt.title('Stock Forecast(第400条到第480条数据)')
plt.legend()
plt.savefig('dataset/some.jpg')
plt.show()

在这里插入图片描述

import matplotlib.pyplot as plt 
# Normalised predictions and labels for the whole test set; the chart is also
# written to dataset/pred_real.jpg.
plt.figure(figsize=(20, 6))
for series, series_name in ((pred_list, 'pred'), (label_list, 'real')):
    plt.plot(series, label=series_name)
plt.title('Stock Forecast(测试集所有数据)')
plt.legend()
plt.savefig('dataset/pred_real.jpg')
plt.show()

在这里插入图片描述

import matplotlib.pyplot as plt 
# First 482 test samples mapped back from the normalised scale; the chart is
# also written to dataset/all.jpg.
restored_pred = [value*(close_max-close_min)+close_min for value in pred_list[:482]]
restored_real = [value*(close_max-close_min)+close_min for value in label_list[:482]]

plt.figure(figsize=(20, 6))
plt.plot(restored_pred, label='pred')
plt.plot(restored_real, label='real')
plt.title('Stock Forecast(测试集所有数据(还原数据))')
plt.legend()
plt.savefig('dataset/all.jpg')
plt.show()

在这里插入图片描述

10 保存图片

import matplotlib.pyplot as plt 
# Restored predictions only, saved as dataset/pred.jpg.
restored_pred = [value*(close_max-close_min)+close_min for value in pred_list[:482]]

plt.figure(figsize=(20, 6))
plt.plot(restored_pred, label='pred')
plt.title('Stock Forecast pred')
plt.legend()
plt.savefig('dataset/pred.jpg')
plt.show()

在这里插入图片描述

import matplotlib.pyplot as plt 
# Restored ground truth only, saved as dataset/real.jpg.
restored_real = [value*(close_max-close_min)+close_min for value in label_list[:482]]

plt.figure(figsize=(20, 6))
plt.plot(restored_real, label='real')
plt.title('Stock Forecast (real)')
plt.legend()
plt.savefig('dataset/real.jpg')
plt.show()

在这里插入图片描述

11 计算相似度

len(label_list)
482
欧式距离的平方(误差平方和)
# Sum of squared differences between labels and predictions, accumulated in
# list order starting from 0.
sum_all = sum(
    (real_value - pred_value)**2
    for real_value, pred_value in zip(label_list, pred_list)
)
sum_all
1.1392427957537818
DTW
# Sequences compared by the DTW analysis: the complete test set. To analyse
# only a sub-range, slice both lists identically (e.g. [50:100]).
pred_some = pred_list[:]
label_some = label_list[:]
欧式距离矩阵
# Pairwise squared differences: distances[i, j] compares prediction i (row)
# with label j (column). Broadcasting gives the matrix the correct
# len(pred_some) x len(label_some) shape even when the lengths differ.
pred_values = np.asarray(pred_some, dtype=float)
label_values = np.asarray(label_some, dtype=float)
distances = (label_values[np.newaxis, :] - pred_values[:, np.newaxis]) ** 2
len(distances)
482

计算两个序列的距离矩阵。横着表示x序列,竖着是y序列。
比如说第0行第0列的元素表示x序列的第0个值和y序列的第0个值的距离(Python的索引从0开始)

颜色越深表示距离越远

欧式距离矩阵可视化

def distance_cost_plot(distances):
    """Draw a matrix as a red heat map with a colour bar in a new 20x6 figure.

    The figure stays current after the call, so callers can overlay further
    artists with ``plt.plot``.
    """
    plt.figure(figsize=(20, 6))
    heat_map = plt.imshow(distances, cmap='Reds', interpolation='nearest')
    axes = heat_map.axes
    # Flip the y axis so that row 0 sits at the bottom, next to column 0.
    axes.invert_yaxis()
    axes.set_xlabel("X")
    axes.set_ylabel("Y")
    plt.colorbar(heat_map)
distance_cost_plot(distances)

在这里插入图片描述

x, y = pred_some, label_some
# Accumulated-cost matrix, all zeros except the start cell, which equals the
# local distance at the origin.
accumulated_cost = np.zeros((len(pred_some), len(label_some)))
accumulated_cost[0, 0] = distances[0, 0]
pd.DataFrame(accumulated_cost)
0123456789...472473474475476477478479480481
04.165926e-070.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
10.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
20.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
30.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
40.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
..................................................................
4770.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
4780.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
4790.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
4800.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
4810.000000e+000.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0

482 rows × 482 columns

distance_cost_plot(accumulated_cost)

在这里插入图片描述

显然累积距离矩阵的第0行第0列=距离矩阵的第0行第0列(这里约为4.17e-07),因为路径必须经过起点。如果我们一直往右走,那么累积距离矩阵的第0行就是距离矩阵第0行的逐项累加:

# Row 0 of the accumulated-cost matrix: moving only to the right, each cell is
# the running sum of the local distances, starting from the current value of
# cell (0, 0).
n_cols = len(label_some)
row_terms = np.concatenate(([accumulated_cost[0, 0]], distances[0, 1:n_cols]))
accumulated_cost[0, 1:n_cols] = np.cumsum(row_terms)[1:]
pd.DataFrame(accumulated_cost)
0123456789...472473474475476477478479480481
04.165926e-070.0008850.0010750.0221390.0253170.0285030.0288940.0296630.0301730.030348...2.7913182.7946772.7949192.7988352.8033862.8128282.8275252.8281092.8505782.850708
10.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
20.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
30.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
40.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
..................................................................
4770.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
4780.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
4790.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
4800.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
4810.000000e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000

482 rows × 482 columns

distance_cost_plot(accumulated_cost)

在这里插入图片描述

如果我们一直往上走,那么累积距离矩阵的第0列就是距离矩阵第0列的逐项累加:

# Column 0 of the accumulated-cost matrix: moving only upwards, each cell is
# the running sum of the local distances, starting from the current value of
# cell (0, 0).
n_rows = len(pred_some)
column_terms = np.concatenate(([accumulated_cost[0, 0]], distances[1:n_rows, 0]))
accumulated_cost[1:n_rows, 0] = np.cumsum(column_terms)[1:]
pd.DataFrame(accumulated_cost)
0123456789...472473474475476477478479480481
04.165926e-070.0008850.0010750.0221390.0253170.0285030.0288940.0296630.0301730.030348...2.7913182.7946772.7949192.7988352.8033862.8128282.8275252.8281092.8505782.850708
12.328260e-040.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
22.395939e-040.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
36.511071e-030.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
46.629250e-030.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
..................................................................
4772.676938e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
4782.696313e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
4792.696541e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
4802.709715e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
4812.709716e+000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000

482 rows × 482 columns

distance_cost_plot(accumulated_cost)

在这里插入图片描述

把累积距离矩阵计算完整

# Interior cells: local distance plus the cheapest of the three predecessor
# cells (row-1, col-1), (row-1, col) and (row, col-1).
for row in range(1, len(pred_some)):
    for col in range(1, len(label_some)):
        cheapest_predecessor = min(
            accumulated_cost[row-1, col-1],
            accumulated_cost[row-1, col],
            accumulated_cost[row, col-1],
        )
        accumulated_cost[row, col] = cheapest_predecessor + distances[row, col]
pd.DataFrame(accumulated_cost)
0123456789...472473474475476477478479480481
04.165926e-070.0008850.0010750.0221390.0253170.0285030.0288940.0296630.0301730.030348...2.7913182.7946772.7949192.7988352.8033862.8128282.8275252.8281092.8505782.850708
12.328260e-040.0001920.0010730.0177770.0230000.0246440.0259150.0260550.0275350.027542...2.7847742.7902292.7902292.7924082.7993562.8059622.8170592.8171282.8446132.844633
22.395939e-040.0011960.0003320.0219680.0207390.0241490.0244660.0253480.0257730.026003...2.7571072.7602442.7605502.7647152.7690052.7788312.7940072.7946892.8165752.816754
36.511071e-030.0027500.0090970.0045960.0231510.0212860.0312050.0271800.0358360.030210...2.7591532.7760962.7643782.7608482.7825452.7693062.7710192.7741192.8268972.821256
46.629250e-030.0030820.0033900.0212440.0092060.0112240.0122020.0124650.0136270.013630...2.7105712.7153992.7154152.7180222.7242592.7315962.7436342.7437942.7698482.769848
..................................................................
4772.676938e+002.6707012.6513172.6332922.6652502.6257482.5895582.5609922.5600402.527341...0.7778890.7982320.7771220.7733600.8058540.7704490.7704500.7800980.8542350.798668
4782.696313e+002.6828252.6749172.6333202.6717962.6327042.6150292.5735592.5864182.543372...0.7907500.8170150.7925720.7793300.8163340.7722700.7707950.7838320.8640460.815159
4792.696541e+002.6830212.6757892.6500642.6385212.6343612.6162892.5737032.5750262.543378...0.7908650.7961820.7925720.7815240.7862520.7789010.7819250.7708660.7983000.798319
4802.709715e+002.7037182.6858582.7172822.6418562.6634562.6251962.5938302.5820862.559598...0.8106220.7940200.8093870.8127490.7837020.8235480.8342970.7899900.7721450.787908
4812.709716e+002.7045002.6861002.7064142.6452372.6448462.6256592.5945052.5826782.559730...0.7951560.7975870.7942100.7979080.7884940.7928040.8070780.7904920.7951450.772238

482 rows × 482 columns

distance_cost_plot(accumulated_cost)

在这里插入图片描述

现在,最佳路径已经清晰地显示在了累积距离矩阵之中,就是图中颜色最淡的方块。

现在,我们只需要通过回溯的方法找回最佳路径就可以了:

# Trace the warping path back from the last cell to the origin.
# `i` is the row (index into pred_some, drawn on the y axis) and `j` the column
# (index into label_some, drawn on the x axis); path entries are [x, y] = [j, i].
i = len(pred_some)-1
j = len(label_some)-1
path = [[j, i]]
# Continue until BOTH indices reach 0. Stopping as soon as one of them is 0
# would skip the remaining cells along the matrix edge.
while i > 0 or j > 0:
    if i == 0:
        j = j - 1  # on the bottom row: the only way back is to the left
    elif j == 0:
        i = i - 1  # on the left column: the only way back is downwards
    else:
        cheapest = min(accumulated_cost[i-1, j-1], accumulated_cost[i-1, j], accumulated_cost[i, j-1])
        if accumulated_cost[i-1, j] == cheapest:
            i = i - 1  # came from the cell below
        elif accumulated_cost[i, j-1] == cheapest:
            j = j - 1  # came from the cell to the left
        else:
            i = i - 1  # came from the lower-left (diagonal) cell
            j = j - 1
    path.append([j, i])
# The loop ends right after appending [0, 0], so no extra origin entry is added.
path_x = [point[0] for point in path]
path_y = [point[1] for point in path]
distance_cost_plot(accumulated_cost)
plt.plot(path_x, path_y)
[<matplotlib.lines.Line2D at 0x2393ee7a220>]

在这里插入图片描述

图片相似度
from skimage.metrics import structural_similarity as sk_cpt_ssim
import matplotlib.pyplot as plt
import numpy as np
import cv2

def mse(imageA, imageB):
    """Return the mean squared error between two images (lower is more similar).

    Both arrays are converted to float before subtraction. The summed squared
    difference is divided by the number of pixels (rows x columns) of ``imageA``.
    """
    difference = imageA.astype("float") - imageB.astype("float")
    pixel_count = float(imageA.shape[0] * imageA.shape[1])
    return np.sum(difference ** 2) / pixel_count

def compare_images(imageA, imageB, title):
    """Show two images side by side with their MSE and SSIM in the figure title.

    ``title`` is used as the identifier of the matplotlib figure.
    """
    # Similarity metrics for the pair.
    mse_value = mse(imageA, imageB)
    ssim_value = sk_cpt_ssim(imageA, imageB)

    fig = plt.figure(title)
    plt.suptitle("MSE: %.2f, SSIM: %.2f" % (mse_value, ssim_value))

    # One grayscale panel per image, left to right, without axes.
    for position, picture in enumerate((imageA, imageB), start=1):
        fig.add_subplot(1, 2, position)
        plt.imshow(picture, cmap = plt.cm.gray)
        plt.axis("off")
    plt.tight_layout()
    plt.show()



def _load_image(path):
    """Read an image with OpenCV, failing loudly when it cannot be loaded."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports a missing or unreadable file by returning None;
        # without this check the failure only surfaces later inside cvtColor.
        raise FileNotFoundError("cannot read image: %s" % path)
    return image


# Load the charts saved by the plotting cells.
pred_image = _load_image("dataset/pred.jpg")
real_image = _load_image("dataset/real.jpg")
all_image = _load_image('dataset/all.jpg')
some_image = _load_image('dataset/some.jpg')

# Convert the colour images to grayscale.
pred = cv2.cvtColor(pred_image, cv2.COLOR_BGR2GRAY)
real = cv2.cvtColor(real_image, cv2.COLOR_BGR2GRAY)
all_image = cv2.cvtColor(all_image,cv2.COLOR_BGR2GRAY)
some_image = cv2.cvtColor(some_image,cv2.COLOR_BGR2GRAY)

# Figure that previews the images side by side.
fig = plt.figure("Images")
# images = ("pred", pred), ("real", real),('all',all_image),('some',some_image)
images = ("pred", pred), ("real", real)

# One panel per image; the grid width follows the number of images so no
# empty panels are left over.
for (i, (name, image)) in enumerate(images):
    ax = fig.add_subplot(1, len(images), i + 1)
    ax.set_title(name)
    plt.imshow(image, cmap = plt.cm.gray)
    plt.axis("off")
plt.tight_layout()
plt.show()

# Compare the images.
# compare_images(real, real, "real vs real")
compare_images(real, pred, "real vs pred")
# compare_images(all_image, pred, "real vs pred")
# compare_images(some_image,all_image,'some vs all')

在这里插入图片描述
在这里插入图片描述