diff --git a/feta.py b/feta.py
index 909713da..783ecb45 100644
--- a/feta.py
+++ b/feta.py
@@ -6,7 +6,34 @@
 import torch
 import torch.nn as nn
 from torch import optim
+from torch.utils.data import DataLoader
+from torch.utils.data import TensorDataset
 
 
+class Permutation:
+    def __init__(self):
+        pass
+
+
+class Ranking:
+    """A ranking of objects, stored as the objects plus their ranking indices."""
+
+    def __init__(self, objects, indices):
+        self.objects = objects
+        self.indices = indices
+
+    def get_objects(self):
+        return self.objects
+
+    def get_ranking_indices(self):
+        return self.indices
+
+    def get_ranked_objects(self):
+        return self.objects[self.indices]
+
+    def __repr__(self):
+        return repr(self.get_ranked_objects())
+
+
 def hinged_rank_loss(y_true, y_pred):
     """Compute the loss between two rankings.
@@ -26,8 +53,11 @@ def hinged_rank_loss(y_true, y_pred):
     # 2d matrix which is 1 if the row-element *should* be ranked higher than the column element
     # TODO should be torch.gt this might be <=1 and is not a proper mask
     mask = torch.clamp(y_true[:, None] - y_true[:, :, None], min=0, max=1)
+    # mask[b, i, j] is 1 exactly when object j is truly ranked above object i,
+    # i.e. the pairs whose predicted score order we need to check.
     # how much higher/lower the elements are actually ranked
     diff = y_pred[:, :, None] - y_pred[:, None]
+    # diff[b, i, j] > 0 means object i received a higher score than object j.
     # loss for those elements that should be ranked higher, but are ranked lower
     # TODO verify this is correct
     hinge = torch.clamp(mask * diff, min=0)
@@ -42,8 +72,8 @@ def __init__(self, n_features: int):
         # TODO non-linear model
         self.lin = nn.Linear(n_features, 1)
 
-    def forward(self, object):
-        return self.lin(object)
+    def forward(self, objects):
+        return self.lin(objects)
 
 
 class FirstOrderUtility(nn.Module):
@@ -52,66 +82,72 @@ def __init__(self, n_features: int):
         # TODO both ways?
         self.lin = nn.Linear(n_features * 2, 1)
 
-    def forward(self, first_object, second_object):
+    def forward(self, object_pairs):
         # TODO more complicated model
-        return self.lin(torch.cat(frist_object, second_object))
+        return self.lin(object_pairs)
 
 
 class FETAModel(nn.Module):
-    def __init__(self, zeroth_order_utility, first_order_utility):
+    def __init__(self, zeroth_order_utility, first_order_utility, n_objects):
         super().__init__()
         self.zeroth_order_utility = zeroth_order_utility
         self.first_order_utility = first_order_utility
+        self.n_objects = n_objects
 
-    def forward(self, instance):
+    def forward(self, instances):
         """Aggregate zeroth and first order utility, returning a 1d tensor of evaluations."""
         # TODO multiple at once?
-        (n_objects, n_features) = instance.shape
         context_utility = 0
         # TODO for every object, compute its own utility and all pairwise utilities
         # context_utility = self.first_order_utility(instances).sum(axis=1)
         # context_utility = 0
         result = (
-            self.zeroth_order_utility(instance) + 1 / (n_objects - 1) * context_utility
+            self.zeroth_order_utility(instances) + 1 / (self.n_objects - 1) * context_utility
         )
-        self.zeroth_order_utility(instance)
         return result[:,0]
 
 
 class FETARankingEstimator(BaseEstimator):
-    def __init__(self, lr=1e-5, batch_size=1):
+    def __init__(self, lr=1e-5, batch_size=64):
         self.lr = lr
         self.batch_size = batch_size
 
-    def fit(self, X, Y):
-        (n_samples, n_objects, n_features) = X.shape
-        self.model_ = FETAModel(
-            ZerothOrderUtility(n_features), FirstOrderUtility(n_features)
-        ).double()
+    def fit(self, dataset):
+        data_loader = DataLoader(dataset, batch_size=self.batch_size)
+        # Infer the object/feature dimensions from the first instance.
+        (n_objects, n_features) = dataset[0][0].shape
+        # TODO restore the full FETA model once the first-order utility is
+        # wired up:
+        # self.model_ = FETAModel(
+        #     ZerothOrderUtility(n_features), FirstOrderUtility(n_features), n_objects
+        # ).double()
+        self.model_ = ZerothOrderUtility(n_features).double()
         self.model_.train()
         self.loss_function_ = hinged_rank_loss
         self.optimizer_ = optim.SGD(self.model_.parameters(), lr=self.lr)
-        assert self.batch_size == 1  # for now
 
-        # Overfit to make sure the model can learn.
-        for i in range(1000):
-            for (sample, true_ranking) in zip(X, Y):
-                # scores should ideally match the ranks
-                scores = self.model_(sample)
-                loss = self.loss_function_(true_ranking[None, :], scores[None, :])
-                print(loss)
-                loss.backward()
-                self.optimizer_.step()
-                self.optimizer_.zero_grad()
+        for (samples, true_rankings) in data_loader:
+            # Scores should ideally match the ranks. The model returns each
+            # object's score as a singleton list; squeeze away that last
+            # dimension.
+            scores = self.model_(samples).squeeze(-1)
+            # hinged_rank_loss yields one value per instance; sum to get a
+            # scalar to backpropagate.
+            loss = self.loss_function_(true_rankings, scores)
+            loss.sum().backward()
+            self.optimizer_.step()
+            self.optimizer_.zero_grad()
+        self.model_.eval()
 
     def rank(self, objects):
         """Rank based on evaluations."""
         evaluations = self.model_(objects)
-        (_values, indices) = torch.sort(evaluations)
+        (_values, indices) = torch.sort(evaluations.squeeze())
         return indices
 
 
-def trivial_ranking_problem(n_objects, n_instances, random_state):
+class TrivialRankingProblem(TensorDataset):
     """Generate a trivial ranking problem for testing purposes.
 
     Each object has only one feature, and the true ranking is determined by
@@ -119,7 +155,12 @@ def trivial_ranking_problem(n_objects, n_instances, random_state):
 
     Consider this generated problem:
 
-    >>> (x, y_true) = trivial_ranking_problem(n_objects=2, n_instances=2, random_state=np.random.RandomState(42))
+    >>> dataset = TrivialRankingProblem(
+    ...     n_objects=2,
+    ...     n_instances=2,
+    ...     random_state=np.random.RandomState(42),
+    ... )
+    >>> x, y_true = dataset.tensors
     >>> x
     tensor([[[ 0.4967],
              [-0.1383]],
@@ -133,25 +174,40 @@ def trivial_ranking_problem(n_objects, n_instances, random_state):
     tensor([[1, 0],
             [0, 1]])
     """
-    n_features = 1
-    x = torch.tensor(
-        random_state.randn(n_instances, n_objects, n_features)
-    )
-    y_true = x.argsort(axis=1).argsort(axis=1).squeeze(axis=-1)
-    return x, y_true
+
+    def __init__(self, n_objects, n_instances, random_state):
+        n_features = 1
+        x = torch.tensor(
+            random_state.randn(n_instances, n_objects, n_features)
+        )
+        # The double argsort turns feature values into per-instance ranks;
+        # squeeze drops the singleton feature dimension.
+        y_true = x.argsort(axis=1).argsort(axis=1).squeeze(axis=-1)
+        super().__init__(x, y_true)
 
 
 def _main():
     n_objects = 3
-    n_instances = 100
+    n_instances = 100000
+    batch_size = 64
     random_state = np.random.RandomState(42)
-    (object_lists, true_rankings) = trivial_ranking_problem(
+
+    train_ds = TrivialRankingProblem(
+        n_objects, n_instances, random_state
+    )
+    test_ds = TrivialRankingProblem(
         n_objects, n_instances, random_state
     )
-    estimator = FETARankingEstimator()
-    estimator.fit(object_lists, true_rankings)
-    print(estimator.rank(object_lists[0]), true_rankings[0])
+
+    estimator = FETARankingEstimator(batch_size=batch_size)
+    estimator.fit(train_ds)
+    print("===INPUT===")
+    print(test_ds[0])
+    print("===SCORES===")
+    print(estimator.model_(test_ds[0][0]))
+    print("===RANKING===")
+    print(estimator.rank(test_ds[0][0]))
 
 
 if __name__ == "__main__":
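
`hinged_rank_loss` can be checked by hand on a two-object instance. The sketch below is a minimal reproduction, assuming the function ends by summing the hinge matrix over both object axes so that it returns one loss value per instance (consistent with the `loss.sum().backward()` call in `fit`); the reduction itself lies outside the hunks shown above.

```python
import torch

def hinged_rank_loss(y_true, y_pred):
    # Body as in feta.py; the final sum reduction is an assumption.
    mask = torch.clamp(y_true[:, None] - y_true[:, :, None], min=0, max=1)
    diff = y_pred[:, :, None] - y_pred[:, None]
    hinge = torch.clamp(mask * diff, min=0)
    return hinge.sum(dim=(1, 2))

# One instance of two objects; object 1 is truly ranked above object 0.
y_true = torch.tensor([[0, 1]])

# Correctly ordered scores incur no loss ...
print(hinged_rank_loss(y_true, torch.tensor([[-0.5, 0.5]])))  # tensor([0.])
# ... while swapped scores are penalized by their score gap.
print(hinged_rank_loss(y_true, torch.tensor([[0.5, -0.5]])))  # tensor([1.])
```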
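For the TODO in `FETAModel.forward` ("for every object, compute its own utility and all pairwise utilities"), the pre-concatenated pairs that `FirstOrderUtility.forward` now expects could be built by broadcasting. A minimal sketch, not part of the patch, assuming instances of shape `(batch, n_objects, n_features)` and a context sum that excludes self-pairs:

```python
import torch

def pairwise_context(instances, first_order_utility):
    # instances: (batch, n_objects, n_features)
    (_batch, n_objects, _n_features) = instances.shape
    # Broadcast every ordered pair (i, j) into (batch, n, n, 2 * n_features).
    left = instances[:, :, None, :].expand(-1, -1, n_objects, -1)
    right = instances[:, None, :, :].expand(-1, n_objects, -1, -1)
    pairs = torch.cat([left, right], dim=-1)
    # (batch, n, n): utility of object i in the context of object j.
    u1 = first_order_utility(pairs).squeeze(-1)
    # Zero the diagonal so self-pairs do not contribute to the context.
    u1 = u1 * (1 - torch.eye(n_objects, dtype=u1.dtype))
    return u1.sum(dim=2)  # (batch, n_objects)
```

`FETAModel.forward` could then combine this with the zeroth-order scores as `self.zeroth_order_utility(instances)[:, :, 0] + pairwise_context(instances, self.first_order_utility) / (self.n_objects - 1)`, matching the `1 / (n_objects - 1)` normalization already in the code.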
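Since the trivial problem ranks objects by their single feature, a fitted estimator can be smoke-tested by comparing `rank` against the feature argsort on held-out instances. A hypothetical check (the dataset size, seed, and the fitted `estimator` are assumptions, not part of the diff):

```python
import numpy as np
import torch

# Assumes `estimator` has already been fitted as in _main().
test_ds = TrivialRankingProblem(
    n_objects=3, n_instances=100, random_state=np.random.RandomState(7)
)
correct = 0
for (x, _y_true) in test_ds:
    # rank() sorts by predicted score; the true order is the feature argsort.
    correct += int(torch.equal(estimator.rank(x), x.squeeze(-1).argsort()))
print(f"{correct}/100 instances ranked correctly")
```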