forked from chimpler/postgres-aws-s3
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathaws_s3--1.0.0.sql
321 lines (288 loc) · 11 KB
/
aws_s3--1.0.0.sql
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
\echo Use "CREATE EXTENSION aws_s3" to load this file. \quit
-- Schemas used by this script: aws_commons holds the shared composite types and
-- their constructor functions, aws_s3 holds the import/export functions.
CREATE SCHEMA IF NOT EXISTS aws_commons;
CREATE SCHEMA IF NOT EXISTS aws_s3;
-- Composite type describing an S3 location (bucket, object key/path, region).
-- The type is dropped first so that it can be recreated from scratch.
-- NOTE: CASCADE also drops every object that depends on the type.
DROP TYPE IF EXISTS aws_commons._s3_uri_1 CASCADE;
CREATE TYPE aws_commons._s3_uri_1 AS (bucket TEXT, file_path TEXT, region TEXT);
-- Composite type bundling AWS credentials (access key, secret key, session token).
-- Same drop-and-recreate pattern as above, with the same CASCADE caveat.
DROP TYPE IF EXISTS aws_commons._aws_credentials_1 CASCADE;
CREATE TYPE aws_commons._aws_credentials_1 AS (access_key TEXT, secret_key TEXT, session_token TEXT);
--
-- Build an aws_commons._s3_uri_1 value out of a bucket name, an object key
-- and a region.
--
CREATE OR REPLACE FUNCTION aws_commons.create_s3_uri(
    s3_bucket text,
    s3_key text,
    aws_region text
) RETURNS aws_commons._s3_uri_1
LANGUAGE plpython3u IMMUTABLE
AS $$
    # The composite result is returned as a mapping keyed by attribute name,
    # which makes the argument -> attribute pairing explicit.
    s3_uri = {
        'bucket': s3_bucket,
        'file_path': s3_key,
        'region': aws_region,
    }
    return s3_uri
$$;
--
-- Build an aws_commons._aws_credentials_1 value out of an access key, a
-- secret key and a session token.
--
CREATE OR REPLACE FUNCTION aws_commons.create_aws_credentials(
    access_key text,
    secret_key text,
    session_token text
) RETURNS aws_commons._aws_credentials_1
LANGUAGE plpython3u IMMUTABLE
AS $$
    # The composite result is returned as a mapping keyed by attribute name,
    # which makes the argument -> attribute pairing explicit.
    aws_credentials = {
        'access_key': access_key,
        'secret_key': secret_key,
        'session_token': session_token,
    }
    return aws_credentials
