From ec67a2d34349c078749dc2cd240b27a07241a899 Mon Sep 17 00:00:00 2001 From: zengbin93 Date: Sun, 3 Dec 2023 22:10:08 +0800 Subject: [PATCH] =?UTF-8?q?V0.9.38=20=E6=9B=B4=E6=96=B0=E4=B8=80=E6=89=B9?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=20(#179)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 0.9.38 start coding * 0.9.38 fix create_grid_params * 0.9.38 新增 resample_to_daily 方法 * 0.9.38 夏普和卡玛设定上下限 * 0.9.38 新增 show_monthly_return * 0.9.38 fix test * 0.9.38 update docs --- .github/workflows/pythonpackage.yml | 2 +- czsc/__init__.py | 5 +- czsc/objects.py | 27 ++++- czsc/sensors/event.py | 42 ++++++- czsc/sensors/utils.py | 15 +++ czsc/strategies.py | 65 +++++++++-- czsc/traders/base.py | 139 ++++++++++++++++++++--- czsc/traders/performance.py | 40 +++++++ czsc/traders/sig_parse.py | 69 +++++++++++- czsc/utils/__init__.py | 28 +++-- czsc/utils/plotly_plot.py | 166 +++++++++++++++++++++++++++- czsc/utils/st_components.py | 15 +++ czsc/utils/stats.py | 15 ++- czsc/utils/trade.py | 50 +++++++++ test/test_trade_utils.py | 19 ++++ test/test_utils.py | 4 +- 16 files changed, 651 insertions(+), 50 deletions(-) diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 1dc650925..87b1345e9 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -5,7 +5,7 @@ name: Python package on: push: - branches: [ master, V0.9.37 ] + branches: [ master, V0.9.38 ] pull_request: branches: [ master ] diff --git a/czsc/__init__.py b/czsc/__init__.py index a7022dc4c..7ac958cd1 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -61,6 +61,7 @@ update_tbars, update_nbars, risk_free_returns, + resample_to_daily, CrossSectionalPerformance, cross_sectional_ranker, @@ -112,10 +113,10 @@ feture_cross_layering, ) -__version__ = "0.9.37" +__version__ = "0.9.38" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" -__date__ = "20231118" +__date__ = "20231126" def welcome(): diff --git a/czsc/objects.py b/czsc/objects.py index ad6f1b5ec..00ecad284 100644 --- a/czsc/objects.py +++ b/czsc/objects.py @@ -1058,6 +1058,31 @@ def evaluate(self, trade_dir: str = "多空") -> dict: def update(self, s: dict): """更新持仓状态 + 函数执行逻辑: + + - 首先,检查最新信号的时间是否在上次信号之前,如果是则打印警告信息并返回。 + - 初始化一些变量,包括操作类型(op)和操作描述(op_desc)。 + - 遍历所有的事件,检查是否与最新信号匹配。如果匹配,则记录操作类型和操作描述,并跳出循环。 + - 提取最新信号的相关信息,包括交易对符号、时间、价格和成交量。 + - 更新持仓状态的结束时间为最新信号的时间。 + - 如果操作类型是开仓(LO或SO),更新最后一个事件的信息。 + - 定义一个内部函数__create_operate,用于创建操作记录。 + - 根据操作类型更新仓位和操作记录。 + + - 如果操作类型是LO(开多),检查是否满足开仓条件,如果满足则开多仓,否则只平空仓。 + - 如果操作类型是SO(开空),检查是否满足开仓条件,如果满足则开空仓,否则只平多仓。 + - 如果当前持仓为多仓,进行多头出场的判断: + - 如果操作类型是LE(平多),平多仓。 + - 如果当前价格相对于最后一个事件的价格的收益率小于止损阈值,平多仓。 + - 如果当前成交量相对于最后一个事件的成交量的增加量大于超时阈值,平多仓。 + + - 如果当前持仓为空仓,进行空头出场的判断: + - 如果操作类型是SE(平空),平空仓。 + - 如果当前价格相对于最后一个事件的价格的收益率小于止损阈值,平空仓。 + - 如果当前成交量相对于最后一个事件的成交量的增加量大于超时阈值,平空仓。 + + - 将当前持仓状态和价格记录到持仓列表中。 + :param s: 最新信号字典 :return: """ @@ -1091,7 +1116,7 @@ def update(self, s: dict): def __create_operate(_op, _op_desc): self.pos_changed = True return { - "symbol": self.symbol, + "symbol": symbol, "dt": dt, "bid": bid, "price": price, diff --git a/czsc/sensors/event.py b/czsc/sensors/event.py index 3fbd4f9ee..8dc84c747 100644 --- a/czsc/sensors/event.py +++ b/czsc/sensors/event.py @@ -91,6 +91,16 @@ def __init__(self, events: List[Union[Dict[str, Any], Event]], symbols: List[str self.csc = df def _get_signals_config(self): + """获取所有事件的信号配置,并将其合并为一个不包含重复项的列表。 + 该列表包含了所有事件所需的信号计算和解析规则,以便于后续的事件匹配过程。 + + 1. 创建一个空列表 config,用于存储所有的信号配置。 + 2. 遍历 self.events 中的所有事件(Event 对象)。对于每个事件,调用其 get_signals_config 方法, + 传入 signals_module 参数,并将返回值(即该事件的信号配置)添加到 config 列表中。 + 3. 通过 list comprehension 生成一个新的列表 config。对 config 列表中的每个字典 d + 使用 tuple(d.items()) 转换为元组,然后将这些元组转换回 dict 并加入新列表中。 + 4. 返回处理后的 config 列表。 + """ config = [] for event in self.events: _c = event.get_signals_config(signals_module=self.signals_module) @@ -99,7 +109,28 @@ def _get_signals_config(self): return config def _single_symbol(self, symbol): - """单个symbol的事件匹配""" + """单个symbol的事件匹配 + + 对单个标的(symbol)进行事件匹配。它首先获取 K 线数据,然后生成 CZSC 信号,接着遍历每个事件并计算匹配情况, + 最后整理数据框并返回。如果在过程中遇到问题,则记录错误并返回一个空 DataFrame。 + + 函数执行逻辑: + + 1. 调用 self.read_bars 方法读取指定 symbol、频率(self.base_freq)、开始时间(self.bar_sdt) + 和结束时间(self.edt)的 K 线数据,并将返回值赋给 bars。 + 2. 使用 generate_czsc_signals 函数生成 CZSC 信号。这里传入了 bars、复制后的 + signals_config(以防止修改原始配置)、开始时间(self.sdt)以及 df=False(表示返回一个字典列表而非 DataFrame 对象)。 + 3. 将上一步生成的信号转换为 DataFrame 并保存到 sigs 变量中。 + 4. 创建一个新的 events 复制品(以防止修改原始事件列表),并创建一个空列表 new_cols,用于存储新添加的列名。 + 5. 遍历新的 events 列表,对于每个 event: + a. 获取 event 的名称 e_name。 + b. 使用 apply 函数应用 is_match 方法来判断每行数据是否与该事件相匹配。 + 结果是一个布尔值和一个 float 值(表示匹配得分),它们分别被保存为 e_name 和 f'{e_name}_F' 列。 + c. 将这两个新列名添加到 new_cols 列表中。 + 6. 在 sigs 数据框中添加一列 n1b,表示涨跌幅。 + 7. 最后,重新组织 sigs 数据框的列顺序,使其包含以下列:symbol、dt、open、close、high、low、vol、amount、n1b 以及所有新添加的列。 + + """ try: bars = self.read_bars(symbol, freq=self.base_freq, sdt=self.bar_sdt, edt=self.edt, **self.kwargs) sigs = generate_czsc_signals(bars, deepcopy(self.signals_config), sdt=self.sdt, df=False) @@ -139,6 +170,15 @@ def get_event_csc(self, event_name: str): csc = cross section count,表示截面匹配次数 + 函数执行逻辑: + + 1. 创建一个 self.data 的副本 df。 + 2. 在 df 中筛选出 event_name 列等于 1 的行。 + 3. 使用 groupby 方法按 symbol 和 dt 对筛选后的数据进行分组,并计算 event_name 列的总和。 + 结果将形成一个新的 DataFrame,其中索引为 (symbol, dt) 组合,只有一个列 event_name,表示每个组合的匹配次数。 + 4. 再次使用 groupby 方法按 dt 对上一步的结果进行分组,并计算 event_name 列的总和。这次得到的新 DataFrame + 只有一个列 event_name,表示在每个时间点所有标的的事件匹配总数。 + :param event_name: 事件名称 :return: DataFrame """ diff --git a/czsc/sensors/utils.py b/czsc/sensors/utils.py index 76384b59d..b41728b64 100644 --- a/czsc/sensors/utils.py +++ b/czsc/sensors/utils.py @@ -138,6 +138,21 @@ def holds_concepts_effect(holds: pd.DataFrame, concepts: dict, top_n=20, min_n=3 原理概述:在选股时,如果股票的概念板块与组合中的其他股票的概念板块有重合,那么这个股票的表现会更好。 + 函数计算逻辑: + + 1. 如果kwargs中存在'copy'键且对应值为True,则将holds进行复制。 + 2. 为holds添加'概念板块'列,该列的值是holds中'symbol'列对应的股票的概念板块列表,如果没有对应的概念板块则填充为空。 + 3. 添加'概念数量'列,该列的值是每个股票的概念板块数量。 + 4. 从holds中筛选出概念数量大于0的行,赋值给holds。 + 5. 创建空列表new_holds和空字典dt_key_concepts。 + 6. 对holds按照'dt'进行分组,遍历每个分组,计算板块效应。 + a. 计算密集出现的概念,选取出现次数最多的前top_n个概念,赋值给key_concepts列表。 + b. 将日期dt和对应的key_concepts存入dt_key_concepts字典。 + c. 计算在密集概念中出现次数超过min_n的股票,将符合条件的股票添加到new_holds列表中。 + 7. 使用pd.concat将new_holds中的DataFrame进行合并,忽略索引,赋值给dfh。 + 8. 创建DataFrame dfk,其中包含日期(dt)和对应的强势概念(key_concepts)。 + 9. 返回dfh和dfk。 + :param holds: 组合股票池数据,样例: =================== ========= ========== diff --git a/czsc/strategies.py b/czsc/strategies.py index 8e26bf387..c81e0fdcc 100644 --- a/czsc/strategies.py +++ b/czsc/strategies.py @@ -79,19 +79,30 @@ def positions(self) -> List[Position]: def init_bar_generator(self, bars: List[RawBar], **kwargs): """使用策略定义初始化一个 BarGenerator 对象 + 函数执行逻辑: + + - 该方法的目的是使用策略定义初始化一个BarGenerator对象。BarGenerator是用于生成K线数据的类。 + - 参数bars表示基础周期的K线数据,**kwargs用于接收额外的关键字参数。 + - 首先,方法获取了基础K线的频率,并检查了是否已经有一个初始化好的BarGenerator对象传入。 + - 然后,根据基础频率是否在排序后的频率列表中,确定要使用的频率列表。 + - 如果没有传入BarGenerator对象,则根据传入的基础K线数据和其他参数创建一个新的BarGenerator对象, + 并使用部分K线数据初始化它。余下的K线数据将用于trader的初始化区间。 + - 如果传入了BarGenerator对象,则会做一些断言检查,确保传入的基础K线数据与已有的BarGenerator对象的基础周期一致, + 并且BarGenerator的end_dt是datetime类型。然后,筛选出在BarGenerator的end_dt之后的K线数据。 + - 最后,返回BarGenerator对象和余下的K线数据。 + :param bars: 基础周期K线 :param kwargs: + bg 已经初始化好的BarGenerator对象,如果传入了bg,则忽略sdt和n参数 sdt 初始化开始日期 n 初始化最小K线数量 + :return: """ base_freq = str(bars[0].freq.value) bg: BarGenerator = kwargs.get("bg", None) - if base_freq in self.sorted_freqs: - freqs = self.sorted_freqs[1:] - else: - freqs = self.sorted_freqs + freqs = self.sorted_freqs[1:] if base_freq in self.sorted_freqs else self.sorted_freqs if bg is None: uni_times = sorted(list({x.dt.strftime("%H:%M") for x in bars})) @@ -125,11 +136,20 @@ def init_trader(self, bars: List[RawBar], **kwargs) -> CzscTrader: **注意:** 这里会将所有持仓策略在 sdt 之后的交易信号计算出来并缓存在持仓策略实例内部,所以初始化的过程本身也是回测的过程。 + 函数执行逻辑: + + - 首先,它通过调用init_bar_generator方法获取已经初始化好的BarGenerator对象和余下的K线数据。 + - 然后,它创建一个CzscTrader对象,将BarGenerator对象、持仓策略的深拷贝、交易信号配置的深拷贝等参数传递给CzscTrader的构造函数。 + - 接着,使用余下的K线数据对CzscTrader对象进行初始化,通过调用trader.on_bar(bar)方法处理每一根K线数据。 + - 最后,返回初始化完成的CzscTrader对象。 + :param bars: 基础周期K线 :param kwargs: + bg 已经初始化好的BarGenerator对象,如果传入了bg,则忽略sdt和n参数 sdt 初始化开始日期 n 初始化最小K线数量 + :return: 完成策略初始化后的 CzscTrader 对象 """ bg, bars2 = self.init_bar_generator(bars, **kwargs) @@ -152,7 +172,7 @@ def dummy(self, sigs: List[dict], **kwargs) -> CzscTrader: sleep_time = kwargs.get("sleep_time", 0) sleep_step = kwargs.get("sleep_step", 1000) - trader = CzscTrader(positions=deepcopy(self.positions)) # type: ignore + trader = CzscTrader(positions=deepcopy(self.positions)) # type: ignore for i, sig in tqdm(enumerate(sigs), desc=f"回测 {self.symbol} {self.sorted_freqs}"): trader.on_sig(sig) @@ -164,12 +184,28 @@ def dummy(self, sigs: List[dict], **kwargs) -> CzscTrader: def replay(self, bars: List[RawBar], res_path, **kwargs): """交易策略交易过程回放 + 函数执行逻辑: + + - 该方法用于交易策略交易过程的回放。它接受基础周期的K线数据、结果目录以及额外的关键字参数作为输入。 + - 首先,它检查refresh参数,如果为True,则使用shutil.rmtree删除已存在的结果目录。 + - 然后,它检查结果目录是否已存在,并且是否允许覆盖。如果目录已存在且不允许覆盖,则记录一条警告信息并返回。 + - 通过调用os.makedirs创建结果目录,确保目录的存在。 + - 接着,调用init_bar_generator方法初始化BarGenerator对象,并进行相关的初始化操作。 + - 创建一个CzscTrader对象,并将初始化好的BarGenerator对象、持仓策略的深拷贝、交易信号配置的深拷贝等参数传递给CzscTrader的构造函数。 + - 为每个持仓策略创建相应的目录。 + - 遍历K线数据,调用trader.on_bar(bar)方法处理每一根K线数据。 + - 在每根K线数据处理完成后,检查每个持仓策略是否有操作,并且操作的时间是否与当前K线的时间一致。 + 如果有操作,则生成相应的HTML文件名,并调用trader.take_snapshot(file_html)方法生成交易快照。 + - 最后,遍历每个持仓策略,记录其评估信息,包括多空合并表现、多头表现、空头表现等。 + :param bars: 基础周期K线 :param res_path: 结果目录 :param kwargs: - bg 已经初始化好的BarGenerator对象,如果传入了bg,则忽略sdt和n参数 - sdt 初始化开始日期 - n 初始化最小K线数量 + + bg 已经初始化好的BarGenerator对象,如果传入了bg,则忽略sdt和n参数 + sdt 初始化开始日期 + n 初始化最小K线数量 + refresh 是否刷新结果目录 :return: """ if kwargs.get("refresh", False): @@ -294,7 +330,7 @@ def save_positions(self, path): :return: None """ os.makedirs(path, exist_ok=True) - for pos in self.positions: # type: ignore + for pos in self.positions: # type: ignore pos_ = pos.dump() pos_.pop("symbol") hash_code = hashlib.md5(str(pos_).encode()).hexdigest() @@ -322,6 +358,17 @@ def load_positions(self, files: List, check=True) -> List[Position]: class CzscJsonStrategy(CzscStrategyBase): """仅传入Json配置的Positions就完成策略创建 + 执行逻辑: + + 1. 定义CzscJsonStrategy类,并继承自CzscStrategyBase。这个类可以通过仅传入Json配置的Positions来完成策略创建。 + 2. 类中定义了一个名为positions的属性,使用@property装饰器将其标记为只读属性。 + 3. 在positions属性的getter方法中,执行以下操作: + - 从self.kwargs字典中获取键为"files_position"的值,并将其赋值给变量files。 + 这里的self.kwargs可能是通过在实例化该类时传入的参数或其他方式设置的一个字典,其中包含了策略配置文件的路径列表。 + - 使用self.kwargs.get方法获取键为"check_position"的值,并设置默认值为True,将其赋值给变量check。这个值用于确定是否对JSON持仓策略进行MD5校验。 + - 调用self.load_positions(files, check)方法,并返回其结果。这个方法可能是从父类CzscStrategyBase中继承的方法, + 用于从配置文件中加载持仓策略。将文件列表和校验标志作为参数传递给该方法,并返回加载的持仓策略列表。 + 必须参数: files_position: 以 json 文件配置的策略,每个json文件对应一个持仓策略配置 check_position: 是否对 json 持仓策略进行 MD5 校验,默认为 True diff --git a/czsc/traders/base.py b/czsc/traders/base.py index a29f8c78d..b102d5b22 100644 --- a/czsc/traders/base.py +++ b/czsc/traders/base.py @@ -68,6 +68,13 @@ def __repr__(self): def get_signals_by_conf(self): """通过信号参数配置获取信号 + 函数执行逻辑: + + 1. 函数首先创建一个空的有序字典s。 + 2. 如果self.signals_config不存在,函数直接返回空字典s,否则,函数遍历其中的每一个配置。 + 3. 对于每一个参数,函数提取出信号名称和freq,并根据这两个参数获取相应的信号,获取到的信号被添加到字典s中。 + 4. 函数最后返回字典s,其中包含了所有获取到的信号。 + 信号参数配置,格式如下: signals_config = [ @@ -98,6 +105,14 @@ def get_signals_by_conf(self): def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "580px"): """获取快照 + 函数执行逻辑: + + 1. 函数首先创建一个Tab对象,用于存储所有的图表和表格。 + 2. 函数遍历所有的freq,对于每一个freq,函数获取相应的CZSC对象,并将其转换为一个图表,然后添加到Tab对象中。 + 3. 函数提取出所有的信号,并按照freq分组。对于每一个freq,函数创建一个表格,包含该freq下的所有信号,然后添加到Tab对象中。 + 4. 如果还有其他的信号,函数创建一个表格,包含所有的其他信号,然后添加到Tab对象中。 + 5. 最后,如果提供了file_html参数,函数将Tab对象渲染为一个HTML文件并保存;否则,函数返回Tab对象。 + :param file_html: 交易快照保存的 html 文件名 :param width: 图表宽度 :param height: 图表高度 @@ -135,13 +150,33 @@ def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "58 return tab def open_in_browser(self, width="1400px", height="580px"): - """直接在浏览器中打开分析结果""" + """直接在浏览器中打开分析结果 + + 函数执行逻辑: + + 1. 首先创建一个HTML文件的路径file_html,这个文件将被保存在用户的主目录下,文件名为"temp_czsc_advanced_trader.html"。 + 2. 然后,函数调用self.take_snapshot方法,将分析结果保存为一个HTML文件。 + 3. 最后,函数使用webbrowser.open方法打开这个HTML文件 + """ file_html = os.path.join(home_path, "temp_czsc_advanced_trader.html") self.take_snapshot(file_html, width, height) webbrowser.open(file_html) def update_signals(self, bar: RawBar): - """输入基础周期已完成K线,更新信号,更新仓位""" + """输入基础周期已完成K线,更新信号,更新仓位 + + 函数执行逻辑: + + 1. 函数首先调用self.bg.update(bar),输入一个已完成的基础周期K线bar,更新各周期K线。 + 2. 然后,函数遍历所有的K线freq和对应的K线数据,对每一个K线数据,函数调用self.kas[freq].update(b[-1]),更新对应的 CZSC 对象。 + 3. 函数提取出K线的标的代码bar.symbol,并将其赋值给self.symbol。 + 4. 函数提取出基础freq的最后一根K线last_bar,并从中提取出结束时间dt,K线IDid,以及收盘价close,并将它们分别赋值给self.end_dt,self.bid,和self.latest_price。 + 5. 函数创建一个空的有序字典s,并调用self.get_signals_by_conf()获取所有的信号,然后将这些信号更新到字典s中。 + 6. 最后,函数将last_bar的所有属性更新到字典s中。 + + :param bar: 基础周期已完成K线 + :return: None + """ self.bg.update(bar) for freq, b in self.bg.bars.items(): self.kas[freq].update(b[-1]) @@ -158,6 +193,16 @@ def generate_czsc_signals(bars: List[RawBar], signals_config: List[dict], sdt: Union[AnyStr, datetime] = "20170101", init_n: int = 500, df=False, **kwargs): """使用 CzscSignals 生成信号 + 函数执行逻辑: + + 1. 函数首先从信号配置signals_config中获取所有的freqs。 + 2. 然后,函数将信号计算开始时间sdt转换为datetime类型,并将开始时间之前的K线数据分配给bars_left,开始时间之后的K线数据分配给bars_right。 + 3. 如果bars_right为空,即没有开始时间之后的K线数据,函数会发出一个警告,并返回一个空的DataFrame或空列表。 + 4. 函数创建一个BarGenerator对象bg,并使用bars_left中的K线数据来初始化它。 + 5. 函数创建一个CzscSignals对象cs,并将bg和信号配置signals_config作为参数传入。 + 6. 函数遍历bars_right中的每一根K线,对于每一根K线,函数调用cs.update_signals(bar)来更新信号,并将更新后的信号添加到_sigs列表中。 + 7. 最后,如果df参数为True,函数将_sigs转换为DataFrame并返回;否则,直接返回_sigs。 + :param bars: 基础周期 K 线序列 :param signals_config: 信号函数配置,格式如下: signals_config = [ @@ -207,9 +252,20 @@ def generate_czsc_signals(bars: List[RawBar], signals_config: List[dict], def check_signals_acc(bars: List[RawBar], signals_config: List[dict], delta_days: int = 5, **kwargs) -> None: - """人工验证形态信号识别的准确性的辅助工具: + """输入基础周期K线和想要验证的信号,输出信号识别结果的快照 - 输入基础周期K线和想要验证的信号,输出信号识别结果的快照 + 函数执行逻辑: + + 1. 函数首先获取基础周期K线的base_freq,并检查输入的K线数据bars是否按时间升序排列。如果bars的长度小于600,函数直接返回。 + 2. 然后,函数调用generate_czsc_signals方法,生成Czsc信号,并将结果保存在df中。 + 3. 函数提取出df中所有的信号列s_cols,并打印每一列的值的数量。然后,函数将所有的信号添加到signals列表中。 + 4. 函数将bars分为两部分,bars_left和bars_right,并获取信号配置signals_config中的所有freqs。 + 5. 函数创建一个BarGenerator对象bg,并使用bars_left中的K线数据来初始化它。 + 6. 函数创建一个CzscSignals对象ct,并将bg和信号配置signals_config作为参数传入。 + 7. 函数创建一个字典last_dt,用于存储每一个信号最后一次出现的时间。 + 8. 函数遍历bars_right中的每一根K线,对于每一根K线,函数调用ct.update_signals(bar)来更新信号。 + 9. 对于每一个信号,如果当前K线的时间与该信号最后一次出现的时间的差值大于delta_days,并且该信号与当前的信号匹配, + 函数将创建一个HTML文件,保存信号识别结果的快照,并更新该信号最后一次出现的时间。 :param bars: 原始K线 :param signals_config: 需要验证的信号列表 @@ -259,10 +315,18 @@ def check_signals_acc(bars: List[RawBar], signals_config: List[dict], delta_days def get_unique_signals(bars: List[RawBar], signals_config: List[dict], **kwargs): """获取信号函数中定义的所有信号列表 + 函数执行逻辑: + + 1. 函数首先检查输入的K线数据bars是否按时间升序排列。如果bars的长度小于600,函数直接返回一个空列表。 + 2. 然后,函数调用generate_czsc_signals方法,生成CZSC信号,并将结果保存在df中。 + 3. 函数遍历df中的所有列,对于每一列,如果列名包含三个部分,函数提取出该列中的所有唯一值,然后将列名和每一个唯一值组合成一个新的信号, + 并添加到_res列表中。注意,如果唯一值中包含"其他",则不会被添加到_res中。 + 4. 最后,函数返回_res,其中包含了所有的唯一信号。 + :param bars: 基础K线数据 :param signals_config: 信号函数配置 - :param kwargs: - :return: + :param kwargs: 传递给generate_czsc_signals方法的参数 + :return: 信号列表 """ assert bars[2].dt > bars[1].dt > bars[0].dt and bars[2].id > bars[1].id, "bars 中的K线元素必须按时间升序" if len(bars) < 600: @@ -282,6 +346,18 @@ def __init__(self, bg: Optional[BarGenerator] = None, positions: Optional[List[P ensemble_method: Union[AnyStr, Callable] = "mean", **kwargs): """ + 初始化逻辑: + + 1. 首先接收几个参数: + bg是一个可选的BarGenerator对象, + positions是一个可选的Position对象列表, + ensemble_method是一个集成方法,可以是字符串或者一个回调函数。 + 2. 函数将positions赋值给self.positions。如果positions不为空,函数会检查positions中的所有名称是否都是唯一的,如果不是,函数会抛出一个断言错误。 + 3. 函数将ensemble_method赋值给self.__ensemble_method。这个参数用于指定如何从多个仓位中集成一个仓位。 + 它可以是"mean"(平均),"vote"(投票),"max"(最大),或者一个回调函数。 + 4. 函数将"CzscTrader"赋值给self.name。 + 5. 最后,函数调用父类的初始化函数,传入bg和其他参数。 + :param bg: bar generator 对象 :param get_signals: 信号计算函数,输入是 CzscSignals 对象,输出是信号字典 :param ensemble_method: 多个仓位集成一个仓位的方法,可选值 mean, vote, max;也可以传入一个回调函数 @@ -308,6 +384,12 @@ def __repr__(self): def update(self, bar: RawBar) -> None: """输入基础周期已完成K线,更新信号,更新仓位 + 函数执行逻辑: + + 1. 函数首先接收一个参数bar,这是一个已完成的基础周期K线。 + 2. 函数调用self.update_signals(bar),输入这个已完成的基础周期K线,更新信号。 + 3. 如果self.positions不为空,即存在仓位,函数遍历所有的仓位,对于每一个仓位,函数调用position.update(self.s),更新该仓位的状态。 + :param bar: 基础周期已完成K线 :return: None """ @@ -317,9 +399,14 @@ def update(self, bar: RawBar) -> None: position.update(self.s) def on_sig(self, sig: dict) -> None: - """通过信号字典直接交易 + """通过信号字典直接交易,用于快速回测场景 + + 函数执行逻辑: - 主要用于快速回测场景 + 1. 函数首先接收一个参数sig,这是一个信号字典,赋值给self.s。 + 2. 函数从sig中提取出标的代码symbol,结束时间dt,K线ID id,以及收盘价close, + 并将它们分别赋值给self.symbol,self.end_dt,self.bid,和self.latest_price。 + 4. 如果self.positions不为空,即存在持仓策略,函数遍历所有position,函数调用position.update(self.s),更新该仓位的状态 :param sig: 信号字典 :return: None @@ -343,6 +430,10 @@ def on_bar(self, bar: RawBar) -> None: def pos_changed(self) -> bool: """判断仓位是否发生变化 + 1. 函数首先检查self.positions是否为空。如果为空,即没有仓位,函数直接返回False。 + 2. 如果self.positions不为空,函数遍历所有的仓位,对于每一个仓位,函数检查其pos_changed属性。 + 如果任何一个仓位的pos_changed属性为True,即该仓位发生了变化,函数返回True。 + :return: True/False """ if not self.positions: @@ -352,6 +443,18 @@ def pos_changed(self) -> bool: def get_ensemble_pos(self, method: Union[AnyStr, Callable] = None) -> float: """获取多个仓位的集成仓位 + 函数执行逻辑: + + 1. 函数首先检查self.positions是否为空。如果为空,即没有仓位,函数直接返回0。 + 2. 如果self.positions不为空,函数获取集成方法method。如果没有传入method参数,函数使用self.__ensemble_method作为集成方法。 + 3. 如果method是一个字符串,函数将其转换为小写,然后获取所有仓位的仓位序列pos_seq。 + 1. 如果method是"mean",函数计算pos_seq的平均值作为集成仓位。 + 2. 如果method是"vote",函数计算pos_seq的和的符号作为集成仓位。 + 3. 如果method是"max",函数获取pos_seq的最大值作为集成仓位。 + 4. 如果method不是以上任何一个值,函数抛出一个值错误。 + + 4. 如果method不是一个字符串,即它是一个回调函数,函数将所有仓位的名称和仓位组成的字典作为参数传入method,并将返回值作为集成仓位。 + :param method: 多个仓位集成一个仓位的方法,可选值 mean, vote, max;也可以传入一个回调函数 假设有三个仓位对象,当前仓位分别是 1, 1, -1 @@ -388,6 +491,13 @@ def get_ensemble_pos(self, method: Union[AnyStr, Callable] = None) -> float: def get_position(self, name: str) -> Optional[Position]: """获取指定名称的仓位策略对象 + 函数执行逻辑: + + 1. 函数首先接收一个参数name,这是要查找的仓位名称。 + 2. 函数检查self.positions是否为空。如果为空,即没有仓位,函数直接返回None。 + 3. 如果self.positions不为空,函数遍历所有的仓位,对于每一个仓位,函数检查其名称是否与输入的名称相同。如果相同,函数返回该仓位。 + 4. 如果遍历所有的仓位都没有找到与输入名称相同的仓位,函数返回None。 + :param name: 仓位名称 :return: Position """ @@ -453,6 +563,12 @@ def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "58 def get_ensemble_weight(self, method: Optional[Union[AnyStr, Callable]] = None): """获取 CzscTrader 中所有 positions 按照 method 方法集成之后的权重 + 函数执行逻辑: + + 1. 函数首先接收一个参数method,这是集成方法,可以是字符串或者一个回调函数。 + 2. 函数检查是否提供了method参数。如果没有提供,函数使用self.__ensemble_method作为集成方法;如果提供了,函数使用提供的method作为集成方法。 + 3. 函数调用get_ensemble_weight函数,输入self和method,获取所有仓位按照指定方法集成之后的权重。 + :param method: str or callable 集成方法,可选值包括:'mean', 'max', 'min', 'vote' 也可以传入自定义的函数,函数的输入为 dict,key 为 position.name,value 为 position.pos, 样例输入: @@ -473,7 +589,6 @@ def weight_backtest(self, **kwargs): - method: str or callable,集成方法,参考 get_ensemble_weight 方法 - digits: int,权重小数点后保留的位数,例如 2 表示保留两位小数 - fee_rate: float,手续费率,例如 0.0002 表示万二 - - res_path: str,回测结果保存路径 :return: 回测结果 """ @@ -482,8 +597,6 @@ def weight_backtest(self, **kwargs): method = kwargs.get("method", self.__ensemble_method) digits = kwargs.get("digits", 2) fee_rate = kwargs.get("fee_rate", 0.0002) - res_path = kwargs.get("res_path", "./weight_backtest") dfw = self.get_ensemble_weight(method) - wb = WeightBacktest(dfw, digits=digits, fee_rate=fee_rate, res_path=res_path) - _res = wb.backtest() - return _res + wb = WeightBacktest(dfw, digits=digits, fee_rate=fee_rate) + return wb diff --git a/czsc/traders/performance.py b/czsc/traders/performance.py index 98f707fbf..35d7dc4f2 100644 --- a/czsc/traders/performance.py +++ b/czsc/traders/performance.py @@ -235,6 +235,32 @@ def agg_to_excel(self, file_xlsx): def combine_holds_and_pairs(holds, pairs, results_path): """结合股票池和择时策略开平交易进行分析 + 函数计算逻辑: + + 1. 将holds和pairs数据进行处理和准备。 + - 将holds复制到dfh变量。 + - 将dfh的'成分日期'列转换为日期类型。 + - 将dfh的'证券代码'列赋值给'标的代码'列。 + - 将pairs复制到dfp变量。 + - 将dfp的'开仓时间'列转换为日期类型,并将日期部分提取出来赋值给'开仓日期'列。 + + 2. 合并数据并筛选交易对。 + - 将dfp与dfh的[['开仓日期', '标的代码', '持仓权重']]列进行左连接,得到dfp_。 + - 从dfp_中选择持仓权重大于0的交易对,赋值给df_pairs。 + - 从dfp中选择开仓时间在df_pairs的开仓时间范围内的数据,赋值给dfp_sub。 + + 3. 进行评价和分析。 + - 使用dfp_sub创建PairsPerformance对象tp_old。 + - 使用df_pairs创建PairsPerformance对象tp_new。 + + 4. 创建结果目录并保存评价结果和交易数据。 + - 使用os.makedirs创建结果目录。 + - 将tp_old的统计结果保存为Excel文件,文件名为"原始交易评价.xlsx"。 + - 将tp_new的统计结果保存为Excel文件,文件名为"组合过滤评价.xlsx"。 + - 将df_pairs的数据保存为Feather文件,文件名为"组合过滤交易.feather"。 + + 5. 返回tp_old和tp_new对象。 + :param holds: 组合股票池数据,样例: 成分日期 证券代码 n1b 持仓权重 0 2020-01-02 000001.SZ 183.758194 0.001232 @@ -290,6 +316,20 @@ def combine_holds_and_pairs(holds, pairs, results_path): def combine_dates_and_pairs(dates: list, pairs: pd.DataFrame, results_path): """结合大盘日期择时和择时策略开平交易进行分析 + 函数计算逻辑: + + 1. 将dates转换为日期类型,并赋值给变量dates。 + 2. 将pairs复制到dfp变量。 + 3. 将dfp的'开仓时间'列转换为日期类型,并将日期部分提取出来赋值给'开仓日期'列。 + 4. 从dfp中选择开仓日期在dates中的数据,赋值给df_pairs。 + 5. 从dfp中选择开仓时间在df_pairs的开仓时间范围内的数据,赋值给dfp_sub。 + 6. 使用dfp_sub创建PairsPerformance对象tp_old。 + 7. 使用df_pairs创建PairsPerformance对象tp_new。 + 8. 打印原始交易的基本信息和平仓年度统计。 + 9. 打印组合过滤后的交易的基本信息和平仓年度统计。 + 10. 创建结果目录并保存评价结果和交易数据。 + 11. 返回tp_old和tp_new对象。 + :param dates: 大盘日期择时日期数据,数据样例 ['2020-01-02', ..., '2022-01-06'] :param pairs: 择时策略开平交易数据,数据格式如下 标的代码 交易方向 最大仓位 开仓时间 累计开仓 平仓时间 \ diff --git a/czsc/traders/sig_parse.py b/czsc/traders/sig_parse.py index 20393a4c3..cc7578b09 100644 --- a/czsc/traders/sig_parse.py +++ b/czsc/traders/sig_parse.py @@ -19,6 +19,17 @@ class SignalsParser: def __init__(self, signals_module: str = 'czsc.signals'): """ + 函数执行逻辑: + + 1. 将传入的 signals_module 参数赋给实例变量 self.signals_module,代表信号函数所在的模块,默认模块是czsc库的signals模块。 + 2. 使用 import_by_name 函数导入了指定名称的模块 signals_module。 + 3. 对于导入的模块中的每个属性名进行遍历: + - 魔法函数和私有函数不进行处理。 + - 获取函数的注解信息,并通过正则表达式获取注解中的参数模板和信号列表。 + - 如果解析到了参数模板,则将其存储在 sig_pats_map 中,key是函数名称。 + - 如果解析到了信号列表,则将其存储在 sig_name_map 中,并且为每个信号创建了 Signal 对象并存储在列表中,key是函数名称。 + 4. 最后将得到的 sig_name_map 和 sig_pats_map 存储在实例变量中,以便其他方法使用。 + :param signals_module: 指定信号函数所在模块 """ self.signals_module = signals_module @@ -51,8 +62,15 @@ def __init__(self, signals_module: str = 'czsc.signals'): def parse_params(self, name, signal): """获取信号函数参数 - :param name: 信号函数名称 - :param signal: 需要解析的信号 + 函数执行逻辑: + + 1. 首先根据传入的 name 和 signal 参数,通过 Signal(signal).key 获取一个键值。 + 2. 然后从实例变量 sig_pats_map 中获取与指定名称对应的参数模板,并将其存储在 pats 中。 + 3. 如果没有找到参数模板,则返回 None。 + 4. 最后将信号函数的完整名称存储在参数字典中,并返回参数字典。 + + :param name: 信号函数名称, 如:cxt_bi_end_V230222 + :param signal: 需要解析的信号, 如:15分钟_D1K_量柱V221218_低量柱_6K_任意_0 :return: """ key = Signal(signal).key @@ -74,6 +92,12 @@ def parse_params(self, name, signal): def get_function_name(self, signal: str): """获取信号对应的信号函数名称 + 函数执行逻辑: + + 1. 创建一个 _signal 对象,通过传入的信号字符串进行初始化。 + 2. 通过遍历 sig_name_map 中的项目,找出那些与 _signal.k3 相匹配的键,并将它们存储在 _k3_match 列表中。 + 3. 如果只有一个匹配项,则返回该项;否则记录错误日志并返回 None。 + :param signal: 信号,数据样例:15分钟_D1K_量柱V221218_低量柱_6K_任意_0 :return: 信号函数名称 """ @@ -90,9 +114,18 @@ def get_function_name(self, signal: str): def config_to_keys(self, config: List[Dict]): """将信号函数配置转换为信号key列表 + 函数执行逻辑: + + 1. 首先创建了一个空列表 keys 用于存储信号key。 + 2. 对于传入的 config 列表中的每个配置字典 conf 进行以下操作: + - 获取信号函数的名称。 + - 如果该信号函数的名称在 self.sig_pats_map 中存在对应的模板,使用参数填充模板,并将结果添加到 keys 列表中。 + :param config: 信号函数配置 + config = [{'freq': '日线', 'max_overlap': '3', 'name': 'czsc.signals.cxt_bi_end_V230222'}, {'freq1': '日线', 'freq2': '60分钟', 'name': 'czsc.signals.cxt_zhong_shu_gong_zhen_V221221'}] + :return: 信号key列表 """ keys = [] @@ -103,7 +136,22 @@ def config_to_keys(self, config: List[Dict]): return keys def parse(self, signal_seq: List[str]): - """解析信号序列""" + """解析信号序列 + + 函数执行逻辑: + + 1. 接受一个signal_seq 参数。 + 2. 定义一个空列表res ,用于存储解析结果。 + 3. 遍历信号序列signal_seq 中的每一个信号: + + - 调用get_function_name 方法,以信号为参数,获取该信号对应的函数名。 + - 进行函数名存在性判断,name 在sig_pats_map 中存在, + 调用parse_params 方法,以函数名和信号为参数,解析参数并返回结果。 + + :param signal_seq: 信号序列, 样例: + ['15分钟_D1K_量柱V221218_低量柱_6K_任意_0', '日线_D1K_量柱V221218_低量柱_6K_任意_0'] + :return: 信号函数配置 + """ res = [] for signal in signal_seq: name = self.get_function_name(signal) @@ -119,6 +167,12 @@ def parse(self, signal_seq: List[str]): def get_signals_config(signals_seq: List[str], signals_module: str = 'czsc.signals') -> List[Dict]: """获取信号列表对应的信号函数配置 + 函数执行逻辑: + + 1. 首先创建了一个 SignalsParser 类的实例对象 sp,传入了参数 signals_module进行初始化, + 初始化工作主要是解析signals_module下的信号函数,生成了sig_pats_map信号参数模板字典和sig_name_map信号列表字典。 + 2. 然后使用 sp 实例调用 parse 方法,该方法解析 signals_seq 中的信号,并返回信号函数的配置信息。 + :param signals_seq: 信号列表 :param signals_module: 信号函数所在模块 :return: 信号函数配置 @@ -131,6 +185,15 @@ def get_signals_config(signals_seq: List[str], signals_module: str = 'czsc.signa def get_signals_freqs(signals_seq: List) -> List[str]: """获取信号列表对应的K线周期列表 + 函数执行逻辑: + + 1. 然后对于 signals_seq 中的每个信号进行以下操作: + + - 使用正则表达式从信号中提取信号周期,并将其存储在 _freqs 变量中。 + - 如果提取到了信号周期,则将其加入到 freqs 列表中。 + + 2. 最后验证数据是否符合sorted_freqs列表规范,并且以sorted_freqs列表的排序进行返回。 + :param signals_seq: 信号列表 / 信号函数配置列表 :return: K线周期列表 """ diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index d96c0a710..2d4419cb7 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -16,7 +16,7 @@ from .sig import check_pressure_support, check_gap_info, is_bis_down, is_bis_up, get_sub_elements, is_symmetry_zs from .sig import same_dir_counts, fast_slow_cross, count_last_same, create_single_signal from .plotly_plot import KlineChart -from .trade import cal_trade_price, update_nbars, update_bbars, update_tbars, risk_free_returns +from .trade import cal_trade_price, update_nbars, update_bbars, update_tbars, risk_free_returns, resample_to_daily from .cross import CrossSectionalPerformance, cross_sectional_ranker from .stats import daily_performance, net_value_stats, subtract_fee from .signal_analyzer import SignalAnalyzer, SignalPerformance @@ -67,6 +67,17 @@ def get_py_namespace(file_py: str, keys: list = []) -> dict: def import_by_name(name): """通过字符串导入模块、类、函数 + 函数执行逻辑: + + 1. 检查 name 中是否包含点号('.')。如果没有,则直接使用内置的 import 函数来导入整个模块,并返回该模块对象。 + 2. 如果 name 包含点号,先处理一个相对路径。将 name 拆分为两部分:module_name 和 function_name。 + 使用 Python 内置的 rsplit 方法从右边开始分割,只取一次,这样可以确保我们将最后的一个点号前的部分作为 module_name,点号后面的部分作为 function_name。 + 3. 使用import函数导入指定的 module_name。 + 这里传入三个参数:globals() 和 locals() 分别代表当前全局和局部命名空间; + [function_name] 是一个列表,用于指定要导入的子模块或属性名。 + 这样做是为了避免一次性导入整个模块的所有内容,提高效率。 + 4. 使用 vars 函数获取模块的字典表示形式(即模块内所有的变量和函数),取出 function_name 对应的值,然后返回这个值。 + :param name: 模块名,如:'czsc.objects.Factor' :return: 模块对象 """ @@ -89,11 +100,11 @@ def freqs_sorted(freqs): return _freqs_new -def create_grid_params(prefix: str, detail=False, **kwargs) -> dict: +def create_grid_params(prefix: str = "", multiply=3, **kwargs) -> dict: """创建 grid search 参数组合 :param prefix: 参数组前缀 - :param detail: 是否使用参数值构建参数组的名称 + :param multiply: 参数组合的位数,如果为 0,则使用 # 分隔参数 :param kwargs: 任意参数的候选序列,参数值推荐使用 iterable :return: 参数组合字典 @@ -111,8 +122,8 @@ def create_grid_params(prefix: str, detail=False, **kwargs) -> dict: >>>x = create_grid_params("test", x=2, y=('a', 'b'), detail=False) >>>print(x) Out[1]: - {'test@001': {'x': 2, 'y': 'a'}, - 'test@002': {'x': 2, 'y': 'b'}} + {'test001': {'x': 2, 'y': 'a'}, + 'test002': {'x': 2, 'y': 'b'}} """ from sklearn.model_selection import ParameterGrid @@ -126,13 +137,12 @@ def create_grid_params(prefix: str, detail=False, **kwargs) -> dict: params = {} for i, row in enumerate(ParameterGrid(params_grid), 1): - if detail: + if multiply == 0: key = "#".join([f"{k}={v}" for k, v in row.items()]) - # params[f"{prefix}@{key}"] = row else: - key = str(i).zfill(3) + key = str(i).zfill(multiply) - row['version'] = f"{prefix}@{key}" + row['version'] = f"{prefix}{key}" params[f"{prefix}@{key}"] = row return params diff --git a/czsc/utils/plotly_plot.py b/czsc/utils/plotly_plot.py index ce7834b12..c46bf93cc 100644 --- a/czsc/utils/plotly_plot.py +++ b/czsc/utils/plotly_plot.py @@ -23,7 +23,22 @@ class KlineChart: """ def __init__(self, n_rows=3, **kwargs): - # 子图数量 + """K线绘图工具类 + + 初始化执行逻辑: + + - 接收一个可选参数 n_rows,默认值为 3。这个参数表示图表中的子图数量。 + - 接收一个可变参数列表 **kwargs,可以传递其他配置参数。 + - 如果没有提供 row_heights 参数,则根据 n_rows 设置默认的行高度。 + - 定义了一些颜色变量:color_red 和 color_green。 + - 使用 make_subplots 函数创建一个具有 n_rows 行和 1 列的子图布局,并设置一些共享属性和间距。 + - 使用 fig.update_yaxes 和 fig.update_xaxes 更新 Y 轴和 X 轴的属性,如显示网格、自动调整范围等。 + - 使用 fig.update_layout 更新整个图形的布局,包括标题、边距、图例位置和样式、背景模板等。 + - 将 fig 对象保存在 self.fig 属性中。 + + :param n_rows: 子图数量 + :param kwargs: + """ self.n_rows = n_rows row_heights = kwargs.get("row_heights", None) if not row_heights: @@ -65,7 +80,24 @@ def __init__(self, n_rows=3, **kwargs): self.fig = fig def add_kline(self, kline: pd.DataFrame, name: str = "K线", **kwargs): - """绘制K线""" + """绘制K线 + + 函数执行逻辑: + + 1. 检查 kline 数据框是否包含 'text' 列。如果没有,则添加一个空字符串列。 + 2. 使用 go.Candlestick 创建一个K线图,并传入以下参数: + - x: 日期时间数据 + - open, high, low, close: 开盘价、最高价、最低价和收盘价 + - text: 显示在每个 K 线上的文本标签 + - name: 图例名称 + - showlegend: 是否显示图例 + - increasing_line_color 和 decreasing_line_color: 上涨时的颜色和下跌时的颜色 + - increasing_fillcolor 和 decreasing_fillcolor: 上涨时填充颜色和下跌时填充颜色 + - **kwargs: 可以传递其他自定义参数给 Candlestick 函数。 + + 3. 将创建的烛台图对象添加到 self.fig 中的第一个子图(row=1, col=1)。 + 4. 使用 fig.update_traces 更新所有 traces 的 xaxis 属性为 "x1"。 + """ if 'text' not in kline.columns: kline['text'] = "" @@ -77,13 +109,43 @@ def add_kline(self, kline: pd.DataFrame, name: str = "K线", **kwargs): self.fig.update_traces(xaxis="x1") def add_vol(self, kline: pd.DataFrame, row=2, **kwargs): - """绘制成交量图""" + """绘制成交量图 + + 函数执行逻辑: + + 1. 首先,复制输入的 kline 数据框到 df。 + 2. 使用 np.where 函数根据收盘价(df['close'])和开盘价(df['open'])之间的关系为 df 创建一个新列 'vol_color'。 + 如果收盘价大于开盘价,则使用红色(self.color_red),否则使用绿色(self.color_green)。 + 3. 调用 add_bar_indicator 方法绘制成交量图。传递以下参数: + - x: 日期时间数据 + - y: 成交量数据 + - color: 根据 'vol_color' 列的颜色 + - name: 图例名称 + - row: 指定要添加指标的子图行数,默认值为 2 + - show_legend: 是否显示图例,默认值为 False + """ df = kline.copy() df['vol_color'] = np.where(df['close'] > df['open'], self.color_red, self.color_green) self.add_bar_indicator(df['dt'], df['vol'], color=df['vol_color'], name="成交量", row=row, show_legend=False) def add_sma(self, kline: pd.DataFrame, row=1, ma_seq=(5, 10, 20), visible=False, **kwargs): - """绘制均线图""" + """绘制均线图 + + 函数执行逻辑: + + 1. 复制输入的 kline 数据框到 df。 + 2. 获取自定义参数 line_width,默认值为 0.6。 + 3. 遍历 ma_seq 中的所有均线周期: + - 对每个周期使用 pandas rolling 方法计算收盘价的移动平均线。 + - 调用 add_scatter_indicator 方法将移动平均线数据绘制为折线图。传递以下参数: + - x: 日期时间数据 + - y: 移动平均线数据 + - name: 图例名称,格式为 "MA{ma}",其中 {ma} 是当前的均线周期。 + - row: 指定要添加指标的子图行数,默认值为 1 + - line_width: 线宽,默认值为 0.6 + - visible: 是否可见,默认值为 False + - show_legend: 是否显示图例,默认值为 True + """ df = kline.copy() line_width = kwargs.get('line_width', 0.6) for ma in ma_seq: @@ -91,7 +153,30 @@ def add_sma(self, kline: pd.DataFrame, row=1, ma_seq=(5, 10, 20), visible=False, row=row, line_width=line_width, visible=visible, show_legend=True) def add_macd(self, kline: pd.DataFrame, row=3, **kwargs): - """绘制MACD图""" + """绘制MACD图 + + 函数执行逻辑: + + 1. 首先,复制输入的 kline 数据框到 df。 + 2. 获取自定义参数 fastperiod、slowperiod 和 signalperiod。这些参数分别对应于计算 MACD 时使用的快周期、慢周期和信号周期,默认值分别为 12、26 和 9。 + 3. 使用 talib 库的 MACD 函数计算 MACD 值(diff, dea, macd)。 + 4. 创建一个名为 macd_colors 的 numpy 数组,根据 macd 值大于零的情况设置颜色:大于零使用红色(self.color_red),否则使用绿色(self.color_green)。 + 5. 调用 add_scatter_indicator 方法将 diff 和 dea 绘制为折线图。传递以下参数: + - x: 日期时间数据 + - y: diff 或 dea 数据 + - name: 图例名称,分别为 "DIFF" 和 "DEA" + - row: 指定要添加指标的子图行数,默认值为 3 + - line_color: 线的颜色,分别为 'white' 和 'yellow' + - show_legend: 是否显示图例,默认值为 False + - line_width: 线宽,默认值为 0.6 + 6. 调用 add_bar_indicator 方法将 macd 绘制为柱状图。传递以下参数: + - x: 日期时间数据 + - y: macd 数据 + - name: 图例名称,为 "MACD" + - row: 指定要添加指标的子图行数,默认值为 3 + - color: 根据 macd_colors 设置颜色 + - show_legend: 是否显示图例,默认值为 False + """ df = kline.copy() fastperiod = kwargs.get('fastperiod', 12) slowperiod = kwargs.get('slowperiod', 26) @@ -106,7 +191,26 @@ def add_macd(self, kline: pd.DataFrame, row=3, **kwargs): self.add_bar_indicator(df['dt'], macd, name="MACD", row=row, color=macd_colors, show_legend=False) def add_indicator(self, dt, scatters: list = None, scatter_names: list = None, bar=None, bar_name='', row=4, **kwargs): - """绘制曲线叠加bar型指标""" + """绘制曲线叠加bar型指标 + + 1. 获取自定义参数 line_width,默认值为 0.6。 + 2. 如果 scatters(列表)不为空,则遍历 scatters 中的所有散点数据: + - 对于每个散点数据,调用 add_scatter_indicator 方法将其绘制为折线图。传递以下参数: + - x: 日期时间数据 + - y: 散点数据 + - name: 图例名称,来自 scatter_names 列表 + - row: 指定要添加指标的子图行数,默认值为 4 + - show_legend: 是否显示图例,默认值为 False + - line_width: 线宽,默认值为 0.6 + 3. 如果 bar 不为空,则使用 np.where 函数根据 bar 值大于零的情况设置颜色:大于零使用红色(self.color_red),否则使用绿色(self.color_green)。 + 4. 调用 add_bar_indicator 方法将 bar 绘制为柱状图。传递以下参数: + - x: 日期时间数据 + - y: bar 数据 + - name: 图例名称,为传入的 bar_name 参数 + - row: 指定要添加指标的子图行数,默认值为 4 + - color: 根据上一步计算的颜色设置 + - show_legend: 是否显示图例,默认值为 False + """ line_width = kwargs.get('line_width', 0.6) for i, scatter in enumerate(scatters): self.add_scatter_indicator(dt, scatter, name=scatter_names[i], row=row, show_legend=False, line_width=line_width) @@ -118,6 +222,25 @@ def add_indicator(self, dt, scatters: list = None, scatter_names: list = None, b def add_marker_indicator(self, x, y, name: str, row: int, text=None, **kwargs): """绘制标记类指标 + 函数执行逻辑: + + 1. 获取自定义参数 line_color、line_width、hover_template、show_legend 和 visible。 + 这些参数分别对应于折线颜色、宽度、鼠标悬停时显示的模板、是否显示图例和是否可见。 + 2. 使用给定的 x、y 数据创建一个 go.Scatter 对象(散点图),并传入以下参数: + - x: 指标的x轴数据 + - y: 指标的y轴数据 + - name: 指标名称 + - text: 文本说明 + - line_width: 线宽 + - line_color: 线颜色 + - hovertemplate: 鼠标悬停时显示的模板 + - showlegend: 是否显示图例 + - visible: 是否可见 + - opacity: 透明度 + - mode: 绘制模式,为 'markers' 表示只绘制标记 + - marker: 标记的样式,包括大小、颜色和符号 + 3. 调用 self.fig.add_trace 方法将创建的 go.Scatter 对象添加到指定子图中,并更新所有 traces 的 X 轴属性为 "x1"。 + :param x: 指标的x轴 :param y: 指标的y轴 :param name: 指标名称 @@ -145,6 +268,21 @@ def add_scatter_indicator(self, x, y, name: str, row: int, text=None, **kwargs): 绘图API文档:https://plotly.com/python-api-reference/generated/plotly.graph_objects.Scatter.html + 函数执行逻辑: + + 1. 获取自定义参数 mode、hover_template、show_legend、opacity 和 visible。这些参数分别对应于绘图模式、鼠标悬停时显示的模板、是否显示图例、透明度和是否可见。 + 2. 使用给定的 x、y 数据创建一个 go.Scatter 对象(散点图),并传入以下参数: + - x: 指标的x轴数据 + - y: 指标的y轴数据 + - name: 指标名称 + - text: 文本说明 + - mode: 绘制模式,默认为 'text+lines',表示同时绘制文本和线条 + - hovertemplate: 鼠标悬停时显示的模板 + - showlegend: 是否显示图例 + - visible: 是否可见 + - opacity: 透明度 + 3. 调用 self.fig.add_trace 方法将创建的 go.Scatter 对象添加到指定子图中,并更新所有 traces 的 X 轴属性为 "x1"。 + :param x: 指标的x轴 :param y: 指标的y轴 :param name: 指标名称 @@ -169,6 +307,22 @@ def add_bar_indicator(self, x, y, name: str, row: int, color=None, **kwargs): 绘图API文档:https://plotly.com/python-api-reference/generated/plotly.graph_objects.Bar.html + 函数执行逻辑: + + 1. 获取自定义参数 hover_template、show_legend、visible 和 base。这些参数分别对应于鼠标悬停时显示的模板、是否显示图例、是否可见和基线(默认为 True)。 + 2. 如果 color 参数为空,则使用 self.color_red 作为颜色。 + 3. 使用给定的 x、y 数据创建一个 go.Bar 对象(条形图),并传入以下参数: + - x: 指标的x轴数据 + - y: 指标的y轴数据 + - marker_line_color: 条形边框的颜色 + - marker_color: 条形填充的颜色 + - name: 指标名称 + - showlegend: 是否显示图例 + - hovertemplate: 鼠标悬停时显示的模板 + - visible: 是否可见 + - base: 基线,默认为 True + 4. 调用 self.fig.add_trace 方法将创建的 go.Bar 对象添加到指定子图中,并更新所有 traces 的 X 轴属性为 "x1"。 + :param x: 指标的x轴 :param y: 指标的y轴 :param name: 指标名称 diff --git a/czsc/utils/st_components.py b/czsc/utils/st_components.py index 379e7f881..4370b2005 100644 --- a/czsc/utils/st_components.py +++ b/czsc/utils/st_components.py @@ -54,6 +54,21 @@ def _stats(df_, type_='持有日'): st.plotly_chart(fig, use_container_width=True) +def show_monthly_return(df, ret_col='total', title="月度累计收益", **kwargs): + """展示指定列的月度累计收益""" + assert df.index.dtype == 'datetime64[ns]', "index 必须是 datetime 类型" + st.subheader(title, divider="rainbow") + monthly = df[[ret_col]].resample('M').sum() + monthly['year'] = monthly.index.year + monthly['month'] = monthly.index.month + monthly = monthly.pivot_table(index='year', columns='month', values=ret_col) + month_cols = [f"{x}月" for x in range(1, 13)] + monthly.columns = month_cols + monthly['年收益'] = monthly.sum(axis=1) + monthly = monthly.style.background_gradient(cmap='RdYlGn_r', axis=None, subset=month_cols).format('{:.2%}', na_rep='-') + st.dataframe(monthly, use_container_width=True) + + def show_correlation(df, cols=None, method='pearson', **kwargs): """用 streamlit 展示相关性 diff --git a/czsc/utils/stats.py b/czsc/utils/stats.py index 67414adf5..bdba34760 100644 --- a/czsc/utils/stats.py +++ b/czsc/utils/stats.py @@ -96,7 +96,7 @@ def daily_performance(daily_returns): dd = np.maximum.accumulate(cum_returns) - cum_returns max_drawdown = np.max(dd) kama = annual_returns / max_drawdown if max_drawdown != 0 else 10 - win_pct = len(daily_returns[daily_returns > 0]) / len(daily_returns) + win_pct = len(daily_returns[daily_returns >= 0]) / len(daily_returns) annual_volatility = np.std(daily_returns) * np.sqrt(252) none_zero_cover = len(daily_returns[daily_returns != 0]) / len(daily_returns) @@ -106,11 +106,20 @@ def daily_performance(daily_returns): for i in range(len(high_index) - 1): max_interval = max(max_interval, high_index[i + 1] - high_index[i]) + def __min_max(x, min_val, max_val, digits=4): + if x < min_val: + x1 = min_val + elif x > max_val: + x1 = max_val + else: + x1 = x + return round(x1, digits) + sta = { "年化": round(annual_returns, 4), - "夏普": round(sharpe_ratio, 2), + "夏普": __min_max(sharpe_ratio, -5, 5, 2), "最大回撤": round(max_drawdown, 4), - "卡玛": round(kama, 2), + "卡玛": __min_max(kama, -10, 10, 2), "日胜率": round(win_pct, 4), "年化波动率": round(annual_volatility, 4), "非零覆盖": round(none_zero_cover, 4), diff --git a/czsc/utils/trade.py b/czsc/utils/trade.py index d3743ad1c..e73b4f006 100644 --- a/czsc/utils/trade.py +++ b/czsc/utils/trade.py @@ -143,3 +143,53 @@ def update_tbars(da: pd.DataFrame, event_col: str) -> None: n_seq = [int(x.strip('nb')) for x in da.columns if x[0] == 'n' and x[-1] == 'b'] for n in n_seq: da[f't{n}b'] = da[f'n{n}b'] * da[event_col] + + +def resample_to_daily(df: pd.DataFrame, sdt=None, edt=None, only_trade_date=True): + """将非日线数据转换为日线数据,以便进行日线级别的分析 + + 函数执行逻辑: + + 1. 首先,函数接收一个数据框`df`,以及可选的开始日期`sdt`,结束日期`edt`,和一个布尔值`only_trade_date`。 + 2. 函数将`df`中的`dt`列转换为日期时间格式。如果没有提供`sdt`或`edt`,则使用`df`中的最小和最大日期作为开始和结束日期。 + 3. 创建一个日期序列。如果`only_trade_date`为真,则只包含交易日期;否则,包含`sdt`和`edt`之间的所有日期。 + 4. 使用`merge_asof`函数,找到每个日期在原始`df`中对应的最近一个日期。 + 5. 创建一个映射,将每个日期映射到原始`df`中的对应行。 + 6. 对于日期序列中的每个日期,复制映射中对应的数据行,并将日期设置为当前日期。 + 7. 最后,将所有复制的数据行合并成一个新的数据框,并返回。 + + :param df: 日线以上周期的数据,必须包含 dt 列 + :param sdt: 开始日期 + :param edt: 结束日期 + :param only_trade_date: 是否只保留交易日数据 + :return: pd.DataFrame + """ + from czsc.utils.calendar import get_trading_dates + + df['dt'] = pd.to_datetime(df['dt']) + sdt = df['dt'].min() if not sdt else pd.to_datetime(sdt) + edt = df['dt'].max() if not edt else pd.to_datetime(edt) + + # 创建日期序列 + if only_trade_date: + trade_dates = get_trading_dates(sdt=sdt, edt=edt) + else: + trade_dates = pd.date_range(sdt, edt, freq='D').tolist() + trade_dates = pd.DataFrame({'date': trade_dates}) + trade_dates = trade_dates.sort_values('date', ascending=True).reset_index(drop=True) + + # 通过 merge_asof 找到每个日期对应原始 df 中最近一个日期 + vdt = pd.DataFrame({'dt': df['dt'].unique()}) + trade_dates = pd.merge_asof(trade_dates, vdt, left_on='date', right_on='dt') + trade_dates = trade_dates.dropna(subset=['dt']).reset_index(drop=True) + + dt_map = {dt: dfg for dt, dfg in df.groupby('dt')} + results = [] + for row in trade_dates.to_dict('records'): + # 注意:这里必须进行 copy,否则默认浅拷贝导致数据异常 + df_ = dt_map[row['dt']].copy() + df_['dt'] = row['date'] + results.append(df_) + + dfr = pd.concat(results, ignore_index=True) + return dfr diff --git a/test/test_trade_utils.py b/test/test_trade_utils.py index bdad6016a..76040d4fd 100644 --- a/test/test_trade_utils.py +++ b/test/test_trade_utils.py @@ -6,6 +6,7 @@ describe: 测试交易价格计算 """ import czsc +import pandas as pd import numpy as np from test.test_analyze import read_1min @@ -18,3 +19,21 @@ def test_trade_price(): close = df['close'].iloc[1:21] vol = df['vol'].iloc[1:21] assert df['VWAP20'].iloc[0] == round(np.average(close, weights=vol), 3) + + +def test_make_it_daily(): + dts = pd.date_range(start='2022-01-01', end='2022-02-28', freq='W') + df = pd.DataFrame({'dt': dts, 'value': np.random.random(len(dts))}) + + # Call the function with the test DataFrame + result = czsc.resample_to_daily(df) + + # Check the result + assert isinstance(result, pd.DataFrame), "Result should be a DataFrame" + assert 'dt' in result.columns, "Result should have a 'dt' column" + assert result['dt'].dtype == 'datetime64[ns]', "'dt' column should be datetime64[ns] type" + assert not result['dt'].isnull().any(), "'dt' column should not have any null values" + + # Check if the result DataFrame has daily data + result = czsc.resample_to_daily(df, only_trade_date=False) + assert (result['dt'].diff().dt.days <= 1).iloc[1:].all(), "Result should have daily data" diff --git a/test/test_utils.py b/test/test_utils.py index 090dbfa70..d01da28be 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -111,9 +111,9 @@ def test_daily_performance(): # Test case 4: normal daily returns daily_returns = np.array([0.01, 0.02, -0.01, 0.03, 0.02, -0.02, 0.01, -0.01, 0.02, 0.01]) result = daily_performance(daily_returns) - assert result == {'年化': 2.016, '夏普': 8.27, '最大回撤': 0.02, '卡玛': 100.8, '日胜率': 0.7, + assert result == {'年化': 2.016, '夏普': 5, '最大回撤': 0.02, '卡玛': 10, '日胜率': 0.7, '年化波动率': 0.2439, '非零覆盖': 1.0, '盈亏平衡点': 0.7, '最大新高时间': 4} result = daily_performance([0.01, 0.02, -0.01, 0.03, 0.02, -0.02, 0.01, -0.01, 0.02, 0.01]) - assert result == {'年化': 2.016, '夏普': 8.27, '最大回撤': 0.02, '卡玛': 100.8, '日胜率': 0.7, + assert result == {'年化': 2.016, '夏普': 5, '最大回撤': 0.02, '卡玛': 10, '日胜率': 0.7, '年化波动率': 0.2439, '非零覆盖': 1.0, '盈亏平衡点': 0.7, '最大新高时间': 4}