Skip to content

Commit

Permalink
V0.9.36 更新一批代码 (#177)
Browse files Browse the repository at this point in the history
* 0.9.36 start coding

* 0.9.35 fix bug

* 0.9.36 优化策略持仓权重发布

* 0.9.36 fix resample_bars

* 0.9.36 remove retry

* 0.9.36 update

* 0.9.36 新增协作成员专属数据接口

* 0.9.36 新增 risk_free_returns 函数
  • Loading branch information
zengbin93 authored Nov 19, 2023
1 parent acdda96 commit 3160c9f
Show file tree
Hide file tree
Showing 14 changed files with 313 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ name: Python package

on:
push:
branches: [ master, V0.9.35 ]
branches: [ master, V0.9.36 ]
pull_request:
branches: [ master ]

Expand Down
8 changes: 5 additions & 3 deletions czsc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from czsc import sensors
from czsc import aphorism
from czsc.analyze import CZSC
from czsc.objects import Freq, Operate, Direction, Signal, Factor, Event, RawBar, NewBar, Position
from czsc.objects import Freq, Operate, Direction, Signal, Factor, Event, RawBar, NewBar, Position, ZS
from czsc.strategies import CzscStrategyBase, CzscJsonStrategy
from czsc.sensors import holds_concepts_effect, CTAResearch, EventMatchSensor
from czsc.sensors.feature import FixedNumberSelector, FeatureAnalyzeBase
Expand Down Expand Up @@ -60,6 +60,8 @@
update_bbars,
update_tbars,
update_nbars,
risk_free_returns,

CrossSectionalPerformance,
cross_sectional_ranker,
cross_sectional_ic,
Expand Down Expand Up @@ -110,10 +112,10 @@
feture_cross_layering,
)

__version__ = "0.9.35"
__version__ = "0.9.36"
__author__ = "zengbin93"
__email__ = "[email protected]"
__date__ = "20231104"
__date__ = "20231112"



Expand Down
19 changes: 18 additions & 1 deletion czsc/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,24 @@


def remove_include(k1: NewBar, k2: NewBar, k3: RawBar):
"""去除包含关系:输入三根k线,其中k1和k2为没有包含关系的K线,k3为原始K线"""
"""去除包含关系:输入三根k线,其中k1和k2为没有包含关系的K线,k3为原始K线
处理逻辑如下:
1. 首先,通过比较k1和k2的高点(high)的大小关系来确定direction的值。如果k1的高点小于k2的高点,
则设定direction为Up;如果k1的高点大于k2的高点,则设定direction为Down;如果k1和k2的高点相等,
则创建一个新的K线k4,与k3具有相同的属性,并返回False和k4。
2. 接下来,判断k2和k3之间是否存在包含关系。如果存在,则根据direction的值进行处理。
- 如果direction为Up,则选择k2和k3中的较大高点作为新K线k4的高点,较大低点作为低点,较大高点所在的时间戳(dt)作为k4的时间戳。
- 如果direction为Down,则选择k2和k3中的较小高点作为新K线k4的高点,较小低点作为低点,较小低点所在的时间戳(dt)作为k4的时间戳。
- 如果direction的值不是Up也不是Down,则抛出ValueError异常。
3. 根据上述处理得到的高点、低点、开盘价(open_)、收盘价(close),计算新K线k4的成交量(vol)和成交金额(amount),
并将k2中除了与k3时间戳相同的元素之外的其他元素与k3一起作为k4的元素列表(elements)。
4. 返回一个布尔值和新的K线k4。如果k2和k3之间存在包含关系,则返回True和k4;否则返回False和k4,其中k4与k3具有相同的属性。
"""
if k1.high < k2.high:
direction = Direction.Up
elif k1.high > k2.high:
Expand Down
107 changes: 107 additions & 0 deletions czsc/connectors/cooperation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
"""
author: zengbin93
email: [email protected]
create_dt: 2023/11/15 20:45
describe: CZSC开源协作团队内部使用数据接口
"""
import os
import czsc
import pandas as pd
from czsc import RawBar, Freq

# 首次使用需要打开一个python终端按如下方式设置 token
# czsc.set_url_token(token='your token', url='http://zbczsc.com:9106')

cache_path = os.getenv("CZSC_CACHE_PATH", os.path.expanduser("~/.quant_data_cache"))
dc = czsc.DataClient(url='http://zbczsc.com:9106', cache_path=cache_path)


