# 使用AI来打造智能策略
TIP
需要使用 northstar
v7.1及以上版本
# 理解架构
首先从架构设计上,所谓的AI能力本质上是通过集成了Tensorflow的预训练模型加载器来实现的。而要实现预训练模型的生产环境闭环,就必须要让策略在生产环境中产生相同的高维度实时数据。
举个例子,比如我观察到行情的波动可能与成交量、持仓量、MACD、均线相关,我希望把这些数据作为神经网络模型的输入向量,然后通过构建一个神经网络去预测未来五个K线的涨跌比例,并以此来作为我策略进出场的依据。
这时我就需要在策略中实现 SamplingAware
接口,该接口定义了策略是如何产生神经网络的输入向量。
通过收集不同行情状态下的输入向量,我们便有了训练数据。然后便可以通过Tensorflow构建我们的神经网络,并且进行训练。
最后,我们便可以把训练好的模型,通过 Tensorflow-java
的API来加载模型,然后通过实时的向量来产生预测数据。
整个完整的流程如下图:
# 如何进行数据采样
假设我们已经定义好了一个策略,并在策略中实现了 SamplingAware
接口,并继承 AbstractModelBasedStrategy
, 看起就像这样
@StrategicComponent(AiSupportedStrategy.NAME)
public class AiSupportedStrategy extends AbstractModelBasedStrategy implements TradeStrategy, SamplingAware{
public static final String NAME = "AI策略";
@Override
public SampleData sample() {
// 实现采样方法
}
@Override
public boolean isSampling() {
// 区分是采样阶段,还是推理阶段
}
@Override
public void onMergedBar(Bar bar) {
if(!indexContract.equals(bar.contract()) || !isEnabled()) {
return;
}
curBar = bar;
/* 如果是采样阶段,不需要执行交易逻辑 */
if(isSampling()) {
return;
}
/* 以下是交易逻辑 */
predict().thenAccept(result -> {
// 根据模型的预测结果进行交易
});
}
}
采样逻辑已经在主程序的 ModuleContext
中实现,只要处于采样阶段就会自动采样。
我们此时可以启动程序,创建一个历史回放网关及一个模组,像平常做回测一样回放历史数据。回放的同时便会在data/sampling目录生成以模组名称命名的csv文件,如下图位置:
以下是完整的示例代码,仅供参考:
package org.dromara.northstar.external.strategy;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import org.dromara.northstar.ai.AbstractModelBasedStrategy;
import org.dromara.northstar.ai.SampleData;
import org.dromara.northstar.ai.SamplingAware;
import org.dromara.northstar.common.constant.DateTimeConstant;
import org.dromara.northstar.common.constant.FieldType;
import org.dromara.northstar.common.constant.ModuleState;
import org.dromara.northstar.common.model.DynamicParams;
import org.dromara.northstar.common.model.Setting;
import org.dromara.northstar.common.model.core.Bar;
import org.dromara.northstar.common.model.core.Contract;
import org.dromara.northstar.common.utils.TradeHelper;
import org.dromara.northstar.indicator.Indicator;
import org.dromara.northstar.indicator.constant.PeriodUnit;
import org.dromara.northstar.indicator.constant.ValueType;
import org.dromara.northstar.indicator.helper.NormalizeIndicator;
import org.dromara.northstar.indicator.helper.SimpleValueIndicator;
import org.dromara.northstar.indicator.model.Configuration;
import org.dromara.northstar.indicator.momentum.KDIndicator;
import org.dromara.northstar.indicator.trend.MACDIndicator;
import org.dromara.northstar.indicator.trend.PuBuIndicator;
import org.dromara.northstar.indicator.volume.ExpandedVolumeThresholdIndicator;
import org.dromara.northstar.strategy.IModuleContext;
import org.dromara.northstar.strategy.IModuleStrategyContext;
import org.dromara.northstar.strategy.StrategicComponent;
import org.dromara.northstar.strategy.TradeStrategy;
import org.slf4j.Logger;
import cn.hutool.core.lang.Assert;
import lombok.Getter;
import lombok.Setter;
@StrategicComponent(AiSupportedStrategy.NAME)
public class AiSupportedStrategy extends AbstractModelBasedStrategy implements TradeStrategy, SamplingAware{
public static final String NAME = "AI策略";
InitParams params; // 策略的参数配置信息
Contract tradeContract;
Contract indexContract;
PeriodRule rule5m;
PeriodRule rule10m;
PeriodRule rule15m;
PeriodRule rule30m;
PeriodRule rule60m;
Bar curBar;
Logger logger;
TradeHelper helper;
@Override
protected void initIndicators() {
logger = ctx.getLogger(getClass());
Set<Contract> contracts = ((IModuleContext)ctx).getModule().getModuleDescription()
.getModuleAccountSettingsDescription()
.stream()
.flatMap(mad -> mad.getBindedContracts().stream())
.map(csi -> ctx.getContract(csi.getUnifiedSymbol()))
.collect(Collectors.toSet());
Assert.isTrue(contracts.size() <= 2, "只能绑定两个合约");
contracts.stream().filter(c -> c.name().contains("指数")).findAny().ifPresent(c -> indexContract = c);
contracts.stream().filter(c -> c.tradable()).findAny().ifPresent(c -> tradeContract = c);
Assert.notNull(indexContract, "指数合约为空");
if(tradeContract == null) {
logger.warn("交易合约未设置,用指数合约代替");
tradeContract = indexContract;
}
rule5m = new PeriodRule(ctx, indexContract, 5, PeriodUnit.MINUTE);
rule10m = new PeriodRule(ctx, indexContract, 10, PeriodUnit.MINUTE);
rule15m = new PeriodRule(ctx, indexContract, 15, PeriodUnit.MINUTE);
rule30m = new PeriodRule(ctx, indexContract, 30, PeriodUnit.MINUTE);
rule60m = new PeriodRule(ctx, indexContract, 60, PeriodUnit.MINUTE);
helper = TradeHelper.builder().context(getContext()).tradeContract(tradeContract).build();
}
@Override
public void onMergedBar(Bar bar) {
if(!indexContract.equals(bar.contract()) || !isEnabled()) {
return;
}
curBar = bar;
/* 如果是采样阶段,不需要执行交易逻辑 */
if(isSampling()) {
return;
}
/* 以下是交易逻辑 */
predict().thenAccept(result -> {
double longRatio = result[0];
logger.info("{} {} {}", curBar.actionDay(), curBar.actionTime(), longRatio);
if(ctx.getState().isEmpty()) {
if(longRatio > 0.7) {
helper.doBuyOpen(1);
} else {
helper.doSellOpen(1);
}
}
if(ctx.getState() == ModuleState.HOLDING_SHORT && longRatio > 0.4) {
helper.doBuyClose(1);
}
if(ctx.getState() == ModuleState.HOLDING_LONG && longRatio < 0.6) {
helper.doSellClose(1);
}
});
}
@Override
public SampleData sample() {
double[] data = List.of(rule5m, rule10m, rule15m, rule30m, rule60m)
.stream()
.flatMapToDouble(r -> DoubleStream.of(r.sample()))
.toArray();
return SampleData.builder()
.actionDate(curBar.actionDay().format(DateTimeConstant.D_FORMAT_INT_FORMATTER))
.actionTime(curBar.actionTime().format(DateTimeConstant.T_FORMAT_FORMATTER))
.marketPrice(curBar.closePrice())
.states(convert(data))
.build();
}
private float[] convert(double[] dblArr) {
float[] result = new float[dblArr.length];
for(int i=0; i<dblArr.length; i++) {
result[i] = (float) dblArr[i];
}
return result;
}
@Override
public boolean isSampling() {
return params.mode.equals(MODE_SAMPLING);
}
@Override
protected int inputDim() {
return 40;
}
@Override
protected int outputDim() {
return 1;
}
@Override
public String name() {
return NAME;
}
@Override
public DynamicParams getDynamicParams() {
return new InitParams();
}
@Override
public void initWithParams(DynamicParams params) {
this.params = (InitParams) params;
}
@Getter
public class PeriodRule {
private static final int STAT_SAMPLE_SIZE = 180;
private MACDIndicator macd;
private Indicator stdMACD;
private Indicator stdDIFF;
private Indicator stdDEA;
private Indicator close;
private Indicator pb6;
private ExpandedVolumeThresholdIndicator volBaseline;
private Indicator vol;
private Indicator oi;
private Indicator stdOI;
private KDIndicator kd;
public PeriodRule(IModuleStrategyContext ctx, Contract c, int numOfUnit, PeriodUnit unit) {
macd = new MACDIndicator(Configuration.builder().indicatorName("MACD").contract(c).numOfUnits(numOfUnit).period(unit).cacheLength(STAT_SAMPLE_SIZE).build(), 12, 26, 9);
stdMACD = new NormalizeIndicator(Configuration.builder().indicatorName("stdMACD").contract(c).numOfUnits(numOfUnit).period(unit).build(), macd);
stdDIFF = new NormalizeIndicator(Configuration.builder().indicatorName("stdDIFF").contract(c).numOfUnits(numOfUnit).period(unit).build(), macd.getDiffLine());
stdDEA = new NormalizeIndicator(Configuration.builder().indicatorName("stdDEA").contract(c).numOfUnits(numOfUnit).period(unit).build(), macd.getDeaLine());
close = new SimpleValueIndicator(Configuration.builder().indicatorName("C").contract(c).numOfUnits(numOfUnit).period(unit).build());
pb6 = new PuBuIndicator(Configuration.builder().indicatorName("PB6").contract(c).numOfUnits(numOfUnit).period(unit).build(), 24);
volBaseline = new ExpandedVolumeThresholdIndicator(Configuration.builder().indicatorName("VolBaseline").contract(c).numOfUnits(numOfUnit).period(unit).build(), 5);
vol = new SimpleValueIndicator(Configuration.builder().indicatorName("Vol").contract(c).numOfUnits(numOfUnit).period(unit).valueType(ValueType.VOL_DELTA).build());
oi = new SimpleValueIndicator(Configuration.builder().indicatorName("OI").contract(c).numOfUnits(numOfUnit).period(unit).valueType(ValueType.OI_DELTA).cacheLength(STAT_SAMPLE_SIZE).build());
stdOI = new NormalizeIndicator(Configuration.builder().indicatorName("stdOI").contract(c).numOfUnits(numOfUnit).period(unit).build(), oi);
kd = new KDIndicator(Configuration.builder().indicatorName("KD").contract(c).numOfUnits(numOfUnit).period(unit).build(), 9, 3, 3);
ctx.registerIndicator(macd);
ctx.registerIndicator(stdMACD);
ctx.registerIndicator(stdDIFF);
ctx.registerIndicator(stdDEA);
ctx.registerIndicator(close);
ctx.registerIndicator(pb6);
ctx.registerIndicator(volBaseline);
ctx.registerIndicator(vol);
ctx.registerIndicator(oi);
ctx.registerIndicator(stdOI);
ctx.registerIndicator(kd);
}
/**
* 量、仓、K、D、PB6、MACD、DEA、DIFF
* @return
*/
public double[] sample() {
// 采样时顺便对数据进行标准化处理
return new double[] {
Math.log(vol.value(0) / volBaseline.value(0)),
stdOI.value(0),
kd.getK().value(0) / 100 - 0.5, // K值为[0, 100]的范围,把范围缩小成[0, 1]并平移使得值范围为[-0.5, 0.5]
(kd.getK().value(0) - kd.getD().value(0)) / 10,
100 * (close.value(0) - pb6.value(0)) / pb6.value(0),
stdMACD.value(0),
stdDIFF.value(0),
stdDEA.value(0),
};
}
}
@Setter
@Getter
public static class InitParams extends DynamicParams {
@Setting(label="模式", order=0, type = FieldType.SELECT, options = {MODE_SAMPLING, MODE_PREDICTING})
private String mode;
}
}
# 如何构建与训练模型
原始的训练数据已经准备好,便可以在jupyter notebook上构建自己的神经网络模型。
同时,由于数据文件是CSV格式的,采用numpy做进一步的加工是轻而易举的事情。不过,要注意的是,为了确保生产环境中能实现推理闭环,对数据的处理仅限于生成标注结果Y,而不应该修改输入向量X。
至于具体采用什么样的神经网络,便由用户自行探索,不在此讨论范围之内。
# 如何保存与加载模型
在python中,Tensorflow模型训练完成后,可以使用 model.save()
来保存训练模型。
注意,保存的【名称】应该与模组名一致
model.save('名称', save_format='tf')
其保存后目录内的结构应该如下:
.(模组名)
├─ fingerprint.pb
├─ keras_metadata.pb
├─ saved_model.pb
├─ assets
└─ variables
把模型目录复制到 northstar
主程序目录下的 models
文件夹下(如果不存在就创建一个)。
最后,别忘记了把模组调回到 非采样
模式,便可以读取预训练的模型进行推理了。
← 监控与管理程序化交易 架构设计 →