# 使用AI来打造智能策略

TIP

需要使用 northstar v7.1及以上版本

# 理解架构

AI模块架构

首先从架构设计上,所谓的AI能力本质上是通过集成了Tensorflow的预训练模型加载器来实现的。而要实现预训练模型的生产环境闭环,就必须要让策略在生产环境中产生相同的高维度实时数据。

举个例子,比如我观察到行情的波动可能与成交量、持仓量、MACD、均线相关,我希望把这些数据作为神经网络模型的输入向量,然后通过构建一个神经网络去预测未来五个K线的涨跌比例,并以此来作为我策略进出场的依据。

这时我就需要在策略中实现 SamplingAware 接口,该接口定义了策略是如何产生神经网络的输入向量。

通过收集不同行情状态下的输入向量,我们便有了训练数据。然后便可以通过Tensorflow构建我们的神经网络,并且进行训练。

最后,我们便可以把训练好的模型,通过 Tensorflow-java 的API来加载模型,然后通过实时的向量来产生预测数据。

整个完整的流程如下图:

AI模块采样与推理流程

# 如何进行数据采样

假设我们已经定义好了一个策略,并在策略中实现了 SamplingAware 接口,并继承 AbstractModelBasedStrategy, 看起就像这样

@StrategicComponent(AiSupportedStrategy.NAME)
public class AiSupportedStrategy extends AbstractModelBasedStrategy implements TradeStrategy, SamplingAware{

	public static final String NAME = "AI策略";	

    @Override
	public SampleData sample() {
        // 实现采样方法
    }

	@Override
	public boolean isSampling() {
		// 区分是采样阶段,还是推理阶段
	}

    @Override
	public void onMergedBar(Bar bar) {
		if(!indexContract.equals(bar.contract()) || !isEnabled()) {
			return;
		}
		curBar = bar;
		/* 如果是采样阶段,不需要执行交易逻辑 */
		if(isSampling()) {
			return;
		}

		/* 以下是交易逻辑 */
		predict().thenAccept(result -> {
			// 根据模型的预测结果进行交易
		});
	}
}

采样逻辑已经在主程序的 ModuleContext 中实现,只要处于采样阶段就会自动采样。

我们此时可以启动程序,创建一个历史回放网关及一个模组,像平常做回测一样回放历史数据。回放的同时便会在data/sampling目录生成以模组名称命名的csv文件,如下图位置: 采样文件

以下是完整的示例代码,仅供参考:

package org.dromara.northstar.external.strategy;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;

import org.dromara.northstar.ai.AbstractModelBasedStrategy;
import org.dromara.northstar.ai.SampleData;
import org.dromara.northstar.ai.SamplingAware;
import org.dromara.northstar.common.constant.DateTimeConstant;
import org.dromara.northstar.common.constant.FieldType;
import org.dromara.northstar.common.constant.ModuleState;
import org.dromara.northstar.common.model.DynamicParams;
import org.dromara.northstar.common.model.Setting;
import org.dromara.northstar.common.model.core.Bar;
import org.dromara.northstar.common.model.core.Contract;
import org.dromara.northstar.common.utils.TradeHelper;
import org.dromara.northstar.indicator.Indicator;
import org.dromara.northstar.indicator.constant.PeriodUnit;
import org.dromara.northstar.indicator.constant.ValueType;
import org.dromara.northstar.indicator.helper.NormalizeIndicator;
import org.dromara.northstar.indicator.helper.SimpleValueIndicator;
import org.dromara.northstar.indicator.model.Configuration;
import org.dromara.northstar.indicator.momentum.KDIndicator;
import org.dromara.northstar.indicator.trend.MACDIndicator;
import org.dromara.northstar.indicator.trend.PuBuIndicator;
import org.dromara.northstar.indicator.volume.ExpandedVolumeThresholdIndicator;
import org.dromara.northstar.strategy.IModuleContext;
import org.dromara.northstar.strategy.IModuleStrategyContext;
import org.dromara.northstar.strategy.StrategicComponent;
import org.dromara.northstar.strategy.TradeStrategy;
import org.slf4j.Logger;

import cn.hutool.core.lang.Assert;
import lombok.Getter;
import lombok.Setter;

@StrategicComponent(AiSupportedStrategy.NAME)
public class AiSupportedStrategy extends AbstractModelBasedStrategy implements TradeStrategy, SamplingAware{

	public static final String NAME = "AI策略";	
	
	InitParams params;	// 策略的参数配置信息 
	
	Contract tradeContract;
	Contract indexContract;
	
	PeriodRule rule5m;
	PeriodRule rule10m;
	PeriodRule rule15m;
	PeriodRule rule30m;
	PeriodRule rule60m;
	
