Skip to content

Commit

Permalink
Improve type mismatch with subexpression error
Browse files Browse the repository at this point in the history
  • Loading branch information
valis committed Apr 18, 2024
1 parent 19f9a2a commit f72884a
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 30 deletions.
40 changes: 40 additions & 0 deletions base/src/main/java/org/arend/core/expr/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,46 @@ public Expression copy() {
return accept(new SubstVisitor(new ExprSubstitution(), LevelSubstitution.EMPTY), null);
}

public Expression copyStrict() {
return accept(new SubstVisitor(new ExprSubstitution(), LevelSubstitution.EMPTY) {
@Override
public Expression visitReference(ReferenceExpression expr, Void params) {
Expression result = super.visitReference(expr, params);
return result == expr ? new ReferenceExpression(expr.getBinding()) : result;
}

@Override
public Expression visitInferenceReference(InferenceReferenceExpression expr, Void params) {
Expression result = super.visitInferenceReference(expr, params);
return result == expr ? new InferenceReferenceExpression(expr.getVariable(), expr.getSubstExpression()) : result;
}

@Override
public Expression visitUniverse(UniverseExpression expr, Void params) {
Expression result = super.visitUniverse(expr, params);
return result == expr ? new UniverseExpression(expr.getSort()) : result;
}

@Override
public Expression visitError(ErrorExpression expr, Void params) {
Expression result = super.visitError(expr, params);
return result == expr ? new ErrorExpression(expr.getExpression(), expr.getGoalName(), expr.useExpression()) : result;
}

@Override
public Expression visitInteger(IntegerExpression expr, Void params) {
Expression result = super.visitInteger(expr, params);
return result != expr ? result : expr instanceof SmallIntegerExpression ? new SmallIntegerExpression(expr.getSmallInteger()) : new BigIntegerExpression(expr.getBigInteger());
}

@Override
public Expression visitString(StringExpression expr, Void params) {
Expression result = super.visitString(expr, params);
return result == expr ? new StringExpression(expr.getString()) : result;
}
}, null);
}

