让 Java 和 Python 携手合作非常容易,这在开发原型时尤其有价值。
我们从一个实现 Snake 游戏逻辑的 Java 程序开始:场上总有一块食物。每当蛇到达食物时,它就会生长并出现新的食物。如果蛇咬自己或咬墙,游戏结束。
我们的目标是训练一个神经网络来控制蛇,让蛇在犯错和游戏结束之前吃掉尽可能多的食物。首先,我们需要一个代表游戏当前状态的张量。它充当我们神经网络的输入,以便网络可以使用它来预测下一步要采取的最佳步骤。为了让这个例子简单,我们的张量只是一个包含七个元素的向量,可以是 1 或 0:前四个表示食物是在蛇的右边、左边、前面还是后面,接下来的三个条目表示如果蛇头的左边、前面和右边的田地都被一堵墙或蛇的尾巴挡住了。
我们示例的完整源代码可在 GitHub 上找到。
使用 JPype 导入Java类即可:
<b>import</b> jpype <b>import</b> jpype.<b>import</b>s from jpype.types <b>import</b> * # launch the JVM jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar']) # <b>import</b> the Java module from me.schawe.autosnake <b>import</b> SnakeLogic # construct an object of the `SnakeLogic` <b>class</b> ... width, height = 10, 10 snake_logic = SnakeLogic(width, height) # ... and call a method on it print(snake_logic.trainingState())
JPype 在与 Python 解释器相同的进程中启动 JVM,并让它们使用 Java 本机接口 (JNI) 进行通信。
其他选项:
在 Java 中加载模型
使用deeplearning4j将训练好的模型加载到 Java 中……
<font><i>// https://deeplearning4j.konduit.ai/deeplearning4j/how-to-guides/keras-import</i></font><font> <b>public</b> <b>class</b> Autopilot { ComputationGraph model; <b>public</b> Autopilot(String pathToModel) { <b>try</b> { model = KerasModelImport.importKerasModelAndWeights(pathToModel, false); } <b>catch</b> (Exception e) { e.printStackTrace(); } } </font><font><i>// infer the next move from the given state</i></font><font> <b>public</b> <b>int</b> nextMove(<b>boolean</b>[] state) { INDArray input = Nd4j.create(state).reshape(1, state.length); INDArray output = model.output(input)[0]; <b>int</b> action = output.ravel().argMax().getInt(0); <b>return</b> action; } } 调用: <b>public</b> <b>class</b> SnakeLogic { Autopilot autopilot = <b>new</b> Autopilot(</font><font>"path/to/model.h5"</font><font>); <b>public</b> <b>void</b> update() { <b>int</b> action = autopilot.nextMove(trainingState()); turnRelative(action); </font><font><i>// rest of the update omitted</i></font><font> } </font><font><i>// further methods omitted</i></font><font> } </font>