Skip to content

Commit

Permalink
Merge pull request #32524 from vespa-engine/interns/magnus/tensorexpr…
Browse files Browse the repository at this point in the history
…essions

Interns/magnus/tensorexpressions
  • Loading branch information
Mangern authored Oct 11, 2024
2 parents 4b679f0 + 3436d42 commit ec05917
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ public List<Symbol> findSymbols(Symbol symbol) {
public List<Symbol> findSymbols(Symbol scope, SymbolType type, String shortIdentifier) {
// First candidates are all symbols with correct type and correct short identifier

// Special case for schema and document because a schema can sometimes refer to a document and vice versa
if (type == SymbolType.SCHEMA || type == SymbolType.DOCUMENT) {
SymbolType firstCheck = (type == SymbolType.SCHEMA ? SymbolType.SCHEMA : SymbolType.DOCUMENT);
List<Symbol> schemaDefinitions =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ public static ParseResult parseContent(ParseContext context) {

diagnostics.addAll(ResolverTraversal.traverse(context, tolerantResult.CST().get()));

// Unresolved field arguments (type and indexing settings) need to be resolved after 'ResolverTraversal'
// because the indexing language is traversed in the ResolverTraversal
for (UnresolvedFieldArgument fieldArg : context.unresolvedFieldArguments()) {
Optional<Diagnostic> diagnostic = FieldArgumentResolver.resolveFieldArgument(context, fieldArg);
if (diagnostic.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
import ai.vespa.schemals.parser.ast.rootSchema;
import ai.vespa.schemals.parser.ast.structFieldDefinition;
import ai.vespa.schemals.parser.ast.tensorTypeElm;
import ai.vespa.schemals.parser.rankingexpression.ast.LCURLY;
import ai.vespa.schemals.parser.rankingexpression.ast.lambdaFunction;
import ai.vespa.schemals.parser.rankingexpression.ast.tensorType;
import ai.vespa.schemals.parser.rankingexpression.ast.tensorTypeDimension;
import ai.vespa.schemals.tree.CSTUtils;
import ai.vespa.schemals.tree.SchemaNode;
import ai.vespa.schemals.tree.SchemaNode.LanguageType;
Expand All @@ -52,6 +55,11 @@ public IdentifySymbolDefinition(ParseContext context) {
}


/**
* Marks the node as a symbol with SymbolStatus DEFINITION
* It is mainly based on the node type being an identifier, and the parent being of a certain type.
* But in a lot of cases we need to check more.
*/
public ArrayList<Diagnostic> identify(SchemaNode node) {
ArrayList<Diagnostic> ret = new ArrayList<Diagnostic>();

Expand All @@ -72,94 +80,107 @@ public ArrayList<Diagnostic> identify(SchemaNode node) {
SchemaNode parent = node.getParent();
if (parent == null) return ret;

if (parent.isASTInstance(importField.class) && node.getPreviousSibling() != null && node.getPreviousSibling().isASTInstance(AS.class)) {
createSymbol(node, SymbolType.FIELD);
if (handleSpecialCases(node, parent, ret)) {
return ret;
}

// Prevent inheritance from beeing marked as a definition
if (parent.indexOf(node) >= 3) {
// Unnless it is a paramenter to a function
if (parent.isASTInstance(functionElm.class) && node.isASTInstance(identifierStr.class)) {
createSymbol(node, SymbolType.PARAMETER);
}
Map<Class<?>, SymbolType> searchMap = isIdentifier ? SchemaIndex.IDENTIFIER_TYPE_MAP : SchemaIndex.IDENTIFIER_WITH_DASH_TYPE_MAP;
SymbolType symbolType = searchMap.get(parent.getASTClass());
if (symbolType == null) return ret;

// Root item, should not have a scope
if (parent.isASTInstance(namedDocument.class) || parent.isASTInstance(rootSchema.class)) {
node.setSymbol(symbolType, context.fileURI());
node.setSymbolStatus(SymbolStatus.DEFINITION);
context.schemaIndex().insertSymbolDefinition(node.getSymbol());
return ret;
}

Map<Class<?>, SymbolType> searchMap = isIdentifier ? SchemaIndex.IDENTIFIER_TYPE_MAP : SchemaIndex.IDENTIFIER_WITH_DASH_TYPE_MAP;
SymbolType symbolType = searchMap.get(parent.getASTClass());
if (symbolType != null) {
Optional<Symbol> scope = CSTUtils.findScope(node);
if (scope.isEmpty()) {
if (symbolType == SymbolType.RANK_PROFILE && parent.getParent() != null && parent.getParent().isASTInstance(RootRankProfile.class)) {
// we are in a rank-profile file (.profile)
String workspaceRootURI = context.scheduler().getWorkspaceURI();
if (workspaceRootURI == null) return ret;
String currentURI = context.fileURI();

if (parent.isASTInstance(namedDocument.class) || parent.isASTInstance(rootSchema.class)) {
node.setSymbol(symbolType, context.fileURI());
node.setSymbolStatus(SymbolStatus.DEFINITION);
context.schemaIndex().insertSymbolDefinition(node.getSymbol());
return ret;
}
String schemaName = FileUtils.firstPathComponentAfterPrefix(currentURI, workspaceRootURI);

Optional<Symbol> scope = CSTUtils.findScope(node);
if (scope.isEmpty()) {
if (symbolType == SymbolType.RANK_PROFILE && parent.getParent() != null && parent.getParent().isASTInstance(RootRankProfile.class)) {
// we are in a rank-profile file (.profile)
String workspaceRootURI = context.scheduler().getWorkspaceURI();
if (workspaceRootURI == null) return ret;
String currentURI = context.fileURI();
if (schemaName == null) return ret;

String schemaName = FileUtils.firstPathComponentAfterPrefix(currentURI, workspaceRootURI);
Optional<Symbol> schemaSymbol = context.schemaIndex().getSchemaDefinition(schemaName);

if (schemaName == null) return ret;
if (schemaSymbol.isEmpty()) return ret;

Optional<Symbol> schemaSymbol = context.schemaIndex().getSchemaDefinition(schemaName);
// TODO: rank-profile belonging to namedDocument??
node.setSymbol(symbolType, context.fileURI(), schemaSymbol.get());
node.setSymbolStatus(SymbolStatus.DEFINITION);
context.schemaIndex().insertSymbolDefinition(node.getSymbol());
}
return ret;
}

if (schemaSymbol.isEmpty()) return ret;
node.setSymbol(symbolType, context.fileURI(), scope.get());

// TODO: rank-profile belonging to namedDocument??
node.setSymbol(symbolType, context.fileURI(), schemaSymbol.get());
node.setSymbolStatus(SymbolStatus.DEFINITION);
context.schemaIndex().insertSymbolDefinition(node.getSymbol());
}
return ret;
}
// Check if this is an invalid 'redefinition' of existing identifier.
Optional<Symbol> existingSymbol = context.schemaIndex().findSymbolInScope(node.getSymbol());

node.setSymbol(symbolType, context.fileURI(), scope.get());
if (existingSymbol.isEmpty()) {
node.setSymbolStatus(SymbolStatus.DEFINITION);
context.schemaIndex().insertSymbolDefinition(node.getSymbol());

Optional<Symbol> existingSymbol = context.schemaIndex().findSymbolInScope(node.getSymbol());
if (node.getSymbol().getType() == SymbolType.FUNCTION) {
verifySymbolFunctionName(node, ret);
}
return ret;
}

if (existingSymbol.isEmpty()) {
node.setSymbolStatus(SymbolStatus.DEFINITION);
context.schemaIndex().insertSymbolDefinition(node.getSymbol());
node.setSymbolStatus(SymbolStatus.INVALID);

if (node.getSymbol().getType() == SymbolType.FUNCTION) {
verifySymbolFunctionName(node, ret);
}
if (symbolType == SymbolType.FIELD) {
Range range = null;

} else {
node.setSymbolStatus(SymbolStatus.INVALID);
if (parent.getParent().isASTInstance(fieldOutsideDoc.class)) {
range = node.getRange();
} else if (!context.fieldIndex().getIsInsideDoc(existingSymbol.get())) {
range = existingSymbol.get().getNode().getRange();
}

if (range != null)
ret.add(new SchemaDiagnostic.Builder()
.setRange(range)
.setMessage("Field '" + node.getText() + "' shadows a document field with the same name.")
.setSeverity(DiagnosticSeverity.Warning)
.build());
}

if (symbolType == SymbolType.FIELD) {
Range range = null;
return ret;
}

if (parent.getParent().isASTInstance(fieldOutsideDoc.class)) {
range = node.getRange();
} else if (!context.fieldIndex().getIsInsideDoc(existingSymbol.get())) {
range = existingSymbol.get().getNode().getRange();
}
/**
* @return true if it was a special case that should require early return of {@link IdentifySymbolDefinition#identify}.
*/
private boolean handleSpecialCases(SchemaNode node, SchemaNode parent, List<Diagnostic> diagnostics) {
// import ... as <DEFINITION>
if (parent.isASTInstance(importField.class) && node.getPreviousSibling() != null && node.getPreviousSibling().isASTInstance(AS.class)) {
createSymbol(node, SymbolType.FIELD);
return true;
}

if (range != null)
ret.add(new SchemaDiagnostic.Builder()
.setRange(range)
.setMessage("Field '" + node.getText() + "' shadows a document field with the same name.")
.setSeverity(DiagnosticSeverity.Warning)
.build());
}
}
// function <FUNCTION-DEFINITION>(<PARAMETER-DEFINITION>, <PARAMETER-DEFINITION>, ...) { ... }
//
if (parent.indexOf(node) >= 3 && parent.isASTInstance(functionElm.class) && node.isASTInstance(identifierStr.class)) {
createSymbol(node, SymbolType.PARAMETER);
return true;
}

return ret;
// Prevent inheritance from being marked as a definition
// <keyword> <DEFINITION> inherits <NOT-DEFINITION>, <NOT-DEFINITION> ...
if (parent.indexOf(node) >= 3) {
return true;
}

return ret;
return false;
}

/**
Expand Down Expand Up @@ -256,28 +277,38 @@ private ArrayList<Diagnostic> identifyDefinitionInRankExpression(SchemaNode node
return ret;
}

SchemaNode grandParent = node.getParent(2);
if (grandParent == null || !grandParent.isASTInstance(lambdaFunction.class) || grandParent.size() < 1) {
SchemaNode parent = node.getParent();
if (parent == null) return ret;

SchemaNode grandParent = parent.getParent();
if (grandParent == null) return ret;

if (parent.isASTInstance(tensorTypeDimension.class) && grandParent.isASTInstance(tensorType.class)) {
handleTensorTypeDefinitions(node, grandParent, ret);
return ret;
}

SchemaNode parent = grandParent.get(0);
if (!grandParent.isASTInstance(lambdaFunction.class) || grandParent.size() < 1) {
return ret;
}

if (!parent.hasSymbol()) {
// This is specific to lambda function definitions
SchemaNode lambdaDefinitionNode = grandParent.get(0);
if (!lambdaDefinitionNode.hasSymbol()) {

Optional<Symbol> parentScope = CSTUtils.findScope(parent);

if (parentScope.isEmpty()) {
return ret;
}

parent.setSymbol(SymbolType.LAMBDA_FUNCTION, context.fileURI(), parentScope.get(), "lambda_" + node.hashCode());
parent.setSymbolStatus(SymbolStatus.DEFINITION);
context.schemaIndex().insertSymbolDefinition(parent.getSymbol());
lambdaDefinitionNode.setSymbol(SymbolType.LAMBDA_FUNCTION, context.fileURI(), parentScope.get(), "lambda_" + node.hashCode());
lambdaDefinitionNode.setSymbolStatus(SymbolStatus.DEFINITION);
context.schemaIndex().insertSymbolDefinition(lambdaDefinitionNode.getSymbol());
}


node.setSymbol(SymbolType.PARAMETER, context.fileURI(), parent.getSymbol());
node.setSymbol(SymbolType.PARAMETER, context.fileURI(), lambdaDefinitionNode.getSymbol());

if (context.schemaIndex().findSymbolsInScope(node.getSymbol()).size() == 0) {
node.setSymbolStatus(SymbolStatus.DEFINITION);
Expand All @@ -289,6 +320,50 @@ private ArrayList<Diagnostic> identifyDefinitionInRankExpression(SchemaNode node
return ret;
}

/**
 * Registers a tensor dimension identifier found in a rank expression as a symbol definition.
 * Example:
 *   tensor<float>(d0[1], d1[10])
 * Both d0 and d1 are marked as definitions so that they can be referenced in the body.
 * To give the dimensions a scope, the tensor type node itself also receives a definition symbol
 * the first time it is seen. Its identifier is derived from the hashCode of the tensor type node.
 * A dimension name already registered in the same tensor type (as either a mapped or an indexed
 * dimension) is marked INVALID and reported as an error.
 *
 * @param identifierNode the node holding the dimension name
 * @param tensorTypeNode the tensor type node that owns the dimension
 * @param diagnostics    receives an error diagnostic when the dimension name is a duplicate
 */
private void handleTensorTypeDefinitions(SchemaNode identifierNode, SchemaNode tensorTypeNode, List<Diagnostic> diagnostics) {
    Optional<Symbol> enclosingScope = CSTUtils.findScope(tensorTypeNode.getParent());
    if (!enclosingScope.isPresent()) {
        return;
    }

    // Lazily define the tensor type itself; it acts as the scope for its dimensions.
    boolean tensorAlreadyDefined = tensorTypeNode.hasSymbol();
    if (!tensorAlreadyDefined) {
        String tensorIdentifier = "tensor_" + tensorTypeNode.hashCode();
        tensorTypeNode.setSymbol(SymbolType.TENSOR, context.fileURI(), enclosingScope.get(), tensorIdentifier);
        tensorTypeNode.setSymbolStatus(SymbolStatus.DEFINITION);
        context.schemaIndex().insertSymbolDefinition(tensorTypeNode.getSymbol());
    }
    Symbol tensorSymbol = tensorTypeNode.getSymbol();

    // TODO: better check of indexed versus mapped dimension type based on existing tensor parsing?
    // A dimension name directly followed by '{' is treated as mapped, anything else as indexed.
    SchemaNode following = identifierNode.getNextSibling();
    boolean isMapped = following != null && following.isASTInstance(LCURLY.class);
    SymbolType dimensionType = isMapped ? SymbolType.TENSOR_DIMENSION_MAPPED : SymbolType.TENSOR_DIMENSION_INDEXED;

    String dimensionName = identifierNode.getText();
    identifierNode.setSymbol(dimensionType, context.fileURI(), tensorSymbol, dimensionName);

    // The name must be unused in this tensor type regardless of dimension kind.
    boolean definedAsMapped = context.schemaIndex()
            .findSymbolInScope(tensorSymbol, SymbolType.TENSOR_DIMENSION_MAPPED, dimensionName)
            .isPresent();
    boolean definedAsIndexed = context.schemaIndex()
            .findSymbolInScope(tensorSymbol, SymbolType.TENSOR_DIMENSION_INDEXED, dimensionName)
            .isPresent();

    if (!definedAsMapped && !definedAsIndexed) {
        identifierNode.setSymbolStatus(SymbolStatus.DEFINITION);
        context.schemaIndex().insertSymbolDefinition(identifierNode.getSymbol());
        return;
    }

    identifierNode.setSymbolStatus(SymbolStatus.INVALID);
    diagnostics.add(new SchemaDiagnostic.Builder()
            .setRange(identifierNode.getRange())
            .setMessage("Duplicate tensor dimension " + dimensionName)
            .setSeverity(DiagnosticSeverity.Error)
            .build());
}

private static final Set<String> reservedFunctionNames = ReservedFunctionNames.getReservedNames();
// TODO: Maybe add distance and bm25 to the list?
private void verifySymbolFunctionName(SchemaNode node, List<Diagnostic> diagnostics) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
import ai.vespa.schemals.schemadocument.resolvers.RankExpression.argument.FieldArgument.UnresolvedFieldArgument;
import ai.vespa.schemals.tree.SchemaNode;

/**
* This checks that the type or indexing settings for a referenced field are correct.
* For instance, some uses require that a field has indexing 'attribute'. Other uses require that a field is of a certain type
* or set of types.
* The information about the required data is in class {@link ai.vespa.schemals.schemadocument.resolvers.RankExpression.argument.FieldArgument.UnresolvedFieldArgument}
* The resolver then looks up the field definition and checks if the registered indexing settings are correct.
*/
public class FieldArgumentResolver {


public static Optional<Diagnostic> resolveFieldArgument(ParseContext context, UnresolvedFieldArgument fieldArgument) {

if (fieldArgument.indexingTypes().size() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ private static List<Diagnostic> traverseRankExpressionTree(RankNode node, ParseC
diagnostics.addAll(traverseRankExpressionTree(child, context));
}

// All feature nodes has a symbol before the parse
// All feature nodes have a symbol before the traversal
if (node.hasSymbol()) {

if (node.getSymbolStatus() == SymbolStatus.UNRESOLVED) {
Expand Down Expand Up @@ -144,6 +144,8 @@ private static void findBuiltInFunction(RankNode node, ParseContext context, Lis
// add(SymbolType.PARAMETER); // This is a special case
add(SymbolType.FUNCTION);
add(SymbolType.RANK_CONSTANT);
add(SymbolType.TENSOR_DIMENSION_MAPPED);
add(SymbolType.TENSOR_DIMENSION_INDEXED);
}};

private static void resolveReference(RankNode referenceNode, ParseContext context, List<Diagnostic> diagnostics) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
import ai.vespa.schemals.index.SchemaIndex;
import ai.vespa.schemals.index.Symbol;
import ai.vespa.schemals.index.Symbol.SymbolStatus;
import ai.vespa.schemals.index.Symbol.SymbolType;
import ai.vespa.schemals.parser.Node;
import ai.vespa.schemals.parser.TokenSource;
import ai.vespa.schemals.parser.ast.documentElm;
import ai.vespa.schemals.parser.ast.functionElm;
import ai.vespa.schemals.parser.rankingexpression.ast.BaseNode;
import ai.vespa.schemals.parser.rankingexpression.ast.lambdaFunction;
import ai.vespa.schemals.parser.rankingexpression.ast.tensorGenerateBody;
import ai.vespa.schemals.parser.rankingexpression.ast.tensorType;
import ai.vespa.schemals.tree.SchemaNode.LanguageType;
import ai.vespa.schemals.tree.indexinglanguage.ILUtils;
import ai.vespa.schemals.tree.rankingexpression.RankingExpressionUtils;
Expand Down Expand Up @@ -194,13 +197,18 @@ public static Optional<Symbol> findScope(SchemaNode node) {
if (currentNode.isASTInstance(lambdaFunction.class)) {
indexGuess = 0;
}

if (indexGuess < currentNode.size()) {
SchemaNode potentialDefinition = currentNode.get(indexGuess);
if (potentialDefinition.hasSymbol() && potentialDefinition.getSymbol().getStatus() == SymbolStatus.DEFINITION) {
return Optional.of(potentialDefinition.getSymbol());
}
}
} else if (currentNode.isASTInstance(tensorGenerateBody.class) &&
currentNode.getPreviousSibling() != null &&
currentNode.getPreviousSibling().hasSymbol() &&
currentNode.getPreviousSibling().getSymbol().getType() == SymbolType.TENSOR) {
// Edge case for tensor type in rank expression
return Optional.of(currentNode.getPreviousSibling().getSymbol());
}

currentNode = currentNode.getParent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ Stream<DynamicTest> generateBadFileTests() {

new BadFileTestCase("src/test/sdfiles/single/rankprofilefuncs.sd", 2),
new BadFileTestCase("src/test/sdfiles/single/onnxmodel.sd", 1),
new BadFileTestCase("src/test/sdfiles/single/tensorGenerate.sd", 2)
};

return Arrays.stream(tests)
Expand Down
Loading

0 comments on commit ec05917

Please sign in to comment.