	Bar curBar;
	
	Logger logger;
	
	TradeHelper helper;
	
	@Override
	protected void initIndicators() {
		logger = ctx.getLogger(getClass());
		Set<Contract> contracts = ((IModuleContext)ctx).getModule().getModuleDescription()
				.getModuleAccountSettingsDescription()
				.stream()
				.flatMap(mad -> mad.getBindedContracts().stream())
				.map(csi -> ctx.getContract(csi.getUnifiedSymbol()))
				.collect(Collectors.toSet());
		Assert.isTrue(contracts.size() <= 2, "只能绑定两个合约");
		contracts.stream().filter(c -> c.name().contains("指数")).findAny().ifPresent(c -> indexContract = c);
		contracts.stream().filter(c -> c.tradable()).findAny().ifPresent(c -> tradeContract = c);
		Assert.notNull(indexContract, "指数合约为空");
		if(tradeContract == null) {
			logger.warn("交易合约未设置,用指数合约代替");
			tradeContract = indexContract;
		}
		
		rule5m = new PeriodRule(ctx, indexContract, 5, PeriodUnit.MINUTE);
		rule10m = new PeriodRule(ctx, indexContract, 10, PeriodUnit.MINUTE);
		rule15m = new PeriodRule(ctx, indexContract, 15, PeriodUnit.MINUTE);
		rule30m = new PeriodRule(ctx, indexContract, 30, PeriodUnit.MINUTE);
		rule60m = new PeriodRule(ctx, indexContract, 60, PeriodUnit.MINUTE);
		
		helper = TradeHelper.builder().context(getContext()).tradeContract(tradeContract).build();
	}
	
	@Override
	public void onMergedBar(Bar bar) {
		if(!indexContract.equals(bar.contract()) || !isEnabled()) {
			return;
		}
		curBar = bar;
		/* 如果是采样阶段,不需要执行交易逻辑 */
		if(isSampling()) {
			return;
		}

		/* 以下是交易逻辑 */
		predict().thenAccept(result -> {
			double longRatio = result[0];
			logger.info("{} {} {}", curBar.actionDay(), curBar.actionTime(), longRatio);
			if(ctx.getState().isEmpty()) {
				if(longRatio > 0.7) {
					helper.doBuyOpen(1);
				} else {
					helper.doSellOpen(1);
				}
			}
			if(ctx.getState() == ModuleState.HOLDING_SHORT && longRatio > 0.4) {
				helper.doBuyClose(1);
			}
			if(ctx.getState() == ModuleState.HOLDING_LONG && longRatio < 0.6) {
				helper.doSellClose(1);
			}
		});
	}
	
	@Override
	public SampleData sample() {
		double[] data = List.of(rule5m, rule10m, rule15m, rule30m, rule60m)
				.stream()
				.flatMapToDouble(r -> DoubleStream.of(r.sample()))
				.toArray();
		return SampleData.builder()
				.actionDate(curBar.actionDay().format(DateTimeConstant.D_FORMAT_INT_FORMATTER))
				.actionTime(curBar.actionTime().format(DateTimeConstant.T_FORMAT_FORMATTER))
				.marketPrice(curBar.closePrice())
				.states(convert(data))
				.build();
	}
	
	private float[] convert(double[] dblArr) {
		float[] result = new float[dblArr.length];
		for(int i=0; i<dblArr.length; i++) {
			result[i] = (float) dblArr[i];
		}
		return result;
	}
	
	@Override
	public boolean isSampling() {
		return params.mode.equals(MODE_SAMPLING);
	}
	
	@Override
	protected int inputDim() {
		return 40;
	}

	@Override
	protected int outputDim() {
		return 1;
	}
	
	@Override
	public String name() {
		return NAME;
	}
	
	@Override
	public DynamicParams getDynamicParams() {
		return new InitParams();
	}

	@Override
	public void initWithParams(DynamicParams params) {
		this.params = (InitParams) params;
	}
	
	@Getter
	public class PeriodRule {
		
		private static final int STAT_SAMPLE_SIZE = 180;
		
		private MACDIndicator macd;
		private Indicator stdMACD;
		private Indicator stdDIFF;
		private Indicator stdDEA;
		
		private Indicator close;
		private Indicator pb6;
		
		private ExpandedVolumeThresholdIndicator volBaseline;
		private Indicator vol;
		private Indicator oi;
		private Indicator stdOI;
		
		private KDIndicator kd;
		
