diff --git a/czsc/eda.py b/czsc/eda.py index 845ef1ba8..171271a96 100644 --- a/czsc/eda.py +++ b/czsc/eda.py @@ -83,25 +83,27 @@ def remove_beta_effects(df, **kwargs): return dfr -def cross_sectional_strategy(df, factor, **kwargs): +def cross_sectional_strategy(df, factor, weight="weight", long=0.3, short=0.3, **kwargs): """根据截面因子值构建多空组合 :param df: pd.DataFrame, 包含因子列的数据, 必须包含 dt, symbol, factor 列 :param factor: str, 因子列名称 + :param weight: str, 权重列名称,默认为 weight + :param long: float, 多头持仓比例/数量,默认为 0.3, 取值范围为 [0, n_symbols], 0~1 表示比例,大于等于1表示数量 + :param short: float, 空头持仓比例/数量,默认为 0.3, 取值范围为 [0, n_symbols], 0~1 表示比例,大于等于1表示数量 :param kwargs: - factor_direction: str, 因子方向,positive 或 negative - - long_num: int, 多头持仓数量 - - short_num: int, 空头持仓数量 - logger: loguru.logger, 日志记录器 + - norm: bool, 是否对 weight 进行截面持仓标准化,默认为 False :return: pd.DataFrame, 包含 weight 列的数据 """ factor_direction = kwargs.get("factor_direction", "positive") - long_num = kwargs.get("long_num", 5) - short_num = kwargs.get("short_num", 5) logger = kwargs.get("logger", loguru.logger) + norm = kwargs.get("norm", True) + assert long >= 0 and short >= 0, "long 和 short 参数必须大于等于0" assert factor in df.columns, f"{factor} 不在 df 中" assert factor_direction in ["positive", "negative"], f"factor_direction 参数错误" @@ -109,20 +111,33 @@ def cross_sectional_strategy(df, factor, **kwargs): if factor_direction == "negative": df[factor] = -df[factor] - df['weight'] = 0 + df[weight] = 0 + rows = [] + for dt, dfg in df.groupby("dt"): - if len(dfg) < long_num + short_num: - logger.warning(f"{dt} 截面数据量过小,跳过;仅有 {len(dfg)} 条数据,需要 {long_num + short_num} 条数据") + long_num = long if long >= 1 else int(len(dfg) * long) + short_num = short if short >= 1 else int(len(dfg) * short) + + if long_num == 0 and short_num == 0: + logger.warning(f"{dt} 多空目前持仓数量都为0; long: {long}, short: {short}") + rows.append(dfg) continue - dfa = dfg.sort_values(factor, ascending=False).head(long_num) - dfb = dfg.sort_values(factor, ascending=True).head(short_num) - if long_num > 0: - df.loc[dfa.index, "weight"] = 1 / long_num - if short_num > 0: - df.loc[dfb.index, "weight"] = -1 / short_num + long_symbols = dfg.sort_values(factor, ascending=False).head(long_num)['symbol'].tolist() + short_symbols = dfg.sort_values(factor, ascending=True).head(short_num)['symbol'].tolist() - return df + union_symbols = set(long_symbols) & set(short_symbols) + if union_symbols: + logger.warning(f"{dt} 存在同时在多头和空头的品种:{union_symbols}") + long_symbols = list(set(long_symbols) - union_symbols) + short_symbols = list(set(short_symbols) - union_symbols) + + dfg.loc[dfg['symbol'].isin(long_symbols), weight] = 1 / long_num if norm else 1 + dfg.loc[dfg['symbol'].isin(short_symbols), weight] = -1 / short_num if norm else -1 + rows.append(dfg) + + dfx = pd.concat(rows, ignore_index=True) + return dfx def judge_factor_direction(df: pd.DataFrame, factor, target='n1b', by='symbol', **kwargs): diff --git a/examples/develop/weight_backtest.py b/examples/develop/weight_backtest.py index f3f66e84a..745a59680 100644 --- a/examples/develop/weight_backtest.py +++ b/examples/develop/weight_backtest.py @@ -20,11 +20,16 @@ def test_ensemble_weight(): def test_rust_weight_backtest(): """从持仓权重样例数据中回测""" - from rs_czsc import PyBacktest as WeightBacktest + from rs_czsc import WeightBacktest + # from rs_czsc import daily_performance + # from czsc import daily_performance + + # stats = daily_performance([0.01, 0.02, -0.03, 0.04, 0.05]) dfw = pd.read_feather(r"C:\Users\zengb\Downloads\weight_example.feather") - wb = WeightBacktest(czsc.to_arrow(dfw), digits=2, fee_rate=0.0002, n_jobs=1) + # wb = WeightBacktest(czsc.to_arrow(dfw), digit=2, fee_rate=0.0002, n_jobs=1) + wb = WeightBacktest(dfw, digits=2, fee_rate=0.0002, n_jobs=1) - # ss = sorted(wb.stats.items()) - # print(ss) + ss = sorted(wb.stats.items()) + print(ss) diff --git a/test/test_cross_sectional_strategy.py b/test/test_cross_sectional_strategy.py new file mode 100644 index 000000000..143d4fbef --- /dev/null +++ b/test/test_cross_sectional_strategy.py @@ -0,0 +1,68 @@ +# tests/test_cross_sectional_strategy.py +import pytest +import pandas as pd +from czsc.eda import cross_sectional_strategy + + +@pytest.fixture +def sample_data(): + data = { + "dt": [ + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-01-04", + "2023-01-05", + "2023-01-06", + "2023-01-07", + "2023-01-08", + "2023-01-09", + "2023-01-10", + ] + * 5, + "symbol": ["A"] * 10 + ["B"] * 10 + ["C"] * 10 + ["D"] * 10 + ["E"] * 10, + "factor": list(range(1, 51)), + } + return pd.DataFrame(data) + + +def test_cross_sectional_strategy_positive(sample_data): + result = cross_sectional_strategy(sample_data, factor="factor", long=0.5, short=0.5, factor_direction="positive") + assert "weight" in result.columns + assert result["weight"].sum() == 0 # Long and short positions should balance out + + +def test_cross_sectional_strategy_negative(sample_data): + result = cross_sectional_strategy(sample_data, factor="factor", long=0.5, short=0.5, factor_direction="negative") + assert "weight" in result.columns + assert result["weight"].sum() == 0 # Long and short positions should balance out + print(result) + + +def test_cross_sectional_strategy_negative_norm(sample_data): + result = cross_sectional_strategy( + sample_data, factor="factor", long=0.5, short=0.5, factor_direction="negative", norm=False + ) + assert "weight" in result.columns + assert result["weight"].sum() == 0 # Long and short positions should balance out + print(result) + + +def test_cross_sectional_strategy_no_positions(sample_data): + result = cross_sectional_strategy(sample_data, factor="factor", long=0, short=0) + assert "weight" in result.columns + assert result["weight"].sum() == 0 # No positions should be taken + + +def test_cross_sectional_strategy_invalid_factor(sample_data): + with pytest.raises(AssertionError): + cross_sectional_strategy(sample_data, factor="invalid_factor", long=0.5, short=0.5) + + +def test_cross_sectional_strategy_invalid_factor_direction(sample_data): + with pytest.raises(AssertionError): + cross_sectional_strategy(sample_data, factor="factor", long=0.5, short=0.5, factor_direction="invalid") + + +if __name__ == "__main__": + pytest.main()