Commit 74ad5e7
wip batched training
timokau committed Oct 19, 2020
1 parent 755ee12 commit 74ad5e7
Showing 1 changed file with 95 additions and 40 deletions.
feta.py (135 changes: 95 additions & 40 deletions)
@@ -6,7 +6,30 @@
import torch
import torch.nn as nn
from torch import optim
+from torch.utils.data import DataLoader
+from torch.utils.data import TensorDataset
+
+class Permutation:
+    def __init__(self):
+        pass
+
+class Ranking:
+    # A class representing a ranking of objects.
+    def __init__(self, objects, indices):
+        self.objects = objects
+        self.indices = indices
+
+    def get_objects(self):
+        return self.objects
+
+    def get_ranking_indices(self):
+        return self.indices
+
+    def get_ranked_objects(self):
+        return self.objects[self.indices]
+
+    def __repr__(self):
+        return repr(self.get_ranked_objects())

def hinged_rank_loss(y_true, y_pred):
    """Compute the loss between two rankings.
@@ -26,8 +49,11 @@ def hinged_rank_loss(y_true, y_pred):
    # For each instance, a 2d matrix which is 1 if the row-element *should*
    # be ranked higher than the column element.
    # TODO should be torch.gt; this might be <= 1 and is not a proper mask
    mask = torch.clamp(y_true[:, None] - y_true[:, :, None], min=0, max=1)
+    # print("===MASK===")
+    # print(mask)
    # how much higher/lower the elements are actually ranked
    diff = y_pred[:, :, None] - y_pred[:, None]
+    # print(diff)
    # loss for those elements that should be ranked higher, but are ranked lower
    # TODO verify this is correct
    hinge = torch.clamp(mask * diff, min=0)
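
To make the three tensors above concrete, here is a small worked example (not part of the commit; it assumes y_true holds integer rank values and y_pred holds scores, with a leading batch dimension):

import torch

y_true = torch.tensor([[1, 0]])      # object 0 has the higher true rank value
y_pred = torch.tensor([[0.2, 0.5]])  # but the model scores object 1 higher

mask = torch.clamp(y_true[:, None] - y_true[:, :, None], min=0, max=1)
# mask[0] == [[0, 0], [1, 0]]; entry (i, j) is 1 when object j's true
# rank value exceeds object i's.
diff = y_pred[:, :, None] - y_pred[:, None]
# diff[0] == [[0.0, -0.3], [0.3, 0.0]]; entry (i, j) is how much higher
# object i is scored than object j.
hinge = torch.clamp(mask * diff, min=0)
# hinge[0] == [[0.0, 0.0], [0.3, 0.0]]; only the offending pair contributes:
# object 1 outscores object 0 by 0.3 despite object 0's higher true rank.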
@@ -42,8 +68,8 @@ def __init__(self, n_features: int):
        # TODO non-linear model
        self.lin = nn.Linear(n_features, 1)

-    def forward(self, object):
-        return self.lin(object)
+    def forward(self, objects):
+        return self.lin(objects)
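
A note on why one call can now score whole batches (the reason for the rename from object to objects): nn.Linear applies to the last dimension and broadcasts over any leading dimensions. A quick check, not part of the commit:

import torch
import torch.nn as nn

lin = nn.Linear(4, 1)
lin(torch.randn(3, 4)).shape      # torch.Size([3, 1]), one instance of 3 objects
lin(torch.randn(64, 3, 4)).shape  # torch.Size([64, 3, 1]), a whole batch at once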


class FirstOrderUtility(nn.Module):
@@ -52,74 +78,92 @@ def __init__(self, n_features: int):
        # TODO both ways?
        self.lin = nn.Linear(n_features * 2, 1)

-    def forward(self, first_object, second_object):
+    def forward(self, object_pairs):
        # TODO more complicated model
-        return self.lin(torch.cat(first_object, second_object))
+        return self.lin(object_pairs)


class FETAModel(nn.Module):
-    def __init__(self, zeroth_order_utility, first_order_utility):
+    def __init__(self, zeroth_order_utility, first_order_utility, n_objects):
        super().__init__()
        self.zeroth_order_utility = zeroth_order_utility
        self.first_order_utility = first_order_utility
+        self.n_objects = n_objects

-    def forward(self, instance):
+    def forward(self, instances):
        """Aggregate zeroth and first order utility, returning a 1d tensor of evaluations."""
-        # TODO multiple at once?
-        (n_objects, n_features) = instance.shape
+        # TODO for every object, compute its own utility and all pairwise utilities
+        # context_utility = self.first_order_utility(instances).sum(axis=1)
        context_utility = 0
        result = (
-            self.zeroth_order_utility(instance) + 1 / (n_objects - 1) * context_utility
+            self.zeroth_order_utility(instances) + 1 / (self.n_objects - 1) * context_utility
        )
-        self.zeroth_order_utility(instance)
+        self.zeroth_order_utility(instances)
        return result[:, 0]
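
The commented-out line above is the first-order half of the FETA ("first evaluate then aggregate") scheme: score(x_i) = u0(x_i) + 1/(n-1) * sum over j != i of u1(x_i, x_j). One way the TODO could be filled in is sketched below; pairwise_context_utility is a hypothetical helper, not part of the commit, and it assumes FirstOrderUtility consumes pairs concatenated along the feature axis:

import torch

def pairwise_context_utility(first_order_utility, instances):
    # Hypothetical sketch, not from the commit. Build every ordered pair
    # (i, j) by broadcasting, score each pair, mask out self-pairs, and
    # sum over partners j. instances: (batch, n_objects, n_features).
    (batch, n_objects, n_features) = instances.shape
    left = instances[:, :, None, :].expand(-1, -1, n_objects, -1)
    right = instances[:, None, :, :].expand(-1, n_objects, -1, -1)
    pairs = torch.cat([left, right], dim=-1)        # (batch, n, n, 2 * n_features)
    utilities = first_order_utility(pairs)[..., 0]  # (batch, n, n)
    off_diag = 1.0 - torch.eye(n_objects, dtype=utilities.dtype, device=utilities.device)
    # Summed here; FETAModel.forward then scales by 1 / (n_objects - 1).
    return (utilities * off_diag).sum(dim=-1)       # (batch, n_objects)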


class FETARankingEstimator(BaseEstimator):
-    def __init__(self, lr=1e-5, batch_size=1):
+    def __init__(self, lr=1e-5, batch_size=64):
        self.lr = lr
        self.batch_size = batch_size

-    def fit(self, X, Y):
-        (n_samples, n_objects, n_features) = X.shape
-        self.model_ = FETAModel(
-            ZerothOrderUtility(n_features), FirstOrderUtility(n_features)
-        ).double()
+    def fit(self, dataset):
+        data_loader = DataLoader(dataset, batch_size=self.batch_size)
+        # n_instances = len(dataset)
+        # (n_samples, n_objects, n_features) = X.shape
+        (n_objects, n_features) = dataset[0][0].shape
+        # self.model_ = FETAModel(
+        #     ZerothOrderUtility(n_features), FirstOrderUtility(n_features), n_objects
+        # ).double()
+        self.model_ = ZerothOrderUtility(n_features).double()
+        self.model_.train()
        self.loss_function_ = hinged_rank_loss
        self.optimizer_ = optim.SGD(self.model_.parameters(), lr=self.lr)

-        assert self.batch_size == 1  # for now
        # Overfit to make sure the model can learn.
        for i in range(1000):
-            for (sample, true_ranking) in zip(X, Y):
-                # scores should ideally match the ranks
-                scores = self.model_(sample)
-                loss = self.loss_function_(true_ranking[None, :], scores[None, :])
-                print(loss)
-                loss.backward()
-                self.optimizer_.step()
-                self.optimizer_.zero_grad()
+            for (samples, true_rankings) in data_loader:
+                # scores should ideally match the ranks
+                # The model gives the score of each object in a singleton list;
+                # squeeze to remove that last dimension.
+                scores = self.model_(samples).squeeze(-1)
+                # print("==SAMPLES==")
+                # print(samples)
+                # print("==SCORES==")
+                # print(scores)
+                # print(true_rankings)
+                loss = self.loss_function_(true_rankings, scores)
+                # print("===DEBUG===")
+                # print(true_rankings)
+                # print("===DEBUG2===")
+                # print(scores)
+                # print("===DEBUG3===")
+                loss.sum().backward()
+                self.optimizer_.step()
+                self.optimizer_.zero_grad()
+        self.model_.eval()
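
For orientation, the tensor shapes through one pass of this loop on the trivial problem below, assuming the default batch size of 64 and three objects with one feature each (the tail of hinged_rank_loss is collapsed above, so the loss shape is inferred from the .sum() call):

# samples:        (64, 3, 1)   one feature per object
# model(samples): (64, 3, 1)   one score per object, in a singleton list
# scores:         (64, 3)      after .squeeze(-1)
# true_rankings:  (64, 3)      one rank value per object
# loss:           non-scalar (hence loss.sum() before backward())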

    def rank(self, objects):
        """Rank based on evaluations."""
        evaluations = self.model_(objects)
-        (_values, indices) = torch.sort(evaluations)
+        (_values, indices) = torch.sort(evaluations.squeeze())
        return indices
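
torch.sort returns both the sorted values and the indices that produce that order; only the indices are kept, so rank returns object indices from lowest to highest evaluation. For example (not from the commit):

import torch

values, indices = torch.sort(torch.tensor([0.3, 0.7, -1.2]))
# values  == tensor([-1.2000,  0.3000,  0.7000])
# indices == tensor([2, 0, 1]); object 2 comes first, then 0, then 1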


-def trivial_ranking_problem(n_objects, n_instances, random_state):
+class TrivialRankingProblem(TensorDataset):
    """Generate a trivial ranking problem for testing purposes.
    Each object has only one feature, and the true ranking is determined by
    ranking on that feature.
    Consider this generated problem:
-    >>> (x, y_true) = trivial_ranking_problem(n_objects=2, n_instances=2, random_state=np.random.RandomState(42))
+    >>> dataset = TrivialRankingProblem(
+    ...     n_objects=2,
+    ...     n_instances=2,
+    ...     random_state=np.random.RandomState(42),
+    ... )
+    >>> x, y_true = dataset.tensors
    >>> x
    tensor([[[ 0.4967],
             [-0.1383]],
@@ -133,25 +177,36 @@ def trivial_ranking_problem(n_objects, n_instances, random_state):
    tensor([[1, 0],
            [0, 1]])
    """
-    n_features = 1
-    x = torch.tensor(
-        random_state.randn(n_instances, n_objects, n_features)
-    )
-    y_true = x.argsort(axis=1).argsort(axis=1).squeeze(axis=-1)
-    return x, y_true
+    def __init__(self, n_objects, n_instances, random_state):
+        n_features = 1
+        x = torch.tensor(
+            random_state.randn(n_instances, n_objects, n_features)
+        )
+        y_true = x.argsort(axis=1).argsort(axis=1).squeeze(axis=-1)
+        super().__init__(x, y_true)
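
The double argsort above is the usual scores-to-ranks trick: the first argsort yields the sort order, the second inverts that permutation so each object gets its rank position. A quick check (not from the commit):

import torch

x = torch.tensor([0.3, 0.7, -1.2])
order = x.argsort()      # tensor([2, 0, 1]); indices that would sort x
ranks = order.argsort()  # tensor([1, 2, 0]); the rank of each entry of x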


def _main():
    n_objects = 3
-    n_instances = 100
+    n_instances = 100000
+    batch_size = 64
    random_state = np.random.RandomState(42)
-    (object_lists, true_rankings) = trivial_ranking_problem(
+
+    train_ds = TrivialRankingProblem(
        n_objects, n_instances, random_state
    )
+    test_ds = TrivialRankingProblem(
+        n_objects, n_instances, random_state
+    )

    estimator = FETARankingEstimator()
-    estimator.fit(object_lists, true_rankings)
-    print(estimator.rank(object_lists[0]), true_rankings[0])
+    estimator.fit(train_ds)
+    print("===INPUT===")
+    print(test_ds[0])
+    print("===SCORES===")
+    print(estimator.model_(test_ds[0][0]))
+    print("===RANKING===")
+    print(estimator.rank(test_ds[0][0]))


if __name__ == "__main__":
