Add automatic differentiation algorithm #10977
@@ -0,0 +1,324 @@

"""
Demonstration of the Automatic Differentiation (Reverse mode).

Reference: https://en.wikipedia.org/wiki/Automatic_differentiation

Author: Poojan Smart
Email: [email protected]

Examples:

>>> with GradientTracker() as tracker:
...     a = Variable([2.0, 5.0])
...     b = Variable([1.0, 2.0])
...     m = Variable([1.0, 2.0])
...     c = a + b
...     d = a * b
...     e = c / d
>>> print(tracker.gradient(e, a))
Suggested change:

- >>> print(tracker.gradient(e, a))
+ >>> tracker.gradient(e, a)
Repeat below.
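For reference, the value being printed can be checked by hand: with e = (a + b) / (a * b), the quotient rule gives de/da = (a*b - (a + b)*b) / (a*b)**2 = -1/a**2, so for a = [2.0, 5.0] the gradient is [-0.25, -0.04]. A quick NumPy sanity check (not part of the PR):

import numpy as np

# e = (a + b) / (a * b)  =>  de/da = -1 / a**2 by the quotient rule
a = np.array([2.0, 5.0])
print(-1 / a**2)  # [-0.25 -0.04]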
Suggested change:

- >>> print(tracker.gradient(e, m))
- None
+ >>> tracker.gradient(e, m) is None
+ True

(m is created inside the tracker context but never used in any tracked operation, so no gradient is recorded and the method returns None.)
We get 88 characters per line max so let's use them to get to fewer lines.
Suggested change:

- Class represents n-dimensional object which is used to wrap
- numpy array on which operations will be performed and gradient
- will be calculated.
+ Class represents n-dimensional object which is used to wrap numpy array on which
+ operations will be performed and the gradient will be calculated.
from typing import Any
Suggested change:

- def __init__(self, value) -> None:
+ def __init__(self, value: Any) -> None:
Suggested change:

- def __str__(self) -> str:
+ def __repr__(self) -> str:
Suggested change:

- def numpy(self) -> np.ndarray:
+ def to_ndarray(self) -> np.ndarray:
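Taken together, the three suggestions above would give the wrapper roughly this shape (a sketch only, assuming the constructor simply wraps a NumPy array; not the PR's full class):

from typing import Any

import numpy as np


class Variable:
    """n-dimensional object wrapping a numpy array (sketch of the suggested API)."""

    def __init__(self, value: Any) -> None:
        # np.asarray accepts lists, scalars, and existing ndarrays alike
        self.value = np.asarray(value, dtype=np.float64)

    def __repr__(self) -> str:
        return f"Variable({self.value})"

    def to_ndarray(self) -> np.ndarray:
        return self.value


v = Variable([2.0, 5.0])
print(v)                   # Variable([2. 5.])
print(v.to_ndarray() + 1)  # [3. 6.]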
add ... SUB is a bit difficult to track. What about tracker.add_operation() --> tracker.append()?
I would put this at the top of the file.
88 chars
As there is no test file in this pull request nor any test function or class in the file machine_learning/automatic_differentiation.py, please provide doctest for the function add_params.
As there is no test file in this pull request nor any test function or class in the file machine_learning/automatic_differentiation.py, please provide doctest for the function add_output.
Please provide type hint for the parameter: value
Please provide type hint for the parameter: value
Please provide type hint for the parameter: value
As there is no test file in this pull request nor any test function or class in the file machine_learning/automatic_differentiation.py, please provide doctest for the function __eq__.
Please provide type hint for the parameter: value
Suggested change:

- if isinstance(value, OpType):
-     return self.op_type == value
- return False
+ return self.op_type == value if isinstance(value, OpType) else False
This should be a defaultdict so that we can streamline below.
Suggested change:

- if param in partial_deriv:
-     partial_deriv[param] += dparam_dtarget
- else:
-     partial_deriv[param] = dparam_dtarget
+ partial_deriv[param] += dparam_dtarget
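For example (a sketch, not the PR's exact code; a zero-valued default rather than a list keeps += valid for NumPy arrays):

from collections import defaultdict

import numpy as np

# With a zero default, the first += for a new key just works,
# since 0 + ndarray broadcasts -- no membership check needed.
partial_deriv: defaultdict = defaultdict(lambda: np.float64(0.0))
for dparam_dtarget in (np.array([1.0, 2.0]), np.array([0.5, 0.5])):
    partial_deriv["param"] += dparam_dtarget
print(partial_deriv["param"])  # [1.5 2.5]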
Suggested change:

- if param.result_of:
-     if param.result_of != OpType.NOOP:
-         operation_queue.append(param.result_of)
+ if param.result_of and param.result_of != OpType.NOOP:
+     operation_queue.append(param.result_of)
Suggested change:

- if source in partial_deriv:
-     return partial_deriv[source]
- return None
+ return partial_deriv.get(source)
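dict.get already returns None for a missing key, so the two forms behave identically:

partial_deriv = {"x": 1.0}
print(partial_deriv.get("x"))  # 1.0
print(partial_deriv.get("y"))  # None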
The method .derivative() does not need a variable also called derivative. Just return the value as soon as you have it.

Suggested change:

- derivative = None
Suggested change:

- derivative = np.ones_like(params[0].numpy(), dtype=np.float64)
- elif operation == OpType.SUB:
+ return np.ones_like(params[0].numpy(), dtype=np.float64)
+ if operation == OpType.SUB:
Suggested change:

-     derivative = np.ones_like(params[0].numpy(), dtype=np.float64)
- else:
-     derivative = -np.ones_like(params[1].numpy(), dtype=np.float64)
- elif operation == OpType.MUL:
-     derivative = (
+     return np.ones_like(params[0].numpy(), dtype=np.float64)
+ else:
+     return -np.ones_like(params[1].numpy(), dtype=np.float64)
+ if operation == OpType.MUL:
+     return (
Suggested change:

- elif operation == OpType.DIV:
+ if operation == OpType.DIV:
Suggested change:

-     derivative = 1 / params[1].numpy()
- else:
-     derivative = -params[0].numpy() / (params[1].numpy() ** 2)
- elif operation == OpType.MATMUL:
-     derivative = (
+     return 1 / params[1].numpy()
+ else:
+     return -params[0].numpy() / (params[1].numpy() ** 2)
+ if operation == OpType.MATMUL:
+     return (
Suggested change:

- elif operation == OpType.POWER:
-     power = operation.other_params["power"]
-     derivative = power * (params[0].numpy() ** (power - 1))
+ if operation == OpType.POWER:
+     power = operation.other_params["power"]
+     return power * (params[0].numpy() ** (power - 1))
Do we want to raise ValueError("Invalid operation") here?
Your choice...
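Putting the early-return and ValueError suggestions together, the dispatch could end up roughly like this (a sketch with a minimal stand-in OpType, operating on raw ndarrays rather than the PR's Variable wrappers; the MUL body here is the standard elementwise product rule and MATMUL is left out, since both are truncated in the diff above):

from enum import Enum, auto

import numpy as np


class OpType(Enum):
    # Minimal stand-in for the PR's enum, for illustration only.
    ADD = auto()
    SUB = auto()
    MUL = auto()
    DIV = auto()
    POWER = auto()


def derivative(
    wrt: int, operation: OpType, params: list, power: float = 1.0
) -> np.ndarray:
    """Early-return dispatch in the style the review suggests (sketch only)."""
    if operation == OpType.ADD:
        return np.ones_like(params[0], dtype=np.float64)
    if operation == OpType.SUB:
        if wrt == 0:
            return np.ones_like(params[0], dtype=np.float64)
        return -np.ones_like(params[1], dtype=np.float64)
    if operation == OpType.MUL:
        # elementwise product rule: d(x*y)/dx = y, d(x*y)/dy = x
        return params[1 - wrt]
    if operation == OpType.DIV:
        if wrt == 0:
            return 1 / params[1]
        return -params[0] / (params[1] ** 2)
    if operation == OpType.POWER:
        return power * (params[0] ** (power - 1))
    raise ValueError("Invalid operation")


x, y = np.array([2.0, 4.0]), np.array([1.0, 2.0])
print(derivative(0, OpType.DIV, [x, y]))  # d(x/y)/dx = 1/y -> [1.  0.5]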