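In an earlier article we showed how to classify images with softmax regression; see "Implementing Softmax Classification in Code" (《使用Softmax进行分类代码实现》) for details. Today we use DJL's high-level APIs to implement a multilayer perceptron (MLP) more concisely.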
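First, load the MNIST training and validation sets:

private static RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException {
    Mnist mnist = Mnist.builder()
            .optUsage(usage)          // TRAIN or TEST split
            .setSampling(32, true)    // batch size 32, random order
            .optLimit(64)             // use at most 64 examples
            .build();
    mnist.prepare(new ProgressBar()); // download (if needed) and load the data
    return mnist;
}

// Training set
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
// Validation set
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);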
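setSampling(32, true) requests mini-batches of 32 examples drawn in random order, and optLimit(64) caps the dataset at 64 examples, so each epoch should consist of two batches. The plain-Java sketch below imitates that idea (shuffle the example indices, then cut them into fixed-size batches) so the effect of the two settings is concrete. It is our own illustration, not DJL's actual sampler; the class name `BatchSamplingSketch` and the seed are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Illustrative sketch only: mimics "random order + fixed batch size",
// not DJL's real BatchSampler/RandomSampler implementation.
final class BatchSamplingSketch {

    /** Shuffles indices 0..size-1 and splits them into batches (the last batch may be smaller). */
    static List<List<Integer>> randomBatches(int size, int batchSize, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        // Random order, analogous to setSampling(batchSize, true)
        Collections.shuffle(indices, new Random(seed));

        List<List<Integer>> batches = new ArrayList<>();
        for (int start = 0; start < size; start += batchSize) {
            int end = Math.min(start + batchSize, size);
            batches.add(new ArrayList<>(indices.subList(start, end)));
        }
        return batches;
    }

    public static void main(String[] args) {
        // 64 examples (as with optLimit(64)) in batches of 32
        List<List<Integer>> batches = randomBatches(64, 32, 42L);
        System.out.println("batches per epoch: " + batches.size()); // 2
        System.out.println("first batch size: " + batches.get(0).size()); // 32
    }
}
```

Every index appears exactly once per epoch; only the order changes, which is what "random sampling" means here.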
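Next, define the model: an MLP whose input is a flattened 28 x 28 image (784 values), with two hidden layers of 128 and 64 units and one output per digit class (10):

Block block = new Mlp(
        Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, // input size: 28 * 28 = 784
        Mnist.NUM_CLASSES,                      // output size: 10 classes
        new int[] {128, 64});                   // hidden layer sizes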
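As we understand DJL's `Mlp`, this constructor flattens each input and stacks fully connected layers 784 → 128 → 64 → 10, with ReLU after each hidden layer and a linear output layer producing one score (logit) per class. The stdlib-only sketch below reproduces that structure by hand to show what the network computes and how many parameters it has. The weights are a fixed placeholder pattern rather than trained values, and `MlpSketch` is a hypothetical name, not part of DJL.

```java
import java.util.Arrays;

// Illustrative plain-Java sketch of a 784 -> 128 -> 64 -> 10 MLP forward pass.
// Weights follow a deterministic placeholder pattern; they are not trained values.
final class MlpSketch {

    /** Number of weights + biases in all fully connected layers. */
    static long parameterCount(int input, int output, int[] hidden) {
        long total = 0;
        int in = input;
        for (int h : hidden) {
            total += (long) in * h + h; // weight matrix + bias vector
            in = h;
        }
        total += (long) in * output + output; // output layer
        return total;
    }

    /** One fully connected layer: y = W x + b, optionally followed by ReLU. */
    static double[] dense(double[] x, double[][] w, double[] b, boolean relu) {
        double[] y = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double sum = b[j];
            for (int i = 0; i < x.length; i++) {
                sum += w[j][i] * x[i];
            }
            y[j] = relu ? Math.max(0.0, sum) : sum;
        }
        return y;
    }

    /** Runs a flattened input through ReLU hidden layers and a linear output layer. */
    static double[] forward(double[] x, int output, int[] hidden) {
        int[] sizes = Arrays.copyOf(hidden, hidden.length + 1);
        sizes[hidden.length] = output;

        double[] activation = x;
        for (int layer = 0; layer < sizes.length; layer++) {
            int in = activation.length;
            int out = sizes[layer];
            double[][] w = new double[out][in];
            for (int j = 0; j < out; j++) {
                for (int i = 0; i < in; i++) {
                    w[j][i] = ((i + j) % 7 - 3) * 0.01; // placeholder weights
                }
            }
            double[] b = new double[out]; // zero bias
            boolean isHidden = layer < hidden.length;
            activation = dense(activation, w, b, isHidden);
        }
        return activation; // raw scores (logits), one per class
    }

    public static void main(String[] args) {
        int[] hidden = {128, 64};
        System.out.println(parameterCount(28 * 28, 10, hidden)); // 109386
        double[] image = new double[28 * 28];
        Arrays.fill(image, 0.5); // a dummy "flattened image"
        double[] logits = forward(image, 10, hidden);
        System.out.println("logits length: " + logits.length); // 10
    }
}
```

Of the 109,386 parameters, 100,480 belong to the first 784 x 128 layer, so the input layer dominates the model size.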
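Then set up the training configuration: softmax cross-entropy loss, an accuracy evaluator, logging listeners, and a listener that saves the model and records its validation accuracy and loss as model properties:

private static DefaultTrainingConfig setupTrainingConfig() {
    String outputDir = "build/model";
    // Listener that saves the trained model to outputDir
    SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
    listener.setSaveModelCallback(
            trainer -> {
                // Invoked when the listener saves the model:
                // store the validation accuracy and loss as model properties
                TrainingResult result = trainer.getTrainingResult();
                Model model = trainer.getModel();
                float accuracy = result.getValidateEvaluation("Accuracy");
                model.setProperty("Accuracy", String.format("%.5f", accuracy));
                model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
            });
    return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            .addEvaluator(new Accuracy())
            .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
            .addTrainingListeners(listener);
}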
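The configuration trains against Loss.softmaxCrossEntropyLoss() and tracks an Accuracy evaluator, and the callback writes both validation values into the model with "%.5f". For a single example, softmax cross-entropy is the negative log of the softmax probability assigned to the true class, which can be computed stably as logSumExp(logits) minus the true-class logit; accuracy is the fraction of examples whose highest-scoring class equals the label. The stdlib-only sketch below computes both by hand to show what the two numbers mean. It is not DJL's implementation (which operates on batched NDArrays), and `LossAccuracySketch` is a hypothetical name.

```java
import java.util.Locale;

// Illustrative sketch of softmax cross-entropy and accuracy on plain arrays.
final class LossAccuracySketch {

    /** -log(softmax(logits)[label]), computed stably as logSumExp(logits) - logits[label]. */
    static double softmaxCrossEntropy(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max); // subtract the max to avoid overflow
        }
        double logSumExp = max + Math.log(sumExp);
        return logSumExp - logits[label];
    }

    /** Index of the largest score, i.e. the predicted class. */
    static int argmax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of rows whose argmax equals the label. */
    static float accuracy(double[][] batchLogits, int[] labels) {
        int correct = 0;
        for (int n = 0; n < labels.length; n++) {
            if (argmax(batchLogits[n]) == labels[n]) {
                correct++;
            }
        }
        return (float) correct / labels.length;
    }

    public static void main(String[] args) {
        // Uniform scores over 10 classes: loss = ln(10)
        double[] uniform = new double[10];
        System.out.println(String.format(Locale.ROOT, "%.5f", softmaxCrossEntropy(uniform, 3))); // 2.30259

        double[][] logits = {{2.0, 0.1, -1.0}, {0.3, 0.2, 1.5}, {0.0, 3.0, 0.5}, {1.0, 0.9, 0.8}};
        int[] labels = {0, 2, 1, 2};
        // Rows 0-2 are predicted correctly, row 3 is not
        System.out.println(String.format(Locale.ROOT, "%.5f", accuracy(logits, labels))); // 0.75000
    }
}
```

A loss of ln(10) ≈ 2.30259 is what an untrained 10-class model that assigns equal scores to every class would report, which is a useful reference point when reading the logged validation loss.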
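Finally, train the model and save it:

DefaultTrainingConfig config = setupTrainingConfig();
try (Model model = Model.newInstance("mlp")) {
    model.setBlock(block);
    try (Trainer trainer = model.newTrainer(config)) {
        trainer.setMetrics(new Metrics());
        // Shape of one flattened 28 x 28 image, used to initialize the parameters
        Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
        trainer.initialize(inputShape);
        // Train for 15 epochs, evaluating on validateSet after each epoch
        EasyTrain.fit(trainer, 15, trainingSet, validateSet);
        // Save the model
        model.save(Paths.get("build/model"), "mlp");
        return trainer.getTrainingResult();
    }
}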
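EasyTrain.fit(trainer, 15, trainingSet, validateSet) runs 15 epochs; in each epoch it iterates over the training batches, computes the loss and gradients for each batch, updates the parameters (trainer.step()), and then evaluates on the validation set. The plain-Java sketch below shows that same epoch → mini-batch → update → evaluate loop on a deliberately tiny stand-in problem, fitting y = 2x + 1 with mini-batch gradient descent instead of training the MLP, so that it runs without DJL. The toy data, the learning rate 0.5, and the names `EpochLoopSketch`, `fit` and `meanLoss` are all hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Illustrative sketch of an epoch / mini-batch / update / evaluate loop.
// Model: y = w * x + b (a stand-in for the MLP), trained with mini-batch SGD on squared error.
final class EpochLoopSketch {

    /** Mean of 0.5 * (w*x + b - y)^2 over the given examples. */
    static double meanLoss(double w, double b, double[] xs, double[] ys) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double err = w * xs[i] + b - ys[i];
            total += 0.5 * err * err;
        }
        return total / xs.length;
    }

    /** Trains for numEpochs and returns the validation loss recorded after each epoch. */
    static double[] fit(int numEpochs, int batchSize, double lr, long seed) {
        // Hypothetical toy training data: 64 points on y = 2x + 1
        int n = 64;
        double[] trainX = new double[n];
        double[] trainY = new double[n];
        for (int i = 0; i < n; i++) {
            trainX[i] = i / (double) n;
            trainY[i] = 2.0 * trainX[i] + 1.0;
        }
        // Hypothetical validation data on the same line
        double[] validX = {0.05, 0.45, 0.95};
        double[] validY = {1.1, 1.9, 2.9};

        double w = 0.0;
        double b = 0.0;
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        Random random = new Random(seed);
        double[] validLoss = new double[numEpochs];

        for (int epoch = 0; epoch < numEpochs; epoch++) {
            Collections.shuffle(order, random); // new random batch order every epoch
            for (int start = 0; start < n; start += batchSize) {
                int end = Math.min(start + batchSize, n);
                double gradW = 0.0;
                double gradB = 0.0;
                for (int k = start; k < end; k++) {
                    int i = order.get(k);
                    double err = w * trainX[i] + b - trainY[i];
                    gradW += err * trainX[i]; // d(0.5*err^2)/dw
                    gradB += err;             // d(0.5*err^2)/db
                }
                int size = end - start;
                w -= lr * gradW / size; // parameter update, the role trainer.step() plays in DJL
                b -= lr * gradB / size;
            }
            validLoss[epoch] = meanLoss(w, b, validX, validY); // evaluate after each epoch
        }
        return validLoss;
    }

    public static void main(String[] args) {
        double[] history = fit(15, 32, 0.5, 0L);
        for (int epoch = 0; epoch < history.length; epoch++) {
            System.out.printf("epoch %2d  validation loss %.6f%n", epoch + 1, history[epoch]);
        }
    }
}
```

The printed per-epoch validation loss plays the same role as the validation loss that the logging listener reports for the MLP after each of the 15 epochs.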
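Here, after 15 epochs of training, the trained model is saved to the build/model directory. In a follow-up article we will use this trained model to classify images. Follow our WeChat official account to read the image-prediction part.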