$$;
--
-- Import one or more S3 objects into a table with COPY ... FROM.
--
-- Parameters:
--   table_name     target table; interpolated verbatim into TRUNCATE/COPY
--   column_list    comma-separated column names; NULL or '' omits the list
--   options        COPY options; interpolated verbatim after the file name
--   bucket         S3 bucket name
--   file_path      comma-separated list of keys; an item ending in "/" is
--                  treated as a prefix and every object under it is imported
--   region         region name handed to boto3
--   access_key, secret_key, session_token
--                  explicit credentials; when empty, the aws_s3.access_key_id,
--                  aws_s3.secret_access_key and aws_s3.session_token settings
--                  are used instead
--   endpoint_url   custom endpoint; falls back to the aws_s3.endpoint_url setting
--   read_timeout   value passed to boto3's Config(read_timeout=...)
--   override       when true the table is truncated before importing
--   tempfile_dir   directory in which each object is staged for COPY
--
-- Returns the sum of the row counts reported by every COPY executed.
--
-- NOTE(security): table_name, column_list and options are spliced into SQL
-- text as-is. Never pass untrusted input through these arguments.
--
CREATE OR REPLACE FUNCTION aws_s3.table_import_from_s3 (
    table_name text,
    column_list text,
    options text,
    bucket text,
    file_path text,
    region text,
    access_key text default null,
    secret_key text default null,
    session_token text default null,
    endpoint_url text default null,
    read_timeout integer default 60,
    override boolean default false,
    tempfile_dir text default '/var/lib/postgresql/data/'
) RETURNS int
LANGUAGE plpython3u
AS $$
    def cache_import(module_name):
        # Imported modules are memoised in SD (PL/Python's static dictionary
        # for this function) so later calls can reuse them.
        module_cache = SD.setdefault('__modules__', {})
        if module_name not in module_cache:
            import importlib
            module_cache[module_name] = importlib.import_module(module_name)
        return module_cache[module_name]

    def is_gzip(encoding):
        return bool(encoding) and encoding.lower() == 'gzip'

    boto3 = cache_import('boto3')
    tempfile = cache_import('tempfile')
    gzip = cache_import('gzip')
    shutil = cache_import('shutil')

    # Session-level defaults; current_setting(..., true) yields NULL (None)
    # for settings that are not defined.
    plan = plpy.prepare("select name, current_setting('aws_s3.' || name, true) as value from (select unnest(array['access_key_id', 'secret_access_key', 'session_token', 'endpoint_url']) as name) a")
    default_aws_settings = {
        row['name']: row['value']
        for row in plan.execute()
    }

    aws_settings = {
        'aws_access_key_id': access_key if access_key else default_aws_settings.get('access_key_id', 'unknown'),
        'aws_secret_access_key': secret_key if secret_key else default_aws_settings.get('secret_access_key', 'unknown'),
        'aws_session_token': session_token if session_token else default_aws_settings.get('session_token'),
        'endpoint_url': endpoint_url if endpoint_url else default_aws_settings.get('endpoint_url')
    }

    s3 = boto3.resource(
        's3',
        region_name=region,
        config=boto3.session.Config(read_timeout=read_timeout),
        **aws_settings
    )

    if override:
        plpy.execute("TRUNCATE TABLE {table_name} RESTRICT;".format(table_name=table_name))

    formatted_column_list = "({column_list})".format(column_list=column_list) if column_list else ''
    chunk_size = 204800
    num_rows = 0

    for file_path_item in file_path.split(","):
        file_path_item = file_path_item.strip()
        if not file_path_item:
            continue

        if file_path_item.endswith("/"):  # Directory
            # Zero-byte keys ending in "/" are folder placeholders, not data
            # files; feeding them to COPY is useless at best and an error for
            # formats that require a file signature.
            s3_objects = [
                bucket_object
                for bucket_object in s3.Bucket(bucket).objects.filter(Prefix=file_path_item)
                if not (bucket_object.key.endswith("/") and bucket_object.size == 0)
            ]
        else:  # File
            s3_objects = [s3.Object(bucket, file_path_item)]

        for s3_object in s3_objects:
            response = s3_object.get()
            body = response['Body']
            try:
                content_encoding = response.get('ContentEncoding')
                # boto3 exposes user-defined metadata under 'Metadata' with the
                # "x-amz-meta-" prefix removed. The prefixed top-level lookup is
                # kept only as a fallback for backward compatibility.
                metadata = {
                    str(key).lower(): value
                    for key, value in (response.get('Metadata') or {}).items()
                }
                user_content_encoding = (
                    metadata.get('content-encoding')
                    or response.get('x-amz-meta-content-encoding')
                )

                with tempfile.NamedTemporaryFile(dir=tempfile_dir) as fd:
                    if is_gzip(content_encoding) or is_gzip(user_content_encoding):
                        with gzip.GzipFile(fileobj=body) as gzipfile:
                            shutil.copyfileobj(gzipfile, fd, chunk_size)
                    else:
                        shutil.copyfileobj(body, fd, chunk_size)
                    # COPY opens the file by name, so buffered bytes must be
                    # on disk before it runs.
                    fd.flush()

                    res = plpy.execute(
                        "COPY {table_name} {formatted_column_list} FROM {filename} {options};".format(
                            table_name=table_name,
                            filename=plpy.quote_literal(fd.name),
                            formatted_column_list=formatted_column_list,
                            options=options
                        )
                    )
                    num_rows += res.nrows()
            finally:
                # Release the underlying HTTP connection even when the
                # download or the COPY fails.
                body.close()

    return num_rows
