The Essence of Neural Networks (As Explained by Karpathy) (7 Part Series)
1 The Essence of Neural Networks (As Explained by Karpathy)
2 Understanding Backpropagation from Scratch with micrograd – Derivatives
3 Representing Math Expressions As Graphs in micrograd
4 Back-Propagation Spelled Out – As Explained by Karpathy
5 Modeling a Neuron in micrograd (As Explained by Karpathy)
6 Fully Automated Gradient Calculation on Expression Graph (As Explained By Karpathy)
7 Fixing A Bug in micrograd BackProp (As Explained by Karpathy)
Hi there! I’m Shrijith Venkatrama, founder of Hexmos. Right now, I’m building LiveAPI, a tool that makes generating API docs from your code ridiculously easy.
What We Want to Represent
In the last post, we manually calculated a slope like so:
```python
h = 0.001

# inputs
a = 2.0
b = -3.0
c = 10.0

d1 = a*b + c
c += h  # nudge c by a small amount h
d2 = a*b + c

print(f"d1 = {d1}")
print(f"d2 = {d2}")
print(f"w.r.t c (d2 - d1) / h = {(d2 - d1) / h}")  # it is c that was nudged
```
The goal is to represent the above expression, L = a*b + c, in a form that is easy to work with, and then perform key operations on it, such as finding dL/da, dL/db, and dL/dc. This is important because training a neural network is about using these derivatives to adjust the values in the expression graph until the inputs are mapped to the desired outputs.
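To make the target concrete: for L = a*b + c, calculus gives dL/da = b, dL/db = a, and dL/dc = 1. A minimal sketch that checks this numerically with the same nudge-by-h idea as above, one input at a time (the helper names `expr` and `numeric_slopes` are hypothetical, introduced only for illustration):

```python
def expr(a, b, c):
    # The expression L = a*b + c
    return a * b + c


def numeric_slopes(a, b, c, h=1e-6):
    """Estimate dL/da, dL/db, dL/dc by nudging one input at a time by h."""
    base = expr(a, b, c)
    return {
        "dL/da": (expr(a + h, b, c) - base) / h,
        "dL/db": (expr(a, b + h, c) - base) / h,
        "dL/dc": (expr(a, b, c + h) - base) / h,
    }


slopes = numeric_slopes(2.0, -3.0, 10.0)
for name, value in slopes.items():
    print(f"{name} ≈ {value:.4f}")
```

With a = 2, b = -3 and c = 10, the estimates should land very close to -3, 2 and 1 respectively. Repeating this by hand for every input does not scale, which is the motivation for storing the expression as a graph.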
Building the Value class (foundation for neural networks)
The first step toward this is to represent a single value.
Iteration 1: Represent a single Value
```python
class Value:
    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return f"Value(data={self.data})"
```
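A quick usage check (a sketch, not part of the original post): `__repr__` controls how a `Value` is shown by `print`, in the REPL, and inside containers such as lists, which is what makes the results in the next iterations readable.

```python
class Value:
    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return f"Value(data={self.data})"


a = Value(2.0)
print(a)                  # Value(data=2.0)  (print falls back to __repr__ since __str__ isn't defined)
print([a, Value(-3.0)])   # [Value(data=2.0), Value(data=-3.0)]  (lists use each element's __repr__)
```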
Iteration 2: Represent Multiple Values And Operations on Them
We add support for addition and multiplication to the Value class, so that we can write expressions like a + b or a * b + c:
```python
class Value:
    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return f"Value(data={self.data})"

    def __add__(self, other):
        # a + b is evaluated as a.__add__(b)
        return Value(self.data + other.data)

    def __mul__(self, other):
        # a * b is evaluated as a.__mul__(b)
        return Value(self.data * other.data)


a = Value(2.0)
b = Value(-3.0)
print(a*b)
c = Value(10)
print(a * b + c)
print((a.__mul__(b)).__add__(c))  # same as above
```
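Python turns `a * b + c` into `a.__mul__(b).__add__(c)`, since `*` binds tighter than `+`. Each call builds a fresh `Value` and leaves its operands untouched. The sketch below checks both points and also shows one limitation of this minimal version: both operands must be `Value` objects, so adding a plain number fails because `int` has no `.data` attribute.

```python
class Value:
    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return f"Value(data={self.data})"

    def __add__(self, other):
        return Value(self.data + other.data)

    def __mul__(self, other):
        return Value(self.data * other.data)


a, b, c = Value(2.0), Value(-3.0), Value(10)

result = a * b + c                    # operator syntax
explicit = a.__mul__(b).__add__(c)    # the method calls Python actually makes
print(result, explicit)               # Value(data=4.0) Value(data=4.0)
print(a, b, c)                        # operands are unchanged: Value(data=2.0) Value(data=-3.0) Value(data=10)

try:
    a + 1                             # 1 is an int, not a Value, so other.data fails
except AttributeError as err:
    print("AttributeError:", err)
```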
Iteration 3: Store whole expressions
The next step is to record the whole chain of values and operations that produced a result, so that the expression is stored as a graph.
We do this by introducing two new attributes: _prev and _op. For each node, _prev records the nodes that were combined to produce it, and _op records which operation combined them.
```python
class Value:
    def __init__(self, data, _children=(), _op=''):
        self.data = data
        self._prev = set(_children)  # the Values this one was computed from
        self._op = _op               # the operation that produced this Value

    def __repr__(self):
        return f"Value(data={self.data})"

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), '+')

    def __mul__(self, other):
        # label multiplication as '*'
        return Value(self.data * other.data, (self, other), '*')


a = Value(2.0)
b = Value(-3.0)
c = Value(10)
e = a * b
d = e + c
print(d._prev)
print(d._op)
print("---")
print(e._prev)
print(e._op)
```
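Because every node keeps references to its inputs in `_prev`, the whole expression can be recovered by walking backwards from the output. A minimal sketch, assuming the Iteration 3 class above (with `'*'` as the multiplication label); the helper `leaves` is hypothetical and simply collects the input nodes, those with no children:

```python
class Value:
    def __init__(self, data, _children=(), _op=''):
        self.data = data
        self._prev = set(_children)
        self._op = _op

    def __repr__(self):
        return f"Value(data={self.data})"

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), '+')

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), '*')


def leaves(node):
    """Return the input Values (nodes with no children) that node depends on."""
    if not node._prev:
        return {node}
    found = set()
    for child in node._prev:
        found |= leaves(child)  # recurse into each child and merge the results
    return found


a, b, c = Value(2.0), Value(-3.0), Value(10)
e = a * b
d = e + c

print(d._op, e._op)                        # + *
print(sorted(v.data for v in leaves(d)))   # [-3.0, 2.0, 10]
```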
Visualizing the Expression Graph
Karpathy shares a nice bit of code, built on top of Graphviz, that displays an expression as a graph. Values are drawn as rectangles and operations as ellipses:
```python
from graphviz import Digraph

def trace(root):
    # Builds a set of all nodes and edges in a graph
    nodes, edges = set(), set()
    def build(v):
        if v not in nodes:
            nodes.add(v)
            for child in v._prev:
                edges.add((child, v))
                build(child)
    build(root)
    return nodes, edges

def draw_dot(root):
    dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'})  # LR = left to right

    nodes, edges = trace(root)
    for n in nodes:
        uid = str(id(n))
        # For any value in the graph, create a rectangular ('record') node for it
        dot.node(name=uid, label="{ data %.4f }" % (n.data,), shape='record')
        if n._op:
            # If this value is a result of some operation, create an op node for it
            dot.node(name=uid + n._op, label=n._op)
            # And connect this node to it
            dot.edge(uid + n._op, uid)

    for n1, n2 in edges:
        # Connect n1 to the op node of n2
        dot.edge(str(id(n1)), str(id(n2)) + n2._op)

    return dot
```
I can do the following to get an image of the graph:
```python
draw_dot(d)  # where d is the expression defined above
```
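If Graphviz isn't available, the same `trace` output can be inspected as plain text. A sketch, assuming the Iteration 3 `Value` class; `print_edges` is a hypothetical helper, and `trace` is repeated from above so the block runs on its own without the `graphviz` package:

```python
class Value:
    def __init__(self, data, _children=(), _op=''):
        self.data = data
        self._prev = set(_children)
        self._op = _op

    def __repr__(self):
        return f"Value(data={self.data})"

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), '+')

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), '*')


def trace(root):
    # Builds a set of all nodes and edges in a graph (same as above)
    nodes, edges = set(), set()
    def build(v):
        if v not in nodes:
            nodes.add(v)
            for child in v._prev:
                edges.add((child, v))
                build(child)
    build(root)
    return nodes, edges


def print_edges(root):
    """Print one line per edge as: input data --op--> output data."""
    nodes, edges = trace(root)
    lines = sorted(f"{child.data} --{parent._op}--> {parent.data}" for child, parent in edges)
    for line in lines:
        print(line)
    return nodes, edges


a, b, c = Value(2.0), Value(-3.0), Value(10)
d = a * b + c
nodes, edges = print_edges(d)
print(len(nodes), "nodes,", len(edges), "edges")  # 5 nodes, 4 edges
```

Each printed line corresponds to one arrow in the Graphviz picture, with the operation node folded into the arrow label.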
Reference
The spelled-out intro to neural networks and backpropagation: building micrograd (Andrej Karpathy)