Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Glebashnik/fix tensor adress hash code #32943

Merged
merged 7 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion vespajlib/src/main/java/com/yahoo/tensor/impl/LabelImpl.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.impl;


import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
import com.yahoo.tensor.Label;
import com.yahoo.tensor.Tensor;

Expand All @@ -10,17 +13,25 @@
* @author glebashnik
*/
class LabelImpl implements Label {
// Hash function with avalanche effect to avoid too many hash bucket collisions.
private static final HashFunction hashFunction = Hashing.murmur3_32_fixed();

private final long numeric;
private final String string;

// Caching the hash code to avoid recalculating it when cached labels are reused in multiple tensors.
private final int hashCode;

LabelImpl(long numeric) {
this.numeric = numeric;
this.string = null;
this.hashCode = hashFunction.hashLong(numeric).asInt();
}

LabelImpl(long numeric, String string) {
this.numeric = numeric;
this.string = string;
this.hashCode = hashFunction.hashLong(numeric).asInt();
}

@Override
Expand Down Expand Up @@ -57,6 +68,6 @@ public boolean equals(Object object) {

@Override
public int hashCode() {
return Long.hashCode(numeric);
return hashCode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ public TensorAddress withLabel(int labelIndex, long label) {
if (labelIndex == 0) return new TensorAddressAny1(LabelCache.GLOBAL.getOrCreateLabel(label));
throw new IllegalArgumentException("No label " + labelIndex);
}

@Override public int hashCode() { return (int)Math.abs(label.asNumeric()); }

@Override public int hashCode() {
return 31 + label.hashCode();
}

@Override
public boolean equals(Object o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import com.yahoo.tensor.Label;
import com.yahoo.tensor.TensorAddress;

import static java.lang.Math.abs;

/**
* A two-dimensional address.
*
Expand Down Expand Up @@ -40,11 +38,12 @@ public TensorAddress withLabel(int labelIndex, long label) {
};
}

// Same as Objects.hash(...) but a little faster since it avoids creating an array, loop and null checks.
@Override
public int hashCode() {
long hash = abs(label0.asNumeric()) |
(abs(label1.asNumeric()) << (64 - Long.numberOfLeadingZeros(abs(label0.asNumeric()))));
return (int) hash;
return 31 * 31
+ 31 * label0.hashCode()
+ label1.hashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import com.yahoo.tensor.Label;
import com.yahoo.tensor.TensorAddress;

import static java.lang.Math.abs;

/**
* A three-dimensional address.
*
Expand Down Expand Up @@ -43,12 +41,13 @@ public TensorAddress withLabel(int labelIndex, long label) {
};
}

// Same as Objects.hash(...) but a little faster since it avoids creating an array, loop and null checks.
@Override
public int hashCode() {
long hash = abs(label0.asNumeric()) |
(abs(label1.asNumeric()) << (64 - Long.numberOfLeadingZeros(abs(label0.asNumeric())))) |
(abs(label2.asNumeric()) << (2*64 - (Long.numberOfLeadingZeros(abs(label0.asNumeric())) + Long.numberOfLeadingZeros(abs(label1.asNumeric())))));
return (int) hash;
return 31 * 31 * 31
+ 31 * 31 * label0.hashCode()
+ 31 * label1.hashCode()
+ label2.hashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import com.yahoo.tensor.Label;
import com.yahoo.tensor.TensorAddress;

import static java.lang.Math.abs;

/**
* A four-dimensional address.
*
Expand Down Expand Up @@ -46,13 +44,14 @@ public TensorAddress withLabel(int labelIndex, long label) {
};
}

// Same as Objects.hash(...) but a little faster since it avoids creating an array, loop and null checks.
@Override
public int hashCode() {
long hash = abs(label0.asNumeric()) |
(abs(label1.asNumeric()) << (64 - Long.numberOfLeadingZeros(abs(label0.asNumeric())))) |
(abs(label2.asNumeric()) << (2*64 - (Long.numberOfLeadingZeros(abs(label0.asNumeric())) + Long.numberOfLeadingZeros(abs(label1.asNumeric()))))) |
(abs(label3.asNumeric()) << (3*64 - (Long.numberOfLeadingZeros(abs(label0.asNumeric())) + Long.numberOfLeadingZeros(abs(label1.asNumeric())) + Long.numberOfLeadingZeros(abs(label1.asNumeric())))));
return (int) hash;
return 31 * 31 * 31 * 31
+ 31 * 31 * 31 * label0.hashCode()
+ 31 * 31 * label1.hashCode()
+ 31 * label2.hashCode()
+ label3.hashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import java.util.Arrays;

import static java.lang.Math.abs;

/**
* An n-dimensional address.
*
Expand Down Expand Up @@ -42,13 +40,15 @@ public TensorAddress withLabel(int labelIndex, long label) {
return new TensorAddressAnyN(copy);
}

// Same as Arrays.hashCode(labels) but without null checks.
@Override public int hashCode() {
long hash = abs(labels[0].asNumeric());
for (int i = 0; i < size(); i++) {
hash = hash | (abs(labels[i].asNumeric()) << (32 - Long.numberOfLeadingZeros(hash)));
}
return (int) hash;
}
int result = 1;

for (var label : labels)
result = 31 * result + label.hashCode();

return result;
}

@Override
public boolean equals(Object o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ void testStringVersusNumericAddressEquality() {
void testInEquality() {
notEqual(ofLabels("1"), ofLabels("2"));
notEqual(of(1), of(2));
notEqual(ofLabels("1"), ofLabels("01"));
notEqual(ofLabels("0"), ofLabels("00"));
notEqual(ofLabels("1"), ofLabels("01"));
}
@Test
void testDimensionsEffectsEqualityAndHash() {
Expand Down Expand Up @@ -79,5 +79,22 @@ void testPartialCopy() {
int[] o_1_3_2 = {1,3,2};
equal(ofLabels("b", "d", "c"), abcd.partialCopy(o_1_3_2));
}

// This test was designed for a previous version of the hashCode to produce collisions, see:
// https://github.com/vespa-engine/vespa/blob/073cb004ce0c26da9eff4b8f4745e0174bc85424/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java#L45
// Current implementation of the hashCode shouldn't produce collisions for this test.
@Test
void testHashCodeForLowEntropy() {
var e = TensorAddress.ofLabels("1", "4", "5", "6", "x", "y", "z");
var f = TensorAddress.ofLabels("1", "4", "5", "6", "x", "y", "z");
assertEquals(e.hashCode(), f.hashCode());

var a = TensorAddress.ofLabels("a", "b", "c", "d", "e", "f", "g");
var b = TensorAddress.ofLabels("a", "b", "c", "d", "e", "f", "g", "h", "i", "j");
assertNotEquals(a.hashCode(), b.hashCode());

var c = TensorAddress.ofLabels("1", "4", "5", "6", "x", "y", "z");
var d = TensorAddress.ofLabels("1", "3", "5", "7", "z", "b", "c", "d", "e", "f");
assertNotEquals(c.hashCode(), d.hashCode());
glebashnik marked this conversation as resolved.
Show resolved Hide resolved
}
}