$$;
--
-- Import data from S3 into a table, taking the S3 location and the
-- credentials as aws_commons composite values.
--
-- Unpacks s3_info (bucket, file_path, region) and credentials (access_key,
-- secret_key, session_token) and delegates to the 13-argument
-- aws_s3.table_import_from_s3 overload; every other argument is forwarded
-- unchanged. A NULL credentials value is forwarded as three NULLs.
--
-- Returns the row count reported by the delegated call.
--
CREATE OR REPLACE FUNCTION aws_s3.table_import_from_s3(
    table_name text,
    column_list text,
    options text,
    s3_info aws_commons._s3_uri_1,
    credentials aws_commons._aws_credentials_1,
    endpoint_url text default null,
    read_timeout integer default 60,
    override boolean default false,
    tempfile_dir text default '/var/lib/postgresql/data/'
) RETURNS INT
LANGUAGE plpython3u
AS $$
    plan = plpy.prepare(
        'SELECT aws_s3.table_import_from_s3($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) AS num_rows',
        ['TEXT', 'TEXT', 'TEXT', 'TEXT', 'TEXT', 'TEXT', 'TEXT', 'TEXT', 'TEXT', 'TEXT', 'INTEGER', 'BOOLEAN', 'TEXT']
    )

    # A NULL composite argument arrives as None; subscripting it would raise
    # TypeError, so fall back to an empty mapping (same tolerance as the
    # aws_s3.query_export_to_s3 wrapper).
    creds = credentials if credentials else {}

    return plan.execute(
        [
            table_name,
            column_list,
            options,
            s3_info['bucket'],
            s3_info['file_path'],
            s3_info['region'],
            creds.get('access_key'),
            creds.get('secret_key'),
            creds.get('session_token'),
            endpoint_url,
            read_timeout,
            override,
            tempfile_dir
        ]
    )[0]['num_rows']
$$;
--
-- Export the result of a query to a single S3 object with COPY ... TO.
--
-- Parameters:
--   query          SELECT statement; interpolated verbatim into COPY (...)
--   bucket         destination bucket
--   file_path      destination key
--   region         region name handed to boto3
--   access_key, secret_key, session_token
--                  explicit credentials; when empty, the aws_s3.access_key_id,
--                  aws_s3.secret_access_key and aws_s3.session_token settings
--                  are used instead
--   options        COPY options, wrapped in parentheses when present
--   endpoint_url   custom endpoint; falls back to the aws_s3.endpoint_url setting
--   read_timeout   value passed to boto3's Config(read_timeout=...)
--   override       when false and file_path already exists, the upload goes
--                  to "<base>_part<N><extension>" for the first free N >= 1
--   tempfile_dir   directory in which the COPY output is staged
--
-- Returns one row: rows_uploaded (newline count of the staged file, minus the
-- header line when the HEADER option is enabled), files_uploaded (always 1)
-- and bytes_uploaded (size of the staged file).
--
-- NOTE(security): query and options are spliced into SQL text as-is. Never
-- pass untrusted input through these arguments.
--
CREATE OR REPLACE FUNCTION aws_s3.query_export_to_s3(
    query text,
    bucket text,
    file_path text,
    region text default null,
    access_key text default null,
    secret_key text default null,
    session_token text default null,
    options text default null,
    endpoint_url text default null,
    read_timeout integer default 60,
    override boolean default false,
    tempfile_dir text default '/var/lib/postgresql/data/',
    OUT rows_uploaded bigint,
    OUT files_uploaded bigint,
    OUT bytes_uploaded bigint
) RETURNS SETOF RECORD
LANGUAGE plpython3u
AS $$
    def cache_import(module_name):
        # Imported modules are memoised in SD (PL/Python's static dictionary
        # for this function) so later calls can reuse them.
        module_cache = SD.setdefault('__modules__', {})
        if module_name not in module_cache:
            import importlib
            module_cache[module_name] = importlib.import_module(module_name)
        return module_cache[module_name]

    def file_exists(bucket, file_path, s3_client):
        # Deliberately best-effort: any failure of the HEAD request (missing
        # key, denied access, ...) is reported as "does not exist".
        try:
            s3_client.head_object(Bucket=bucket, Key=file_path)
            return True
        except Exception:
            return False

    def get_unique_file_path(base_name, counter, extension):
        return f"{base_name}_part{counter}{extension}"

    boto3 = cache_import('boto3')
    tempfile = cache_import('tempfile')
    re = cache_import("re")

    # Session-level defaults; current_setting(..., true) yields NULL (None)
    # for settings that are not defined.
    plan = plpy.prepare("select name, current_setting('aws_s3.' || name, true) as value from (select unnest(array['access_key_id', 'secret_access_key', 'session_token', 'endpoint_url']) as name) a")
    default_aws_settings = {
        row['name']: row['value']
        for row in plan.execute()
    }

    aws_settings = {
        'aws_access_key_id': access_key if access_key else default_aws_settings.get('access_key_id', 'unknown'),
        'aws_secret_access_key': secret_key if secret_key else default_aws_settings.get('secret_access_key', 'unknown'),
        'aws_session_token': session_token if session_token else default_aws_settings.get('session_token'),
        'endpoint_url': endpoint_url if endpoint_url else default_aws_settings.get('endpoint_url')
    }

    s3 = boto3.client(
        's3',
        region_name=region,
        config=boto3.session.Config(read_timeout=read_timeout),
        **aws_settings
    )

    upload_file_path = file_path
    if not override and file_exists(bucket, file_path, s3):
        # Split off the extension of the last path segment only. Excluding "/"
        # from the extension keeps a dot in a directory name ("dir.v1/file")
        # from being mistaken for the start of the extension.
        file_path_parts = re.match(r'^(.*?)(\.[^./]*$|$)', file_path)
        base_name = file_path_parts.group(1)
        extension = file_path_parts.group(2)

        counter = 1
        while file_exists(bucket, get_unique_file_path(base_name, counter, extension), s3):
            counter += 1
        upload_file_path = get_unique_file_path(base_name, counter, extension)

    with tempfile.NamedTemporaryFile(dir=tempfile_dir) as fd:
        # quote_literal keeps a quote character in tempfile_dir from breaking
        # (or altering) the COPY statement.
        plpy.execute(
            "COPY ({query}) TO {filename} {options}".format(
                query=query,
                filename=plpy.quote_literal(fd.name),
                options="({options})".format(options=options) if options else ''
            )
        )

        # Single pass over the staged file to gather line and byte counts.
        num_lines = 0
        size = 0
        while True:
            buffer = fd.read(8192 * 1024)
            if not buffer:
                break
            num_lines += buffer.count(b'\n')
            size += len(buffer)

        fd.seek(0)
        s3.upload_fileobj(fd, bucket, upload_file_path)

    # options may be NULL. COPY treats a bare HEADER, HEADER TRUE, HEADER ON
    # and HEADER 1 as enabled; only an explicit FALSE/OFF/0 disables it.
    has_header = bool(options) and re.search(
        r"\bheader\b(?!\s*'?(?:false|off|0)\b)", options, re.IGNORECASE
    ) is not None
    if has_header and num_lines > 0:
        num_lines -= 1

    yield (num_lines, 1, size)