		public PeriodRule(IModuleStrategyContext ctx, Contract c, int numOfUnit, PeriodUnit unit) {
			macd = new MACDIndicator(Configuration.builder().indicatorName("MACD").contract(c).numOfUnits(numOfUnit).period(unit).cacheLength(STAT_SAMPLE_SIZE).build(), 12, 26, 9);
			stdMACD = new NormalizeIndicator(Configuration.builder().indicatorName("stdMACD").contract(c).numOfUnits(numOfUnit).period(unit).build(), macd);
			stdDIFF = new NormalizeIndicator(Configuration.builder().indicatorName("stdDIFF").contract(c).numOfUnits(numOfUnit).period(unit).build(), macd.getDiffLine());
			stdDEA = new NormalizeIndicator(Configuration.builder().indicatorName("stdDEA").contract(c).numOfUnits(numOfUnit).period(unit).build(), macd.getDeaLine());
			
			close = new SimpleValueIndicator(Configuration.builder().indicatorName("C").contract(c).numOfUnits(numOfUnit).period(unit).build());
			pb6 = new PuBuIndicator(Configuration.builder().indicatorName("PB6").contract(c).numOfUnits(numOfUnit).period(unit).build(), 24);
			
			volBaseline = new ExpandedVolumeThresholdIndicator(Configuration.builder().indicatorName("VolBaseline").contract(c).numOfUnits(numOfUnit).period(unit).build(), 5);
			vol = new SimpleValueIndicator(Configuration.builder().indicatorName("Vol").contract(c).numOfUnits(numOfUnit).period(unit).valueType(ValueType.VOL_DELTA).build());
			oi = new SimpleValueIndicator(Configuration.builder().indicatorName("OI").contract(c).numOfUnits(numOfUnit).period(unit).valueType(ValueType.OI_DELTA).cacheLength(STAT_SAMPLE_SIZE).build());
			stdOI = new NormalizeIndicator(Configuration.builder().indicatorName("stdOI").contract(c).numOfUnits(numOfUnit).period(unit).build(), oi);
			
			kd = new KDIndicator(Configuration.builder().indicatorName("KD").contract(c).numOfUnits(numOfUnit).period(unit).build(), 9, 3, 3);
			
			ctx.registerIndicator(macd);
			ctx.registerIndicator(stdMACD);
			ctx.registerIndicator(stdDIFF);
			ctx.registerIndicator(stdDEA);
			ctx.registerIndicator(close);
			ctx.registerIndicator(pb6);
			ctx.registerIndicator(volBaseline);
			ctx.registerIndicator(vol);
			ctx.registerIndicator(oi);
			ctx.registerIndicator(stdOI);
			ctx.registerIndicator(kd);
		}
		
		/**
		 * 量、仓、K、D、PB6、MACD、DEA、DIFF
		 * @return
		 */
		public double[] sample() {
			// 采样时顺便对数据进行标准化处理
			return new double[] {
					Math.log(vol.value(0) / volBaseline.value(0)),
					stdOI.value(0),
					kd.getK().value(0) / 100 - 0.5, // K值为[0, 100]的范围,把范围缩小成[0, 1]并平移使得值范围为[-0.5, 0.5]
					(kd.getK().value(0) - kd.getD().value(0)) / 10,
					100 * (close.value(0) - pb6.value(0)) / pb6.value(0),
					stdMACD.value(0),
					stdDIFF.value(0),
					stdDEA.value(0),
			};
		}
	}
	
	@Setter
	@Getter
	public static class InitParams extends DynamicParams {			
		
		@Setting(label="模式", order=0, type = FieldType.SELECT, options = {MODE_SAMPLING, MODE_PREDICTING})
		private String mode;
		
	}

}

# 如何构建与训练模型

原始的训练数据已经准备好,便可以在jupyter notebook上构建自己的神经网络模型。

同时,由于数据文件是CSV格式的,采用numpy做进一步的加工是轻而易举的事情。不过,要注意的是,为了确保生产环境中能实现推理闭环,对数据的处理仅限于生成标注结果Y,而不应该修改输入向量X。

至于具体采用什么样的神经网络,便由用户自行探索,不在此讨论范围之内。

# 如何保存与加载模型

在python中,Tensorflow模型训练完成后,可以使用 model.save() 来保存训练模型。
注意,保存的【名称】应该与模组名一致

model.save('名称', save_format='tf')

其保存后目录内的结构应该如下:

.(模组名)
├─ fingerprint.pb
├─ keras_metadata.pb
├─ saved_model.pb
├─ assets                               
└─ variables                                

把模型目录复制到 northstar 主程序目录下的 models 文件夹下(如果不存在就创建一个)。

最后,别忘记了把模组调回到 非采样 模式,便可以读取预训练的模型进行推理了。