forked from eladhoffer/ImageNet-Training
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CreateLMDBs.lua
108 lines (91 loc) · 2.66 KB
/
CreateLMDBs.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
require 'image'
require 'xlua'
require 'lmdb'
local gm = require 'graphicsmagick'
local DataProvider = require 'DataProvider'
local config = require 'Config'
-------------------------------Settings----------------------------------------------
-- Rescales an image and normalises its channel dimension.
--   Img: image tensor; after scaling it is either 2-D, or 3-D with the
--        channel count in dimension 1.
-- Returns the scaled tensor. A 2-D or single-channel image is replicated to
-- three channels, an image with more than three channels keeps only the
-- first three, and any other channel count is returned unchanged.
local function PreProcess(Img)
    -- '^N' requests a minimum side of N (N = config.ImageMinSide).
    local scaled = image.scale(Img, '^' .. config.ImageMinSide)

    -- A 2-D tensor has no channel dimension: add a leading singleton one.
    if scaled:dim() == 2 then
        scaled = scaled:reshape(1, scaled:size(1), scaled:size(2))
    end

    local channels = scaled:size(1)
    if channels == 1 then
        -- Single channel: repeat it three times along dimension 1.
        return torch.repeatTensor(scaled, 3, 1, 1)
    elseif channels > 3 then
        -- Extra channels: keep only the first three.
        return scaled[{{1, 3}, {}, {}}]
    end
    return scaled
end
-- Loads one image file and prepares it for storage.
--   filename: path of the image to read.
-- Returns the tensor produced by PreProcess, or its image.compressJPG
-- encoding when config.Compressed is set.
-- On a load failure the filename (and the underlying error, if any) is
-- printed and the process terminates with a non-zero exit status.
local LoadImgData = function(filename)
    -- NOTE(review): the loader presumably raises on unreadable/corrupt files
    -- rather than returning nil; pcall makes sure the offending filename is
    -- reported in either case.
    local ok, img = pcall(function()
        return gm.Image(filename):toTensor('float', 'RGB', 'DHW')
    end)
    if not ok or img == nil then
        print('Image is buggy')
        print(filename)
        if not ok then
            -- On failure pcall's second result is the error value.
            print(img)
        end
        -- Exit with a failure status; a bare os.exit() reports success.
        os.exit(1)
    end

    img = PreProcess(img)
    if config.Compressed then
        return image.compressJPG(img)
    end
    return img
end
-- Derives the name stored with an image record from its file path.
--   filename: path of the image file; paths.basename is called with the
--             'JPEG' extension argument.
-- Returns:
--   * the basename itself when its first '_'-separated field is not
--     'ILSVRC2012' (training file);
--   * otherwise (validation file) '<wnid>_<num>', where <num> is the third
--     '_'-separated field converted to a number and <wnid> is looked up via
--     config.ValidationLabels and config.ImageNetClasses.ClassNum2Wnid.
function NameFile(filename)
    local stem = paths.basename(filename, 'JPEG')
    local fields = string.split(stem, '_')

    -- Training file: the basename is used as-is.
    if fields[1] ~= 'ILSVRC2012' then
        return stem
    end

    -- Validation file: the label comes from the numeric index in the name.
    local index = tonumber(fields[3])
    local classNum = config.ValidationLabels[index]
    return config.ImageNetClasses.ClassNum2Wnid[classNum] .. '_' .. index
end
-- Writes one record per file from filenamesProvider into the LMDB env.
--   filenamesProvider: object exposing size() and getItem(i).
--   env: LMDB environment; it is opened here and closed before returning.
-- Record i is stored under config.Key(i) as {Data = ..., Name = ...}.
-- The write transaction is committed and renewed every 1000 records, and
-- once more after the last record.
function LMDBFromFilenames(filenamesProvider, env)
    local COMMIT_EVERY = 1000

    env:open()
    local transaction = env:txn()
    local writer = transaction:cursor()

    for idx = 1, filenamesProvider:size() do
        local path = filenamesProvider:getItem(idx)
        local record = {Data = LoadImgData(path), Name = NameFile(path)}
        writer:put(config.Key(idx), record, lmdb.C.MDB_NODUPDATA)

        if idx % COMMIT_EVERY == 0 then
            -- Commit the batch, print environment statistics, collect
            -- garbage, then start a fresh transaction and cursor.
            transaction:commit()
            print(env:stat())
            collectgarbage()
            transaction = env:txn()
            writer = transaction:cursor()
        end

        xlua.progress(idx, filenamesProvider:size())
    end

    -- Commit whatever was written since the last periodic commit.
    transaction:commit()
    env:close()
end
-- File lists for the two dataset splits.
local TrainingFiles = DataProvider.FileSearcher({
    Name = 'TrainingFilenames',
    CachePrefix = config.TRAINING_DIR,
    MaxNumItems = 1e8,
    CacheFiles = true,
    PathList = {config.TRAINING_PATH},
    SubFolders = true,
    Verbose = true
})

local ValidationFiles = DataProvider.FileSearcher({
    Name = 'ValidationFilenames',
    CachePrefix = config.VALIDATION_DIR,
    MaxNumItems = 1e8,
    PathList = {config.VALIDATION_PATH},
    Verbose = true
})

-- LMDB environments that receive the records for each split.
local TrainDB = lmdb.env({
    Path = config.TRAINING_DIR,
    Name = 'TrainDB'
})

local ValDB = lmdb.env({
    Path = config.VALIDATION_DIR,
    Name = 'ValDB'
})

-- Only the training list is shuffled; the validation list keeps its order.
TrainingFiles:shuffleItems()

-- Build the validation database first, then the training database.
LMDBFromFilenames(ValidationFiles, ValDB)
LMDBFromFilenames(TrainingFiles, TrainDB)