def format_kline(kline: pd.DataFrame, freq: Freq):
"""格式化K线数据
:param kline: K线数据,格式如下:
========== ========= ====== ======= ====== ===== =========== ===========
dt code open close high low vol amount
========== ========= ====== ======= ====== ===== =========== ===========
2022-01-04 600520.SH 20.54 21.12 21.17 20.33 2.1724e+06 1.94007e+07
2022-01-05 600520.SH 21.17 20.73 21.29 20.52 1.8835e+06 1.67258e+07
2022-01-06 600520.SH 20.56 21.17 21.57 18.69 3.4227e+06 3.11461e+07
2022-01-07 600520.SH 21.5 20.61 21.5 20.61 2.51741e+06 2.24819e+07
2022-01-10 600520.SH 20.4 21.69 21.69 20.4 4.80894e+06 4.39598e+07
========== ========= ====== ======= ====== ===== =========== ===========
:return: 格式化后的K线数据
"""
bars = []
for i, row in kline.iterrows():
bar = RawBar(symbol=row['code'], id=i, freq=freq, dt=row['dt'],
open=row['open'], close=row['close'], high=row['high'],
low=row['low'], vol=row['vol'], amount=row['amount'])
bars.append(bar)
return bars


def get_symbols(name, **kwargs):
"""获取指定分组下的所有标的代码
:param name: 分组名称,可选值:'A股指数', 'ETF', '股票', '期货主力'
:param kwargs:
:return:
"""
if name == "股票":
data = dc.stock_basic(nobj=1, status=1)
return data['code'].tolist()

if name == "ETF":
raise NotImplementedError

if name == "A股指数":
raise NotImplementedError

if name == "期货主力":
kline = dc.future_klines(trade_date="20231101")
return kline['code'].unique().tolist()

raise ValueError(f"{name} 分组无法识别,获取标的列表失败!")


def get_raw_bars(symbol, freq, sdt, edt, fq='前复权', **kwargs):
"""获取 CZSC 库定义的标准 RawBar 对象列表
:param symbol: 标的代码
:param freq: 周期,支持 Freq 对象,或者字符串,如
'1分钟', '5分钟', '15分钟', '30分钟', '60分钟', '日线', '周线', '月线', '季线', '年线'
:param sdt: 开始时间
:param edt: 结束时间
:param fq: 除权类型,可选值:'前复权', '后复权', '不复权'
:param kwargs:
:return:
"""
freq = czsc.Freq(freq)

if symbol.endswith(".SH") or symbol.endswith(".SZ"):
fq_map = {"前复权": "qfq", "后复权": "hfq", "不复权": None}
adj = fq_map.get(fq, None)
if freq.value.endswith('分钟'):
df = dc.pro_bar(code=symbol, sdt=sdt, edt=edt, freq='min', adj=adj)
df = df[~df['dt'].str.endswith("09:30:00")].reset_index(drop=True)
else:
df = dc.pro_bar(code=symbol, sdt=sdt, edt=edt, freq='day', adj=adj)
df.rename(columns={'code': 'symbol'}, inplace=True)
df['dt'] = pd.to_datetime(df['dt'])
return czsc.resample_bars(df, target_freq=freq)

if symbol.endswith("9001"):
if freq.value.endswith('分钟'):
df = dc.future_klines(code=symbol, sdt=sdt, edt=edt, freq='1m')
else:
df = dc.future_klines(code=symbol, sdt=sdt, edt=edt, freq='1d')
df.rename(columns={'code': 'symbol'}, inplace=True)
df['amount'] = df['vol'] * df['close']
df = df[['symbol', 'dt', 'open', 'close', 'high', 'low', 'vol', 'amount']].copy().reset_index(drop=True)
df['dt'] = pd.to_datetime(df['dt'])
return czsc.resample_bars(df, target_freq=freq)

