Weka机器学习实战之模型存储与读取

这一段时间突然想起来一件非常关键的事情,就是每次运行程序的时候模型都是重复训练的。试想一下,如果数据集非常庞大的时候,训练的时间将会被极大的放大,这对于系统来说是不可接受的。我们相有没有一种方式能够很快速地使用模型呢?答案是肯定的。

Weka训练模型保存

可以看到前面不管是使用J48决策树也好,还是Kmeans也好,都是一次性的模型训练构建和使用,下一次启动程序的时候一样还是要重新训练,非常浪费时间。这里Weka为我们提供了一种非常不错的工具,可以将我们训练的模型持久化。就是SerializationHelper,这个类可以用来存储和读取模型文件的参数。具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
/**
* 保存模型名称
* @param classifier
* @param modelName 模型名称
*/
public static void saveModel(Classifier classifier, String modelName) {
try {
SerializationHelper.write(MODEL_STORAGE_DIR +
modelName + MODEL_EXTENSION, classifier);
} catch (Exception e) {
e.printStackTrace();
}
}

一般建议模型的保存目录设置在系统的一个拥有读写权限的目录当中,而不是在项目文件当中,这样项目打包的时候就不会很臃肿,而且管理起来也非常方便。

Weka训练模型读取

好的,模型一般以.model后缀的形式保存起来,后面我们启动项目的时候需要使用到这个模型,那么如何加载呢?答案也是SerializationHelper,它负责读取模型文件中的参数,并以此构建一个决策树对象出来。具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/**
* 读取模型
* @param modelName
* @param <T>
* @return
*/
public static <T> T readModel(String modelName) {
Classifier classifier = null;
try {
classifier = (Classifier) SerializationHelper.read(MODEL_STORAGE_DIR +
modelName + MODEL_EXTENSION);
} catch (Exception e) {
e.printStackTrace();
}
return (T) classifier;
}

这里使用到了一点泛型的知识,相信会Java的同学这个基本都知道,就不做解释了哈。

模型存储和读取测试

下面进行模型的存储和读取测试,还是以我们上一次说到的西瓜数据集为例,代码做一点点改变,以适配我们的这个类。首先是西瓜模型构建:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/**
* 训练生成分类器
* @return
*/
public static Classifier generateClassifier() throws Exception{
Instances instances = loadDataSet(TRAINING_DATASET_FILENAME);
// 初始化分类器
Classifier j48 = new J48();
// 训练该数据集
j48.buildClassifier(instances);
// 模型保存
TrainningModelUtil.saveModel(j48, "watermelon");
return j48;
}

接着是模型的读取和应用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/**
* 预测出当前数据最可能所属的类别
* @return
*/
public static String predict(List<String> data, Instances trainingSet) throws Exception{
Classifier j48 = TrainningModelUtil.readModel("watermelon");
// 创建Instance
Instance instance = new DenseInstance(trainingSet.numAttributes());
// 分别添加待预测特征值
for (int i = 0; i < data.size(); i++) {
instance.setValue(trainingSet.attribute(i), data.get(i));
}
// 需要能访问数据集
instance.setDataset(trainingSet);
// 得出最可能所属类别
int index = (int)j48.classifyInstance(instance);
return trainingSet.classAttribute().value(index);
}

测试方法,这里需要先执行generateClassifier()方法产生一个模型文件哈,然后才能执行下面的代码测试,否则会报错的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public static void main(String[] args) {
try {
Instances instances = loadDataSet(TRAINING_DATASET_FILENAME);
// 青绿,蜷缩,沉闷,清晰,凹陷,硬滑 [是]
// 浅白,蜷缩,浊响,模糊,平坦,软粘 [否]
List<String> data =
Lists.newArrayList("浅白","蜷缩","浊响","模糊","平坦","软粘");
// 训练模型保存
// generateClassifier();
// 进行预测
String classOfData = predict(data, instances);
System.out.println("class of data is: " + classOfData);
} catch (Exception e) {
e.printStackTrace();
}
}

结果输出:

1
class of data is: 否

结果是符合预期的,说明我们编写的代码是没有问题的。关于TrainningModelUtil类的完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package com.qinjiangbo.util;

import weka.classifiers.Classifier;
import weka.core.SerializationHelper;

/**
* @date: 28/12/2017 10:34 AM
* @author: [email protected]
* @description: 主要是用来保存训练好的模型到文件,以及从文件中读取训练好的模型
*/
public class TrainningModelUtil {

/**
* 模型保存的目录
*/
private static final String MODEL_STORAGE_DIR = "/Users/richard/Documents/Weka Models/";
/**
* 模型文件的后缀
*/
private static final String MODEL_EXTENSION = ".model";

/**
* 保存模型名称
* @param classifier
* @param modelName 模型名称
*/
public static void saveModel(Classifier classifier, String modelName) {
try {
SerializationHelper.write(MODEL_STORAGE_DIR +
modelName + MODEL_EXTENSION, classifier);
} catch (Exception e) {
e.printStackTrace();
}
}

/**
* 读取模型
* @param modelName
* @param <T>
* @return
*/
public static <T> T readModel(String modelName) {
Classifier classifier = null;
try {
classifier = (Classifier) SerializationHelper.read(MODEL_STORAGE_DIR +
modelName + MODEL_EXTENSION);
} catch (Exception e) {
e.printStackTrace();
}
return (T) classifier;
}
}

总结

关于模型文件的保存是非常关键的一步,因为它关系到后面整体系统性能的好坏,试想一下,如果我们每一次都需要重新训练模型,那对于整个系统的性能伤害将是巨大的。所以,我比较建议先训练模型,保存为一个文件,然后通后面再过读取这个模型文件进行预测。这也是数据挖掘和机器学习的正规的方法论。

分享到