Skip to content

Commit

Permalink
Added feature to fetch s3 credential from aws provider chain
Browse files Browse the repository at this point in the history
  • Loading branch information
shhnwz committed May 30, 2024
1 parent 50f3e02 commit fe200f2
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 27 deletions.
3 changes: 3 additions & 0 deletions parquet_s3_fdw.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ typedef struct parquet_s3_server_opt
bool use_minio; /* Connect to MinIO instead of Amazon S3. */
bool keep_connections; /* setting value of keep_connections
* server option */
bool use_credential_provider; /* Retrieve AWS credentials using
* AWS Credential providers shhnwz */
char *region; /* AWS region to connect to */
char *endpoint; /* Address and port to connect to */
} parquet_s3_server_opt;
Expand All @@ -71,6 +73,7 @@ int ExecForeignDDL(Oid serverOid,
/* Option name for CREATE FOREIGN SERVER. */
#define SERVER_OPTION_USE_MINIO "use_minio"
#define SERVER_OPTION_KEEP_CONNECTIONS "keep_connections"
#define SERVER_OPTION_USE_CREDENTIAL_PROVIDER "use_credential_provider" //shhnwz
#define SERVER_OPTION_REGION "region"
#define SERVER_OPTION_ENDPOINT "endpoint"

Expand Down
4 changes: 2 additions & 2 deletions parquet_s3_fdw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ typedef enum FileLocation_t
/*
* We would like to cache FileReader. When creating new hash entry,
* the memory of entry is allocated by PostgreSQL core. But FileReader is
* a unique_ptr. In order to initialize it in parquet_s3_fdw, we define
* a unique_ptr. In order to initialize it in parquet_s3_fdw, we define
* FileReaderCache class and the cache entry has the pointer of this class.
*/
class FileReaderCache
Expand All @@ -84,7 +84,7 @@ extern List *extract_parquet_fields(const char *path, const char *dirname, Aws::
extern char *create_foreign_table_query(const char *tablename, const char *schemaname, const char *servername,
char **paths, int npaths, List *fields, List *options);

extern Aws::S3::S3Client *parquetGetConnection(UserMapping *user, bool use_minio);
extern Aws::S3::S3Client *parquetGetConnection(UserMapping *user, parquet_s3_server_opt* option);
extern Aws::S3::S3Client *parquetGetConnectionByTableid(Oid foreigntableid, Oid userid);
extern void parquetReleaseConnection(Aws::S3::S3Client *conn);
extern List* parquetGetS3ObjectList(Aws::S3::S3Client *s3_cli, const char *s3path);
Expand Down
58 changes: 38 additions & 20 deletions parquet_s3_fdw_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
*-------------------------------------------------------------------------
*/
#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/core/auth/AWSCredentialsProviderChain.h> //shhnwz
#include <aws/core/platform/Environment.h>
#include <aws/core/auth/AWSAuthSigner.h>
#include <aws/core/Aws.h>
#include <aws/s3/S3Client.h>
Expand Down Expand Up @@ -119,13 +121,13 @@ PG_FUNCTION_INFO_V1(parquet_s3_fdw_disconnect_all);
}

/* prototypes of private functions */
static void make_new_connection(ConnCacheEntry *entry, UserMapping *user, bool use_minio);
static void make_new_connection(ConnCacheEntry *entry, UserMapping *user, parquet_s3_server_opt* option);
static bool disconnect_cached_connections(Oid serverid);
static Aws::S3::S3Client *create_s3_connection(ForeignServer *server, UserMapping *user, bool use_minio);
static Aws::S3::S3Client *create_s3_connection(ForeignServer *server, UserMapping *user, parquet_s3_server_opt* option);
static void close_s3_connection(ConnCacheEntry *entry);
static void check_conn_params(const char **keywords, const char **values, UserMapping *user);
static void parquet_fdw_inval_callback(Datum arg, int cacheid, uint32 hashvalue);
static Aws::S3::S3Client* s3_client_open(const char *user, const char *password, bool use_minio, const char *endpoint, const char *awsRegion);
static Aws::S3::S3Client* s3_client_open(const char *user, const char *password, bool use_minio, bool use_credential_provider, const char *endpoint, const char *awsRegion);
static void s3_client_close(Aws::S3::S3Client *s3_client);

extern "C" void
Expand All @@ -148,7 +150,7 @@ parquet_s3_shutdown()
* if we don't already have a suitable one.
*/
Aws::S3::S3Client *
parquetGetConnection(UserMapping *user, bool use_minio)
parquetGetConnection(UserMapping *user, parquet_s3_server_opt* option)
{
bool found;
ConnCacheEntry *entry;
Expand Down Expand Up @@ -215,7 +217,7 @@ parquetGetConnection(UserMapping *user, bool use_minio)
* will remain in a valid empty state, ie conn == NULL.)
*/
if (entry->conn == NULL)
make_new_connection(entry, user, use_minio);
make_new_connection(entry, user, option);

return entry->conn;
}
Expand All @@ -225,7 +227,7 @@ parquetGetConnection(UserMapping *user, bool use_minio)
* establish new connection to the remote server.
*/
static void
make_new_connection(ConnCacheEntry *entry, UserMapping *user, bool use_minio)
make_new_connection(ConnCacheEntry *entry, UserMapping *user, parquet_s3_server_opt* option)
{
ForeignServer *server = GetForeignServer(user->serverid);

Expand All @@ -242,7 +244,7 @@ make_new_connection(ConnCacheEntry *entry, UserMapping *user, bool use_minio)
ObjectIdGetDatum(user->umid));

