diff --git a/docs/en/reference/sql/ddl/SET_STATEMENT.md b/docs/en/reference/sql/ddl/SET_STATEMENT.md
index c2e8416fc67..6c0e83de75a 100644
--- a/docs/en/reference/sql/ddl/SET_STATEMENT.md
+++ b/docs/en/reference/sql/ddl/SET_STATEMENT.md
@@ -34,6 +34,7 @@ The following format is also equivalent.
| @@session.enable_trace|@@enable_trace | When the value is `true`, an error message stack will be printed when the SQL statement has a syntax error or an error occurs during the plan generation process.
When the value is `false`, only the basic error message will be printed if there is a SQL syntax error or an error occurs during the plan generation process. | `true`,
`false` | `false` |
| @@session.sync_job|@@sync_job | When the value is `true`, the offline command will be executed synchronously, waiting for the final result of the execution.
When the value is `false`, the offline command returns immediately. If you need to check the execution, please use `SHOW JOB` command. | `true`,
`false` | `false` |
| @@session.sync_timeout|@@sync_timeout | When `sync_job=true`, you can configure the waiting time for synchronization commands. The timeout will return immediately. After the timeout returns, you can still view the command execution through `SHOW JOB`. | Int | 20000 |
+| @@session.spark_config|@@spark_config | Sets the Spark configuration for offline jobs as semicolon-separated `key=value` pairs, for example 'spark.executor.memory=2g;spark.executor.cores=2'. Note that this Spark configuration has higher priority than the TaskManager Spark configuration but lower priority than the CLI Spark configuration file. | String | "" |
## Example
diff --git a/docs/zh/openmldb_sql/ddl/SET_STATEMENT.md b/docs/zh/openmldb_sql/ddl/SET_STATEMENT.md
index 0284e37e19f..1b513913e10 100644
--- a/docs/zh/openmldb_sql/ddl/SET_STATEMENT.md
+++ b/docs/zh/openmldb_sql/ddl/SET_STATEMENT.md
@@ -35,7 +35,7 @@ sessionVariableName ::= '@@'Identifier | '@@session.'Identifier | '@@global.'Ide
| @@session.enable_trace|@@enable_trace | 当该变量值为 `true`,SQL语句有语法错误或者在计划生成过程发生错误时,会打印错误信息栈。
当该变量值为 `false`,SQL语句有语法错误或者在计划生成过程发生错误时,仅打印基本错误信息。 | "true" \| "false" | "false" |
| @@session.sync_job|@@sync_job | 当该变量值为 `true`,离线的命令将变为同步,等待执行的最终结果。
当该变量值为 `false`,离线的命令即时返回,若要查看命令的执行情况,请使用`SHOW JOB`。 | "true" \| "false" | "false" |
| @@session.job_timeout|@@job_timeout | 可配置离线异步命令或离线管理命令的等待时间(以*毫秒*为单位),将立即返回。离线异步命令返回后仍可通过`SHOW JOB`查看命令执行情况。 | Int | "20000" |
-
+| @@session.spark_config|@@spark_config | 设置离线任务的 Spark 参数,配置项参考 'spark.executor.memory=2g;spark.executor.cores=2'。注意此 Spark 配置优先级高于 TaskManager 默认 Spark 配置,低于命令行的 Spark 配置文件。 | String | "" |
## Example
### 设置和显示会话系统变量
diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc
index 1a55e94fb2e..bb4793dfb56 100644
--- a/src/sdk/sql_cluster_router.cc
+++ b/src/sdk/sql_cluster_router.cc
@@ -19,8 +19,10 @@
#include
#include
#include
+#include
#include
#include
+#include
#include
#include
@@ -319,6 +321,7 @@ bool SQLClusterRouter::Init() {
session_variables_.emplace("enable_trace", "false");
session_variables_.emplace("sync_job", "false");
session_variables_.emplace("job_timeout", "60000"); // rpc request timeout for taskmanager
+ session_variables_.emplace("spark_config", "");
}
return true;
}
@@ -2980,7 +2983,7 @@ std::shared_ptr SQLClusterRouter::ExecuteOfflineQuery(
bool is_sync_job, int job_timeout,
::hybridse::sdk::Status* status) {
RET_IF_NULL_AND_WARN(status, "output status is nullptr");
- std::map config;
+ std::map config = ParseSparkConfigString(GetSparkConfig());
ReadSparkConfFromFile(std::dynamic_pointer_cast(options_)->spark_conf_path, &config);
if (is_sync_job) {
@@ -3049,6 +3052,16 @@ int SQLClusterRouter::GetJobTimeout() {
return 60000;
}
+// Returns the current value of the "spark_config" session variable, or an
+// empty string when the variable is not present. The lookup is performed
+// while holding mu_.
+std::string SQLClusterRouter::GetSparkConfig() {
+    std::lock_guard<::openmldb::base::SpinMutex> guard(mu_);
+    const auto entry = session_variables_.find("spark_config");
+    if (entry == session_variables_.end()) {
+        return "";
+    }
+    return entry->second;
+}
+
::hybridse::sdk::Status SQLClusterRouter::SetVariable(hybridse::node::SetPlanNode* node) {
std::string key = node->Key();
std::transform(key.begin(), key.end(), key.begin(), ::tolower);
@@ -3083,6 +3096,13 @@ ::hybridse::sdk::Status SQLClusterRouter::SetVariable(hybridse::node::SetPlanNod
if (!absl::SimpleAtoi(value, &new_timeout)) {
return {StatusCode::kCmdError, "Fail to parse value, can't set the request timeout"};
}
+ } else if (key == "spark_config") {
+ if (!CheckSparkConfigString(value)) {
+ return {
+ StatusCode::kCmdError,
+ "Fail to parse spark config, set like 'spark.executor.memory=2g;spark.executor.cores=2'"
+ };
+ }
} else {
return {};
}
@@ -3090,6 +3110,20 @@ ::hybridse::sdk::Status SQLClusterRouter::SetVariable(hybridse::node::SetPlanNod
return {};
}
+// Validates a Spark configuration string of the form
+// "spark.key1=value1;spark.key2=value2".
+//
+// Every ';'-separated entry must contain an '=' and the key in front of the
+// first '=' must start with "spark.". An empty input is valid (it has no
+// entries). These are the same conditions under which
+// ParseSparkConfigString keeps an entry, so a string accepted here is never
+// silently dropped when the offline job configuration is built.
+//
+// @param input the raw value of the spark_config session variable
+// @return true if every entry is a well-formed "spark.*=value" pair
+bool SQLClusterRouter::CheckSparkConfigString(const std::string& input) {
+    const std::string prefix = "spark.";
+    std::istringstream iss(input);
+    std::string key_value;
+
+    while (std::getline(iss, key_value, ';')) {
+        // An entry without '=' has no value and would be discarded by the parser.
+        const size_t equal_pos = key_value.find('=');
+        if (equal_pos == std::string::npos) {
+            return false;
+        }
+
+        // The key is everything before the first '='; it has to be at least as
+        // long as the prefix and begin with it. compare() only inspects the
+        // first prefix.size() characters instead of scanning the whole entry.
+        if (equal_pos < prefix.size() || key_value.compare(0, prefix.size(), prefix) != 0) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
::hybridse::sdk::Status SQLClusterRouter::ParseNamesFromArgs(const std::string& db,
const std::vector& args, std::string* db_name, std::string* name) {
if (args.size() == 1) {
@@ -4523,6 +4557,34 @@ bool SQLClusterRouter::CheckTableStatus(const std::string& db, const std::string
return check_succeed;
}
+std::map SQLClusterRouter::ParseSparkConfigString(const std::string& input) {  // Parses "k1=v1;k2=v2" into a key -> value map.
+ std::map configMap;
+ // NOTE(review): the template arguments of std::map are not visible in this view; presumably string -> string, confirm.
+ std::istringstream iss(input);
+ std::string keyValue;
+ // Entries are separated by ';'. A malformed entry is reported on std::cerr and skipped; parsing continues.
+ while (std::getline(iss, keyValue, ';')) {
+ // Split the entry at the first '=': the text before it is the key, everything after it is the value.
+ size_t equalPos = keyValue.find('=');
+ if (equalPos != std::string::npos) {
+ std::string key = keyValue.substr(0, equalPos);
+ std::string value = keyValue.substr(equalPos + 1);
+ // The value is taken verbatim, so it may be empty or contain further '=' characters.
+ // Only keys that begin with "spark." are accepted; other keys are reported and dropped.
+ if (key.find("spark.") == 0) {
+ // Store the pair; assigning through operator[] lets a later entry with the same key replace an earlier one.
+ configMap[key] = value;
+ } else {
+ std::cerr << "Error: Key does not start with 'spark.' - " << key << std::endl;
+ }
+ } else {
+ std::cerr << "Error: Invalid key-value pair - " << keyValue << std::endl;
+ }
+ }
+ // Returns only the accepted pairs; the result is empty when input is empty or no entry was accepted.
+ return configMap;
+}
+
void SQLClusterRouter::ReadSparkConfFromFile(std::string conf_file_path, std::map* config) {
if (!conf_file_path.empty()) {
boost::property_tree::ptree pt;
diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h
index f5661c9a1bb..b5854fe7ab3 100644
--- a/src/sdk/sql_cluster_router.h
+++ b/src/sdk/sql_cluster_router.h
@@ -283,6 +283,12 @@ class SQLClusterRouter : public SQLRouter {
// get job timeout from the session variables, we will use the timeout when sending requests to the taskmanager
int GetJobTimeout();
+ std::string GetSparkConfig();  // get the "spark_config" session variable, or "" when it is not present
+ // parse a "key=value;key=value" Spark config string into a map, keeping only entries whose key starts with "spark."
+ std::map ParseSparkConfigString(const std::string& input);
+ // validate a Spark config string before it is stored as the spark_config session variable; returns false when malformed
+ bool CheckSparkConfigString(const std::string& input);
+
::openmldb::base::Status ExecuteOfflineQueryAsync(const std::string& sql,
const std::map& config,
const std::string& default_db, int job_timeout,