raise ValueError(f"symbol {symbol} 无法识别,获取数据失败!")
96 changes: 68 additions & 28 deletions czsc/traders/rwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
create_dt: 2023/9/24 15:19
describe: 策略持仓权重管理
"""
import os
import time
import json
import redis
Expand All @@ -14,18 +15,15 @@
from datetime import datetime


logger.disable(__name__)


class RedisWeightsClient:
"""策略持仓权重收发客户端"""

version = "V231111"
version = "V231112"

def __init__(self, strategy_name, redis_url, **kwargs):
def __init__(self, strategy_name, redis_url=None, send_heartbeat=True, **kwargs):
"""
:param strategy_name: str, 策略名
:param redis_url: str, redis连接字符串
:param redis_url: str, redis连接字符串, 默认为None, 即从环境变量 RWC_REDIS_URL 中读取
For example::
Expand All @@ -41,19 +39,27 @@ def __init__(self, strategy_name, redis_url, **kwargs):
<https://www.iana.org/assignments/uri-schemes/prov/rediss>
- ``unix://``: creates a Unix Domain Socket connection.
:param send_heartbeat: boolean, 是否发送心跳
如果为True,会在后台启动一个线程,每15秒向redis发送一次心跳,用于检测策略是否存活。
推荐在写入数据时设置为True,读取数据时设置为False,避免无用的心跳。
:param kwargs: dict, 其他参数
- key_prefix: str, redis中key的前缀,默认为 Weights
- heartbeat_prefix: str, 心跳key的前缀,默认为 heartbeat
"""
self.strategy_name = strategy_name
self.redis_url = redis_url
self.redis_url = redis_url if redis_url else os.getenv("RWC_REDIS_URL")
self.key_prefix = kwargs.get("key_prefix", "Weights")

self.heartbeat_client = redis.from_url(redis_url, decode_responses=True)
self.heartbeat_prefix = kwargs.get("heartbeat_prefix", "heartbeat")

thread_safe_pool = redis.BlockingConnectionPool.from_url(redis_url, decode_responses=True)
thread_safe_pool = redis.BlockingConnectionPool.from_url(self.redis_url, decode_responses=True)
self.r = redis.Redis(connection_pool=thread_safe_pool)
self.lua_publish = RedisWeightsClient.register_lua_publish(self.r)

if kwargs.get('send_heartbeat', True):
if send_heartbeat:
self.heartbeat_client = redis.from_url(self.redis_url, decode_responses=True)
self.heartbeat_prefix = kwargs.get("heartbeat_prefix", "heartbeat")
self.heartbeat_thread = threading.Thread(target=self.__heartbeat, daemon=True)
self.heartbeat_thread.start()

Expand Down Expand Up @@ -134,7 +140,17 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000):
"""
df = df.copy()
df['dt'] = pd.to_datetime(df['dt'])
df = df.sort_values('dt')
logger.info(f"输入数据中有 {len(df)} 条权重信号")

# 去除单个品种下相邻时间权重相同的数据
_res = []
for _, dfg in df.groupby('symbol'):
dfg = dfg.sort_values('dt', ascending=True).reset_index(drop=True)
dfg = dfg[dfg['weight'].diff().fillna(1) != 0].copy()
_res.append(dfg)
df = pd.concat(_res, ignore_index=True)
df = df.sort_values(['dt']).reset_index(drop=True)
logger.info(f"去除单个品种下相邻时间权重相同的数据后,剩余 {len(df)} 条权重信号")

if 'price' not in df.columns:
df['price'] = 0
Expand Down Expand Up @@ -251,18 +267,36 @@ def get_symbols(self):
symbols = {x.split(":")[2] for x in keys}
return list(symbols)

def get_last_weights(self, symbols=None, ignore_zero=True):
def get_last_weights(self, symbols=None, ignore_zero=True, lua=True):
"""获取最近的持仓权重
:param symbols: list, 品种列表
:param ignore_zero: boolean, 是否忽略权重为0的品种
:param lua: boolean, 是否使用 lua 脚本获取,默认为True
如果要全量获取,推荐使用 lua 脚本,速度更快;如果要获取指定 symbols,不推荐使用 lua 脚本。
:return: pd.DataFrame
"""
symbols = symbols if symbols else self.get_symbols()
with self.r.pipeline() as pipe:
for symbol in symbols:
pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST')
rows = pipe.execute()
if lua:
lua_script = """
local keys = redis.call('KEYS', ARGV[1])
local results = {}
for i=1, #keys do
results[i] = redis.call('HGETALL', keys[i])
end
return results
"""
key_pattern = self.key_prefix + ':' + self.strategy_name + ':*:LAST'
results = self.r.eval(lua_script, 0, key_pattern)
rows = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore
if symbols:
rows = [r for r in rows if r['symbol'] in symbols]