public final Expression subst(Binding binding, Expression substExpr) {
if (substExpr instanceof ReferenceExpression && ((ReferenceExpression) substExpr).getBinding() == binding) {
return this;
Expand Down
88 changes: 58 additions & 30 deletions base/src/main/java/org/arend/core/expr/visitor/CompareVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ public boolean nonNormalizingCompare(Expression expr1, Expression expr2, Express

private boolean initResult(Expression expr1, Expression expr2) {
if (myNormalCompare && myResult == null) {
expr1 = expr1.copyStrict();
expr2 = expr2.copyStrict();
myResult = new Result(expr1, expr2, expr1, expr2);
}
return myResult != null;
Expand All @@ -223,6 +225,8 @@ private boolean initResult(Expression expr1, Expression expr2, boolean correctOr

private void initResult(Expression expr1, Expression expr2, Levels levels1, Levels levels2) {
if (myNormalCompare && myResult == null) {
expr1 = expr1.copyStrict();
expr2 = expr2.copyStrict();
myResult = new Result(expr1, expr2, expr1, expr2, levels1, levels2);
}
}
Expand Down Expand Up @@ -687,42 +691,45 @@ private boolean compareConstArray(FunCallExpression atExpr, Expression otherExpr
@Override
public Boolean visitFunCall(FunCallExpression expr1, Expression expr2, Expression type) {
if (expr1.getDefinition() == Prelude.ARRAY_INDEX) {
boolean ok;
if (expr2 instanceof FunCallExpression && ((FunCallExpression) expr2).getDefinition() == Prelude.ARRAY_INDEX) {
ok = visitDefCall(expr1, expr2) || compareConstArray(expr1, expr2, type, true) || compareConstArray((FunCallExpression) expr2, expr1, type, false);
if (visitDefCall(expr1, expr2) || compareConstArray(expr1, expr2, type, true) || compareConstArray((FunCallExpression) expr2, expr1, type, false)) {
return true;
}
} else {
ok = compareConstArray(expr1, expr2, type, true);
boolean ok = compareConstArray(expr1, expr2, type, true);
if (!ok) {
initResult(expr1, expr2);
}
return ok;
}
if (!ok) {
initResult(expr1, expr2);
} else {
if (visitDefCall(expr1, expr2)) {
return true;
}
return ok;
}

if (myResult == null) {
initResult(expr1, expr2);
} else {
if (!visitDefCall(expr1, expr2)) {
if (myResult == null) {
initResult(expr1, expr2);
} else {
if (myResult.index >= 0 && myResult.index < expr1.getDefCallArguments().size()) {
List<Expression> args = new ArrayList<>(expr1.getDefCallArguments());
args.set(myResult.index, myResult.wholeExpr1);
myResult.wholeExpr1 = FunCallExpression.make(expr1.getDefinition(), expr1.getLevels(), args);
} else {
myResult.wholeExpr1 = expr1;
}
FunCallExpression funCall2 = expr2.cast(FunCallExpression.class);
if (funCall2 != null && myResult.index >= 0 && myResult.index < funCall2.getDefCallArguments().size()) {
List<Expression> args = new ArrayList<>(funCall2.getDefCallArguments());
args.set(myResult.index, myResult.wholeExpr2);
myResult.wholeExpr2 = FunCallExpression.make(funCall2.getDefinition(), funCall2.getLevels(), args);
} else {
myResult.wholeExpr2 = expr2;
}
myResult.index = -1;
}
return false;
if (myResult.index >= 0 && myResult.index < expr1.getDefCallArguments().size()) {
List<Expression> args = new ArrayList<>(expr1.getDefCallArguments());
args.set(myResult.index, myResult.wholeExpr1);
myResult.wholeExpr1 = FunCallExpression.make(expr1.getDefinition(), expr1.getLevels(), args);
} else {
myResult.wholeExpr1 = expr1;
}
return true;
FunCallExpression funCall2 = expr2.cast(FunCallExpression.class);
if (funCall2 != null && myResult.index >= 0 && myResult.index < funCall2.getDefCallArguments().size()) {
List<Expression> args = new ArrayList<>(funCall2.getDefCallArguments());
args.set(myResult.index, myResult.wholeExpr2);
myResult.wholeExpr2 = FunCallExpression.make(funCall2.getDefinition(), funCall2.getLevels(), args);
} else {
myResult.wholeExpr2 = expr2;
}
myResult.index = -1;
}

return false;
}

private void restoreConCalls(List<Pair<ConCallExpression, ConCallExpression>> stack) {
Expand Down Expand Up @@ -1945,7 +1952,28 @@ private boolean compareClassInstances(Expression expr1, ClassCallExpression clas
mySubstitution.put(classCall2.getThisBinding(), prevBinding);
}
if (!ok) {
myResult = null;
if (expr1 instanceof NewExpression && expr2 instanceof NewExpression) {
if (myResult == null) {
initResult(expr1, expr2);
} else {
if (classCall1.isImplementedHere(field)) {
Map<ClassField, Expression> impls = new LinkedHashMap<>(classCall1.getImplementedHere());
impls.put(field, myResult.wholeExpr1);
myResult.wholeExpr1 = new NewExpression(null, new ClassCallExpression(classCall1.getDefinition(), classCall1.getLevels(), impls, classCall1.getSort(), classCall1.getUniverseKind()));
} else {
myResult.wholeExpr1 = expr1;
}
if (classCall2.isImplementedHere(field)) {
Map<ClassField, Expression> impls = new LinkedHashMap<>(classCall2.getImplementedHere());
impls.put(field, myResult.wholeExpr2);
myResult.wholeExpr2 = new NewExpression(null, new ClassCallExpression(classCall2.getDefinition(), classCall2.getLevels(), impls, classCall2.getSort(), classCall2.getUniverseKind()));
} else {
myResult.wholeExpr2 = expr2;
}
}
} else {
myResult = null;
}
return false;
}
}
Expand Down

0 comments on commit f72884a

Please sign in to comment.