|
|
51CTO旗下网站
|
|
移动端

2.3.1 类的封装

《量化交易之路:用Python做股票量化分析》第2章量化语言——Python,第2部分(第2~6章)主要讲解了量化交易需要的基础知识及相关工具,如Python语言、NumPy、pandas、数据可视化及量化数学等知识,适合完全没有任何编程经验的读者从头开始阅读。本节为大家介绍类的封装。

作者:阿布来源:机械工业出版社|2017-10-19 16:15

2.3  面向对象

面向对象编程(Object Oriented Programming,OOP)是一种计算机编程架构。OOP的一条基本原则是计算机程序是由单个能够起到子程序作用的单元或对象组合而成。OOP达到了软件工程的3个主要目标:重用性、灵活性和扩展性。

2.3.1  类的封装

面向对象的设计思想是抽象出类,类的抽象程度比函数要高,类既包含属性(数据),又包含操作属性的方法。

Python中的类定义方式与其他语言差别不大,最大的特点是每一个类方法的第一个参数都是以self开始,后面是其他参数。

Python中并没有访问控制,没有类似Java、C++中private(私有)、protect(受保护的)和public(公有)的方法显式声明,而是通过遵循一定的属性和方法命名规则来达到这个效果。

任何以单下画线开头的名字都代表protect,如self._price_array、def_init_stock_ dict(self);

任何以双下画线开头的名字都代表private,如self.__init_change(self);

以双下画线开头结尾的名字都代表系统保留定义,如__init__( self)、__str__( self)、 __iter__( self)、__len__( self)。

例如,StockTradeDays中__init__(self)方法抽象了2.2节的代码完成了由一个字符串对象price_array创建出stockdict的整个过程。可以发现,类中定义的__init_change(self)、_init_stock_dict(self)以及filter_stock(self)都是2.2节使用过的方法,StockTradeDays将它们进一步封装在类中。

  1. from collections import namedtuple  
  2. from collections import OrderedDict  
  3. class StockTradeDays(object):  
  4.     def __init__(self, price_array, start_date, date_array=None):  
  5.         # 私有价格序列  
  6.         self.__price_array = price_array  
  7.         # 私有日期序列  
  8.         selfself.__date_array = self._init_days(start_date, date_array)  
  9.         # 私有涨跌幅序列  
  10.         selfself.__change_array = self.__init_change()  
  11.         # 进行OrderedDict的组装  
  12.         selfself.stock_dict = self._init_stock_dict()  
  13.     def __init_change(self):  
  14.         """  
  15.         从price_array生成change_array  
  16.         :return:  
  17.         """  
  18.         price_float_array = [float(price_str) for price_str in  
  19.                              self.__price_array]  
  20.         # 通过将时间平移形成两个错开的收盘价序列,通过zip()函数打包成为一个新的序列  
  21.         # 每个元素为相邻的两个收盘价格  
  22.         pp_array = [(price1, price2) for price1, price2 in  
  23.                     zip(price_float_array[:-1], price_float_array[1:])]  
  24.         change_array = map(  
  25.             lambda pp: reduce(lambda a, b: round((b - a) / a, 3), pp),  
  26.             pp_array)  
  27.         # list insert()函数插入数据,将第一天的涨跌幅设置为0  
  28.         change_array.insert(0, 0)  
  29.         return change_array  
  30.     def _init_days(self, start_date, date_array):  
  31.         """  
  32.         protect方法,  
  33.         :param start_date: 初始日期  
  34.         :param date_array: 给定日期序列  
  35.         :return:  
  36.         """  
  37.         if date_array is None:  
  38.             # 由start_date和self.__price_array来确定日期序列  
  39.             date_array = [str(start_date + ind) for ind, _ in  
  40.                           enumerate(self.__price_array)]  
  41.         else:  
  42.             # 稍后的内容会使用外部直接设置的方式  
  43.             # 如果外面设置了date_array,就直接转换str类型组成新date_array  
  44.             date_array = [str(date) for date in date_array]  
  45.         return date_array  
  46.     def _init_stock_dict(self):  
  47.         """  
  48.         使用namedtuple,OrderedDict将结果合并  
  49.         :return:  
  50.         """  
  51.         stock_namedtuple = namedtuple('stock',  
  52.                                       ('date', 'price', 'change'))  
  53.         # 使用以被赋值的__date_array等进行OrderedDict的组装  
  54.         stock_dict = OrderedDict(  
  55.             (date, stock_namedtuple(date, price, change))  
  56.             for date, price, change in  
  57.             zip(self.__date_array, self.__price_array,  
  58.                 self.__change_array))  
  59.         return stock_dict  
  60.     def filter_stock(self, want_up=Truewant_calc_sum=False):  
  61.         """  
  62.         筛选结果子集  
  63.         :param want_up: 是否筛选上涨  
  64.         :param want_calc_sum: 是否计算涨跌幅和  
  65.         :return:  
  66.         """  
  67.         # Python中的三目表达式的写法  
  68.         filter_func = (lambda day: day.change > 0) if want_up else (  
  69.             lambda day: day.change < 0)  
  70.         # 使用filter_func作为筛选函数  
  71.         want_days = filter(filter_func, self.stock_dict.values())  
  72.         if not want_calc_sum:  
  73.             return want_days  
  74.         # 需要计算涨跌幅和  
  75.         change_sum = 0.0  
  76.         for day in want_days:  
  77.             change_sum += day.change  
  78.         return change_sum  
  79.     """  
  80.         下面的__str__、__iter__、__getitem__和__len__稍后会详细讲解  
  81.     """  
  82.     def __str__(self):  
  83.         return str(self.stock_dict)  
  84.     __repr__ = __str__ 
  85.     def __iter__(self):  
  86.         """  
  87.         通过代理stock_dict的迭代,yield元素  
  88.         :return:  
  89.         """  
  90.         for key in self.stock_dict:  
  91.             yield self.stock_dict[key]  
  92.     def __getitem__(self, ind):  
  93.         date_key = self.__date_array[ind]  
  94.         return self.stock_dict[date_key]  
  95.     def __len__(self):  
  96.         return len(self.stock_dict) 