/* Now try to make the handle */
entry->conn = create_s3_connection(server, user, use_minio);
entry->conn = create_s3_connection(server, user, option);

elog(DEBUG3, "parquet_s3_fdw: new parquet_fdw handle %p for server \"%s\" (user mapping oid %u, userid %u)",
entry->conn, server->servername, user->umid, user->userid);
Expand Down Expand Up @@ -330,7 +332,7 @@ ExtractConnectionOptions(List *defelems, const char **keywords,
* Connect to remote server using specified server and user mapping properties.
*/
static Aws::S3::S3Client *
create_s3_connection(ForeignServer *server, UserMapping *user, bool use_minio)
create_s3_connection(ForeignServer *server, UserMapping *user, parquet_s3_server_opt* option)
{
Aws::S3::S3Client *volatile conn = NULL;

Expand All @@ -354,7 +356,7 @@ create_s3_connection(ForeignServer *server, UserMapping *user, bool use_minio)
n = list_length(lst_options) + 1;
keywords = (const char **) palloc(n * sizeof(char *));
values = (const char **) palloc(n * sizeof(char *));

n = ExtractConnectionOptions( lst_options,
keywords, values);
keywords[n] = values[n] = NULL;
Expand All @@ -380,7 +382,7 @@ create_s3_connection(ForeignServer *server, UserMapping *user, bool use_minio)
endpoint = defGetString(def);
}

conn = s3_client_open(id, password, use_minio, endpoint, awsRegion);
conn = s3_client_open(id, password, option->use_minio, option->use_credential_provider, endpoint, awsRegion); //shhnwz
if (!conn)
ereport(ERROR,
(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
Expand Down Expand Up @@ -494,13 +496,29 @@ parquet_fdw_inval_callback(Datum arg, int cacheid, uint32 hashvalue)
* Create S3 handle.
*/
static Aws::S3::S3Client*
s3_client_open(const char *user, const char *password, bool use_minio, const char *endpoint, const char * awsRegion)
s3_client_open(const char *user, const char *password, bool use_minio, bool use_credential_provider, const char *endpoint, const char * awsRegion)
{
const Aws::String access_key_id = user;
const Aws::String secret_access_key = password;
Aws::Auth::AWSCredentials cred = Aws::Auth::AWSCredentials(access_key_id, secret_access_key);
Aws::Auth::AWSCredentials cred;
Aws::S3::S3Client *s3_client;

//shhnwz->
if (use_credential_provider)
{
Aws::Auth::DefaultAWSCredentialsProviderChain cred_provider; //shhnwz
cred = cred_provider.GetAWSCredentials();
if (awsRegion == NULL)
{
awsRegion = Aws::Environment::GetEnv("AWS_REGION").c_str();
}
elog(DEBUG1, "parquet_s3_fdw: AWSAccessKeyId %s", cred.GetAWSAccessKeyId().c_str());
elog(DEBUG1, "parquet_s3_fdw: AWSSecretKeyId %s", cred.GetAWSSecretKey().c_str());
elog(DEBUG1, "parquet_s3_fdw: AWSREGION %s", awsRegion);
}else
{
cred = Aws::Auth::AWSCredentials(access_key_id, secret_access_key);
}
//<-shhnwz
pthread_mutex_lock(&cred_mtx);
Aws::Client::ClientConfiguration clientConfig;
pthread_mutex_unlock(&cred_mtx);
Expand Down Expand Up @@ -551,7 +569,7 @@ parquetGetConnectionByTableid(Oid foreigntableid, Oid userid)
Assert(userid != InvalidOid);
user = GetUserMapping(userid, fserver->serverid);
options = parquet_s3_get_options(foreigntableid);
s3client = parquetGetConnection(user, options->use_minio);
s3client = parquetGetConnection(user, options);
}
return s3client;
}
Expand Down Expand Up @@ -621,7 +639,7 @@ parquetGetS3ObjectList(Aws::S3::S3Client *s3_cli, const char *s3path)

/*
* If the keep_connections option of its server is disabled,
* then discard it to recover. Next parquetGetConnection
* then discard it to recover. Next parquetGetConnection
* will open a new connection.
*/
void
Expand Down Expand Up @@ -703,7 +721,7 @@ parquetIsS3Filenames(List *filenames)

/*
* Split s3 path into bucket name and file path.
* If foreign table option 'dirname' is specified, dirname starts by
* If foreign table option 'dirname' is specified, dirname starts by
* "s3://". And filename is already set by get_filenames_in_dir().
* On the other hand, if foreign table option 'filename' is specified,
* dirname is NULL (Or empty string when ANALYZE was executed)
Expand Down Expand Up @@ -791,14 +809,14 @@ List *
parquetImportForeignSchemaS3(ImportForeignSchemaStmt *stmt, Oid serverOid)
{
List *cmds = NIL;
Aws::S3::S3Client *s3client;
Aws::S3::S3Client *s3client;
List *objects;
ListCell *cell;

ForeignServer *fserver = GetForeignServer(serverOid);
UserMapping *user = GetUserMapping(GetUserId(), fserver->serverid);
parquet_s3_server_opt *options = parquet_s3_get_server_options(serverOid);
s3client = parquetGetConnection(user, options->use_minio);
s3client = parquetGetConnection(user, options);

objects = parquetGetS3ObjectList(s3client, stmt->remote_schema);

Expand Down Expand Up @@ -878,11 +896,11 @@ parquetExtractParquetFields(List *fields, char **paths, const char *servername)
if (!fields)
{
if (IS_S3_PATH(paths[0]))
{
{
ForeignServer *fserver = GetForeignServerByName(servername, false);
UserMapping *user = GetUserMapping(GetUserId(), fserver->serverid);
parquet_s3_server_opt *options = parquet_s3_get_server_options(fserver->serverid);
Aws::S3::S3Client *s3client = parquetGetConnection(user, options->use_minio);
Aws::S3::S3Client *s3client = parquetGetConnection(user, options);

fields = extract_parquet_fields(paths[0], NULL, s3client);
}
Expand Down
7 changes: 6 additions & 1 deletion parquet_s3_fdw_server_option.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ bool
parquet_s3_is_valid_server_option(DefElem *def)
{
if (strcmp(def->defname, SERVER_OPTION_USE_MINIO) == 0 ||
strcmp(def->defname, SERVER_OPTION_KEEP_CONNECTIONS) == 0)
strcmp(def->defname, SERVER_OPTION_KEEP_CONNECTIONS) == 0 ||
strcmp(def->defname, SERVER_OPTION_USE_CREDENTIAL_PROVIDER) == 0) //shhnwz
{
/* Check that bool value is valid */
bool check_bool_valid;
Expand Down Expand Up @@ -71,6 +72,8 @@ parquet_s3_extract_options(List *options, parquet_s3_server_opt * opt)
opt->use_minio = defGetBoolean(def);
else if (strcmp(def->defname, SERVER_OPTION_KEEP_CONNECTIONS) == 0)
opt->keep_connections = defGetBoolean(def);
else if (strcmp(def->defname, SERVER_OPTION_USE_CREDENTIAL_PROVIDER) == 0) //shhnwz
opt->use_credential_provider = defGetBoolean(def);
else if (strcmp(def->defname, SERVER_OPTION_REGION) == 0)
opt->region = defGetString(def);
else if (strcmp(def->defname, SERVER_OPTION_ENDPOINT) == 0)
Expand Down Expand Up @@ -98,6 +101,7 @@ parquet_s3_get_options(Oid foreignoid)
opt->use_minio = false;
/* By default, all the connections to any foreign servers are kept open. */
opt->keep_connections = true;
opt->use_credential_provider = false; //shhnwz
opt->region = "ap-northeast-1";
opt->endpoint = "127.0.0.1:9000";

Expand Down Expand Up @@ -147,6 +151,7 @@ parquet_s3_get_server_options(Oid serverid)
opt->use_minio = false;
/* By default, all the connections to any foreign servers are kept open. */
opt->keep_connections = true;
opt->use_credential_provider = false; //shhnwz
opt->region = "ap-northeast-1";
opt->endpoint = "127.0.0.1:9000";

Expand Down
9 changes: 5 additions & 4 deletions src/parquet_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ extract_rowgroups_list(const char *filename,
} /* loop over rowgroups */
}
catch(const std::exception& e) {
error = e.what();
error = e.what();
}
if (!error.empty()) {
if (reader_entry)
Expand Down Expand Up @@ -2540,7 +2540,7 @@ parquetAcquireSampleRowsFunc(Relation relation, int /* elevel */,
slcols.insert(std::string(strVal(rcol)));
}

festate = create_parquet_execution_state(RT_MULTI, reader_cxt,
festate = create_parquet_execution_state(RT_MULTI, reader_cxt,
fdw_private.dirname,
fdw_private.s3client,
tupleDesc,
Expand Down Expand Up @@ -2753,7 +2753,7 @@ parquetIsForeignScanParallelSafe(PlannerInfo * /* root */,
RelOptInfo *rel,
RangeTblEntry * /* rte */)
{
/* Plan nodes that reference a correlated SubPlan is always parallel restricted.
/* Plan nodes that reference a correlated SubPlan is always parallel restricted.
* Therefore, return false when there is lateral join.
*/
if (rel->lateral_relids)
Expand Down Expand Up @@ -4329,6 +4329,7 @@ int ExecForeignDDL(Oid serverOid,
table = GetForeignTable(RelationGetRelid(rel));
user = GetUserMapping(GetUserId(), serverOid);

parquet_s3_server_opt *options = parquet_s3_get_options(serverOid);
foreach(lc, server->options)
{
DefElem *def = (DefElem *) lfirst(lc);
Expand All @@ -4351,7 +4352,7 @@ int ExecForeignDDL(Oid serverOid,
}

if (IS_S3_PATH(dirname) || parquetIsS3Filenames(filenames))
s3_client = parquetGetConnection(user, use_minio);
s3_client = parquetGetConnection(user, options); //shhnwz
else
s3_client = NULL;

Expand Down

0 comments on commit fe200f2

Please sign in to comment.