Skip to content

Commit

Permalink
Support versioned key (#36)
Browse files Browse the repository at this point in the history
* support versioned download

* PR comments

* support several other operations

* remove some newlines

* add s3 mv tests

* add cli args

* provide options to list versions

* tweak help str

* add flag for rm

* lint fix

* don't support mv command on versioned files
  • Loading branch information
Chenyang Liu authored Apr 10, 2017
1 parent 87dbd86 commit bfc30a7
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 22 deletions.
12 changes: 9 additions & 3 deletions baiji/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@ class ListCommand(BaijiCommand):
uri = cli.Flag(["-B", "--uri"], help='This option does nothing. It used to return URIs instead of paths, but this is now the default.')
detail = cli.Flag(['-l', '--detail'], help='print details, like `ls -l`')
shallow = cli.Flag("--shallow", help='process key names hierarchically and return only immediate "children" (like ls, instead of like find)')
list_versions = cli.Flag(['--list-versions'], help='print all versions')

def main(self, key):
if self.uri:
print "-B and --uri are deprecated options"
try:
keys = s3.ls(key, return_full_urls=True, require_s3_scheme=True, shallow=self.shallow)
keys = s3.ls(key, return_full_urls=True, require_s3_scheme=True, shallow=self.shallow, list_versions=self.list_versions)
if self.detail:
from baiji.util.console import sizeof_format_human_readable
for key in keys:
info = s3.info(key)
enc = " enc" if info['encrypted'] else " "
print "%s\t%s%s\t%s" % (sizeof_format_human_readable(info['size']), info['last_modified'], enc, key.encode('utf-8'),)
print "%s\t%s%s\t%s\t%s" % (sizeof_format_human_readable(info['size']), info['last_modified'], enc, key.encode('utf-8'), info['version_id'])
else:
print u"\n".join(keys).encode('utf-8')
except s3.InvalidSchemeException as e:
Expand All @@ -47,11 +49,12 @@ class RemoveCommand(BaijiCommand):
DESCRIPTION = "delete files on s3"
recursive = cli.Flag(['-r', '--recursive'], help='remove everything below key')
force = cli.Flag(['-f', '--force'], help="don't prompt for confirmation on recursive rm")
version_id = cli.SwitchAttr('--version-id', str, default=None, help='s3 object version ID')
def main(self, key):
    '''
    Delete a key on S3: the whole subtree when --recursive is set, otherwise
    a single key — optionally one specific version of it via --version-id.
    '''
    if self.recursive:
        # NOTE(review): recursive removal ignores --version-id; rm_r takes no
        # version parameter here.
        s3.rm_r(key, force=self.force)
    else:
        s3.rm(key, version_id=self.version_id)

class CopyCommand(BaijiCommand):
DESCRIPTION = "copy files from or to s3"
Expand All @@ -65,6 +68,7 @@ class CopyCommand(BaijiCommand):
gzip = cli.Flag(['-z', '--gzip'], help='Store compressed')
policy = cli.SwitchAttr('--policy', str, help='override policy when copying to s3 (e.g. private, public-read, bucket-owner-read')
encoding = cli.SwitchAttr('--encoding', str, help='Content-Encoding: gzip, etc')
version_id = cli.SwitchAttr('--version-id', str, default=None, help='s3 object version ID')
def main(self, src, dst):
kwargs = {
'force': self.force,
Expand All @@ -75,6 +79,7 @@ def main(self, src, dst):
'encrypt': self.encrypt,
'gzip': self.gzip,
'skip': self.skip,
'version_id': self.version_id,
}
if self.recursive or self.recursive_parallel:
s3.cp_r(src, dst, parallel=self.recursive_parallel, **kwargs)
Expand Down Expand Up @@ -117,6 +122,7 @@ def main(self, src, dst):
class ExistsCommand(BaijiCommand):
    DESCRIPTION = "check if a file exists on s3"
    retries = cli.SwitchAttr('--retries', int, help='how many times to retry', default=3)
    version_id = cli.SwitchAttr('--version-id', str, default=None, help='s3 object version ID')
    def main(self, key):
        # Forward --version-id (it was accepted but silently ignored) so the
        # existence check targets that specific object version.
        if not s3.exists(key, retries_allowed=self.retries, version_id=self.version_id):
            return -1
Expand Down
34 changes: 22 additions & 12 deletions baiji/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,23 @@ def _bucket(self, name, cache_buckets=None):
else:
raise

def _lookup(self, bucket_name, key, cache_buckets=None, version_id=None):
    '''
    Return the boto Key for ``key`` in ``bucket_name``, or None when the
    bucket or the key (at the requested ``version_id``) is not found.

    See _bucket for the details on cache_buckets
    '''
    from baiji.util.munging import _strip_initial_slashes
    from baiji.util.lookup import get_versioned_key_remote

    key = _strip_initial_slashes(key)

    try:
        bucket = self._bucket(bucket_name, cache_buckets=cache_buckets)
    except BucketNotFound:
        return None

    return get_versioned_key_remote(bucket, key, version_id=version_id)

def cp(self, key_or_file_from, key_or_file_to, force=False, progress=False, policy=None, preserve_acl=False, encoding=None, encrypt=True, gzip=False, content_type=None, guess_content_type=False, metadata=None, skip=False, validate=True, max_size=None):
def cp(self, key_or_file_from, key_or_file_to, force=False, progress=False, policy=None, preserve_acl=False, encoding=None, encrypt=True, gzip=False, content_type=None, guess_content_type=False, metadata=None, skip=False, validate=True, max_size=None, version_id=None):
"""
Copy file to or from AWS S3
Expand Down Expand Up @@ -95,6 +97,7 @@ def cp(self, key_or_file_from, key_or_file_to, force=False, progress=False, poli
op.skip = skip
op.validate = validate
op.max_size = max_size
op.version_id = version_id

if guess_content_type:
op.guess_content_type()
Expand Down Expand Up @@ -157,7 +160,7 @@ def common_prefix(a, b):
except KeyExists as e:
print str(e)

def rm(self, key_or_file):
def rm(self, key_or_file, version_id=None):
'''
Remove a key from AWS S3
'''
Expand All @@ -172,9 +175,9 @@ def rm(self, key_or_file):
else:
raise KeyNotFound("%s does not exist" % key_or_file)
elif k.scheme == 's3':
if not self.exists(key_or_file):
if not self.exists(key_or_file, version_id=version_id):
raise KeyNotFound("%s does not exist" % key_or_file)
return self._bucket(k.netloc).delete_key(_strip_initial_slashes(k.path))
return self._bucket(k.netloc).delete_key(_strip_initial_slashes(k.path), version_id=version_id)
else:
raise InvalidSchemeException("URI Scheme %s is not implemented" % k.scheme)

Expand All @@ -199,7 +202,7 @@ def rm_r(self, key_or_file, force=False, quiet=False):
if not quiet:
print "[deleted] %s" % url

def ls(self, s3prefix, return_full_urls=False, require_s3_scheme=False, shallow=False, followlinks=False):
def ls(self, s3prefix, return_full_urls=False, require_s3_scheme=False, shallow=False, followlinks=False, list_versions=False):
'''
List files on AWS S3
prefix is given as an S3 url: ``s3://bucket-name/path/to/dir``.
Expand All @@ -226,7 +229,13 @@ def ls(self, s3prefix, return_full_urls=False, require_s3_scheme=False, shallow=
clean_paths = lambda x: "s3://" + k.netloc + path.sep + x.name
else:
clean_paths = lambda x: path.sep + x.name
return itertools.imap(clean_paths, self._bucket(k.netloc).list(prefix=prefix, delimiter=delimiter))

if list_versions:
result_list_iterator = self._bucket(k.netloc).list_versions(prefix=prefix, delimiter=delimiter)
else:
result_list_iterator = self._bucket(k.netloc).list(prefix=prefix, delimiter=delimiter)

return itertools.imap(clean_paths, result_list_iterator)
elif k.scheme == 'file':
if require_s3_scheme:
raise InvalidSchemeException('URI should begin with s3://')
Expand Down Expand Up @@ -292,11 +301,12 @@ def info(self, key_or_file):
result['encrypted'] = bool(remote_object.encrypted)
result['acl'] = remote_object.get_acl()
result['owner'] = remote_object.owner
result['version_id'] = remote_object.version_id
else:
raise InvalidSchemeException("URI Scheme %s is not implemented" % k.scheme)
return result

def exists(self, key_or_file, retries_allowed=3):
def exists(self, key_or_file, retries_allowed=3, version_id=None):
'''
Check if a file exists on AWS S3
Expand All @@ -323,7 +333,7 @@ def exists(self, key_or_file, retries_allowed=3):
elif k.scheme == 's3':
retry_attempts = 0
while retry_attempts < retries_allowed:
key = self._lookup(k.netloc, k.path, cache_buckets=True)
key = self._lookup(k.netloc, k.path, cache_buckets=True, version_id=version_id)
if key:
if retry_attempts > 0: # only if we find it after failing at least once
import warnings
Expand All @@ -335,15 +345,15 @@ def exists(self, key_or_file, retries_allowed=3):
else:
raise InvalidSchemeException("URI Scheme %s is not implemented" % k.scheme)

def size(self, key_or_file, version_id=None):
    '''
    Return the size of a file. If it's on s3, don't download it.

    version_id: optional S3 object version to inspect (None means latest).
    Raises KeyNotFound if the key (or requested version) does not exist.
    '''
    k = path.parse(key_or_file)
    if k.scheme == 'file':
        return os.path.getsize(k.path)
    elif k.scheme == 's3':
        # Use a separate name: rebinding `k` to the lookup result used to make
        # the error message below dereference None (None has no netloc/path).
        remote_key = self._lookup(k.netloc, k.path, version_id=version_id)
        if remote_key is None:
            raise KeyNotFound("s3://%s/%s not found on s3" % (k.netloc, k.path))
        return remote_key.size
Expand Down
35 changes: 30 additions & 5 deletions baiji/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,25 @@ def etag(self):
return self.connection.etag(self.uri)
def rm(self):
return self.connection.rm(self.uri)
def lookup(self, version_id=None):
    '''
    Return the boto Key for this operation's remote path, optionally at a
    specific S3 version.

    Raises ValueError when called on the local side of a copy, and
    KeyNotFound when the key (or requested version) does not exist.
    '''
    from baiji.util.lookup import get_versioned_key_remote

    if self.is_file:
        raise ValueError("S3CopyOperation.CopyableKey.lookup called for local file")

    key = get_versioned_key_remote(self.bucket, self.remote_path, version_id=version_id)

    if not key:
        raise KeyNotFound("Error finding %s on s3: doesn't exist" % (self.uri))
    return key

def create(self):
    '''
    Construct a fresh (not-yet-uploaded) boto Key pointing at this
    operation's remote path.
    '''
    if self.is_file:
        raise ValueError("S3CopyOperation.CopyableKey.create called for local file")
    from boto.s3.key import Key
    new_key = Key(self.bucket)
    new_key.key = self.remote_path
    return new_key

def __init__(self, src, dst, connection):
Expand Down Expand Up @@ -99,6 +105,8 @@ def __init__(self, src, dst, connection):
self._retries = 0

self.file_size = None
# s3 version
self._version_id = None

@property # read only
def retries_made(self):
Expand All @@ -113,6 +121,14 @@ def policy(self, val):
raise ValueError("Policy only allowed when copying to s3")
self._policy = val # we get initialized with a call to the setter in init pylint: disable=attribute-defined-outside-init

@property
def version_id(self):
    # S3 source-object version for this copy (None means the latest version).
    return self._version_id
@version_id.setter
def version_id(self, val):
    # NOTE(review): no validation here — a malformed ID surfaces later as
    # InvalidVersionID when the versioned lookup hits S3.
    self._version_id = val # we get initialized with a call to the setter in init pylint: disable=attribute-defined-outside-init


@property
def preserve_acl(self):
return self._preserve_acl
Expand Down Expand Up @@ -356,7 +372,7 @@ def download(self):
# twice by the same process.
tf = tempfile.NamedTemporaryFile(delete=False)
try:
key = self.src.lookup()
key = self.src.lookup(version_id=self.version_id)

with FileTransferProgressbar(supress=(not self.progress)) as cb:
key.get_contents_to_file(tf, cb=cb)
Expand All @@ -380,7 +396,6 @@ def download(self):
self.download()
else:
raise

finally:
self.connection.rm(tf.name)

Expand All @@ -404,7 +419,17 @@ def remote_copy(self):
meta['Content-Encoding'] = key.content_encoding
meta['Content-Type'] = key.content_type
meta = dict(meta.items() + self.metadata.items())
self.dst.bucket.copy_key(self.dst.remote_path, self.src.bucket_name, src, preserve_acl=self.preserve_acl, metadata=meta, headers=headers, encrypt_key=self.encrypt)
self.dst.bucket.copy_key(
self.dst.remote_path,
self.src.bucket_name,
src,
preserve_acl=self.preserve_acl,
metadata=meta,
headers=headers,
encrypt_key=self.encrypt,
src_version_id=self.version_id
)

if self.progress:
print 'Copied %s to %s' % (self.src.uri, self.dst.uri)

Expand Down
3 changes: 3 additions & 0 deletions baiji/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ class S3Warning(RuntimeWarning):

class EventualConsistencyWarning(S3Warning):
pass

class InvalidVersionID(S3Exception):
    '''Raised when S3 rejects a version ID as malformed (HTTP 400).'''
    pass
2 changes: 1 addition & 1 deletion baiji/package_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
#
# See https://www.python.org/dev/peps/pep-0420/#namespace-packages-today

# Post-commit package version (bumped from 2.6.2 for versioned-key support).
__version__ = '2.7.0'
42 changes: 41 additions & 1 deletion baiji/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def setUp(self):
self.tmp_dir = tempfile.mkdtemp('bodylabs-test')
self.local_file = create_random_temporary_file()


def tearDown(self):
    # Remove the scratch directory first (best-effort: ignore_errors swallows
    # failures), then the random local file created in setUp.
    shutil.rmtree(self.tmp_dir, ignore_errors=True)
    os.remove(self.local_file)
Expand Down Expand Up @@ -83,6 +82,21 @@ def existing_remote_file(self):
s3.cp(self.local_file, uri)
return uri

@property
def existing_versioned_remote_file(self):
    '''
    Ensure a test file exists in the versioned bucket and return its URI.

    A hardcoded S3 path is used to avoid bookkeeping: the tests only delete
    the object (re-uploaded here on the next run) or download it, so no
    extra versions pile up.
    '''
    uri = 's3://baiji-test-versioned/FOO/A_preexisting_file.md'
    if not s3.exists(uri):
        s3.cp(self.local_file, uri)
    return uri

class TestS3Exists(TestAWSBase):

@mock.patch('baiji.connection.S3Connection._lookup')
Expand Down Expand Up @@ -114,6 +128,11 @@ def test_s3_exists_return_false_if_the_file_never_shows_up(self, mock_lookup):
self.assertFalse(s3.exists('s3://foo'))
self.assertEqual(mock_lookup.call_count, 3)

def test_s3_exists_return_false_if_with_unmatched_version_id(self):
    # Well-formed version ID that does not belong to any version of the key:
    # exists() should report False rather than raise.
    bogus_version = '5elgojhtA8BGJerqfbciN78eU74SJ9mX'
    self.assertFalse(
        s3.exists(self.existing_versioned_remote_file, version_id=bogus_version))

class TestEtag(TestAWSBase):

Expand Down Expand Up @@ -174,6 +193,27 @@ def test_s3_cp_download(self):
s3.cp(self.existing_remote_file, os.path.join(self.tmp_dir, 'DL'))
self.assertTrue(os.path.exists(os.path.join(self.tmp_dir, 'DL', s3.path.basename(self.existing_remote_file))))

def test_s3_cp_download_versioned_success_with_valid_version_id(self):
    # Download the object's current version by its explicit version ID.
    current_version = s3.info(self.existing_versioned_remote_file)['version_id']
    local_target = os.path.join(self.tmp_dir, 'DL', 'TEST.foo')
    s3.cp(self.existing_versioned_remote_file, local_target, version_id=current_version)
    self.assertTrue(os.path.exists(local_target))

def test_s3_cp_download_versioned_raise_key_not_found_with_unknown_version_id(self):
    from baiji.exceptions import KeyNotFound

    # Well-formed version ID matching no version: the lookup comes back empty,
    # so cp should raise KeyNotFound.
    bogus_version = '5elgojhtA8BGJerqfbciN78eU74SJ9mX'
    local_target = os.path.join(self.tmp_dir, 'DL', 'TEST.foo')
    with self.assertRaises(KeyNotFound):
        s3.cp(self.existing_versioned_remote_file, local_target, version_id=bogus_version)

def test_s3_cp_download_versioned_raise_invalid_version_id_with_bad_version_id(self):
    from baiji.exceptions import InvalidVersionID

    # Malformed version ID: S3 answers 400, which baiji surfaces as
    # InvalidVersionID instead of a raw S3ResponseError.
    malformed_version = '1111'
    local_target = os.path.join(self.tmp_dir, 'DL', 'TEST.foo')
    with self.assertRaises(InvalidVersionID):
        s3.cp(self.existing_versioned_remote_file, local_target, version_id=malformed_version)

@mock.patch('baiji.copy.S3CopyOperation.ensure_integrity')
def test_s3_cp_download_corrupted_recover_in_one_retry(self, ensure_integrity_mock):
from baiji.exceptions import get_transient_error_class
Expand Down
17 changes: 17 additions & 0 deletions baiji/util/lookup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def get_versioned_key_remote(bucket, key_name, version_id=None):
    '''
    Utility function to get versioned key from a bucket.

    Returns the boto Key for ``key_name`` (at ``version_id`` when given), or
    None when the key/version is not found.

    Raises InvalidVersionID when S3 rejects the version ID as malformed
    (HTTP 400); any other S3 error propagates unchanged.
    '''
    from boto.exception import S3ResponseError
    from baiji.exceptions import InvalidVersionID

    try:
        return bucket.get_key(key_name, version_id=version_id)
    except S3ResponseError as e:
        if e.status == 400:
            raise InvalidVersionID("Invalid versionID %s" % version_id)
        # Bare `raise` preserves the original traceback; `raise e` would
        # rebuild it from here (Python 2 semantics).
        raise

0 comments on commit bfc30a7

Please sign in to comment.