1.对象支持信息打印

下面的代码首先从StockTradeDays类初始化一个实例对象trade_days,然后打印出对象信息。

  1. price_array = '30.14,29.58,26.36,32.56,32.82'.split(',')  
  2. date_base = 20170118 
  3. # 从StockTradeDays类初始化一个实例对象trade_days,内部会调用__init__  
  4. trade_days = StockTradeDays(price_array, date_base)  
  5. # 打印对象信息  
  6. trade_days 

输出如下:

  1. OrderedDict(  
  2.    [('20170118', stock(date='20170118'price='30.14'change=0)),   
  3.     ('20170119', stock(date='20170119'price='29.58'change=-0.019)),  
  4.     ('20170120', stock(date='20170120'price='26.36'change=-0.109)),  
  5.     ('20170121', stock(date='20170121'price='32.56'change=0.235)),  
  6.     ('20170122', stock(date='20170122'price='32.82'change=0.008))]) 

这里能打印出信息,是因为StockTradeDays中定义了下面代码的作用。

  1. def __str__(self):  
  2.     return str(self.stock_dict)  
  3. __repr__ = __str__ 

自定义__repr__()和__str__()的目的是简化调试和实例输出复杂度,使对象更具可读性。

2.对象支持长度获取

Python中获取字符串长度或者数组的长度都是使用len(str)和len(list),而不是使用如str.len()和list.len()的方式,因为所有需要长度的对象都需要实现__len__(),len()内建函数就是调用这个对象的__len__()。

以下代码通过len(trade_days)获取元素个数。

  1. print 'trade_days对象长度为: {}'.format(len(trade_days)) 

输出如下:

  1. trade_days对象长度为: 5 

这里能打印出对象长度,是因为StockTradeDays中定义下面代码的作用。

  1. def __len__(self):  
  2.     # 通过代理self.stock_dict的len()方法简单实现  
  3.     return len(self.stock_dict) 

3.对象支持迭代

Python的list、tuple、dict等对象都可以通过for in循环来遍历序列中的每一个元素,这个特性称做可迭代。

要判断对象是否支持迭代操作,可以使用collections.Iterable;

