Skip to content

Commit

Permalink
V0.9.37 更新一批代码 (#178)
Browse files Browse the repository at this point in the history
* 0.9.37 first commit

* 0.9.37 update

* 0.9.37 新增GPT解读函数执行逻辑

* 0.9.37 weight backtest 增加多进程支持

* 0.9.37 fix check bi
  • Loading branch information
zengbin93 authored Nov 26, 2023
1 parent 3160c9f commit 33b347b
Show file tree
Hide file tree
Showing 19 changed files with 505 additions and 104 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ name: Python package

on:
push:
branches: [ master, V0.9.36 ]
branches: [ master, V0.9.37 ]
pull_request:
branches: [ master ]

Expand Down
5 changes: 2 additions & 3 deletions czsc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,10 @@
feture_cross_layering,
)

__version__ = "0.9.36"
__version__ = "0.9.37"
__author__ = "zengbin93"
__email__ = "[email protected]"
__date__ = "20231112"

__date__ = "20231118"


def welcome():
Expand Down
83 changes: 49 additions & 34 deletions czsc/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,23 @@ def remove_include(k1: NewBar, k2: NewBar, k3: RawBar):


def check_fx(k1: NewBar, k2: NewBar, k3: NewBar):
"""查找分型"""
"""查找分型
函数计算逻辑:
1. 如果第二个`NewBar`对象的最高价和最低价都高于第一个和第三个`NewBar`对象的对应价格,那么它被认为是顶分型(G)。
在这种情况下,函数会创建一个新的`FX`对象,其标记为`Mark.G`,并将其赋值给`fx`。
2. 如果第二个`NewBar`对象的最高价和最低价都低于第一个和第三个`NewBar`对象的对应价格,那么它被认为是底分型(D)。
在这种情况下,函数会创建一个新的`FX`对象,其标记为`Mark.D`,并将其赋值给`fx`。
3. 函数最后返回`fx`,如果没有找到分型,`fx`将为`None`。
:param k1: 第一个`NewBar`对象
:param k2: 第二个`NewBar`对象
:param k3: 第三个`NewBar`对象
:return: `FX`对象或`None`
"""
fx = None
if k1.high < k2.high > k3.high and k1.low < k2.low > k3.low:
fx = FX(symbol=k1.symbol, dt=k2.dt, mark=Mark.G, high=k2.high,
Expand All @@ -89,10 +105,24 @@ def check_fx(k1: NewBar, k2: NewBar, k3: NewBar):


def check_fxs(bars: List[NewBar]) -> List[FX]:
"""输入一串无包含关系K线,查找其中所有分型"""
"""输入一串无包含关系K线,查找其中所有分型
函数的主要步骤:
1. 创建一个空列表`fxs`用于存储找到的分型。
2. 遍历`bars`列表中的每个元素(除了第一个和最后一个),并对每三个连续的`NewBar`对象调用`check_fx`函数。
3. 如果`check_fx`函数返回一个`FX`对象,检查它的标记是否与`fxs`列表中最后一个`FX`对象的标记相同。如果相同,记录一个错误日志。
如果不同,将这个`FX`对象添加到`fxs`列表中。
4. 最后返回`fxs`列表,它包含了`bars`列表中所有找到的分型。
这个函数的主要目的是找出`bars`列表中所有的顶分型和底分型,并确保它们是交替出现的。如果发现连续的两个分型标记相同,它会记录一个错误日志。
:param bars: 无包含关系K线列表
:return: 分型列表
"""
fxs = []
for i in range(1, len(bars)-1):
fx = check_fx(bars[i-1], bars[i], bars[i+1])
for i in range(1, len(bars) - 1):
fx = check_fx(bars[i - 1], bars[i], bars[i + 1])
if isinstance(fx, FX):
# 默认情况下,fxs本身是顶底交替的,但是对于一些特殊情况下不是这样; 临时强制要求fxs序列顶底交替
if len(fxs) >= 2 and fx.mark == fxs[-1].mark:
Expand All @@ -115,32 +145,20 @@ def check_bi(bars: List[NewBar], benchmark=None):
return None, bars

fx_a = fxs[0]
try:
if fxs[0].mark == Mark.D:
direction = Direction.Up
fxs_b = [x for x in fxs if x.mark == Mark.G and x.dt > fx_a.dt and x.fx > fx_a.fx]
if not fxs_b:
return None, bars

fx_b = fxs_b[0]
for fx in fxs_b[1:]:
if fx.high >= fx_b.high:
fx_b = fx

elif fxs[0].mark == Mark.G:
direction = Direction.Down
fxs_b = [x for x in fxs if x.mark == Mark.D and x.dt > fx_a.dt and x.fx < fx_a.fx]
if not fxs_b:
return None, bars

fx_b = fxs_b[0]
for fx in fxs_b[1:]:
if fx.low <= fx_b.low:
fx_b = fx
else:
raise ValueError
except Exception as e:
logger.exception(f"笔识别错误: {e}")
if fx_a.mark == Mark.D:
direction = Direction.Up
fxs_b = (x for x in fxs if x.mark == Mark.G and x.dt > fx_a.dt and x.fx > fx_a.fx)
fx_b = max(fxs_b, key=lambda fx: fx.high, default=None)

elif fx_a.mark == Mark.G:
direction = Direction.Down
fxs_b = (x for x in fxs if x.mark == Mark.D and x.dt > fx_a.dt and x.fx < fx_a.fx)
fx_b = min(fxs_b, key=lambda fx: fx.low, default=None)

else:
raise ValueError

if fx_b is None:
return None, bars

bars_a = [x for x in bars if fx_a.elements[0].dt <= x.dt <= fx_b.elements[2].dt]
Expand Down Expand Up @@ -291,10 +309,7 @@ def update(self, bar: RawBar):
self.bars_raw = self.bars_raw[s_index:]

# 如果有信号计算函数,则进行信号计算
if self.get_signals:
self.signals = self.get_signals(c=self)
else:
self.signals = OrderedDict()
self.signals = self.get_signals(c=self) if self.get_signals else OrderedDict()

def to_echarts(self, width: str = "1400px", height: str = '580px', bs=[]):
"""绘制K线分析图
Expand Down
32 changes: 32 additions & 0 deletions czsc/connectors/cooperation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import os
import czsc
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from czsc import RawBar, Freq

# 首次使用需要打开一个python终端按如下方式设置 token
Expand Down Expand Up @@ -105,3 +107,33 @@ def get_raw_bars(symbol, freq, sdt, edt, fq='前复权', **kwargs):
return czsc.resample_bars(df, target_freq=freq)

raise ValueError(f"symbol {symbol} 无法识别,获取数据失败!")


def stocks_daily_klines(years=None, **kwargs):
"""获取全市场A股的日线数据"""
adj = kwargs.get('adj', 'hfq')
if years is None:
years = ['2017', '2018', '2019', '2020', '2021', '2022', '2023']

res = []
for year in years:
ttl = 3600 * 6 if year == str(datetime.now().year) else -1
kline = dc.pro_bar(trade_year=year, adj=adj, v=2, ttl=ttl)
res.append(kline)

dfk = pd.concat(res, ignore_index=True)
dfk['dt'] = pd.to_datetime(dfk['dt'])
dfk = dfk.sort_values(['code', 'dt'], ascending=True).reset_index(drop=True)
if kwargs.get('exclude_bj', True):
dfk = dfk[~dfk['code'].str.endswith(".BJ")].reset_index(drop=True)

nxb = kwargs.get('nxb', [1, 2, 5])
if nxb:
rows = []
for _, dfg in tqdm(dfk.groupby('code'), desc="计算NXB收益率", ncols=80, colour='green'):
czsc.update_nbars(dfg, numbers=nxb, move=1, price_col='open')
rows.append(dfg)
dfk = pd.concat(rows, ignore_index=True)

dfk = dfk.rename(columns={'code': 'symbol'})
return dfk
87 changes: 79 additions & 8 deletions czsc/connectors/ts_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,100 @@
describe: Tushare数据源
"""
import os
from czsc import data
import czsc
import pandas as pd
from czsc import Freq, RawBar
from typing import List

dc = data.TsDataCache(data_path=os.environ.get('ts_data_path', r'D:\ts_data'))
# 首次使用需要打开一个python终端按如下方式设置 token
# czsc.set_url_token(token='your token', url='http://api.tushare.pro')

cache_path = os.getenv("TS_CACHE_PATH", os.path.expanduser("~/.ts_data_cache"))
dc = czsc.DataClient(url='http://api.tushare.pro', cache_path=cache_path)

def get_symbols(step):
if step.upper() == 'ALL':
return data.get_symbols(dc, 'index') + data.get_symbols(dc, 'stock') + data.get_symbols(dc, 'etfs')
return data.get_symbols(dc, step)

def format_kline(kline: pd.DataFrame, freq: Freq) -> List[RawBar]:
"""Tushare K线数据转换
:param kline: Tushare 数据接口返回的K线数据
:param freq: K线周期
:return: 转换好的K线数据
"""
bars = []
dt_key = 'trade_time' if '分钟' in freq.value else 'trade_date'
kline = kline.sort_values(dt_key, ascending=True, ignore_index=True)
records = kline.to_dict('records')

for i, record in enumerate(records):
if freq == Freq.D:
vol = int(record['vol'] * 100) if record['vol'] > 0 else 0
amount = int(record.get('amount', 0) * 1000)
else:
vol = int(record['vol']) if record['vol'] > 0 else 0
amount = int(record.get('amount', 0))

# 将每一根K线转换成 RawBar 对象
bar = RawBar(symbol=record['ts_code'], dt=pd.to_datetime(record[dt_key]),
id=i, freq=freq, open=record['open'], close=record['close'],
high=record['high'], low=record['low'],
vol=vol, # 成交量,单位:股
amount=amount, # 成交额,单位:元
)
bars.append(bar)
return bars


def get_symbols(step="all"):
"""获取标的代码"""
stocks = dc.stock_basic(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date')
stocks_ = stocks[stocks['list_date'] < '2010-01-01'].ts_code.to_list()
stocks_map = {
"index": ['000905.SH', '000016.SH', '000300.SH', '000001.SH', '000852.SH',
'399001.SZ', '399006.SZ', '399376.SZ', '399377.SZ', '399317.SZ', '399303.SZ'],
"stock": stocks.ts_code.to_list(),
"check": ['000001.SZ'],
"train": stocks_[:200],
"valid": stocks_[200:600],
"etfs": ['512880.SH', '518880.SH', '515880.SH', '513050.SH', '512690.SH',
'512660.SH', '512400.SH', '512010.SH', '512000.SH', '510900.SH',
'510300.SH', '510500.SH', '510050.SH', '159992.SZ', '159985.SZ',
'159981.SZ', '159949.SZ', '159915.SZ'],
}

asset_map = {
"index": "I",
"stock": "E",
"check": "E",
"train": "E",
"valid": "E",
"etfs": "FD"
}

if step.lower() == "all":
symbols = []
for k, v in stocks_map.items():
symbols += [f"{ts_code}#{asset_map[k]}" for ts_code in v]
else:
asset = asset_map[step]
symbols = [f"{ts_code}#{asset}" for ts_code in stocks_map[step]]

return symbols


def get_raw_bars(symbol, freq, sdt, edt, fq='后复权', raw_bar=True):
"""读取本地数据"""
from czsc import data
tdc = data.TsDataCache(data_path=cache_path)
ts_code, asset = symbol.split("#")
freq = str(freq)
adj = "qfq" if fq == "前复权" else "hfq"

if "分钟" in freq:
freq = freq.replace("分钟", "min")
bars = dc.pro_bar_minutes(ts_code, sdt=sdt, edt=edt, freq=freq, asset=asset, adj=adj, raw_bar=raw_bar)
bars = tdc.pro_bar_minutes(ts_code, sdt=sdt, edt=edt, freq=freq, asset=asset, adj=adj, raw_bar=raw_bar)

else:
_map = {"日线": "D", "周线": "W", "月线": "M"}
freq = _map[freq]
bars = dc.pro_bar(ts_code, start_date=sdt, end_date=edt, freq=freq, asset=asset, adj=adj, raw_bar=raw_bar)
bars = tdc.pro_bar(ts_code, start_date=sdt, end_date=edt, freq=freq, asset=asset, adj=adj, raw_bar=raw_bar)
return bars
1 change: 1 addition & 0 deletions czsc/data/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __getattr__(self, name):
print("Tushare Pro 初始化失败")


@deprecated(reason="统一到 ts_connector 中", version='1.0.0')
def format_kline(kline: pd.DataFrame, freq: Freq) -> List[RawBar]:
"""Tushare K线数据转换
Expand Down
3 changes: 0 additions & 3 deletions czsc/data/ts_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,3 @@ def stocks_daily_basic_new(self, sdt: str, edt: str):
dfb['上市天数'] = (dfb['trade_date'] - pd.to_datetime(dfb['list_date'], errors='coerce')).apply(lambda x: x.days)
dfb.to_feather(file_cache)
return dfb



8 changes: 4 additions & 4 deletions czsc/traders/rwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def publish(self, symbol, dt, weight, price=0, ref=None, overwrite=False):

if not overwrite:
last_dt = self.get_last_times(symbol)
if last_dt is not None and dt <= last_dt:
if last_dt is not None and dt <= last_dt: # type: ignore
logger.warning(f"不允许重复写入,已过滤 {symbol} {dt} 的重复信号")
return 0

Expand Down Expand Up @@ -213,8 +213,8 @@ def clear_all(self):
"""删除该策略所有记录"""
self.r.delete(f'{self.key_prefix}:META:{self.strategy_name}')
keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*')
if keys is not None and len(keys) > 0:
self.r.delete(*keys)
if keys is not None and len(keys) > 0: # type: ignore
self.r.delete(*keys) # type: ignore

@staticmethod
def register_lua_publish(client):
Expand Down Expand Up @@ -264,7 +264,7 @@ def register_lua_publish(client):
def get_symbols(self):
"""获取策略交易的品种列表"""
keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*')
symbols = {x.split(":")[2] for x in keys}
symbols = {x.split(":")[2] for x in keys} # type: ignore
return list(symbols)

def get_last_weights(self, symbols=None, ignore_zero=True, lua=True):
Expand Down
Loading

0 comments on commit 33b347b

Please sign in to comment.