Scikit Learn 股票投资:p10

前言

本期的教学视频比较短,主要的教学内容是“标记”我们的数据。Machine Learning(机器学习)的主要步骤都离不开数据获取、数据整理和数据标准化,而这些步骤通常也是最耗费时间的。

视频

视频出处

视频系列:Scikit-learn Machine Learning with Python and SKlearn

本视频出处:Scikit Learn Machine Learning for investing Tutorial with Python p. 10

哔哩哔哩:Scikit Learn Machine Learning for investing Tutorial with Python p. 10

数据下载

下载地址:数据

百 度 云: 地址 密码: yyq8

内容

首先我们添加新的列Status用于判断股票是否优于大盘。

# Column layout of the results table; 'Status' is the newly added label column.
column_names = [
  'Date', 'Unix', 'Ticker', 'DE Ratio', 'Price',
  'stock_p_change', 'SP500', 'sp500_p_change',
  'Difference', 'Status',
]
df = pd.DataFrame(columns=column_names)

并在df = df.append()中添加Status对应的字典项。作者将原本写在df = df.append()里的Difference计算单独抽取了出来,并加入了if判断,由此得出相应股票是outperform(跑赢大盘)还是underperform(跑输大盘)。

  # Percentage change of the stock price and of the S&P 500 value relative to
  # their starting values; `difference` is the stock's change minus the index's.
  stock_p_change = ((stock_price - starting_stock_value) / starting_stock_value) * 100
  sp500_p_change = ((sp500_value - starting_sp500_value) / starting_sp500_value) * 100
  difference = stock_p_change - sp500_p_change

  # Label the row: a positive difference is "outperform", anything else
  # (including exactly zero) is "underperform".
  if difference > 0:
      status = "outperform"
  else:
      status = "underperform"

  # Append this row, given as a dict, to the DataFrame
  df = df.append({'Date':date_stamp,
                  'Unix':unix_time,
                  'Ticker':ticker,
                  'DE Ratio':value,
                  'Price':stock_price,
                  'stock_p_change':stock_p_change,
                  'SP500':sp500_value,
                  'sp500_p_change':sp500_p_change,
                  'Difference': difference,
                  'Status': status},
                  ignore_index = True
  )
# The opening `try:` is not part of this excerpt. Any exception raised while
# building the row is silently ignored here.
except Exception as e:
  pass
  #print(str(e))

最后在画图部分添加:

# Plot each stock's difference versus the index. Gaps in the lines show where
# data is missing (to be repaired later in the series).
for each_ticker in ticker_list:
  plot_df = df[df['Ticker'] == each_ticker].set_index('Date')

  # A ticker with no parsed rows has nothing to plot and no "last status".
  # Checking for it explicitly replaces the bare `except: pass`, which hid
  # every other error as well.
  if plot_df.empty:
    continue

  # Colour the line by the last row's status: red for underperform, green
  # otherwise. `.iloc[-1]` selects by position explicitly; plain `[-1]` on a
  # Series that is not integer-indexed relies on a positional fallback that
  # pandas has deprecated.
  if plot_df['Status'].iloc[-1] == "underperform":
    color = 'r'
  else:
    color = 'g'

  plot_df['Difference'].plot(label=each_ticker, color=color)
  plt.legend()

输出

源代码

import pandas as pd
import os
import time
from datetime import datetime
import re
from time import mktime
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import style
# Apply matplotlib's "dark_background" style to the plots drawn below.
style.use("dark_background")


# Path to the data directory
path = "../intraQuarter"

#定义一个function,默认值为Total Debt/Equity (mrq),以后可以更改为其他
# Columns of the result table, in output order.
_KEY_STATS_COLUMNS = [
  'Date',
  'Unix',
  'Ticker',
  'DE Ratio',
  'Price',
  'stock_p_change',
  'SP500',
  'sp500_p_change',
  'Difference',
  'Status',
]

# Matches a decimal number such as "43.27" embedded in an HTML fragment:
# 1-8 digits, a literal dot, then 1-8 digits.
_PRICE_PATTERN = r'(\d{1,8}\.\d{1,8})'


def _parse_key_stat(source, gather):
  """Return the value of the `gather` statistic found in `source`, or None.

  Two page layouts are tried: the label cell immediately followed by the
  value cell, and the same with a newline between the two cells (some
  snapshots use the latter). None is returned when neither layout is present
  or the cell text is not a number.
  """
  markers = (
    gather + ':</td><td class="yfnc_tabledata1">',
    gather + ':</td>\n<td class="yfnc_tabledata1">',
  )
  for marker in markers:
    try:
      return float(source.split(marker)[1].split('</td>')[0])
    except (IndexError, ValueError):
      continue
  return None


def _first_decimal(fragment):
  """Return the first decimal number found in `fragment`, or None."""
  match = re.search(_PRICE_PATTERN, fragment)
  return float(match.group(1)) if match else None


def _parse_stock_price(source):
  """Return the stock price found in the page `source`, or None.

  Tries, in order: the plain '</small><big><b>...</b></big>' text, a decimal
  number embedded in that same fragment (e.g. wrapped in a <span>), and
  finally a decimal number inside '<span class="time_rtq_ticker">'.
  """
  try:
    fragment = source.split('</small><big><b>')[1].split('</b></big>')[0]
  except IndexError:
    fragment = None

  if fragment is not None:
    try:
      return float(fragment)
    except ValueError:
      price = _first_decimal(fragment)
      if price is not None:
        return price

  try:
    fragment = source.split('<span class="time_rtq_ticker">')[1].split('</span>')[0]
  except IndexError:
    return None
  return _first_decimal(fragment)


