package main

import (
    "fmt"
    "math"

    e "github.com/daniel4x/GoGrad/engine"
)

func createXORData() ([][]*e.Value, []float64) {
    x := [][]float64{
        {0, 0},
        {0, 1},
        {1, 0},
        {1, 1},
    }
    y := []float64{-1, 1, 1, -1} // encode false as -1 and true as 1 to get cleaner outputs from the model

    return e.MakeValueMatrix(x), y
}
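
// XOR is the classic example of a problem that is not linearly separable: no
// single weighted sum of the two inputs can separate {(0,1), (1,0)} from
// {(0,0), (1,1)}, which is why the model below needs hidden layers at all.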

func printData(X [][]*e.Value, y []float64) {
    for i := 0; i < len(X); i++ {
        fmt.Printf("(%v, %v) -> %v\n", X[i][0].Data(), X[i][1].Data(), y[i])
    }
}

func main() {
    // Create XOR dataset
    X, y := createXORData()
    fmt.Println("XOR dataset:")
    printData(X, y)

    // Define an MLP with 2 inputs, two hidden layers of 4 neurons each, and a single output neuron
    nn := e.NewMLP(2, []int{4, 4, 1})
    fmt.Println("\nMulti-layer Perceptron Definition:\n", nn)
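
    // Assuming each layer is fully connected with one bias per neuron, this
    // 2 -> 4 -> 4 -> 1 network has (2*4+4) + (4*4+4) + (4*1+1) = 37 trainable parameters.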

    // Train the model
    epochs := 2000 // number of full passes over the dataset
    alpha := 0.01  // learning rate for the gradient-descent updates

    for i := 0; i < epochs; i++ {
        yModel := make([]*e.Value, len(X))

        // Forward pass
        // Feed in each data point
        for j := 0; j < len(X); j++ {
            yModel[j] = nn.Call(X[j])
        }
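
        // Call runs a single sample through the network; with one output
        // neuron it returns a single *e.Value whose computation graph links
        // back to every parameter, which is what Backward() walks later.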

        // Compute the loss as the sum of squared errors over all four samples
        loss := yModel[0].Sub(y[0]).Pow(2)
        for j := 1; j < len(yModel); j++ {
            loss = loss.Add(yModel[j].Sub(y[j]).Pow(2))
        }
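
        // Dividing by len(yModel) would turn this into the mean squared error;
        // that only rescales the gradients, so the plain sum works fine here
        // with a small learning rate.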

        // Backward pass
        // zero the gradients to avoid accumulation between epochs
        params := nn.Parameters()
        for j := 0; j < len(params); j++ {
            params[j].ZeroGrad()
        }
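
        // Backward() accumulates into each parameter's gradient rather than
        // overwriting it, so skipping this reset would let gradients from
        // earlier epochs leak into the current update.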

        loss.Backward() // backpropagate d(loss)/d(parameter) through the computation graph

        // Update the parameters
        for j := 0; j < len(params); j++ {
            params[j].SetData(params[j].Data() - alpha*params[j].Grad())
        }
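
        // This is a plain gradient-descent step: each parameter moves against
        // its gradient, scaled by the learning rate (p <- p - alpha * d(loss)/dp).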

        if (i+1)%100 == 0 {
            // Print the loss every 100 epochs
            fmt.Println("epoch", i+1, "loss", loss.Data())
        }
    }

    // Test the model
    predictions := make([]float64, len(X))
    for i := 0; i < len(X); i++ {
        predictions[i] = nn.Call(X[i]).Data()
    }
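
    // The outputs are raw continuous values that should land close to the
    // -1 / 1 targets; thresholding at 0 would give a boolean answer, but the
    // check below compares the raw values against the targets directly.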

    fmt.Println("\nTesting the model:")
    for i := 0; i < len(X); i++ {
        fmt.Printf("(%v, %v) -> Actual: %v Prediction: %v\n", X[i][0].Data(), X[i][1].Data(), y[i], predictions[i])
    }

    // Panic if any prediction differs from the target by more than 0.1
    for i := 0; i < len(X); i++ {
        if math.Abs(y[i]-predictions[i]) > 0.1 {
            panic(fmt.Sprintf("\nTest failed: (%v, %v) -> Actual: %v Prediction: %v\n", X[i][0].Data(), X[i][1].Data(), y[i], predictions[i]))
        }
    }
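
    // The tolerance check doubles as a smoke test: the program panics (and so
    // exits non-zero) whenever training fails to fit XOR to within 0.1.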
}