From 33b347b8238c9e23595f5c69acdbd2e4bd711860 Mon Sep 17 00:00:00 2001 From: zengbin93 Date: Sun, 26 Nov 2023 16:19:12 +0800 Subject: [PATCH] =?UTF-8?q?V0.9.37=20=E6=9B=B4=E6=96=B0=E4=B8=80=E6=89=B9?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=20(#178)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 0.9.37 first commit * 0.9.37 update * 0.9.37 新增GPT解读函数执行逻辑 * 0.9.37 weight backtest 增加多进程支持 * 0.9.37 fix check bi --- .github/workflows/pythonpackage.yml | 2 +- czsc/__init__.py | 5 +- czsc/analyze.py | 83 +++++++----- czsc/connectors/cooperation.py | 32 +++++ czsc/connectors/ts_connector.py | 87 +++++++++++-- czsc/data/ts.py | 1 + czsc/data/ts_cache.py | 3 - czsc/traders/rwc.py | 8 +- czsc/traders/weight_backtest.py | 119 ++++++++++++++++-- czsc/utils/__init__.py | 2 +- czsc/utils/bar_generator.py | 64 ++++++++-- czsc/utils/features.py | 44 ++++++- czsc/utils/sig.py | 8 +- czsc/utils/signal_analyzer.py | 7 +- czsc/utils/stats.py | 38 +++++- czsc/utils/trade.py | 55 ++++++-- examples/explore_func_tree.py | 15 +++ examples/test_offline/test_update_bi.py | 25 ++++ examples/test_offline/test_weight_backtest.py | 11 +- 19 files changed, 505 insertions(+), 104 deletions(-) create mode 100644 examples/explore_func_tree.py create mode 100644 examples/test_offline/test_update_bi.py diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 2bf36b95f..1dc650925 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -5,7 +5,7 @@ name: Python package on: push: - branches: [ master, V0.9.36 ] + branches: [ master, V0.9.37 ] pull_request: branches: [ master ] diff --git a/czsc/__init__.py b/czsc/__init__.py index b0ca6d789..a7022dc4c 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -112,11 +112,10 @@ feture_cross_layering, ) -__version__ = "0.9.36" +__version__ = "0.9.37" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" -__date__ = "20231112" - +__date__ = "20231118" def welcome(): diff --git a/czsc/analyze.py b/czsc/analyze.py index 23aec103c..b2d5c9718 100644 --- a/czsc/analyze.py +++ b/czsc/analyze.py @@ -75,7 +75,23 @@ def remove_include(k1: NewBar, k2: NewBar, k3: RawBar): def check_fx(k1: NewBar, k2: NewBar, k3: NewBar): - """查找分型""" + """查找分型 + + 函数计算逻辑: + + 1. 如果第二个`NewBar`对象的最高价和最低价都高于第一个和第三个`NewBar`对象的对应价格,那么它被认为是顶分型(G)。 + 在这种情况下,函数会创建一个新的`FX`对象,其标记为`Mark.G`,并将其赋值给`fx`。 + + 2. 如果第二个`NewBar`对象的最高价和最低价都低于第一个和第三个`NewBar`对象的对应价格,那么它被认为是底分型(D)。 + 在这种情况下,函数会创建一个新的`FX`对象,其标记为`Mark.D`,并将其赋值给`fx`。 + + 3. 函数最后返回`fx`,如果没有找到分型,`fx`将为`None`。 + + :param k1: 第一个`NewBar`对象 + :param k2: 第二个`NewBar`对象 + :param k3: 第三个`NewBar`对象 + :return: `FX`对象或`None` + """ fx = None if k1.high < k2.high > k3.high and k1.low < k2.low > k3.low: fx = FX(symbol=k1.symbol, dt=k2.dt, mark=Mark.G, high=k2.high, @@ -89,10 +105,24 @@ def check_fx(k1: NewBar, k2: NewBar, k3: NewBar): def check_fxs(bars: List[NewBar]) -> List[FX]: - """输入一串无包含关系K线,查找其中所有分型""" + """输入一串无包含关系K线,查找其中所有分型 + + 函数的主要步骤: + + 1. 创建一个空列表`fxs`用于存储找到的分型。 + 2. 遍历`bars`列表中的每个元素(除了第一个和最后一个),并对每三个连续的`NewBar`对象调用`check_fx`函数。 + 3. 如果`check_fx`函数返回一个`FX`对象,检查它的标记是否与`fxs`列表中最后一个`FX`对象的标记相同。如果相同,记录一个错误日志。 + 如果不同,将这个`FX`对象添加到`fxs`列表中。 + 4. 最后返回`fxs`列表,它包含了`bars`列表中所有找到的分型。 + + 这个函数的主要目的是找出`bars`列表中所有的顶分型和底分型,并确保它们是交替出现的。如果发现连续的两个分型标记相同,它会记录一个错误日志。 + + :param bars: 无包含关系K线列表 + :return: 分型列表 + """ fxs = [] - for i in range(1, len(bars)-1): - fx = check_fx(bars[i-1], bars[i], bars[i+1]) + for i in range(1, len(bars) - 1): + fx = check_fx(bars[i - 1], bars[i], bars[i + 1]) if isinstance(fx, FX): # 默认情况下,fxs本身是顶底交替的,但是对于一些特殊情况下不是这样; 临时强制要求fxs序列顶底交替 if len(fxs) >= 2 and fx.mark == fxs[-1].mark: @@ -115,32 +145,20 @@ def check_bi(bars: List[NewBar], benchmark=None): return None, bars fx_a = fxs[0] - try: - if fxs[0].mark == Mark.D: - direction = Direction.Up - fxs_b = [x for x in fxs if x.mark == Mark.G and x.dt > fx_a.dt and x.fx > fx_a.fx] - if not fxs_b: - return None, bars - - fx_b = fxs_b[0] - for fx in fxs_b[1:]: - if fx.high >= fx_b.high: - fx_b = fx - - elif fxs[0].mark == Mark.G: - direction = Direction.Down - fxs_b = [x for x in fxs if x.mark == Mark.D and x.dt > fx_a.dt and x.fx < fx_a.fx] - if not fxs_b: - return None, bars - - fx_b = fxs_b[0] - for fx in fxs_b[1:]: - if fx.low <= fx_b.low: - fx_b = fx - else: - raise ValueError - except Exception as e: - logger.exception(f"笔识别错误: {e}") + if fx_a.mark == Mark.D: + direction = Direction.Up + fxs_b = (x for x in fxs if x.mark == Mark.G and x.dt > fx_a.dt and x.fx > fx_a.fx) + fx_b = max(fxs_b, key=lambda fx: fx.high, default=None) + + elif fx_a.mark == Mark.G: + direction = Direction.Down + fxs_b = (x for x in fxs if x.mark == Mark.D and x.dt > fx_a.dt and x.fx < fx_a.fx) + fx_b = min(fxs_b, key=lambda fx: fx.low, default=None) + + else: + raise ValueError + + if fx_b is None: return None, bars bars_a = [x for x in bars if fx_a.elements[0].dt <= x.dt <= fx_b.elements[2].dt] @@ -291,10 +309,7 @@ def update(self, bar: RawBar): self.bars_raw = self.bars_raw[s_index:] # 如果有信号计算函数,则进行信号计算 - if self.get_signals: - self.signals = self.get_signals(c=self) - else: - self.signals = OrderedDict() + self.signals = self.get_signals(c=self) if self.get_signals else OrderedDict() def to_echarts(self, width: str = "1400px", height: str = '580px', bs=[]): """绘制K线分析图 diff --git a/czsc/connectors/cooperation.py b/czsc/connectors/cooperation.py index ef1b6fb96..818115430 100644 --- a/czsc/connectors/cooperation.py +++ b/czsc/connectors/cooperation.py @@ -8,6 +8,8 @@ import os import czsc import pandas as pd +from tqdm import tqdm +from datetime import datetime from czsc import RawBar, Freq # 首次使用需要打开一个python终端按如下方式设置 token @@ -105,3 +107,33 @@ def get_raw_bars(symbol, freq, sdt, edt, fq='前复权', **kwargs): return czsc.resample_bars(df, target_freq=freq) raise ValueError(f"symbol {symbol} 无法识别,获取数据失败!") + + +def stocks_daily_klines(years=None, **kwargs): + """获取全市场A股的日线数据""" + adj = kwargs.get('adj', 'hfq') + if years is None: + years = ['2017', '2018', '2019', '2020', '2021', '2022', '2023'] + + res = [] + for year in years: + ttl = 3600 * 6 if year == str(datetime.now().year) else -1 + kline = dc.pro_bar(trade_year=year, adj=adj, v=2, ttl=ttl) + res.append(kline) + + dfk = pd.concat(res, ignore_index=True) + dfk['dt'] = pd.to_datetime(dfk['dt']) + dfk = dfk.sort_values(['code', 'dt'], ascending=True).reset_index(drop=True) + if kwargs.get('exclude_bj', True): + dfk = dfk[~dfk['code'].str.endswith(".BJ")].reset_index(drop=True) + + nxb = kwargs.get('nxb', [1, 2, 5]) + if nxb: + rows = [] + for _, dfg in tqdm(dfk.groupby('code'), desc="计算NXB收益率", ncols=80, colour='green'): + czsc.update_nbars(dfg, numbers=nxb, move=1, price_col='open') + rows.append(dfg) + dfk = pd.concat(rows, ignore_index=True) + + dfk = dfk.rename(columns={'code': 'symbol'}) + return dfk diff --git a/czsc/connectors/ts_connector.py b/czsc/connectors/ts_connector.py index 1bc0ad5e3..26aae5b8e 100644 --- a/czsc/connectors/ts_connector.py +++ b/czsc/connectors/ts_connector.py @@ -6,29 +6,100 @@ describe: Tushare数据源 """ import os -from czsc import data +import czsc +import pandas as pd +from czsc import Freq, RawBar +from typing import List -dc = data.TsDataCache(data_path=os.environ.get('ts_data_path', r'D:\ts_data')) +# 首次使用需要打开一个python终端按如下方式设置 token +# czsc.set_url_token(token='your token', url='http://api.tushare.pro') +cache_path = os.getenv("TS_CACHE_PATH", os.path.expanduser("~/.ts_data_cache")) +dc = czsc.DataClient(url='http://api.tushare.pro', cache_path=cache_path) -def get_symbols(step): - if step.upper() == 'ALL': - return data.get_symbols(dc, 'index') + data.get_symbols(dc, 'stock') + data.get_symbols(dc, 'etfs') - return data.get_symbols(dc, step) + +def format_kline(kline: pd.DataFrame, freq: Freq) -> List[RawBar]: + """Tushare K线数据转换 + + :param kline: Tushare 数据接口返回的K线数据 + :param freq: K线周期 + :return: 转换好的K线数据 + """ + bars = [] + dt_key = 'trade_time' if '分钟' in freq.value else 'trade_date' + kline = kline.sort_values(dt_key, ascending=True, ignore_index=True) + records = kline.to_dict('records') + + for i, record in enumerate(records): + if freq == Freq.D: + vol = int(record['vol'] * 100) if record['vol'] > 0 else 0 + amount = int(record.get('amount', 0) * 1000) + else: + vol = int(record['vol']) if record['vol'] > 0 else 0 + amount = int(record.get('amount', 0)) + + # 将每一根K线转换成 RawBar 对象 + bar = RawBar(symbol=record['ts_code'], dt=pd.to_datetime(record[dt_key]), + id=i, freq=freq, open=record['open'], close=record['close'], + high=record['high'], low=record['low'], + vol=vol, # 成交量,单位:股 + amount=amount, # 成交额,单位:元 + ) + bars.append(bar) + return bars + + +def get_symbols(step="all"): + """获取标的代码""" + stocks = dc.stock_basic(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date') + stocks_ = stocks[stocks['list_date'] < '2010-01-01'].ts_code.to_list() + stocks_map = { + "index": ['000905.SH', '000016.SH', '000300.SH', '000001.SH', '000852.SH', + '399001.SZ', '399006.SZ', '399376.SZ', '399377.SZ', '399317.SZ', '399303.SZ'], + "stock": stocks.ts_code.to_list(), + "check": ['000001.SZ'], + "train": stocks_[:200], + "valid": stocks_[200:600], + "etfs": ['512880.SH', '518880.SH', '515880.SH', '513050.SH', '512690.SH', + '512660.SH', '512400.SH', '512010.SH', '512000.SH', '510900.SH', + '510300.SH', '510500.SH', '510050.SH', '159992.SZ', '159985.SZ', + '159981.SZ', '159949.SZ', '159915.SZ'], + } + + asset_map = { + "index": "I", + "stock": "E", + "check": "E", + "train": "E", + "valid": "E", + "etfs": "FD" + } + + if step.lower() == "all": + symbols = [] + for k, v in stocks_map.items(): + symbols += [f"{ts_code}#{asset_map[k]}" for ts_code in v] + else: + asset = asset_map[step] + symbols = [f"{ts_code}#{asset}" for ts_code in stocks_map[step]] + + return symbols def get_raw_bars(symbol, freq, sdt, edt, fq='后复权', raw_bar=True): """读取本地数据""" + from czsc import data + tdc = data.TsDataCache(data_path=cache_path) ts_code, asset = symbol.split("#") freq = str(freq) adj = "qfq" if fq == "前复权" else "hfq" if "分钟" in freq: freq = freq.replace("分钟", "min") - bars = dc.pro_bar_minutes(ts_code, sdt=sdt, edt=edt, freq=freq, asset=asset, adj=adj, raw_bar=raw_bar) + bars = tdc.pro_bar_minutes(ts_code, sdt=sdt, edt=edt, freq=freq, asset=asset, adj=adj, raw_bar=raw_bar) else: _map = {"日线": "D", "周线": "W", "月线": "M"} freq = _map[freq] - bars = dc.pro_bar(ts_code, start_date=sdt, end_date=edt, freq=freq, asset=asset, adj=adj, raw_bar=raw_bar) + bars = tdc.pro_bar(ts_code, start_date=sdt, end_date=edt, freq=freq, asset=asset, adj=adj, raw_bar=raw_bar) return bars diff --git a/czsc/data/ts.py b/czsc/data/ts.py index e950691c8..c28f3cb13 100644 --- a/czsc/data/ts.py +++ b/czsc/data/ts.py @@ -79,6 +79,7 @@ def __getattr__(self, name): print("Tushare Pro 初始化失败") +@deprecated(reason="统一到 ts_connector 中", version='1.0.0') def format_kline(kline: pd.DataFrame, freq: Freq) -> List[RawBar]: """Tushare K线数据转换 diff --git a/czsc/data/ts_cache.py b/czsc/data/ts_cache.py index 1d34371a4..00e10bc8e 100644 --- a/czsc/data/ts_cache.py +++ b/czsc/data/ts_cache.py @@ -620,6 +620,3 @@ def stocks_daily_basic_new(self, sdt: str, edt: str): dfb['上市天数'] = (dfb['trade_date'] - pd.to_datetime(dfb['list_date'], errors='coerce')).apply(lambda x: x.days) dfb.to_feather(file_cache) return dfb - - - diff --git a/czsc/traders/rwc.py b/czsc/traders/rwc.py index 938574bb2..f3b7a958d 100644 --- a/czsc/traders/rwc.py +++ b/czsc/traders/rwc.py @@ -120,7 +120,7 @@ def publish(self, symbol, dt, weight, price=0, ref=None, overwrite=False): if not overwrite: last_dt = self.get_last_times(symbol) - if last_dt is not None and dt <= last_dt: + if last_dt is not None and dt <= last_dt: # type: ignore logger.warning(f"不允许重复写入,已过滤 {symbol} {dt} 的重复信号") return 0 @@ -213,8 +213,8 @@ def clear_all(self): """删除该策略所有记录""" self.r.delete(f'{self.key_prefix}:META:{self.strategy_name}') keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*') - if keys is not None and len(keys) > 0: - self.r.delete(*keys) + if keys is not None and len(keys) > 0: # type: ignore + self.r.delete(*keys) # type: ignore @staticmethod def register_lua_publish(client): @@ -264,7 +264,7 @@ def register_lua_publish(client): def get_symbols(self): """获取策略交易的品种列表""" keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*') - symbols = {x.split(":")[2] for x in keys} + symbols = {x.split(":")[2] for x in keys} # type: ignore return list(symbols) def get_last_weights(self, symbols=None, ignore_zero=True, lua=True): diff --git a/czsc/traders/weight_backtest.py b/czsc/traders/weight_backtest.py index 7b6e0087a..1684bd096 100644 --- a/czsc/traders/weight_backtest.py +++ b/czsc/traders/weight_backtest.py @@ -8,9 +8,12 @@ import numpy as np import pandas as pd import plotly.express as px +from tqdm import tqdm from loguru import logger from pathlib import Path from typing import Union, AnyStr, Callable +from multiprocessing import cpu_count +from concurrent.futures import ProcessPoolExecutor from czsc.traders.base import CzscTrader from czsc.utils.io import save_json from czsc.utils.stats import daily_performance, evaluate_pairs @@ -89,12 +92,33 @@ def long_short_equity(factors, returns, hold_period=2, rank=5, **kwargs): def get_ensemble_weight(trader: CzscTrader, method: Union[AnyStr, Callable] = 'mean'): """获取 CzscTrader 中所有 positions 按照 method 方法集成之后的权重 + 函数计算逻辑: + + 1. 获取 trader 持仓信息并转换为DataFrame: + + - 遍历交易者的每个持仓位置。 + - 将每个位置的持仓信息转换为DataFrame,并合并到一个整体的DataFrame中。 + - 将持仓列重命名为对应的位置名称。 + + 2. 根据给定的方法计算权重: + + - 如果方法是可调用对象,将持仓信息转换为字典,并传递给该方法进行计算。 + - 如果方法是预定义字符串("mean"、"max"、"min"、"vote"),根据相应的计算方式计算权重。 + + 3. 返回包含日期、交易标的、权重和价格的DataFrame: + + - 将计算得到的权重与其他相关列一起组成一个新的DataFrame。 + - 将交易标的信息添加到新的DataFrame中。 + - 返回包含日期、交易标的、权重和价格的DataFrame副本。 + :param trader: CzscTrader - 缠论交易者 + 缠论交易员对象 :param method: str or callable + 集成方法,可选值包括:'mean', 'max', 'min', 'vote' 也可以传入自定义的函数,函数的输入为 dict,key 为 position.name,value 为 position.pos, 样例输入: {'多头策略A': 1, '多头策略B': 1, '空头策略A': -1} + :param kwargs: :return: pd.DataFrame columns = ['dt', 'symbol', 'weight', 'price'] @@ -137,11 +161,22 @@ class WeightBacktest: 飞书文档:https://s0cqcxuy3p.feishu.cn/wiki/Pf1fw1woQi4iJikbKJmcYToznxb """ - version = "V231104" + version = "V231126" def __init__(self, dfw, digits=2, **kwargs) -> None: """持仓权重回测 + 初始化函数逻辑: + + 1. 将传入的kwargs保存在实例变量self.kwargs中。 + 2. 复制传入的dfw到实例变量self.dfw。 + 3. 检查self.dfw中是否存在空值,如果存在则抛出ValueError异常,并提示"dfw 中存在空值,请先处理"。 + 4. 设置实例变量self.digits为传入的digits值。 + 5. 从kwargs中获取'fee_rate'参数的值,默认为0.0002,并将其保存在实例变量self.fee_rate中。 + 6. 将self.dfw中的'weight'列转换为浮点型,并保留self.digits位小数。 + 7. 提取self.dfw中的唯一交易标的符号,并将其保存在实例变量self.symbols中。 + 8. 执行backtest()方法进行回测,并将结果保存在实例变量self.results中。 + :param dfw: pd.DataFrame, columns = ['dt', 'symbol', 'weight', 'price'], 持仓权重数据,其中 dt 为K线结束时间,必须是连续的交易时间序列,不允许有时间断层 @@ -164,7 +199,6 @@ def __init__(self, dfw, digits=2, **kwargs) -> None: :param kwargs: - fee_rate: float,单边交易成本,包括手续费与冲击成本, 默认为 0.0002 - - res_path: str,回测结果保存路径,默认为 "weight_backtest" """ self.kwargs = kwargs @@ -175,11 +209,22 @@ def __init__(self, dfw, digits=2, **kwargs) -> None: self.fee_rate = kwargs.get('fee_rate', 0.0002) self.dfw['weight'] = self.dfw['weight'].astype('float').round(digits) self.symbols = list(self.dfw['symbol'].unique().tolist()) - self.results = self.backtest() + self.results = self.backtest(n_jobs=kwargs.get('n_jobs', int(cpu_count() / 2))) def get_symbol_daily(self, symbol): """获取某个合约的每日收益率 + 函数计算逻辑: + + 1. 从实例变量self.dfw中筛选出交易标的为symbol的数据,并复制到新的DataFrame dfs。 + 2. 计算每条数据的收益(edge):权重乘以下一条数据的价格除以当前价格减1。 + 3. 计算每条数据的手续费(cost):当前权重与前一条数据权重之差的绝对值乘以实例变量self.fee_rate。 + 4. 计算每条数据扣除手续费后的收益(edge_post_fee):收益减去手续费。 + 5. 根据日期进行分组,并对每组进行求和操作,得到每日的总收益、总扣除手续费后的收益和总手续费。 + 6. 重置索引,并将交易标的符号添加到DataFrame中。 + 7. 重命名列名,将'edge_post_fee'列改为'return',将'dt'列改为'date'。 + 8. 选择需要的列,并返回包含日期、交易标的、收益、扣除手续费后的收益和手续费的DataFrame。 + :param symbol: str,合约代码 :return: pd.DataFrame,品种每日收益率, @@ -214,7 +259,32 @@ def get_symbol_daily(self, symbol): return daily def get_symbol_pairs(self, symbol): - """获取某个合约的开平交易记录""" + """获取某个合约的开平交易记录 + + 函数计算逻辑: + + 1. 从实例变量self.dfw中筛选出交易标的为symbol的数据,并复制到新的DataFrame dfs。 + 2. 将权重乘以10的self.digits次方,并转换为整数类型,作为volume列的值。 + 3. 生成bar_id列,从1开始递增,与行数对应。 + 4. 创建一个空列表operates,用于存储开平仓交易记录。 + 5. 定义内部函数__add_operate,用于向operates列表中添加开平仓交易记录。 + 函数接受日期dt、bar_id、交易量volume、价格price和操作类型operate作为参数。 + 函数根据交易量的绝对值循环添加交易记录到operates列表中。 + 6. 将dfs转换为字典列表rows。 + 7. 处理第一个行记录。 + - 如果volume大于0,则调用__add_operate函数添加"开多"操作的交易记录。 + - 如果volume小于0,则调用__add_operate函数添加"开空"操作的交易记录。 + 8. 处理后续的行记录。 + - 使用zip函数遍历rows[:-1]和rows[1:],同时获取当前行row1和下一行row2。 + - 根据volume的正负和变化情况,调用__add_operate函数添加对应的开平仓交易记录。 + 9. 创建空列表pairs和opens,用于存储交易对和开仓记录。 + 10. 遍历operates列表中的交易记录。 + - 如果操作类型为"开多"或"开空",将交易记录添加到opens列表中,并继续下一次循环。 + - 如果操作类型为"平多"或"平空",将对应的开仓记录从opens列表中弹出。 + 根据开仓和平仓的价格计算盈亏比例,并创建一个交易对字典,将其添加到pairs列表中。 + 11. 将pairs列表转换为DataFrame,并返回包含交易标的的开平仓交易记录的DataFrame。 + + """ dfs = self.dfw[self.dfw['symbol'] == symbol].copy() dfs['volume'] = (dfs['weight'] * pow(10, self.digits)).astype(int) dfs['bar_id'] = list(range(1, len(dfs) + 1)) @@ -286,14 +356,41 @@ def __add_operate(dt, bar_id, volume, price, operate): df_pairs = pd.DataFrame(pairs) return df_pairs - def backtest(self): - """回测所有合约的收益率""" + def process_symbol(self, symbol): + """处理某个合约的回测数据""" + daily = self.get_symbol_daily(symbol) + pairs = self.get_symbol_pairs(symbol) + return symbol, {"daily": daily, "pairs": pairs} + + def backtest(self, n_jobs=1): + """回测所有合约的收益率 + + 函数计算逻辑: + + 1. 获取数据:遍历所有合约,调用get_symbol_daily方法获取每个合约的日收益,调用get_symbol_pairs方法获取每个合约的交易流水。 + + 2. 数据处理:将每个合约的日收益合并为一个DataFrame,使用pd.pivot_table方法将数据重塑为以日期为索引、合约为列、 + 收益率为值的表格,并将缺失值填充为0。计算所有合约收益率的平均值,并将该列添加到DataFrame中。将结果存储在res字典中, + 键为合约名,值为包含日行情数据和交易对数据的字典。 + + 3. 绩效评价:计算回测结果的开始日期和结束日期,调用daily_performance方法评估总收益率的绩效指标。将每个合约的交易对数据 + 合并为一个DataFrame,调用evaluate_pairs方法评估交易对的绩效指标。将结果存储在stats字典中,并更新到绩效评价的字典中。 + + 4. 返回结果:将合约的等权日收益数据和绩效评价结果存储在res字典中,并将该字典作为函数的返回结果。 + """ + n_jobs = min(n_jobs, cpu_count()) + logger.info(f"n_jobs={n_jobs},将使用 {n_jobs} 个进程进行回测") + symbols = self.symbols res = {} - for symbol in symbols: - daily = self.get_symbol_daily(symbol) - pairs = self.get_symbol_pairs(symbol) - res[symbol] = {"daily": daily, "pairs": pairs} + if n_jobs <= 1: + for symbol in tqdm(sorted(symbols), desc="WBT进度"): + res[symbol] = self.process_symbol(symbol)[1] + else: + with ProcessPoolExecutor(n_jobs) as pool: + for symbol, res_symbol in tqdm(pool.map(self.process_symbol, sorted(symbols)), + desc="WBT进度", total=len(symbols)): + res[symbol] = res_symbol dret = pd.concat([v['daily'] for k, v in res.items() if k in symbols], ignore_index=True) dret = pd.pivot_table(dret, index='date', columns='symbol', values='return').fillna(0) diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index e15073dbb..d96c0a710 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -13,7 +13,7 @@ from .bar_generator import BarGenerator, freq_end_time, resample_bars from .bar_generator import is_trading_time, get_intraday_times, check_freq_and_market from .io import dill_dump, dill_load, read_json, save_json -from .sig import check_pressure_support, check_gap_info, is_bis_down, is_bis_up, get_sub_elements +from .sig import check_pressure_support, check_gap_info, is_bis_down, is_bis_up, get_sub_elements, is_symmetry_zs from .sig import same_dir_counts, fast_slow_cross, count_last_same, create_single_signal from .plotly_plot import KlineChart from .trade import cal_trade_price, update_nbars, update_bbars, update_tbars, risk_free_returns diff --git a/czsc/utils/bar_generator.py b/czsc/utils/bar_generator.py index 240ea917c..0ea5869c0 100644 --- a/czsc/utils/bar_generator.py +++ b/czsc/utils/bar_generator.py @@ -11,7 +11,6 @@ from czsc.objects import RawBar, Freq from pathlib import Path from loguru import logger -from czsc.utils.calendar import next_trading_date mss = pd.read_feather(Path(__file__).parent / "minites_split.feather") @@ -43,6 +42,14 @@ def get_intraday_times(freq='1分钟', market="A股"): def check_freq_and_market(time_seq: List[AnyStr], freq: Optional[AnyStr] = None): """检查时间序列是否为同一周期,是否为同一市场 + 函数计算逻辑: + + 1. 如果`freq`在特定列表中,函数直接返回`freq`和"默认"作为市场类型。 + 2. 如果`freq`是'1分钟',函数会添加额外的时间点到`time_seq`中。 + 3. 函数去除`time_seq`中的重复时间点,并确保其长度至少为2。 + 4. 函数遍历`freq_market_times`字典,寻找与`time_seq`匹配的项,并返回对应的`freq_x`和`market`。 + 5. 如果没有找到匹配的项,函数返回None和"默认"。 + :param time_seq: 时间序列,如 ['11:00', '15:00', '23:00', '01:00', '02:30'] :param freq: 时间序列对应的K线周期,可选参数,使用该参数可以加快检查速度。 可选值:1分钟、5分钟、15分钟、30分钟、60分钟、日线、周线、月线、季线、年线 @@ -137,7 +144,16 @@ def freq_end_time(dt: datetime, freq: Union[Freq, AnyStr], market="A股") -> dat def resample_bars(df: pd.DataFrame, target_freq: Union[Freq, AnyStr], raw_bars=True, **kwargs): - """将df中的K线序列转换为目标周期的K线序列 + """将给定的K线数据重新采样为目标周期的K线数据 + + 函数计算逻辑: + + 1. 确定目标周期`target_freq`的类型和市场类型。 + 2. 添加一个新列`freq_edt`,表示每个数据点对应的目标周期的结束时间。 + 3. 根据`freq_edt`对数据进行分组,并对每组数据进行聚合,得到目标周期的K线数据。 + 4. 重置索引,并选择需要的列。 + 5. 根据`raw_bars`参数,决定返回的数据类型:如果为True,转换为`RawBar`对象;如果为False,直接返回DataFrame。 + 6. 如果`drop_unfinished`参数为True,删除最后一根未完成的K线。 :param df: 原始K线数据,必须包含以下列:symbol, dt, open, close, high, low, vol, amount。样例如下: symbol dt open close high low \ @@ -190,10 +206,6 @@ def resample_bars(df: pd.DataFrame, target_freq: Union[Freq, AnyStr], raw_bars=T if df['dt'].iloc[-1] < _bars[-1].dt: _bars.pop() return _bars - # if df['dt'].iloc[-1] < _bars[-1].dt: - # # 清除最后一根未完成的K线 - # _bars.pop() - # return _bars else: return dfk1 @@ -216,7 +228,9 @@ def __init__(self, base_freq: str, freqs: List[str], max_count: int = 5000, mark def __validate_freqs(self): from czsc.utils import sorted_freqs - # sorted_freqs = ['Tick', '1分钟', '5分钟', '15分钟', '30分钟', '60分钟', '日线', '周线', '月线', '季线', '年线'] + if self.base_freq not in sorted_freqs: + raise ValueError(f'base_freq is not in sorted_freqs: {self.base_freq}') + i = sorted_freqs.index(self.base_freq) f = sorted_freqs[i:] for freq in self.freqs: @@ -226,9 +240,15 @@ def __validate_freqs(self): def init_freq_bars(self, freq: str, bars: List[RawBar]): """初始化某个周期的K线序列 + 函数计算逻辑: + + 1. 首先,它断言`freq`必须是`self.bars`的键之一。如果`freq`不在`self.bars`的键中,代码会抛出一个断言错误。 + 2. 然后,它断言`self.bars[freq]`必须为空。如果`self.bars[freq]`不为空,代码会抛出一个断言错误,并显示一条错误消息。 + 3. 如果以上两个断言都通过,它会将`bars`赋值给`self.bars[freq]`,从而初始化指定频率的K线序列。 + 4. 最后,它会将`bars`列表中的最后一个`RawBar`对象的`symbol`属性赋值给`self.symbol`。 + :param freq: 周期名称 :param bars: K线序列 - :return: """ assert freq in self.bars.keys() assert not self.bars[freq], f"self.bars['{freq}'] 不为空,不允许执行初始化" @@ -241,9 +261,18 @@ def __repr__(self): def _update_freq(self, bar: RawBar, freq: Freq) -> None: """更新指定周期K线 + 函数计算逻辑: + + 1. 计算目标频率的结束时间`freq_edt`。 + 2. 检查`self.bars`中是否已经有目标频率的K线。如果没有,创建一个新的`RawBar`对象,并将其添加到`self.bars`中,然后返回。 + 3. 如果已经有目标频率的K线,获取最后一根K线`last`。 + 4. 检查`freq_edt`是否不等于最后一根K线的日期时间。如果不等于,创建一个新的`RawBar`对象,并将其添加到`self.bars`中。 + 5. 如果`freq_edt`等于最后一根K线的日期时间,创建一个新的`RawBar`对象,其开盘价为最后一根K线的开盘价, + 收盘价为当前K线的收盘价,最高价为最后一根K线和当前K线的最高价中的最大值,最低价为最后一根K线和当前K线的最低价中的最小值, + 成交量和成交金额为最后一根K线和当前K线的成交量和成交金额的和。然后用这个新的`RawBar`对象替换`self.bars`中的最后一根K线。 + :param bar: 基础周期已完成K线 :param freq: 目标周期 - :return: """ freq_edt = freq_end_time(bar.dt, freq, self.market) @@ -268,11 +297,21 @@ def _update_freq(self, bar: RawBar, freq: Freq) -> None: def update(self, bar: RawBar) -> None: """更新各周期K线 + 函数计算逻辑: + + 1. 首先,它获取基准频率`base_freq`,并断言`bar`的频率值等于`base_freq`。 + 2. 然后,它将`bar`的符号和日期时间设置为`self.symbol`和`self.end_dt`。 + 3. 接下来,它检查是否已经有一个与`bar`日期时间相同的K线存在于`self.bars[base_freq]`中。 + 如果存在,它会记录一个警告并返回,不进行任何更新。 + 4. 如果不存在重复的K线,它会遍历`self.bars`的所有键(即所有的频率),并对每个频率调用`self._update_freq`方法来更新该频率的K线。 + 5. 最后,它会限制在内存中的K线数量,确保每个频率的K线数量不超过`self.max_count`。 + :param bar: 必须是已经结束的Bar - :return: + :return: None """ base_freq = self.base_freq - assert bar.freq.value == base_freq + if bar.freq.value != base_freq: + raise ValueError(f"Input bar frequency does not match base frequency. Expected {base_freq}, got {bar.freq.value}") self.symbol = bar.symbol self.end_dt = bar.dt @@ -285,4 +324,5 @@ def update(self, bar: RawBar) -> None: # 限制存在内存中的K限制数量 for f, b in self.bars.items(): - self.bars[f] = b[-self.max_count:] + if len(b) > self.max_count: + self.bars[f] = b[-self.max_count:] diff --git a/czsc/utils/features.py b/czsc/utils/features.py index 1f1df8abf..53cbfc2a2 100644 --- a/czsc/utils/features.py +++ b/czsc/utils/features.py @@ -13,11 +13,21 @@ def normalize_feature(df, x_col, **kwargs): """因子标准化:缩尾,然后标准化 - :param df: pd.DataFrame,数据源 + 函数计算逻辑: + + 1. 首先,检查因子列x_col是否存在缺失值,如果存在缺失值,则抛出异常,提示缺失值的数量。 + 2. 从kwargs参数中获取缩尾比例q的值,默认为0.05。 + 3. 对因子列进行缩尾操作,首先根据 dt 分组,然后使用lambda函数对每个组内的因子进行缩尾处理, + 将超过缩尾比例的值截断,并使用scale函数进行标准化。 + 4. 将处理后的因子列重新赋值给原始DataFrame对象的对应列。 + + :param df: pd.DataFrame,数据 :param x_col: str,因子列名 :param kwargs: - q: float,缩尾比例, 默认 0.05 + + :return: pd.DataFrame,处理后的数据 """ df = df.copy() assert df[x_col].isna().sum() == 0, "因子有缺失值,缺失数量为:{}".format(df[x_col].isna().sum()) @@ -29,6 +39,24 @@ def normalize_feature(df, x_col, **kwargs): def normalize_ts_feature(df, x_col, n=10, **kwargs): """对时间序列数据进行归一化处理 + 函数计算逻辑: + + 1. 首先,进行一系列的断言检查,确保因子值的取值数量大于分层数量,并且因子列没有缺失值。 + 2. 从kwargs参数中获取分层方法method的值,默认为"expanding",以及min_periods的值,默认为300。 + 3. 如果在DataFrame的列中不存在x_col_norm列,则进行以下操作: + - 如果分层方法是"expanding",则使用expanding函数对因子列进行处理,计算每个时间点的标准化值,公式为(当前值 - 平均值) / 标准差。 + - 如果分层方法是"rolling",则使用rolling函数对因子列进行处理,计算每个窗口的标准化值,窗口大小为min_periods,公式同上。 + - 如果分层方法不是上述两种情况,则抛出错误。 + - 对于缺失值,获取原始值,然后进行标准化。 + + 4. 如果在DataFrame的列中不存在x_col_qcut列,则进行以下操作: + - 如果分层方法是"expanding",则使用expanding函数对因子列进行处理,计算每个时间点的分位数,将其转化为分位数的标签(0到n-1)。 + - 如果分层方法是"rolling",则使用rolling函数对因子列进行处理,计算每个窗口的分位数,窗口大小为min_periods。 + - 如果分层方法不是上述两种情况,则抛出错误。 + - 使用分位数后的值填充原始值中的缺失值。 + - 对于缺失值,获取原始值,然后进行分位数处理分层。 + - 创建一个新的列x_col分层,根据分位数的标签值,将其转化为"第xx层"的字符串形式。 + :param df: 因子数据,必须包含 dt, x_col 列,其中 dt 为日期,x_col 为因子值,数据样例: :param x_col: 因子列名 :param n: 分层数量,默认为10 @@ -56,7 +84,7 @@ def normalize_ts_feature(df, x_col, n=10, **kwargs): else: raise ValueError("method 必须为 expanding 或 rolling") - # 用标准化后的值填充原始值中的缺失值 + # 对于缺失值,获取原始值,然后进行标准化 na_x = df[df[f"{x_col}_norm"].isna()][x_col].values df.loc[df[f"{x_col}_norm"].isna(), f"{x_col}_norm"] = na_x - na_x.mean() / na_x.std() @@ -72,7 +100,7 @@ def normalize_ts_feature(df, x_col, n=10, **kwargs): else: raise ValueError("method 必须为 expanding 或 rolling") - # 用分位数后的值填充原始值中的缺失值 + # 对于缺失值,获取原始值,然后进行分位数处理分层 na_x = df[df[f"{x_col}_qcut"].isna()][x_col].values df.loc[df[f"{x_col}_qcut"].isna(), f"{x_col}_qcut"] = pd.qcut(na_x, q=n, labels=False, duplicates='drop', retbins=False) @@ -86,7 +114,15 @@ def normalize_ts_feature(df, x_col, n=10, **kwargs): def feture_cross_layering(df, x_col, **kwargs): - """因子在时间截面上分层 + """对因子数据在时间截面上进行分层处理 + + 函数计算逻辑: + + 1. 首先从参数中获取分层数量 n,默认为10。 + 2. 确保数据 df 包含 dt、symbol 和指定的因子列 x_col, 确保标的数量大于分层数量。 + 3. 如果因子列的唯一值数量大于分层数量,使用 pd.qcut 函数将因子列进行分层,按照分位数进行分组。 + 4. 如果因子列的唯一值数量小于等于分层数量,按照因子列的唯一值进行排序,并将每个因子值映射为对应的层级。 + 5. 将分层结果转换为字符串形式,以表示层级。 :param df: 因子数据,数据样例: diff --git a/czsc/utils/sig.py b/czsc/utils/sig.py index fcb477afa..516410606 100644 --- a/czsc/utils/sig.py +++ b/czsc/utils/sig.py @@ -72,9 +72,9 @@ def check_cross_info(fast: [List, np.array], slow: [List, np.array]): temp_fast.append(fast[i]) temp_slow.append(slow[i]) - if i >= 2 and delta[i-1] <= 0 < delta[i]: + if i >= 2 and delta[i - 1] <= 0 < delta[i]: kind = "金叉" - elif i >= 2 and delta[i-1] >= 0 > delta[i]: + elif i >= 2 and delta[i - 1] >= 0 > delta[i]: kind = "死叉" else: continue @@ -187,9 +187,9 @@ def fast_slow_cross(fast: [List, np.array], slow: [List, np.array]): temp_fast.append(fast[i]) temp_slow.append(slow[i]) - if i >= 2 and delta[i-1] <= 0 < delta[i]: + if i >= 2 and delta[i - 1] <= 0 < delta[i]: kind = "金叉" - elif i >= 2 and delta[i-1] >= 0 > delta[i]: + elif i >= 2 and delta[i - 1] >= 0 > delta[i]: kind = "死叉" else: continue diff --git a/czsc/utils/signal_analyzer.py b/czsc/utils/signal_analyzer.py index b4e4fc44e..6f8733c83 100644 --- a/czsc/utils/signal_analyzer.py +++ b/czsc/utils/signal_analyzer.py @@ -3,7 +3,7 @@ author: zengbin93 email: zeng_bin8888@163.com create_dt: 2023/3/30 21:13 -describe: +describe: """ import os import hashlib @@ -136,7 +136,6 @@ def __init__(self, symbols, read_bars, signals_config, results_path, **kwargs): self.kwargs = kwargs self.task_hash = hashlib.sha256((str(signals_config) + str(symbols)).encode('utf-8')).hexdigest()[:8].upper() - def generate_symbol_signals(self, symbol): from czsc.traders.sig_parse import get_signals_freqs from czsc.traders.base import generate_czsc_signals @@ -155,12 +154,12 @@ def generate_symbol_signals(self, symbol): if len(bars) < 100: logger.error(f"{symbol} 信号生成失败:数据量不足") return pd.DataFrame() - + sigs: pd.DataFrame = generate_czsc_signals(bars, deepcopy(self.signals_config), sdt=sdt, df=True) # type: ignore if sigs.empty: logger.error(f"{symbol} 信号生成失败:数据量不足") return pd.DataFrame() - + sigs.drop(['freq', 'cache'], axis=1, inplace=True) update_nbars(sigs, price_col='open', move=1, numbers=(1, 2, 3, 5, 8, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100)) diff --git a/czsc/utils/stats.py b/czsc/utils/stats.py index 0f8562e42..67414adf5 100644 --- a/czsc/utils/stats.py +++ b/czsc/utils/stats.py @@ -22,7 +22,25 @@ def cal_break_even_point(seq) -> float: def subtract_fee(df, fee=1): - """依据单品种持仓信号扣除手续费""" + """依据单品种持仓信号扣除手续费 + + 函数执行逻辑: + + 1. 首先,函数对输入的df进行检查,确保其包含所需的列:'dt'(日期时间)和'pos'(持仓)。同时,检查'pos'列的值是否符合要求,即只能是0、1或-1。 + 2. 如果df中不包含'n1b'(名义收益率)列,函数会根据'price'列计算'n1b'列。 + 3. 然后,函数为输入的DataFrame df添加一个新列'date',该列包含交易日期(从'dt'列中提取)。 + 4. 接下来,函数根据持仓('pos')和名义收益率('n1b')计算'edge_pre_fee'(手续费前收益)和'edge_post_fee'(手续费后收益)两列。 + 5. 函数根据持仓信号计算开仓和平仓的位置。 + 开仓位置(open_pos)是持仓信号发生变化的位置(即,当前持仓与前一个持仓不同),并且当前持仓不为0。 + 平仓位置(exit_pos)是持仓信号发生变化的位置(即,当前持仓与前一个持仓不同),并且前一个持仓不为0。 + 6. 根据手续费规则,开仓时在第一个持仓K线上扣除手续费,平仓时在最后一个持仓K线上扣除手续费。 + 函数通过将'edge_post_fee'列的值在开仓和平仓位置上分别减去手续费(fee)来实现这一逻辑。 + 7. 最后,函数返回修改后的DataFrame df。 + + :param df: 包含dt、pos、price、n1b列的DataFrame + :param fee: 手续费,单位:BP + :return: 修改后的DataFrame + """ assert 'dt' in df.columns, 'dt 列必须存在' assert 'pos' in df.columns, 'pos 列必须存在' assert all(x in [0, 1, -1] for x in df['pos'].unique()), "pos 列的值必须是 0, 1, -1 中的一个" @@ -44,9 +62,23 @@ def subtract_fee(df, fee=1): def daily_performance(daily_returns): - """计算日收益数据的年化收益率、夏普比率、最大回撤、卡玛比率 + """采用单利计算日收益数据的各项指标 + + 函数计算逻辑: + + 1. 首先,将传入的日收益率数据转换为NumPy数组,并指定数据类型为float64。 + 2. 然后,进行一系列判断:如果日收益率数据为空或标准差为零或全部为零,则返回一个字典,其中所有指标的值都为零。 + 3. 如果日收益率数据满足要求,则进行具体的指标计算: + + - 年化收益率 = 日收益率列表的和 / 日收益率列表的长度 * 252 + - 夏普比率 = 日收益率的均值 / 日收益率的标准差 * 标准差的根号252 + - 最大回撤 = 累计日收益率的最高累积值 - 累计日收益率 + - 卡玛比率 = 年化收益率 / 最大回撤(如果最大回撤不为零,则除以最大回撤;否则为10) + - 日胜率 = 大于零的日收益率的个数 / 日收益率的总个数 + - 年化波动率 = 日收益率的标准差 * 标准差的根号252 + - 非零覆盖 = 非零的日收益率个数 / 日收益率的总个数 - 所有计算都采用单利计算 + 4. 将所有指标的值存储在一个字典中,其中键为指标名称,值为相应的计算结果。 :param daily_returns: 日收益率数据,样例: [0.01, 0.02, -0.01, 0.03, 0.02, -0.02, 0.01, -0.01, 0.02, 0.01] diff --git a/czsc/utils/trade.py b/czsc/utils/trade.py index b39ac5d3b..d3743ad1c 100644 --- a/czsc/utils/trade.py +++ b/czsc/utils/trade.py @@ -13,12 +13,12 @@ def risk_free_returns(start_date="20180101", end_date="20210101", year_returns=0.03): """创建无风险收益率序列 - :param start_date: str, defaults to "20180101" - 起始日期 - :param end_date: str, defaults to "20210101" - 截止日期 - :param year_returns: float, defaults to 0.03 - 年化收益率 + 创建一个 Pandas DataFrame,包含两列:"date" 和 "returns"。"date" 列包含从 trade_dates 获取的所有交易日期, + "returns" 列包含无风险收益率序列,计算方法是将年化收益率(year_returns)除以 252(一年的交易日数量,假设为每周 5 天) + + :param start_date: 起始日期 + :param end_date: 截止日期 + :param year_returns: 年化收益率 :return: pd.DataFrame """ from czsc.utils.calendar import get_trading_dates @@ -30,6 +30,21 @@ def risk_free_returns(start_date="20180101", end_date="20210101", year_returns=0 def cal_trade_price(bars: Union[List[RawBar], pd.DataFrame], decimals=3, **kwargs): """计算给定品种基础周期K线数据的交易价格 + 函数执行逻辑: + + 1. 首先,根据输入的 bars 参数类型(列表或 DataFrame),将其转换为 DataFrame 格式,并将其存储在变量 df 中。 + 2. 计算下一根K线的开盘价和收盘价,分别存储在新列 next_open 和 next_close 中。同时,将这两个新列名添加到 price_cols 列表中。 + 3. 计算 TWAP(时间加权平均价格)和 VWAP(成交量加权平均价格)。为此,函数使用了一个 for 循环, + 遍历 t_seq 参数(默认值为 (5, 10, 15, 20, 30, 60))。在每次循环中: + + - 计算 TWAP:使用 rolling(t).mean().shift(-t) 方法计算时间窗口为 t 的滚动平均收盘价。 + - 计算 VWAP:首先计算滚动窗口内的成交量之和(sum_vol_t)和成交量乘以收盘价之和(sum_vcp_t),然后用后者除以前者,并向下移动 t 个单位。 + - 将 TWAP 和 VWAP 的列名添加到 price_cols 列表中。 + + 4. 遍历 price_cols 列表中的每个列,将其中的 NaN 值替换为对应行的收盘价。 + 5. 从 DataFrame 中选择所需的列(包括基本的K线数据列和新计算的交易价格列),并使用 round(decimals) 方法保留指定的小数位数(默认为3)。 + 6. 返回处理后的 DataFrame。 + :param bars: 基础周期K线数据,一般是1分钟周期的K线 :param decimals: 保留小数位数,默认值3 :return: 交易价格表 @@ -61,10 +76,17 @@ def cal_trade_price(bars: Union[List[RawBar], pd.DataFrame], decimals=3, **kwarg def update_nbars(da, price_col='close', numbers=(1, 2, 5, 10, 20, 30), move=0) -> None: - """在da数据上新增后面 n 根 bar 的累计收益 + """在给定的 da 上计算并添加后面 n 根 bar 的累计收益列 收益计量单位:BP;1倍涨幅 = 10000BP + 函数的逻辑如下: + + 1. 首先,检查 price_col 是否在输入的 DataFrame(da)的列名中。如果不在,抛出 ValueError。 + 2. 确保 move 是一个非负整数。 + 3. 使用 for 循环遍历 numbers 列表中的每个整数 n, 对于每个整数 n,计算 n 根 bar 的累计收益。 + 4. 返回 None,表示这个函数会直接修改输入的 DataFrame(da),而不返回新的 DataFrame。 + :param da: 数据,DataFrame结构 :param price_col: 价格列 :param numbers: 考察的bar的数目的列表 @@ -81,7 +103,13 @@ def update_nbars(da, price_col='close', numbers=(1, 2, 5, 10, 20, 30), move=0) - def update_bbars(da, price_col='close', numbers=(1, 2, 5, 10, 20, 30)) -> None: - """在da数据上新增前面 n 根 bar 的累计收益 + """在给定的 da 数据上计算并添加前面 n 根 bar 的累计收益列 + + 函数的逻辑如下: + + 1. 首先,检查 price_col 是否在输入的 DataFrame(da)的列名中。如果不在,抛出 ValueError。 + 2. 使用 for 循环遍历 numbers 列表中的每个整数 n,对于每个整数 n,计算 n 根 bar 的累计收益。 + 3. 返回 None,表示这个函数会直接修改输入的 da,而不返回新的 DataFrame。 :param da: K线数据,DataFrame结构 :param price_col: 价格列 @@ -99,9 +127,18 @@ def update_bbars(da, price_col='close', numbers=(1, 2, 5, 10, 20, 30)) -> None: def update_tbars(da: pd.DataFrame, event_col: str) -> None: """计算带 Event 方向信息的未来收益 + 函数的逻辑如下: + + 1. 从输入的 da的列名中提取所有以 'n' 开头,以 'b' 结尾的列名,这些列名表示未来 n 根 bar 的累计收益。将这些列名存储在 n_seq 列表中。 + 2. 使用 for 循环遍历 n_seq 列表中的每个整数 n。 + 3. 对于每个整数 n,计算带有 Event 方向信息的未来收益。 + 计算方法是:将前面 n 根 bar 的累计收益(列名 f'n{n}b')与事件信号列(event_col)的值相乘。 + 将计算结果存储在一个新的列中,列名为 f't{n}b'。 + 4. 返回 None,表示这个函数会直接修改输入的 da,而不返回新的 DataFrame。 + :param da: K线数据,DataFrame结构 :param event_col: 事件信号列名,含有 0, 1, -1 三种值,0 表示无事件,1 表示看多事件,-1 表示看空事件 - :return: + :return: None """ n_seq = [int(x.strip('nb')) for x in da.columns if x[0] == 'n' and x[-1] == 'b'] for n in n_seq: diff --git a/examples/explore_func_tree.py b/examples/explore_func_tree.py new file mode 100644 index 000000000..41525ef2a --- /dev/null +++ b/examples/explore_func_tree.py @@ -0,0 +1,15 @@ +import inspect +import czsc +import pandas as pd +import pkgutil + + +def process_functions(): + functions = inspect.getmembers(czsc, inspect.isfunction) + functions = [f"{f[1].__module__}.{f[1].__name__}" for f in functions] + df = pd.DataFrame({"代码块": functions}) + df['负责人'] = None + df['进展'] = None + df['合并入库'] = None + df = df.sort_values(by="代码块", ascending=False) + df.to_excel("czsc功能清单.xlsx", index=False) diff --git a/examples/test_offline/test_update_bi.py b/examples/test_offline/test_update_bi.py new file mode 100644 index 000000000..38b156648 --- /dev/null +++ b/examples/test_offline/test_update_bi.py @@ -0,0 +1,25 @@ +import sys +sys.path.insert(0, '.') +sys.path.insert(0, '..') +sys.path.insert(0, '../..') +import pandas as pd +import os +from czsc import RawBar, Freq, CZSC, welcome + +welcome() + + +def read_daily(): + file_kline = os.path.join(r"D:\ZB\git_repo\waditu\czsc\test", "data/000001.SH_D.csv") + kline = pd.read_csv(file_kline, encoding="utf-8") + kline['amount'] = kline['close'] * kline['vol'] + + kline.loc[:, "dt"] = pd.to_datetime(kline.dt) + bars = [RawBar(symbol=row['symbol'], id=i, freq=Freq.D, open=row['open'], dt=row['dt'], + close=row['close'], high=row['high'], low=row['low'], vol=row['vol'], amount=row['amount']) + for i, row in kline.iterrows()] + return bars + +bars = read_daily() +c = CZSC(bars) +# %timeit c = CZSC(bars) \ No newline at end of file diff --git a/examples/test_offline/test_weight_backtest.py b/examples/test_offline/test_weight_backtest.py index 3ac6c4af9..1657ebbd3 100644 --- a/examples/test_offline/test_weight_backtest.py +++ b/examples/test_offline/test_weight_backtest.py @@ -4,13 +4,14 @@ import czsc import pandas as pd -assert czsc.WeightBacktest.version == "V231005" +assert czsc.WeightBacktest.version == "V231126" def run_by_weights(): """从持仓权重样例数据中回测""" - dfw = pd.read_feather(r"C:\Users\zengb\Desktop\231005\weight_example.feather") - wb = czsc.WeightBacktest(dfw, digits=1, fee_rate=0.0002) + dfw = pd.read_feather(r"C:\Users\zengb\Downloads\weight_example.feather") + wb = czsc.WeightBacktest(dfw, digits=1, fee_rate=0.0002, n_jobs=1) + # wb = czsc.WeightBacktest(dfw, digits=1, fee_rate=0.0002) # ------------------------------------------------------------------------------------ # 查看绩效评价 @@ -38,3 +39,7 @@ def run_by_weights(): print(symbol_res) wb.report(res_path=r"C:\Users\zengb\Desktop\231005\weight_example") + + +if __name__ == '__main__': + run_by_weights()