Skip to content

Commit

Permalink
feat: support custom order in insert stmt
Browse files Browse the repository at this point in the history
  • Loading branch information
vagetablechicken committed Jul 5, 2022
1 parent d5a6a1f commit cb2f0bd
Show file tree
Hide file tree
Showing 11 changed files with 710 additions and 516 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package com._4paradigm.openmldb.common;

import java.io.Serializable;

public class Pair<K, V> implements Serializable {

/**
* Key of this <code>Pair</code>.
*/
private K key;

/**
* Gets the key for this pair.
*
* @return key for this pair
*/
public K getKey() {
return key;
}

/**
* Value of this this <code>Pair</code>.
*/
private V value;

/**
* Gets the value for this pair.
*
* @return value for this pair
*/
public V getValue() {
return value;
}

/**
* Creates a new pair
*
* @param key The key for this pair
* @param value The value to use for this pair
*/
public Pair(K key, V value) {
this.key = key;
this.value = value;
}

/**
* <p><code>String</code> representation of this
* <code>Pair</code>.</p>
*
* <p>The default name/value delimiter '=' is always used.</p>
*
* @return <code>String</code> representation of this <code>Pair</code>
*/
@Override
public String toString() {
return key + "=" + value;
}

/**
* <p>Generate a hash code for this <code>Pair</code>.</p>
*
* <p>The hash code is calculated using both the name and
* the value of the <code>Pair</code>.</p>
*
* @return hash code for this <code>Pair</code>
*/
@Override
public int hashCode() {
// name's hashCode is multiplied by an arbitrary prime number (13)
// in order to make sure there is a difference in the hashCode between
// these two parameters:
// name: a value: aa
// name: aa value: a
return key.hashCode() * 13 + (value == null ? 0 : value.hashCode());
}

/**
* <p>Test this <code>Pair</code> for equality with another
* <code>Object</code>.</p>
*
* <p>If the <code>Object</code> to be tested is not a
* <code>Pair</code> or is <code>null</code>, then this method
* returns <code>false</code>.</p>
*
* <p>Two <code>Pair</code>s are considered equal if and only if
* both the names and values are equal.</p>
*
* @param o the <code>Object</code> to test for
* equality with this <code>Pair</code>
* @return <code>true</code> if the given <code>Object</code> is
* equal to this <code>Pair</code> else <code>false</code>
*/
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o instanceof Pair) {
Pair pair = (Pair) o;
if (key != null ? !key.equals(pair.key) : pair.key != null) return false;
if (value != null ? !value.equals(pair.value) : pair.value != null) return false;
return true;
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

package com._4paradigm.openmldb.jdbc;

import static com._4paradigm.openmldb.sdk.impl.Util.sqlTypeToString;

import com._4paradigm.openmldb.DataType;
import com._4paradigm.openmldb.Schema;
import com._4paradigm.openmldb.common.Pair;
import com._4paradigm.openmldb.sdk.Common;

import java.sql.ResultSetMetaData;
Expand All @@ -28,10 +31,11 @@ public class SQLInsertMetaData implements ResultSetMetaData {

private final List<DataType> schema;
private final Schema realSchema;
private final List<Integer> idx;
private final List<Pair<Long, Integer>> idx;

public SQLInsertMetaData(List<DataType> schema,
Schema realSchema,
List<Integer> idx) {
List<Pair<Long, Integer>> idx) {
this.schema = schema;
this.realSchema = realSchema;
this.idx = idx;
Expand Down Expand Up @@ -90,7 +94,7 @@ public boolean isCurrency(int i) throws SQLException {
@Override
public int isNullable(int i) throws SQLException {
check(i);
int index = idx.get(i - 1);
Long index = idx.get(i - 1).getKey();
if (realSchema.IsColumnNotNull(index)) {
return columnNoNulls;
} else {
Expand Down Expand Up @@ -119,7 +123,7 @@ public String getColumnLabel(int i) throws SQLException {
@Override
public String getColumnName(int i) throws SQLException {
check(i);
int index = idx.get(i - 1);
Long index = idx.get(i - 1).getKey();
return realSchema.GetColumnName(index);
}

Expand Down Expand Up @@ -156,14 +160,13 @@ public String getCatalogName(int i) throws SQLException {
@Override
public int getColumnType(int i) throws SQLException {
check(i);
DataType dataType = schema.get(i - 1);
return Common.type2SqlType(dataType);
Long index = idx.get(i - 1).getKey();
return Common.type2SqlType(realSchema.GetColumnType(index));
}

@Override
@Deprecated
public String getColumnTypeName(int i) throws SQLException {
throw new SQLException("current do not support this method");
return sqlTypeToString(getColumnType(i));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com._4paradigm.openmldb.*;

import com._4paradigm.openmldb.common.Pair;
import com._4paradigm.openmldb.jdbc.SQLInsertMetaData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -32,6 +33,7 @@
import java.sql.Date;
import java.sql.ResultSet;
import java.util.*;
import java.util.stream.Collectors;

public class InsertPreparedStatementImpl implements PreparedStatement {
public static final Charset CHARSET = StandardCharsets.UTF_8;
Expand All @@ -48,7 +50,8 @@ public class InsertPreparedStatementImpl implements PreparedStatement {
private final List<Object> currentDatas;
private final List<DataType> currentDatasType;
private final List<Boolean> hasSet;
private final List<Integer> scehmaIdxs;
// stmt insert idx -> real table schema idx
private final List<Pair<Long, Integer>> schemaIdxes;

private boolean closed = false;
private boolean closeOnComplete = false;
Expand All @@ -63,17 +66,24 @@ public InsertPreparedStatementImpl(String db, String sql, SQLRouter router) thro
this.currentSchema = tempRow.GetSchema();
VectorUint32 idxes = tempRow.GetHoleIdx();

// In stmt order, if no columns in stmt, in schema order
// We'll sort it to schema order later, so needs the map <real_schema_idx, current_data_idx>
schemaIdxes = new ArrayList<>(idxes.size());
for (int i = 0; i < idxes.size(); i++) {
schemaIdxes.add(new Pair<>(idxes.get(i), i));
}

currentDatas = new ArrayList<>(idxes.size());
currentDatasType = new ArrayList<>(idxes.size());
hasSet = new ArrayList<>(idxes.size());
scehmaIdxs = new ArrayList<>(idxes.size());
for (int i = 0; i < idxes.size(); i++) {
long idx = idxes.get(i);
// CurrentDatas and Type order is consistent with insert stmt. We'll do appending in schema order when build
// row.
for (Pair<Long, Integer> pair : schemaIdxes) {
Long idx = pair.getKey();
DataType type = currentSchema.GetColumnType(idx);
currentDatasType.add(type);
currentDatas.add(null);
hasSet.add(false);
scehmaIdxs.add(i);
}
}

Expand Down Expand Up @@ -118,14 +128,14 @@ private void checkIdx(int i) throws SQLException {
if (i <= 0) {
throw new SQLException("error sqe number");
}
if (i > scehmaIdxs.size()) {
if (i > schemaIdxes.size()) {
throw new SQLException("out of data range");
}
}

private void checkType(int i, DataType type) throws SQLException {
if (currentDatasType.get(i - 1) != type) {
throw new SQLException("data type not match");
throw new SQLException("data type not match, expect " + currentDatasType.get(i - 1) + ", actual " + type);
}
}

Expand Down Expand Up @@ -206,7 +216,7 @@ public void setBigDecimal(int i, BigDecimal bigDecimal) throws SQLException {
}

private boolean checkNotAllowNull(int i) {
long idx = this.scehmaIdxs.get(i - 1);
Long idx = this.schemaIdxes.get(i - 1).getKey();
return this.currentSchema.IsColumnNotNull(idx);
}

Expand Down Expand Up @@ -300,22 +310,26 @@ public void setObject(int i, Object o, int i1) throws SQLException {

private void buildRow() throws SQLException {
SQLInsertRow currentRow = getSQLInsertRow();

boolean ok = currentRow.Init(stringsLen);
if (!ok) {
throw new SQLException("init row failed");
}

for (int i = 0; i < currentDatasType.size(); i++) {
Object data = currentDatas.get(i);
// SQLInsertRow::AppendXXX order is the schema order(skip the no-hole columns)
List<Pair<Long, Integer>> sortedIdxes = schemaIdxes.stream().sorted(Comparator.comparing(Pair::getKey))
.collect(Collectors.toList());

for (Pair<Long, Integer> sortedIdx : sortedIdxes) {
Integer currentDataIdx = sortedIdx.getValue();
Object data = currentDatas.get(currentDataIdx);
if (data == null) {
ok = currentRow.AppendNULL();
} else {
DataType curType = currentDatasType.get(i);
DataType curType = currentDatasType.get(currentDataIdx);
if (DataType.kTypeBool.equals(curType)) {
ok = currentRow.AppendBool((boolean) data);
} else if (DataType.kTypeDate.equals(curType)) {
java.sql.Date date = (java.sql.Date) data;
Date date = (Date) data;
ok = currentRow.AppendDate(date.getYear() + 1900, date.getMonth() + 1, date.getDate());
} else if (DataType.kTypeDouble.equals(curType)) {
ok = currentRow.AppendDouble((double) data);
Expand Down Expand Up @@ -423,9 +437,8 @@ public void setArray(int i, Array array) throws SQLException {
}

@Override
@Deprecated
public ResultSetMetaData getMetaData() throws SQLException {
return new SQLInsertMetaData(this.currentDatasType, this.currentSchema, this.scehmaIdxs);
return new SQLInsertMetaData(this.currentDatasType, this.currentSchema, this.schemaIdxes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -183,7 +184,7 @@ public void testForKafkaConnector() throws SQLException {
String tableName = "kafka_test";
stmt = connection.createStatement();
try {
stmt.execute(String.format("create table if not exists %s(c1 int, c2 string)", tableName));
stmt.execute(String.format("create table if not exists %s(c1 int, c2 string, c3 timestamp)", tableName));
} catch (Exception e) {
Assert.fail();
}
Expand All @@ -198,6 +199,15 @@ public void testForKafkaConnector() throws SQLException {
pstmt.setFetchSize(100);

pstmt.addBatch();
insertSql = "INSERT INTO " +
tableName +
"(`c3`,`c2`) VALUES(?,?)";
pstmt = connection.prepareStatement(insertSql);
Assert.assertEquals(pstmt.getMetaData().getColumnCount(), 2);
// index starts from 1
Assert.assertEquals(pstmt.getMetaData().getColumnType(2), Types.VARCHAR);
Assert.assertEquals(pstmt.getMetaData().getColumnName(2), "c2");


try {
stmt = connection.prepareStatement("DELETE FROM " + tableName + " WHERE c1=1");
Expand Down
Loading

0 comments on commit cb2f0bd

Please sign in to comment.