def _sp500_adj_close(sp500_df, unix_time):
  """Return the S&P 500 "Adj Close" for the day of `unix_time`, or None.

  When the exact day has no row, the day 259200 seconds (3 days) earlier is
  tried instead. None is returned when neither day is present.
  """
  for offset in (0, 259200):
    day = datetime.fromtimestamp(unix_time - offset).strftime('%Y-%m-%d')
    row = sp500_df[sp500_df.index == day]
    if not row.empty:
      # Take the scalar explicitly; float() on a Series is deprecated.
      return float(row["Adj Close"].iloc[0])
  return None


def Key_Stats(gather="Total Debt/Equity (mrq)"):
  """Build, plot and save a per-snapshot table comparing stocks to the S&P 500.

  Walks the ticker directories under `path`/_KeyStats (module-level `path`),
  parses each '%Y%m%d%H%M%S.html' snapshot for the `gather` statistic and the
  stock price, looks up the S&P 500 "Adj Close" in "SPY.csv" (read from the
  current working directory), and labels each row "outperform" or
  "underperform" depending on whether the stock's percentage change since the
  ticker's first usable snapshot exceeds the index's.

  Snapshots for which the statistic, the index value or the price cannot be
  determined are skipped. The original code silently reused the previous
  file's values for such rows.

  Side effects: shows a matplotlib figure with one 'Difference' line per
  ticker (red if the last row is "underperform", green otherwise), prints
  the output file name, and writes the table to that CSV file in the current
  working directory.

  Args:
    gather: Label of the statistic to extract from each page. It is also
      used, with spaces, parentheses and slashes removed, as the CSV name.
  """
  statspath = path+'/_KeyStats'

  # os.walk yields a (dirpath, dirnames, filenames) triple for every directory
  # under statspath; only the directory paths are needed here.
  stock_list = [x[0] for x in os.walk(statspath)]

  # S&P 500 data. DataFrame.from_csv has been removed from pandas; this is the
  # documented equivalent (first column as a date-parsed index).
  sp500_df = pd.read_csv("SPY.csv", index_col=0, parse_dates=True)

  rows = []
  ticker_list = []

  # stock_list[0] is the walked root itself, so skip it. Only the first 24
  # ticker directories are processed; sorting makes that selection
  # independent of the order the file system happens to return.
  for each_dir in sorted(stock_list[1:])[:24]:
    # Works for both "/" and "\\" separators.
    ticker = os.path.basename(os.path.normpath(each_dir))
    ticker_list.append(ticker)

    starting_stock_value = None
    starting_sp500_value = None

    # File names are timestamps, so sorted order is chronological and the
    # baseline below is the earliest usable snapshot.
    for file_name in sorted(os.listdir(each_dir)):
      try:
        date_stamp = datetime.strptime(file_name, '%Y%m%d%H%M%S.html')
      except ValueError:
        # Not a snapshot file; ignore it.
        continue
      unix_time = time.mktime(date_stamp.timetuple())

      with open(os.path.join(each_dir, file_name), 'r') as handle:
        source = handle.read()

      value = _parse_key_stat(source, gather)
      if value is None:
        continue

      sp500_value = _sp500_adj_close(sp500_df, unix_time)
      if sp500_value is None:
        continue

      stock_price = _parse_stock_price(source)
      if stock_price is None:
        print('could not parse stock price:', file_name, ticker)
        continue

      # The first usable snapshot of a ticker provides the baseline. A zero
      # value cannot serve as one because the percentage change is undefined.
      if not starting_stock_value:
        starting_stock_value = stock_price
      if not starting_sp500_value:
        starting_sp500_value = sp500_value
      if not starting_stock_value or not starting_sp500_value:
        continue

      stock_p_change = ((stock_price - starting_stock_value) / starting_stock_value) * 100
      sp500_p_change = ((sp500_value - starting_sp500_value) / starting_sp500_value) * 100
      difference = stock_p_change - sp500_p_change

      status = "outperform" if difference > 0 else "underperform"

      # Rows are collected in a list and turned into a DataFrame once at the
      # end; DataFrame.append has been removed from pandas and copied the
      # whole frame on every call.
      rows.append({
        'Date': date_stamp,
        'Unix': unix_time,
        'Ticker': ticker,
        'DE Ratio': value,
        'Price': stock_price,
        'stock_p_change': stock_p_change,
        'SP500': sp500_value,
        'sp500_p_change': sp500_p_change,
        'Difference': difference,
        'Status': status,
      })

  df = pd.DataFrame(rows, columns=_KEY_STATS_COLUMNS)

  # Plot each stock's difference versus the index; gaps reveal missing data.
  plotted = False
  for each_ticker in ticker_list:
    plot_df = df[df['Ticker'] == each_ticker].set_index('Date')
    if plot_df.empty:
      # No usable snapshot for this ticker.
      continue

    # .iloc[-1] is the last row by position.
    color = 'r' if plot_df['Status'].iloc[-1] == "underperform" else 'g'
    plot_df['Difference'].plot(label=each_ticker, color=color)
    plotted = True

  if plotted:
    plt.legend()
  plt.show()

  # Strip spaces, parentheses and slashes from `gather` to get a file name.
  save = gather.replace(' ','').replace(')','').replace('(','').replace('/','')+('.csv')
  print(save)
  df.to_csv(save)

# Run with the default statistic, "Total Debt/Equity (mrq)".
Key_Stats()

最后

虽然分c君_BingWong只是作为一名搬运工,连码农都称不上。 但制作代码中的注释、翻译和搬运都花了很多时间,请各位大侠高抬贵手,在转载时请注明出处。

阅读量: | 柯西君_BingWong | 2017-08-31