Commit b977e82

Make dataset internals more robust; fix TensorFlow MNIST test (all platforms), fix imdb test (macOS)

ttytm committed May 27, 2024
1 parent 0403eb8 commit b977e82
Showing 4 changed files with 55 additions and 77 deletions.
16 changes: 6 additions & 10 deletions datasets/imdb.v
@@ -3,8 +3,7 @@ module datasets
 import vtl
 import os
 
-pub const imdb_folder_name = 'aclImdb'
-pub const imdb_file_name = '${imdb_folder_name}_v1.tar.gz'
+pub const imdb_file_name = 'aclImdb_v1.tar.gz'
 pub const imdb_base_url = 'http://ai.stanford.edu/~amaas/data/sentiment/'
 
 // ImdbDataset is a dataset for sentiment analysis.
@@ -18,20 +17,17 @@ pub:
 
 // load_imdb_helper loads the IMDB dataset for a given split.
 fn load_imdb_helper(split string) !(&vtl.Tensor[string], &vtl.Tensor[int]) {
-    paths := download_dataset(
+    dataset_path := download_dataset(
         dataset: 'imdb'
         baseurl: datasets.imdb_base_url
-        extract: true
-        tar: true
-        urls_names: {
-            datasets.imdb_file_name: datasets.imdb_folder_name
-        }
+        compressed: true
+        uncompressed_dir: 'aclImdb'
+        file: datasets.imdb_file_name
     )!
 
     mut split_paths := []string{}
 
-    dataset_dir := paths[datasets.imdb_file_name]
-    split_dir := os.join_path(dataset_dir, split)
+    split_dir := os.join_path(dataset_path, split)
     pos_dir := os.join_path(split_dir, 'pos')
     neg_dir := os.join_path(split_dir, 'neg')
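The rework replaces the `urls_names` url-to-folder map with a single `file` plus an explicit `uncompressed_dir`, so `download_dataset` now returns the one uncompressed path directly instead of a map the caller had to index. A minimal sketch of the new call shape, written as it would sit inside the `datasets` module since the helper is module-private (`example_imdb_dir` is a hypothetical name for illustration, not part of the commit):

module datasets

import os

// example_imdb_dir mirrors the call in load_imdb_helper: the returned
// path is the extracted 'aclImdb' directory, so split directories can
// be joined onto it directly. Hypothetical helper, for illustration only.
fn example_imdb_dir(split string) !string {
    dataset_path := download_dataset(
        dataset: 'imdb'
        baseurl: datasets.imdb_base_url
        compressed: true
        uncompressed_dir: 'aclImdb'
        file: datasets.imdb_file_name
    )!
    return os.join_path(dataset_path, split)
}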
94 changes: 42 additions & 52 deletions datasets/loader.v
@@ -21,7 +21,6 @@ struct RawDownload {
 
 fn load_from_url(data RawDownload) ! {
     datasets_cache_dir := get_cache_dir('datasets')
-
     if !os.is_dir(datasets_cache_dir) {
         os.mkdir_all(datasets_cache_dir)!
     }
@@ -31,76 +30,67 @@ fn load_from_url(data RawDownload) ! {
     } else {
         data.target
     }
-
     if os.is_file(cache_file_path) {
         return
     }
-
     http.download_file(data.url, cache_file_path)!
 }
 
 @[params]
 struct DatasetDownload {
-    dataset    string
-    baseurl    string
-    extract    bool
-    tar        bool
-    urls_names map[string]string
+    dataset          string
+    baseurl          string
+    file             string
+    compressed       bool
+    uncompressed_dir string @[required]
 }
 
-fn download_dataset(data DatasetDownload) !map[string]string {
-    mut loaded_paths := map[string]string{}
-
-    for path, filename in data.urls_names {
-        dataset_dir := get_cache_dir('datasets', data.dataset)
-
-        if !os.is_dir(dataset_dir) {
-            os.mkdir_all(dataset_dir)!
-        }
-
-        target := os.join_path(dataset_dir, filename)
-
-        if os.exists(target) {
-            $if debug ? {
-                // we assume that the correct extraction process was done
-                // before
-                // TODO: check for extraction...
-                println('${filename} already exists')
-            }
-        } else {
-            $if debug ? {
-                println('Downloading ${filename} from ${data.baseurl}${path}')
-            }
-            load_from_url(url: '${data.baseurl}${path}', target: target)!
-            if data.extract {
-                $if debug ? {
-                    println('Extracting ${target}')
-                }
-                if data.tar {
-                    result := os.execute('tar -xvzf ${target} -C ${dataset_dir}')
-                    if result.exit_code != 0 {
-                        $if debug ? {
-                            println('Error extracting ${target}')
-                            println('Exit code: ${result.exit_code}')
-                            println('Output: ${result.output}')
-                        }
-                        return error_with_code('Error extracting ${target}', result.exit_code)
-                    }
-                } else {
-                    file_content := os.read_file(target)!
-                    content := gzip.decompress(file_content.bytes(),
-                        verify_header_checksum: true
-                        verify_length: false
-                        verify_checksum: false
-                    )!
-                    umcompressed_filename := target#[0..-3]
-                    os.write_file(umcompressed_filename, content.bytestr())!
-                }
-            }
-        }
-
-        loaded_paths[path] = target
-    }
-
-    return loaded_paths
+fn download_dataset(data DatasetDownload) !string {
+    dataset_dir := os.real_path(get_cache_dir('datasets', data.dataset))
+
+    // Handling extensions like `*.tar.gz`.
+    exts := os.file_name(data.file).rsplit_nth('.', 3)
+    is_tar := exts[0] == 'tar' || (exts.len > 1 && exts[1] == 'tar')
+
+    target := os.join_path(dataset_dir, data.file)
+    if os.exists(target) {
+        $if debug ? {
+            println('${data.file} already exists')
+        }
+    } else {
+        if !os.is_dir(dataset_dir) {
+            os.mkdir_all(dataset_dir)!
+        }
+        $if debug ? {
+            println('Downloading ${data.file} from ${data.baseurl}${data.file}')
+        }
+        load_from_url(url: '${data.baseurl}${data.file}', target: target)!
+    }
+
+    uncompressed_path := os.join_path(dataset_dir, data.uncompressed_dir)
+    if data.compressed && !os.is_dir(uncompressed_path) {
+        $if debug ? {
+            println('Extracting ${data.file}')
+        }
+        if is_tar {
+            result := os.execute('tar -xvzf ${target} -C ${dataset_dir}')
+            if result.exit_code != 0 {
+                $if debug ? {
+                    println('Error extracting ${target}')
+                    println('Exit code: ${result.exit_code}')
+                    println('Output: ${result.output}')
+                }
+                return error_with_code('Error extracting ${target}', result.exit_code)
+            }
+        } else {
+            file_content := os.read_file(target)!
+            content := gzip.decompress(file_content.bytes(),
+                verify_header_checksum: true
+                verify_length: false
+                verify_checksum: false
+            )!
+            os.write_file(uncompressed_path, content.bytestr())!
+        }
+    }
+
+    return uncompressed_path
 }
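The `is_tar` probe is worth tracing by hand. This standalone snippet (illustrative values, not part of the commit) walks through what V's right-to-left `rsplit_nth` produces for the two archive styles the loader handles, assuming the builtin string API the new code itself relies on:

import os

fn main() {
    // 'aclImdb_v1.tar.gz' splits from the right into
    // ['gz', 'tar', 'aclImdb_v1'], so exts[1] == 'tar' selects the
    // `tar -xvzf` branch.
    exts := os.file_name('/tmp/aclImdb_v1.tar.gz').rsplit_nth('.', 3)
    assert exts[0] == 'tar' || (exts.len > 1 && exts[1] == 'tar')

    // A plain gzip such as 'train-images-idx3-ubyte.gz' yields
    // ['gz', 'train-images-idx3-ubyte'] and falls through to the
    // gzip.decompress branch instead.
    exts2 := os.file_name('train-images-idx3-ubyte.gz').rsplit_nth('.', 3)
    assert !(exts2[0] == 'tar' || (exts2.len > 1 && exts2[1] == 'tar'))
}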
21 changes: 7 additions & 14 deletions datasets/mnist.v
@@ -3,7 +3,7 @@ module datasets
 import vtl
 import os
 
-pub const mnist_base_url = 'http://yann.lecun.com/exdb/mnist/'
+pub const mnist_base_url = 'https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/'
 pub const mnist_train_images_file = 'train-images-idx3-ubyte.gz'
 pub const mnist_train_labels_file = 'train-labels-idx1-ubyte.gz'
 pub const mnist_test_images_file = 't10k-images-idx3-ubyte.gz'
@@ -19,36 +19,29 @@ pub:
 }
 
 // load_mnist_helper loads the MNIST dataset from the given filename.
-fn load_mnist_helper(filename string) !string {
-    paths := download_dataset(
+fn load_mnist_helper(file string) !string {
+    dataset_path := download_dataset(
         dataset: 'mnist'
         baseurl: datasets.mnist_base_url
-        extract: true
-        urls_names: {
-            filename: filename
-        }
+        compressed: true
+        uncompressed_dir: file.all_before_last('.')
+        file: file
     )!
 
-    path := paths[filename]
-    uncompressed_path := path#[0..-3]
-    return os.read_file(uncompressed_path)!
+    return os.read_file(dataset_path)!
 }
 
 // load_mnist_features loads the MNIST features.
 fn load_mnist_features(filename string) !&vtl.Tensor[u8] {
     content := load_mnist_helper(filename)!
-
     features := content[16..].bytes()
-
     return vtl.from_1d(features)!.reshape([-1, 28, 28])
 }
 
 // load_mnist_labels loads the MNIST labels.
 fn load_mnist_labels(filename string) !&vtl.Tensor[u8] {
     content := load_mnist_helper(filename)!
-
     labels := content[8..].bytes()
-
     return vtl.from_1d(labels)!
 }
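Instead of slicing the '.gz' suffix off with `path#[0..-3]`, the helper now derives the uncompressed name with `all_before_last('.')` and lets `download_dataset` resolve the final path. A quick standalone check of that derivation, using one of the MNIST file names from this diff:

fn main() {
    // all_before_last('.') strips only the final extension, which matches
    // the name download_dataset writes the decompressed payload to.
    file := 'train-images-idx3-ubyte.gz'
    assert file.all_before_last('.') == 'train-images-idx3-ubyte'
}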
1 change: 0 additions & 1 deletion datasets/mnist_test.v
@@ -1,4 +1,3 @@
-// vtest flaky: true
 // vtest retry: 3
 import vtl.datasets
 