else:
symbols = symbols if symbols else self.get_symbols()
with self.r.pipeline() as pipe:
for symbol in symbols:
pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST')
rows = pipe.execute()

dfw = pd.DataFrame(rows)
dfw['weight'] = dfw['weight'].astype(float)
Expand Down Expand Up @@ -318,16 +352,22 @@ def get_all_weights(self, sdt=None, edt=None, ignore_zero=True) -> pd.DataFrame:
:param ignore_zero: boolean, 是否忽略权重为0的品种
:return: pd.DataFrame
"""
keys = self.get_keys(f"{self.key_prefix}:{self.strategy_name}:*:*")
if keys is None or len(keys) == 0: # type: ignore
return pd.DataFrame()
lua_script = """
local keys = redis.call('KEYS', ARGV[1])
local results = {}
for i=1, #keys do
local last_part = keys[i]:match('([^:]+)$')
if #last_part == 14 and tonumber(last_part) ~= nil then
results[#results + 1] = redis.call('HGETALL', keys[i])
end
end
return results
"""
key_pattern = self.key_prefix + ':' + self.strategy_name + ':*:*'
results = self.r.eval(lua_script, 0, key_pattern)
results = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore

keys = [x for x in keys if len(x.split(":")[-1]) == 14] # type: ignore
with self.r.pipeline() as pipe:
for key in keys:
pipe.hgetall(key)
rows = pipe.execute()
df = pd.DataFrame(rows)
df = pd.DataFrame(results)
df['dt'] = pd.to_datetime(df['dt'])
df['weight'] = df['weight'].astype(float)
df = df.sort_values(['dt', 'symbol']).reset_index(drop=True)
Expand Down
2 changes: 1 addition & 1 deletion czsc/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .sig import check_pressure_support, check_gap_info, is_bis_down, is_bis_up, get_sub_elements
from .sig import same_dir_counts, fast_slow_cross, count_last_same, create_single_signal
from .plotly_plot import KlineChart
from .trade import cal_trade_price, update_nbars, update_bbars, update_tbars
from .trade import cal_trade_price, update_nbars, update_bbars, update_tbars, risk_free_returns
from .cross import CrossSectionalPerformance, cross_sectional_ranker
from .stats import daily_performance, net_value_stats, subtract_fee
from .signal_analyzer import SignalAnalyzer, SignalPerformance
Expand Down
22 changes: 17 additions & 5 deletions czsc/utils/bar_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,22 @@ def resample_bars(df: pd.DataFrame, target_freq: Union[Freq, AnyStr], raw_bars=T
4 402854600 1.315272e+12
:param target_freq: 目标周期
:param raw_bars: 是否将转换后的K线序列转换为RawBar对象
:param kwargs:
- base_freq: 基础周期,如果不指定,则根据df中的dt列自动推断
- drop_unfinished: 是否删除最后一根未完成的K线
:return: 转换后的K线序列
"""
if not isinstance(target_freq, Freq):
target_freq = Freq(target_freq)

base_freq = kwargs.get('base_freq', None)
uni_times = df['dt'].head(2000).apply(lambda x: x.strftime("%H:%M")).unique().tolist()
_, market = check_freq_and_market(uni_times, freq=base_freq)
if target_freq.value.endswith("分钟"):
uni_times = df['dt'].head(2000).apply(lambda x: x.strftime("%H:%M")).unique().tolist()
_, market = check_freq_and_market(uni_times, freq=base_freq)
else:
market = "默认"

df['freq_edt'] = df['dt'].apply(lambda x: freq_end_time(x, target_freq, market))
dfk1 = df.groupby('freq_edt').agg(
Expand All @@ -177,11 +185,15 @@ def resample_bars(df: pd.DataFrame, target_freq: Union[Freq, AnyStr], raw_bars=T
row.update({'id': i, 'freq': target_freq})
_bars.append(RawBar(**row))

if df['dt'].iloc[-1] < _bars[-1].dt:
if kwargs.get('drop_unfinished', True):
# 清除最后一根未完成的K线
_bars.pop()

if df['dt'].iloc[-1] < _bars[-1].dt:
_bars.pop()
return _bars
# if df['dt'].iloc[-1] < _bars[-1].dt:
# # 清除最后一根未完成的K线
# _bars.pop()
# return _bars
else:
return dfk1

Expand Down
Loading

0 comments on commit 3160c9f

Please sign in to comment.