🐛 Slice bug fix
fix #2
joey00072 committed Dec 3, 2023
1 parent d771294 commit 2a8dcdc
Showing 1 changed file with 50 additions and 35 deletions.
85 changes: 50 additions & 35 deletions tinytorch.py
@@ -14,6 +14,7 @@
class Tensor:
__slots__ = ("data", "grad", "_ctx", "requires_grad", "device")
_compute_grad = True

def __init__(self, data, requires_grad=False):
self.data: np.ndarray = Tensor._data_to_numpy(data)
self.grad: Tensor = None
@@ -231,31 +232,30 @@ def _undo_broadcast(self, tensor: Tensor, grad: Tensor):
grad = grad.sum(axis=idx, keepdims=True)

return Tensor(grad)
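
# Illustration (assumes numpy as np; not part of the diff): _undo_broadcast
# reduces an upstream gradient back to the operand's shape by summing over
# every axis that broadcasting expanded, exactly as the
# grad.sum(axis=idx, keepdims=True) call above does.
import numpy as np
x = np.zeros((3, 1))                    # operand that was broadcast along axis 1
grad = np.ones((3, 4))                  # upstream gradient in the broadcast shape
grad = grad.sum(axis=1, keepdims=True)  # back to x's shape, (3, 1)
assert grad.shape == x.shape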



def backward(self):
if self._ctx is None:
return

if self.grad is None:
if self.size != 1:
raise RuntimeError("Backward cannot be called on a non-scalar tensor")
self.grad = Tensor([1.0])
- def topo_sort(node:Tensor,visited:set,sortlist:list)->list:
-     if not isinstance(node,Tensor) or node in visited:

+ def topo_sort(node: Tensor, visited: set, sortlist: list) -> list:
+     if not isinstance(node, Tensor) or node in visited:
return sortlist
visited.add(node)
if node._ctx is None:
sortlist.append(node)
return sortlist
for child_node in node._ctx.args:
- topo_sort(child_node,visited,sortlist)
+ topo_sort(child_node, visited, sortlist)
sortlist.append(node)
return sortlist
- node_list:list[Tensor] = reversed(topo_sort(self,set(),[]))

+ node_list: list[Tensor] = reversed(topo_sort(self, set(), []))

for node in node_list:
if node._ctx is None:
continue
@@ -270,10 +270,10 @@ def topo_sort(node:Tensor,visited:set,sortlist:list)->list:
if tensor.grad is None:
tensor.grad = Tensor(np.zeros_like(tensor.data).astype(np.float32))
tensor.grad.data += grad.numpy()
-         node._ctx = None

+     node._ctx = None


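# Standalone sketch (not part of the diff) of the traversal backward() builds
# above: a depth-first post-order emits children before parents, and reversing
# it yields an order where every node's gradient is complete before it is
# pushed to its inputs.
class Node:
    def __init__(self, parents=()):
        self.parents = parents

def topo_sort(node, visited, out):
    if node in visited:
        return out
    visited.add(node)
    for child in node.parents:
        topo_sort(child, visited, out)
    out.append(node)
    return out

a, b = Node(), Node()
c = Node(parents=(a, b))
order = list(reversed(topo_sort(c, set(), [])))
assert order[0] is c  # output node first, leaves last
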
class Function:
__slots__ = (
"op",
@@ -297,7 +297,7 @@ def apply(cls, *args):
def _is_part_of_graph(ctx: Function):
if not Tensor._compute_grad:
return False

for node in ctx.args:
if isinstance(node, Tensor) and (
node.requires_grad or node._ctx is not None
@@ -426,10 +426,14 @@ def backward(ctx, grad):
if isinstance(slice_args, Tensor):
slice_args = slice_args.data
grad_x = np.zeros_like(x.data)
- if JAX:
-     grad_x = grad_x.at[slice_args].set(grad.data)
- else:
-     grad_x[slice_args] = grad.data
+ if isinstance(slice_args, (int, tuple)):
+     if JAX:
+         grad_x = grad_x.at[slice_args].set(grad.data)
+     else:
+         grad_x[slice_args] = grad.data
+ else:
+     for s in np.array(slice_args).reshape(-1):
+         grad_x[s] += grad.data[s]
return Tensor(grad_x), None
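
# Standalone numpy sketch (not part of the diff) of the bug this hunk fixes:
# fancy-index *assignment* collapses repeated indices, so when forward gathered
# the same position twice, one gradient contribution was silently dropped.
# Accumulation keeps every contribution (the diff uses an explicit loop;
# np.add.at is the vectorized equivalent).
import numpy as np
idx = np.array([0, 0, 2])   # position 0 gathered twice in forward
grad = np.ones(3)           # upstream gradient of y = x[idx]
lost = np.zeros(4)
lost[idx] = grad            # buggy: lost[0] == 1.0, one contribution dropped
kept = np.zeros(4)
np.add.at(kept, idx, grad)  # correct: kept[0] == 2.0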


@@ -480,7 +484,6 @@ def backward(ctx: Function, grad: Tensor) -> Tensor:

return Tensor(grad_x), None


class Power(Function):
@staticmethod
def forward(x, y):
@@ -545,9 +548,9 @@ def backward(ctx: Function, grad: Tensor) -> list[Tensor]:
probs = exps / np.sum(exps, axis=1, keepdims=True)
d_loss = np.zeros_like(probs)
if JAX:
- d_loss = d_loss.at[np.arange(len(y_true.data)), y_true.data.astype(int)].set(
-     d_loss[np.arange(len(y_true.data)), y_true.data.astype(int)] - 1
- )
+ d_loss = d_loss.at[
+     np.arange(len(y_true.data)), y_true.data.astype(int)
+ ].set(d_loss[np.arange(len(y_true.data)), y_true.data.astype(int)] - 1)
else:
d_loss[np.arange(len(y_true.data)), y_true.data.astype(int)] -= 1
d_loss += probs
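
# Standalone sketch (assumes numpy as np; not part of the diff): the code above
# builds the softmax cross-entropy gradient probs - one_hot(y_true) by writing
# -1 at the true-class positions and then adding probs; this computes the same
# quantity directly.
import numpy as np
logits = np.array([[2.0, 0.5, 0.1]])
y_true = np.array([0])
exps = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exps / exps.sum(axis=1, keepdims=True)
d_loss = probs.copy()
d_loss[np.arange(len(y_true)), y_true] -= 1.0  # probs - one_hot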
@@ -590,8 +593,9 @@ def multinomial(
tensor: Tensor, num_samples: int = 1, replacement: bool = False
) -> Tensor:
import numpy

tensor = tensor.clone()
- if JAX: # I have to do this ugly shit because jax dont support multinomial
+ if JAX:  # I have to do this ugly shit because jax dont support multinomial
tensor.data = numpy.array(np.asarray(tensor.data).tolist())
# Check for at least 2D tensor (Batch x Classes)
if tensor.data.ndim < 2:
@@ -686,11 +690,13 @@ def rand(*shape, requires_grad=False) -> Tensor:
if JAX:
arr = jax.random.normal(jax.random.key(int(time.time())), shape=shape)
else:
- arr = np.random.normal(mean,std_dev,size=shape)
+ arr = np.random.normal(mean, std_dev, size=shape)
return Tensor(arr, requires_grad=requires_grad)
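
# Note (illustrative, not part of the diff): seeding the JAX PRNG with
# int(time.time()) above yields the same key, and thus identical samples, for
# calls within the same second. The usual pattern (assuming a recent jax) is
# to carry one key and split off a fresh subkey per call:
import jax
key = jax.random.key(0)
key, sub = jax.random.split(key)
sample = jax.random.normal(sub, shape=(2, 3))  # fresh subkey per draw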

- def argmax(tensor:Tensor,axis=None):
-     return Tensor(np.array(np.argmax(tensor.data,axis=axis)))

+ def argmax(tensor: Tensor, axis=None):
+     return Tensor(np.array(np.argmax(tensor.data, axis=axis)))


def tensor(data, requires_grad=False):
return Tensor(data, requires_grad)
@@ -703,21 +709,28 @@ def arange(*args, requires_grad=False):
class Device:
def __init__(self, name: str = "cpu"):
self.name = name



def no_grad():
"""Context manager to temporarily disable gradient computation."""

class NoGradContext:
- def __call__(self,func):
-     def wrapper(*args,**kwargs):
+ def __call__(self, func):
+     def wrapper(*args, **kwargs):
with NoGradContext():
- return func(*args,**kwargs)
+ return func(*args, **kwargs)

return wrapper

def __enter__(self):
Tensor._compute_grad = False

def __exit__(self, exc_type, exc_value, traceback):
Tensor._compute_grad = True

return NoGradContext()
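
# Usage sketch (not part of the diff; uses the + and / ops exercised in
# __main__ below): no_grad() works both as a context manager and as a
# decorator, so ops run under it attach no _ctx and backward() has nothing
# to traverse.
x = tensor([3.0], requires_grad=True)
with no_grad():
    y = (x + 1) / x  # computed without recording a graph

@no_grad()
def evaluate(t):
    return (t + 1) / t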


class Parameter(Tensor):
def __init__(self, tensor):
super().__init__(tensor, requires_grad=True)
@@ -787,11 +800,11 @@ def register_buffer(self, name, value: Tensor):

def to(self, device):
return self # add gpu backend maybe

def eval(self):
for p in self.parameters():
p.requires_grad = False

def train(self):
for p in self.parameters():
p.requires_grad = True
@@ -806,7 +819,7 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True):
self.out_features = out_features

self.weight = Parameter(
- rand((out_features, in_features)) / np.sqrt(in_features+out_features)
+ rand((out_features, in_features)) / np.sqrt(in_features + out_features)
)
self.bias = Parameter(zeros(out_features)) if bias else None

@@ -882,9 +895,11 @@ def step(self):


if __name__ == "__main__":
- x = tensor(2,requires_grad=True)
+ x = tensor(2, requires_grad=True)

def f(x):
- return (x+1)/x
+ return (x + 1) / x

z = f(x)
z.backward()
print(x.grad)
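
# Check (not part of the diff): f(x) = (x + 1) / x = 1 + 1/x, so
# f'(x) = -1/x**2, and at x = 2 the printed gradient should be -0.25.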