自定义的类也可以通过实现__iter__()方法来支持迭代操作。

以下代码判断trade_days对象是否支持迭代(即是否可通过for循环遍历trade_days对象),如果可迭代则迭代trade_days,依次打印序列元素:

  1. from collections import Iterable  
  2. # 如果trade_days是可迭代对象,依次打印出  
  3. if isinstance(trade_days, Iterable) :  
  4.     for day in trade_days:  
  5.         print day 

输出如下:

  1. stock(date='20170118'price='30.14'change=0)  
  2. stock(date='20170119'price='29.58'change=-0.019)  
  3. stock(date='20170120'price='26.36'change=-0.109)  
  4. stock(date='20170121'price='32.56'change=0.235)  
  5. stock(date='20170122'price='32.82'change=0.008) 

这里能迭代trade_days,是因为StockTradeDays中定义下面代码的作用。

  1. def __iter__(self):  
  2.     """  
  3.     通过代理stock_dict的迭代,yield元素  
  4.     :return:  
  5.     """  
  6.     for key in self.stock_dict:  
  7.         yield self.stock_dict[key] 

4.对象方法调用

下面使用类中定义的函数filter_stock(),注意下面调用这个函数的时候,第一个参数self就是调用者tradedays本身,所以Python的类方法都要求第一个参数为self。

  1. trade_days.filter_stock() 

输出如下:

  1. [stock(date='20170121'price='32.56'change=0.235),  
  2. stock(date='20170122'price='32.82'change=0.008)] 

5.对象支持索引获取

下面开始使用真实的股票数据构造trade_days,使用abu量化系统中的ABuSymbolPd获取特斯拉(TSLA)电动车两年的交易数据(abu量化系统代码地址,请通过微信公众号abu_quant获取,本书所有示例的IPython Notebook代码也在对应目录中)。

  1. from abupy import ABuSymbolPd  
  2. # 两年的TSLA收盘数据 to list()  
  3. price_array = ABuSymbolPd.make_kl_df('TSLA', n_folds=2).close.tolist()  
  4. # 两年的TSLA收盘日期 to list(),这里的写法不考虑效率,只做演示使用  
  5. date_array = ABuSymbolPd.make_kl_df('TSLA', n_folds=2).date.tolist()  
  6. price_array[:5], date_array[:5] 

输出如下:

  1. ([222.49000000000001,  
  2.   223.53999999999999,  
  3.   223.56999999999999,  
  4.   224.81999999999999,  
  5.   225.00999999999999],  
  6.  [20140723, 20140724, 20140725, 20140728, 20140729])  

下面通过真实交易数据,构造StockTradeDays,并使用索引获取对象数据。

  1. # 这里传入date_array,在StockTradeDays中_init_days()会直接使用传入的时间序列  
  2. trade_days = StockTradeDays(price_array, date_base, date_array)  
  3. print 'trade_days对象长度为: {}'.format(len(trade_days))  
  4. # 使用索引-1获取最后一天的交易数据  
  5. print '最后一天交易数据为:{}'.format(trade_days[-1]) 

输出如下:

  1. trade_days对象长度为: 504  
  2. 最后一天交易数据为:stock(date='20160726'price=225.93000000000001, change= -0.018) 

上面代码中最后一天的交易数据可以通过trade_days [-1]获取,是因为StockTradeDays支持索引获取对象数据,StockTradeDays中实现了__getitem__(),代码如下:

  1. def __getitem__(self, ind):  
  2.     date_key = self.__date_array[ind]  
  3.     return self.stock_dict[date_key] 

喜欢的朋友可以添加我们的微信账号:

51CTO读书频道二维码


51CTO读书频道活动讨论群:365934973

【责任编辑:book TEL:(010)68476606】

回书目   上一节   下一节
点赞 0
分享:
大家都在看
猜你喜欢

读 书 +更多

精通JavaScript动态网页编程(实例版)

本书通过大量实例代码,以ECMA-262版本3为基础,结合JavaScript 1.5和JavaScript 5.5,由浅入深、循序渐进地介绍了JavaScript知识要点与编...

订阅51CTO邮刊

点击这里查看样刊

订阅51CTO邮刊