diff --git a/parquet_s3_fdw.h b/parquet_s3_fdw.h
index f788342..23de92d 100644
--- a/parquet_s3_fdw.h
+++ b/parquet_s3_fdw.h
@@ -51,6 +51,8 @@ typedef struct parquet_s3_server_opt
     bool        use_minio;          /* Connect to MinIO instead of Amazon S3. */
     bool        keep_connections;   /* setting value of keep_connections
                                      * server option */
+    bool        use_credential_provider;    /* Retrieve AWS credentials via the
+                                             * AWS credential provider chain */
     char       *region;             /* AWS region to connect to */
     char       *endpoint;           /* Address and port to connect to */
 } parquet_s3_server_opt;
@@ -71,6 +73,7 @@ int ExecForeignDDL(Oid serverOid,
 /* Option name for CREATE FOREIGN SERVER. */
 #define SERVER_OPTION_USE_MINIO "use_minio"
 #define SERVER_OPTION_KEEP_CONNECTIONS "keep_connections"
+#define SERVER_OPTION_USE_CREDENTIAL_PROVIDER "use_credential_provider"
 #define SERVER_OPTION_REGION "region"
 #define SERVER_OPTION_ENDPOINT "endpoint"
diff --git a/parquet_s3_fdw.hpp b/parquet_s3_fdw.hpp
index 902cc46..448be54 100644
--- a/parquet_s3_fdw.hpp
+++ b/parquet_s3_fdw.hpp
@@ -58,7 +58,7 @@ typedef enum FileLocation_t
 /*
  * We would like to cache FileReader. When creating new hash entry,
  * the memory of entry is allocated by PostgreSQL core. But FileReader is
- * a unique_ptr. In order to initialize it in parquet_s3_fdw, we define
+ * a unique_ptr. In order to initialize it in parquet_s3_fdw, we define
  * FileReaderCache class and the cache entry has the pointer of this class.
  */
 class FileReaderCache
@@ -84,7 +84,7 @@ extern List *extract_parquet_fields(const char *path, const char *dirname, Aws::
 extern char *create_foreign_table_query(const char *tablename, const char *schemaname, const char *servername, char **paths, int npaths, List *fields, List *options);
-extern Aws::S3::S3Client *parquetGetConnection(UserMapping *user, bool use_minio);
+extern Aws::S3::S3Client *parquetGetConnection(UserMapping *user, parquet_s3_server_opt* option);
 extern Aws::S3::S3Client *parquetGetConnectionByTableid(Oid foreigntableid, Oid userid);
 extern void parquetReleaseConnection(Aws::S3::S3Client *conn);
 extern List* parquetGetS3ObjectList(Aws::S3::S3Client *s3_cli, const char *s3path);
diff --git a/parquet_s3_fdw_connection.cpp b/parquet_s3_fdw_connection.cpp
index ba10c48..9174659 100644
--- a/parquet_s3_fdw_connection.cpp
+++ b/parquet_s3_fdw_connection.cpp
@@ -11,6 +11,8 @@
  *-------------------------------------------------------------------------
  */
 #include
+#include <aws/core/auth/AWSCredentialsProviderChain.h>
+#include <aws/core/platform/Environment.h>
 #include
 #include
 #include
@@ -119,13 +121,13 @@ PG_FUNCTION_INFO_V1(parquet_s3_fdw_disconnect_all);
 }
 
 /* prototypes of private functions */
-static void make_new_connection(ConnCacheEntry *entry, UserMapping *user, bool use_minio);
+static void make_new_connection(ConnCacheEntry *entry, UserMapping *user, parquet_s3_server_opt* option);
 static bool disconnect_cached_connections(Oid serverid);
-static Aws::S3::S3Client *create_s3_connection(ForeignServer *server, UserMapping *user, bool use_minio);
+static Aws::S3::S3Client *create_s3_connection(ForeignServer *server, UserMapping *user, parquet_s3_server_opt* option);
 static void close_s3_connection(ConnCacheEntry *entry);
 static void check_conn_params(const char **keywords, const char **values, UserMapping *user);
 static void parquet_fdw_inval_callback(Datum arg, int cacheid, uint32 hashvalue);
-static Aws::S3::S3Client* s3_client_open(const char *user, const char *password, bool use_minio, const char *endpoint, const char *awsRegion);
+static Aws::S3::S3Client* s3_client_open(const char *user, const char *password, bool use_minio, bool use_credential_provider, const char *endpoint, const char *awsRegion);
 static void s3_client_close(Aws::S3::S3Client *s3_client);
 
 extern "C" void
@@ -148,7 +150,7 @@ parquet_s3_shutdown()
  * if we don't already have a suitable one.
  */
 Aws::S3::S3Client *
-parquetGetConnection(UserMapping *user, bool use_minio)
+parquetGetConnection(UserMapping *user, parquet_s3_server_opt* option)
 {
     bool        found;
     ConnCacheEntry *entry;
@@ -215,7 +217,7 @@ parquetGetConnection(UserMapping *user, bool use_minio)
      * will remain in a valid empty state, ie conn == NULL.)
      */
     if (entry->conn == NULL)
-        make_new_connection(entry, user, use_minio);
+        make_new_connection(entry, user, option);
 
     return entry->conn;
 }
@@ -225,7 +227,7 @@
  * establish new connection to the remote server.
  */
 static void
-make_new_connection(ConnCacheEntry *entry, UserMapping *user, bool use_minio)
+make_new_connection(ConnCacheEntry *entry, UserMapping *user, parquet_s3_server_opt* option)
 {
     ForeignServer *server = GetForeignServer(user->serverid);
@@ -242,7 +244,7 @@ make_new_connection(ConnCacheEntry *entry, UserMapping *user, bool use_minio)
                             ObjectIdGetDatum(user->umid));
 
     /* Now try to make the handle */
-    entry->conn = create_s3_connection(server, user, use_minio);
+    entry->conn = create_s3_connection(server, user, option);
 
     elog(DEBUG3, "parquet_s3_fdw: new parquet_fdw handle %p for server \"%s\" (user mapping oid %u, userid %u)",
          entry->conn, server->servername, user->umid, user->userid);
@@ -330,7 +332,7 @@ ExtractConnectionOptions(List *defelems, const char **keywords,
  * Connect to remote server using specified server and user mapping properties.
  */
 static Aws::S3::S3Client *
-create_s3_connection(ForeignServer *server, UserMapping *user, bool use_minio)
+create_s3_connection(ForeignServer *server, UserMapping *user, parquet_s3_server_opt* option)
 {
     Aws::S3::S3Client *volatile conn = NULL;
@@ -354,7 +356,7 @@ create_s3_connection(ForeignServer *server, UserMapping *user, bool use_minio)
     n = list_length(lst_options) + 1;
     keywords = (const char **) palloc(n * sizeof(char *));
     values = (const char **) palloc(n * sizeof(char *));
-
+
     n = ExtractConnectionOptions( lst_options, keywords, values);
     keywords[n] = values[n] = NULL;
@@ -380,7 +382,7 @@ create_s3_connection(ForeignServer *server, UserMapping *user, bool use_minio)
         endpoint = defGetString(def);
     }
 
-    conn = s3_client_open(id, password, use_minio, endpoint, awsRegion);
+    conn = s3_client_open(id, password, option->use_minio, option->use_credential_provider, endpoint, awsRegion);
     if (!conn)
         ereport(ERROR,
                 (errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
@@ -494,13 +496,33 @@ parquet_fdw_inval_callback(Datum arg, int cacheid, uint32 hashvalue)
  * Create S3 handle.
  */
 static Aws::S3::S3Client*
-s3_client_open(const char *user, const char *password, bool use_minio, const char *endpoint, const char * awsRegion)
+s3_client_open(const char *user, const char *password, bool use_minio, bool use_credential_provider, const char *endpoint, const char * awsRegion)
 {
     const Aws::String access_key_id = user;
     const Aws::String secret_access_key = password;
-    Aws::Auth::AWSCredentials cred = Aws::Auth::AWSCredentials(access_key_id, secret_access_key);
+    Aws::Auth::AWSCredentials cred;
+    Aws::String env_region;
     Aws::S3::S3Client *s3_client;
-
+
+    if (use_credential_provider)
+    {
+        /* Resolve credentials through the default AWS credential provider chain. */
+        Aws::Auth::DefaultAWSCredentialsProviderChain cred_provider;
+        cred = cred_provider.GetAWSCredentials();
+        if (awsRegion == NULL)
+        {
+            /* Keep the Aws::String in scope so awsRegion does not dangle. */
+            env_region = Aws::Environment::GetEnv("AWS_REGION");
+            awsRegion = env_region.c_str();
+        }
+        elog(DEBUG1, "parquet_s3_fdw: AWSAccessKeyId %s", cred.GetAWSAccessKeyId().c_str());
+        elog(DEBUG1, "parquet_s3_fdw: AWSSecretKeyId %s", cred.GetAWSSecretKey().c_str());
+        elog(DEBUG1, "parquet_s3_fdw: AWSREGION %s", awsRegion);
+    }
+    else
+    {
+        cred = Aws::Auth::AWSCredentials(access_key_id, secret_access_key);
+    }
     pthread_mutex_lock(&cred_mtx);
     Aws::Client::ClientConfiguration clientConfig;
     pthread_mutex_unlock(&cred_mtx);
@@ -551,7 +569,7 @@ parquetGetConnectionByTableid(Oid foreigntableid, Oid userid)
         Assert(userid != InvalidOid);
         user = GetUserMapping(userid, fserver->serverid);
         options = parquet_s3_get_options(foreigntableid);
-        s3client = parquetGetConnection(user, options->use_minio);
+        s3client = parquetGetConnection(user, options);
     }
     return s3client;
 }
@@ -621,7 +639,7 @@ parquetGetS3ObjectList(Aws::S3::S3Client *s3_cli, const char *s3path)
 /*
  * If the keep_connections option of its server is disabled,
- * then discard it to recover. Next parquetGetConnection
+ * then discard it to recover. Next parquetGetConnection
  * will open a new connection.
  */
 void
@@ -703,7 +721,7 @@ parquetIsS3Filenames(List *filenames)
 /*
  * Split s3 path into bucket name and file path.
- * If foreign table option 'dirname' is specified, dirname starts by
+ * If foreign table option 'dirname' is specified, dirname starts by
  * "s3://". And filename is already set by get_filenames_in_dir().
 * On the other hand, if foreign table option 'filename' is specified,
 * dirname is NULL (Or empty string when ANALYZE was executed)
@@ -791,14 +809,14 @@ List *
 parquetImportForeignSchemaS3(ImportForeignSchemaStmt *stmt, Oid serverOid)
 {
     List       *cmds = NIL;
-    Aws::S3::S3Client *s3client;
+    Aws::S3::S3Client *s3client;
     List       *objects;
     ListCell   *cell;
     ForeignServer *fserver = GetForeignServer(serverOid);
     UserMapping *user = GetUserMapping(GetUserId(), fserver->serverid);
     parquet_s3_server_opt *options = parquet_s3_get_server_options(serverOid);
 
-    s3client = parquetGetConnection(user, options->use_minio);
+    s3client = parquetGetConnection(user, options);
 
     objects = parquetGetS3ObjectList(s3client, stmt->remote_schema);
@@ -878,11 +896,11 @@ parquetExtractParquetFields(List *fields, char **paths, const char *servername)
     if (!fields)
     {
         if (IS_S3_PATH(paths[0]))
-        {
+        {
             ForeignServer *fserver = GetForeignServerByName(servername, false);
             UserMapping *user = GetUserMapping(GetUserId(), fserver->serverid);
             parquet_s3_server_opt *options = parquet_s3_get_server_options(fserver->serverid);
-            Aws::S3::S3Client *s3client = parquetGetConnection(user, options->use_minio);
+            Aws::S3::S3Client *s3client = parquetGetConnection(user, options);
 
             fields = extract_parquet_fields(paths[0], NULL, s3client);
         }
diff --git a/parquet_s3_fdw_server_option.c b/parquet_s3_fdw_server_option.c
index d836fd9..f3b84a0 100644
--- a/parquet_s3_fdw_server_option.c
+++ b/parquet_s3_fdw_server_option.c
@@ -33,7 +33,8 @@ bool
 parquet_s3_is_valid_server_option(DefElem *def)
 {
     if (strcmp(def->defname, SERVER_OPTION_USE_MINIO) == 0 ||
-        strcmp(def->defname, SERVER_OPTION_KEEP_CONNECTIONS) == 0)
+        strcmp(def->defname, SERVER_OPTION_KEEP_CONNECTIONS) == 0 ||
+        strcmp(def->defname, SERVER_OPTION_USE_CREDENTIAL_PROVIDER) == 0)
     {
         /* Check that bool value is valid */
         bool        check_bool_valid;
@@ -71,6 +72,8 @@ parquet_s3_extract_options(List *options, parquet_s3_server_opt * opt)
             opt->use_minio = defGetBoolean(def);
         else if (strcmp(def->defname, SERVER_OPTION_KEEP_CONNECTIONS) == 0)
             opt->keep_connections = defGetBoolean(def);
+        else if (strcmp(def->defname, SERVER_OPTION_USE_CREDENTIAL_PROVIDER) == 0)
+            opt->use_credential_provider = defGetBoolean(def);
         else if (strcmp(def->defname, SERVER_OPTION_REGION) == 0)
             opt->region = defGetString(def);
         else if (strcmp(def->defname, SERVER_OPTION_ENDPOINT) == 0)
@@ -98,6 +101,7 @@ parquet_s3_get_options(Oid foreignoid)
     opt->use_minio = false;
     /* By default, all the connections to any foreign servers are kept open. */
     opt->keep_connections = true;
+    opt->use_credential_provider = false;
     opt->region = "ap-northeast-1";
     opt->endpoint = "127.0.0.1:9000";
@@ -147,6 +151,7 @@ parquet_s3_get_server_options(Oid serverid)
     opt->use_minio = false;
     /* By default, all the connections to any foreign servers are kept open. */
     opt->keep_connections = true;
+    opt->use_credential_provider = false;
     opt->region = "ap-northeast-1";
     opt->endpoint = "127.0.0.1:9000";
diff --git a/src/parquet_impl.cpp b/src/parquet_impl.cpp
index d1ee672..8070936 100644
--- a/src/parquet_impl.cpp
+++ b/src/parquet_impl.cpp
@@ -989,7 +989,7 @@ extract_rowgroups_list(const char *filename,
         }   /* loop over rowgroups */
     } catch(const std::exception& e) {
-        error = e.what();
+        error = e.what();
     }
     if (!error.empty()) {
         if (reader_entry)
@@ -2540,7 +2540,7 @@ parquetAcquireSampleRowsFunc(Relation relation, int /* elevel */,
         slcols.insert(std::string(strVal(rcol)));
     }
 
-    festate = create_parquet_execution_state(RT_MULTI, reader_cxt,
+    festate = create_parquet_execution_state(RT_MULTI, reader_cxt,
                                              fdw_private.dirname,
                                              fdw_private.s3client,
                                              tupleDesc,
@@ -2753,7 +2753,7 @@ parquetIsForeignScanParallelSafe(PlannerInfo * /* root */, RelOptInfo *rel,
                                  RangeTblEntry * /* rte */)
 {
-    /* Plan nodes that reference a correlated SubPlan is always parallel restricted.
+    /* Plan nodes that reference a correlated SubPlan is always parallel restricted.
      * Therefore, return false when there is lateral join.
      */
     if (rel->lateral_relids)
@@ -4329,6 +4329,7 @@ int ExecForeignDDL(Oid serverOid,
     table = GetForeignTable(RelationGetRelid(rel));
     user = GetUserMapping(GetUserId(), serverOid);
+    parquet_s3_server_opt *options = parquet_s3_get_options(serverOid);
     foreach(lc, server->options)
     {
         DefElem *def = (DefElem *) lfirst(lc);
@@ -4351,7 +4352,7 @@ int ExecForeignDDL(Oid serverOid,
     }
 
     if (IS_S3_PATH(dirname) || parquetIsS3Filenames(filenames))
-        s3_client = parquetGetConnection(user, use_minio);
+        s3_client = parquetGetConnection(user, options);
     else
         s3_client = NULL;
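
Note (not part of the patch): the block below is a minimal standalone sketch of the credential-provider path that the new use_credential_provider option enables. It assumes the AWS SDK for C++ has already been initialized elsewhere (Aws::InitAPI); the function name and the region fallback are illustrative only, not the extension's API. It shows the same lifetime point handled in s3_client_open above: the Aws::String holding the region must outlive any const char * taken from it.

    #include <aws/core/auth/AWSCredentialsProviderChain.h>
    #include <aws/core/platform/Environment.h>
    #include <aws/core/client/ClientConfiguration.h>
    #include <aws/s3/S3Client.h>

    /* Resolve credentials through the default provider chain (environment
     * variables, shared credentials/config files, instance profiles, ...)
     * and fall back to the AWS_REGION environment variable for the region. */
    static Aws::S3::S3Client *
    open_s3_client_via_provider_chain(const char *region_opt)
    {
        Aws::Auth::DefaultAWSCredentialsProviderChain chain;
        Aws::Auth::AWSCredentials cred = chain.GetAWSCredentials();

        /* Keep the region in an Aws::String so the value stays valid. */
        Aws::String region = region_opt ? Aws::String(region_opt)
                                        : Aws::Environment::GetEnv("AWS_REGION");

        Aws::Client::ClientConfiguration config;
        if (!region.empty())
            config.region = region;

        return new Aws::S3::S3Client(cred, config);
    }

With the server option enabled (use_credential_provider 'true'), the intent of the patch is that credentials come from the environment or instance role rather than from keys stored in the user mapping.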