Neural Network in Java
Here is a minimal implementation of a neural network in Java. It is meant to be copied into your project whenever you need a neural network. The code aims to be easy to follow while still performing reasonably well. Check it out if you’re interested. Helpful comments are gratefully accepted: I am not well versed in Java and would like to know how to make the code more idiomatic.
https://github.com/dlidstrom/NeuralNetworkInAllLangs/tree/main/Java
Example usage
```java
import java.util.Locale;

// Each row: the two inputs and the six expected outputs (XOR, XNOR, OR, AND, NOR, NAND).
DataItem[] trainingData = {
    new DataItem(new double[]{0, 0}, new double[]{Logical.xor(0, 0), Logical.xnor(0, 0), Logical.or(0, 0), Logical.and(0, 0), Logical.nor(0, 0), Logical.nand(0, 0)}),
    new DataItem(new double[]{0, 1}, new double[]{Logical.xor(0, 1), Logical.xnor(0, 1), Logical.or(0, 1), Logical.and(0, 1), Logical.nor(0, 1), Logical.nand(0, 1)}),
    new DataItem(new double[]{1, 0}, new double[]{Logical.xor(1, 0), Logical.xnor(1, 0), Logical.or(1, 0), Logical.and(1, 0), Logical.nor(1, 0), Logical.nand(1, 0)}),
    new DataItem(new double[]{1, 1}, new double[]{Logical.xor(1, 1), Logical.xnor(1, 1), Logical.or(1, 1), Logical.and(1, 1), Logical.nor(1, 1), Logical.nand(1, 1)})
};

// 2 inputs, 2 hidden neurons, 6 outputs. `rand` is a random source you must
// supply; check the repository for the type Trainer.create expects.
Trainer trainer = Trainer.create(2, 2, 6, rand);
double learningRate = 1.0;
final int iterations = 4000;
// Cycle through the four samples, one training step per iteration.
for (int e = 0; e < iterations; e++) {
    var sample = trainingData[e % trainingData.length];
    trainer.train(sample.input(), sample.output(), learningRate);
}

Network network = trainer.network();
System.out.println("Result after " + iterations + " iterations");
System.out.println("      XOR   XNOR  OR    AND   NOR   NAND");
for (var sample : trainingData) {
    double[] pred = network.predict(sample.input());
    // Locale.ROOT guarantees a '.' decimal separator regardless of system locale.
    System.out.printf(
        Locale.ROOT,
        "%d,%d = %.3f %.3f %.3f %.3f %.3f %.3f%n",
        (int) sample.input()[0], (int) sample.input()[1],
        pred[0], pred[1], pred[2], pred[3], pred[4], pred[5]);
}
```
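The usage example calls a `Logical` helper class that the post does not show. Below is a minimal sketch of what such a helper might look like, assuming each function takes 0/1 `int` inputs and returns a `double` so the result can go straight into a training target. This is hypothetical; the repository's own `Logical` class may be written differently.

```java
import java.util.Locale;

// Hypothetical sketch of the Logical helper; the repository's version may differ.
// Each function takes 0/1 inputs and returns 0.0 or 1.0 as a double.
final class Logical {
    private Logical() {} // static helpers only, no instances

    static double xor(int a, int b)  { return a ^ b; }
    static double xnor(int a, int b) { return 1 - xor(a, b); }
    static double or(int a, int b)   { return a | b; }
    static double and(int a, int b)  { return a & b; }
    static double nor(int a, int b)  { return 1 - or(a, b); }
    static double nand(int a, int b) { return 1 - and(a, b); }

    public static void main(String[] args) {
        // Print the truth table in the same column order as the training targets.
        for (int a = 0; a <= 1; a++) {
            for (int b = 0; b <= 1; b++) {
                System.out.printf(Locale.ROOT, "%d,%d = %.0f %.0f %.0f %.0f %.0f %.0f%n",
                        a, b, xor(a, b), xnor(a, b), or(a, b), and(a, b), nor(a, b), nand(a, b));
            }
        }
    }
}
```

Returning `double` rather than `boolean` keeps the training-data rows free of conversions, since the network's targets are `double[]`.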
The usage example above trains a network to predict six logical functions: XOR, XNOR, OR, AND, NOR, and NAND. It uses two input neurons, two hidden neurons, and six output neurons. Such a network has 24 trainable parameters: 2×2 + 2×6 = 16 connection weights plus 2 + 6 = 8 biases (one per hidden and output neuron), all trained together so that the network predicts all six functions.
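To make the parameter count and the prediction step concrete, here is a minimal sketch of a 2-2-6 fully connected network's forward pass. It assumes one bias per hidden and output neuron and a sigmoid activation on both layers; these are assumptions for illustration, and the class name `TinyNetwork` is hypothetical, not the repository's `Network`.

```java
import java.util.Random;

// Hypothetical sketch: forward pass of a fully connected inputs-hidden-outputs network.
// Assumes one bias per non-input neuron and sigmoid activations (illustrative only).
final class TinyNetwork {
    final int inputs, hidden, outputs;
    final double[][] w1; // [hidden][inputs]: input -> hidden weights
    final double[] b1;   // [hidden]: hidden biases
    final double[][] w2; // [outputs][hidden]: hidden -> output weights
    final double[] b2;   // [outputs]: output biases

    TinyNetwork(int inputs, int hidden, int outputs, Random rand) {
        this.inputs = inputs;
        this.hidden = hidden;
        this.outputs = outputs;
        w1 = new double[hidden][inputs];
        b1 = new double[hidden];
        w2 = new double[outputs][hidden];
        b2 = new double[outputs];
        // Small random initial weights in [-0.5, 0.5); biases start at zero.
        for (double[] row : w1) for (int i = 0; i < row.length; i++) row[i] = rand.nextDouble() - 0.5;
        for (double[] row : w2) for (int i = 0; i < row.length; i++) row[i] = rand.nextDouble() - 0.5;
    }

    // Connection weights plus biases.
    int parameterCount() {
        return inputs * hidden + hidden * outputs + hidden + outputs;
    }

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    double[] predict(double[] input) {
        // Hidden layer: weighted sum of inputs plus bias, squashed by sigmoid.
        double[] h = new double[hidden];
        for (int j = 0; j < hidden; j++) {
            double sum = b1[j];
            for (int i = 0; i < inputs; i++) sum += w1[j][i] * input[i];
            h[j] = sigmoid(sum);
        }
        // Output layer: same computation, fed by the hidden activations.
        double[] out = new double[outputs];
        for (int k = 0; k < outputs; k++) {
            double sum = b2[k];
            for (int j = 0; j < hidden; j++) sum += w2[k][j] * h[j];
            out[k] = sigmoid(sum);
        }
        return out;
    }

    public static void main(String[] args) {
        TinyNetwork net = new TinyNetwork(2, 2, 6, new Random(0));
        System.out.println(net.parameterCount()); // prints 24 (16 weights + 8 biases)
        double[] pred = net.predict(new double[]{1, 0});
        System.out.println(pred.length); // prints 6: one value in (0, 1) per function
    }
}
```

Training (backpropagation) is omitted here; the sketch only shows where the 24 parameters live and how an input flows through them.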
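With larger layers and suitable training data, the implementation from the repository can in principle be applied to other tasks, such as handwriting recognition, game playing, and general prediction.
Here’s the output of the example usage: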
```
Result after 4000 iterations
      XOR   XNOR  OR    AND   NOR   NAND
0,0 = 0.038 0.962 0.038 0.001 0.963 0.999
0,1 = 0.961 0.039 0.970 0.026 0.029 0.974
1,0 = 0.961 0.039 0.970 0.026 0.030 0.974
1,1 = 0.049 0.952 0.994 0.956 0.006 0.044
```
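The outputs are continuous values close to 0 or 1 rather than exact 0s and 1s, which is typical of what a neural network produces. To get exact answers, round each value, for example by treating anything at or above 0.5 as 1.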
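A minimal sketch of that rounding step, thresholding each output at 0.5. The class and method names (`Rounding`, `toBinary`) are hypothetical and not part of the repository; the input values are the `0,0` row of the sample output.

```java
import java.util.Arrays;

// Hypothetical helper: converts continuous network outputs into 0/1 answers.
final class Rounding {
    private Rounding() {} // static helpers only

    // Values >= 0.5 become 1, everything else becomes 0.
    static int[] toBinary(double[] pred) {
        int[] bits = new int[pred.length];
        for (int i = 0; i < pred.length; i++) {
            bits[i] = pred[i] >= 0.5 ? 1 : 0;
        }
        return bits;
    }

    public static void main(String[] args) {
        // The 0,0 row from the sample output (XOR XNOR OR AND NOR NAND).
        double[] pred = {0.038, 0.962, 0.038, 0.001, 0.963, 0.999};
        System.out.println(Arrays.toString(toBinary(pred))); // prints [0, 1, 0, 0, 1, 1]
    }
}
```

The result matches the truth table for inputs 0,0: XOR 0, XNOR 1, OR 0, AND 0, NOR 1, NAND 1.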