diff --git a/datasets/imdb.v b/datasets/imdb.v
index 7155a85a..9f106ffc 100644
--- a/datasets/imdb.v
+++ b/datasets/imdb.v
@@ -3,8 +3,7 @@ module datasets
 import vtl
 import os
 
-pub const imdb_folder_name = 'aclImdb'
-pub const imdb_file_name = '${imdb_folder_name}_v1.tar.gz'
+pub const imdb_file_name = 'aclImdb_v1.tar.gz'
 pub const imdb_base_url = 'http://ai.stanford.edu/~amaas/data/sentiment/'
 
 // ImdbDataset is a dataset for sentiment analysis.
@@ -18,20 +17,17 @@ pub:
 
 // load_imdb_helper loads the IMDB dataset for a given split.
 fn load_imdb_helper(split string) !(&vtl.Tensor[string], &vtl.Tensor[int]) {
-	paths := download_dataset(
+	dataset_path := download_dataset(
 		dataset: 'imdb'
 		baseurl: datasets.imdb_base_url
-		extract: true
-		tar: true
-		urls_names: {
-			datasets.imdb_file_name: datasets.imdb_folder_name
-		}
+		compressed: true
+		uncompressed_dir: 'aclImdb'
+		file: datasets.imdb_file_name
 	)!
 
 	mut split_paths := []string{}
 
-	dataset_dir := paths[datasets.imdb_file_name]
-	split_dir := os.join_path(dataset_dir, split)
+	split_dir := os.join_path(dataset_path, split)
 
 	pos_dir := os.join_path(split_dir, 'pos')
 	neg_dir := os.join_path(split_dir, 'neg')
diff --git a/datasets/loader.v b/datasets/loader.v
index 7faaa3e1..f7f7a3ff 100644
--- a/datasets/loader.v
+++ b/datasets/loader.v
@@ -21,7 +21,6 @@ struct RawDownload {
 
 fn load_from_url(data RawDownload) ! {
 	datasets_cache_dir := get_cache_dir('datasets')
-
 	if !os.is_dir(datasets_cache_dir) {
 		os.mkdir_all(datasets_cache_dir)!
 	}
@@ -31,76 +30,67 @@ fn load_from_url(data RawDownload) ! {
 	} else {
 		data.target
 	}
-
 	if os.is_file(cache_file_path) {
 		return
 	}
-
 	http.download_file(data.url, cache_file_path)!
 }
 
 @[params]
 struct DatasetDownload {
-	dataset    string
-	baseurl    string
-	extract    bool
-	tar        bool
-	urls_names map[string]string
+	dataset          string
+	baseurl          string
+	file             string
+	compressed       bool
+	uncompressed_dir string @[required]
 }
 
-fn download_dataset(data DatasetDownload) !map[string]string {
-	mut loaded_paths := map[string]string{}
+fn download_dataset(data DatasetDownload) !string {
+	dataset_dir := os.real_path(get_cache_dir('datasets', data.dataset))
 
-	for path, filename in data.urls_names {
-		dataset_dir := get_cache_dir('datasets', data.dataset)
+	// Handle extensions like `*.tar.gz`.
+	exts := os.file_name(data.file).rsplit_nth('.', 3)
+	is_tar := exts[0] == 'tar' || (exts.len > 1 && exts[1] == 'tar')
 
-		if !os.is_dir(dataset_dir) {
-			os.mkdir_all(dataset_dir)!
-		}
-
-		target := os.join_path(dataset_dir, filename)
-
-		if os.exists(target) {
-			$if debug ? {
-				// we assume that the correct extraction process was done
-				// before
-				// TODO: check for extraction...
-				println('${filename} already exists')
-			}
-		} else {
-			$if debug ? {
-				println('Downloading ${filename} from ${data.baseurl}${path}')
-			}
-			load_from_url(url: '${data.baseurl}${path}', target: target)!
-			if data.extract {
-				$if debug ? {
-					println('Extracting ${target}')
-				}
-				if data.tar {
-					result := os.execute('tar -xvzf ${target} -C ${dataset_dir}')
-					if result.exit_code != 0 {
-						$if debug ? {
-							println('Error extracting ${target}')
-							println('Exit code: ${result.exit_code}')
-							println('Output: ${result.output}')
-						}
-						return error_with_code('Error extracting ${target}', result.exit_code)
-					}
-				} else {
-					file_content := os.read_file(target)!
-					content := gzip.decompress(file_content.bytes(),
-						verify_header_checksum: true
-						verify_length: false
-						verify_checksum: false
-					)!
-					umcompressed_filename := target#[0..-3]
-					os.write_file(umcompressed_filename, content.bytestr())!
-				}
-			}
-		}
-
-		loaded_paths[path] = target
-	}
-
-	return loaded_paths
+	target := os.join_path(dataset_dir, data.file)
+	if os.exists(target) {
+		$if debug ? {
+			println('${data.file} already exists')
+		}
+	} else {
+		if !os.is_dir(dataset_dir) {
+			os.mkdir_all(dataset_dir)!
+		}
+		$if debug ? {
+			println('Downloading ${data.file} from ${data.baseurl}${data.file}')
+		}
+		load_from_url(url: '${data.baseurl}${data.file}', target: target)!
+	}
+
+	uncompressed_path := os.join_path(dataset_dir, data.uncompressed_dir)
+	if data.compressed && !os.is_dir(uncompressed_path) {
+		$if debug ? {
+			println('Extracting ${data.file}')
+		}
+		if is_tar {
+			result := os.execute('tar -xvzf ${target} -C ${dataset_dir}')
+			if result.exit_code != 0 {
+				$if debug ? {
+					println('Error extracting ${target}')
+					println('Exit code: ${result.exit_code}')
+					println('Output: ${result.output}')
+				}
+				return error_with_code('Error extracting ${target}', result.exit_code)
+			}
+		} else {
+			file_content := os.read_file(target)!
+			content := gzip.decompress(file_content.bytes(),
+				verify_header_checksum: true
+				verify_length: false
+				verify_checksum: false
+			)!
+			os.write_file(uncompressed_path, content.bytestr())!
+		}
+	}
+	return uncompressed_path
 }
diff --git a/datasets/mnist.v b/datasets/mnist.v
index 83df51cb..a4e9b9f0 100644
--- a/datasets/mnist.v
+++ b/datasets/mnist.v
@@ -3,7 +3,7 @@ module datasets
 import vtl
 import os
 
-pub const mnist_base_url = 'http://yann.lecun.com/exdb/mnist/'
+pub const mnist_base_url = 'https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/'
 pub const mnist_train_images_file = 'train-images-idx3-ubyte.gz'
 pub const mnist_train_labels_file = 'train-labels-idx1-ubyte.gz'
 pub const mnist_test_images_file = 't10k-images-idx3-ubyte.gz'
@@ -19,36 +19,29 @@ pub:
 }
 
 // load_mnist_helper loads the MNIST dataset from the given filename.
-fn load_mnist_helper(filename string) !string {
-	paths := download_dataset(
+fn load_mnist_helper(file string) !string {
+	dataset_path := download_dataset(
 		dataset: 'mnist'
 		baseurl: datasets.mnist_base_url
-		extract: true
-		urls_names: {
-			filename: filename
-		}
+		compressed: true
+		uncompressed_dir: file.all_before_last('.')
+		file: file
 	)!
 
-	path := paths[filename]
-	uncompressed_path := path#[0..-3]
-	return os.read_file(uncompressed_path)!
+	return os.read_file(dataset_path)!
 }
 
 // load_mnist_features loads the MNIST features.
 fn load_mnist_features(filename string) !&vtl.Tensor[u8] {
 	content := load_mnist_helper(filename)!
-
 	features := content[16..].bytes()
-
 	return vtl.from_1d(features)!.reshape([-1, 28, 28])
 }
 
 // load_mnist_labels loads the MNIST labels.
 fn load_mnist_labels(filename string) !&vtl.Tensor[u8] {
 	content := load_mnist_helper(filename)!
-
 	labels := content[8..].bytes()
-
 	return vtl.from_1d(labels)!
 }
 
diff --git a/datasets/mnist_test.v b/datasets/mnist_test.v
index 0c232c67..30a05274 100644
--- a/datasets/mnist_test.v
+++ b/datasets/mnist_test.v
@@ -1,4 +1,3 @@
-// vtest flaky: true
 // vtest retry: 3
 
 import vtl.datasets
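
For context on the API change above: `download_dataset` previously took a `urls_names` map and returned a map of downloaded paths keyed by URL; after this diff it takes a single `file` plus the `uncompressed_dir` it unpacks to, and returns the uncompressed path directly. Below is a minimal sketch of a loader written against the new signature, assuming it sits inside the `datasets` module next to the helpers above; the Fashion-MNIST dataset name, constant, and mirror URL are hypothetical and not part of this diff.

module datasets

import os

// Hypothetical mirror URL; illustrative only.
const fashion_mnist_base_url = 'https://example.org/fashion-mnist/data/'

// load_fashion_mnist_helper sketches a caller of the refactored API:
// one archive per call, and the return value is already the path of the
// uncompressed payload (no map indexing, no manual `.gz` stripping).
fn load_fashion_mnist_helper(file string) !string {
	dataset_path := download_dataset(
		dataset: 'fashion-mnist'
		baseurl: fashion_mnist_base_url
		compressed: true
		// A plain `.gz` payload takes the gzip branch of the loader;
		// a `.tar.gz` file would be detected and handed to `tar` instead.
		uncompressed_dir: file.all_before_last('.')
		file: file
	)!
	return os.read_file(dataset_path)!
}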