$$;
--
-- Export a query result to S3, taking the destination and the credentials as
-- aws_commons composite values.
--
-- Unpacks s3_info and credentials and delegates to the 12-argument
-- aws_s3.query_export_to_s3 overload; the rows produced by that call are
-- returned unchanged. A NULL credentials value is forwarded as three NULLs.
--
CREATE OR REPLACE FUNCTION aws_s3.query_export_to_s3(
    query text,
    s3_info aws_commons._s3_uri_1,
    credentials aws_commons._aws_credentials_1 default null,
    options text default null,
    endpoint_url text default null,
    read_timeout integer default 60,
    override boolean default false,
    tempfile_dir text default '/var/lib/postgresql/data/',
    OUT rows_uploaded bigint,
    OUT files_uploaded bigint,
    OUT bytes_uploaded bigint
) RETURNS SETOF RECORD
LANGUAGE plpython3u
AS $$
    # Argument types of the delegated call, in positional order.
    arg_types = ['TEXT'] * 9 + ['INTEGER', 'BOOLEAN', 'TEXT']
    delegate = plpy.prepare(
        'SELECT * FROM aws_s3.query_export_to_s3($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)',
        arg_types
    )

    # A NULL composite arrives as None; an empty mapping yields None for
    # every credential field.
    creds = credentials if credentials else {}

    location_args = [s3_info.get(field) for field in ('bucket', 'file_path', 'region')]
    credential_args = [creds.get(field) for field in ('access_key', 'secret_key', 'session_token')]
    passthrough_args = [options, endpoint_url, read_timeout, override, tempfile_dir]

    return plpy.execute(delegate, [query] + location_args + credential_args + passthrough_args)
$$;