forked from eladhoffer/ImageNet-Training
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Data.lua
111 lines (93 loc) · 3.27 KB
/
Data.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
require 'xlua'
require 'lmdb'
local DataProvider = require 'DataProvider'
local config = require 'Config'
--- Decode and augment one training sample read from the training LMDB.
-- Pipeline:
--   1. optional JPEG decode (when config.Compressed),
--   2. rescale so the shorter spatial side equals config.ImageMinSide,
--   3. augmentation: rotation + random re-sampling when config.Augment == 3,
--      random re-sampling only when config.Augment == 2,
--   4. a random crop of config.InputSize[2] x config.InputSize[3],
--   5. a random horizontal flip with probability 1/2.
-- Every random draw goes through the torch RNG (torch.random / torch.randn),
-- so torch.manualSeed controls all augmentation decisions. (math.random is
-- seeded independently and would escape that control.)
-- Dimensions 2 and 3 of the image tensor are treated as height and width.
-- @param data LMDB record; data.Name's prefix before the first '_' is looked
--   up in config.ImageNetClasses.Wnid2ClassNum, data.Data holds the image
--   (compressed JPEG bytes when config.Compressed is set)
-- @return img the cropped (and possibly flipped) image tensor
-- @return class the class number mapped from the record's name prefix
function ExtractFromLMDBTrain(data)
  -- Loaded inside the function rather than only at file scope. Presumably this
  -- lets the function work in a worker state that has not loaded 'image' yet.
  require 'image'

  -- Crop a random window that drops between 1 and ceil(dim/4) pixels along
  -- each spatial dimension, then scale it back to the original width/height.
  local function reSample(src)
    local dims = src:size()
    local dropX = torch.random(math.ceil(dims[3] / 4))
    local dropY = torch.random(math.ceil(dims[2] / 4))
    local x0 = torch.random(dropX)
    local y0 = torch.random(dropY)
    local window = src:narrow(2, y0, dims[2] - dropY):narrow(3, x0, dims[3] - dropX)
    return image.scale(window, dims[3], dims[2])
  end

  -- Rotate by angleStd * N(0, 1) degrees (converted to radians for image.rotate).
  local function randomRotate(src, angleStd)
    local angle = torch.randn(1)[1] * angleStd
    return image.rotate(src, math.rad(angle), 'bilinear')
  end

  local wnid = string.split(data.Name, '_')[1]
  local class = config.ImageNetClasses.Wnid2ClassNum[wnid]

  local img = data.Data
  if config.Compressed then
    img = image.decompressJPG(img, 3, 'byte')
  end
  if math.min(img:size(2), img:size(3)) ~= config.ImageMinSide then
    img = image.scale(img, '^' .. config.ImageMinSide)
  end

  if config.Augment == 3 then
    img = randomRotate(img, 0.1)
    img = reSample(img)
  elseif config.Augment == 2 then
    img = reSample(img)
  end

  -- Number of valid top-left positions for the crop along each axis.
  local rangeX = img:size(3) - config.InputSize[3] + 1
  local rangeY = img:size(2) - config.InputSize[2] + 1
  assert(rangeX >= 1 and rangeY >= 1,
    'ExtractFromLMDBTrain: image is smaller than config.InputSize (check config.ImageMinSide)')
  -- Use the torch RNG (not math.random) so crops follow torch.manualSeed.
  local startX = torch.random(rangeX)
  local startY = torch.random(rangeY)
  img = img:narrow(3, startX, config.InputSize[3]):narrow(2, startY, config.InputSize[2])

  if torch.random(2) == 1 then
    img = image.hflip(img)
  end
  return img, class
end
--- Decode one validation sample and return its deterministic center crop.
-- No random augmentation is applied:
--   1. the image is JPEG-decoded when config.Compressed is set,
--   2. it is rescaled so its shorter spatial side equals config.ImageMinSide,
--   3. a config.InputSize[2] x config.InputSize[3] window centred on the image
--      is cut out.
-- Dimensions 2 and 3 of the image tensor are treated as height and width.
-- @param data LMDB record; the prefix of data.Name before the first '_' is
--   looked up in config.ImageNetClasses.Wnid2ClassNum, data.Data holds the image
-- @return cropped the center-cropped image tensor
-- @return label the class number mapped from the record's name prefix
function ExtractFromLMDBTest(data)
  require 'image'
  local label = config.ImageNetClasses.Wnid2ClassNum[string.split(data.Name, '_')[1]]

  local picture = data.Data
  if config.Compressed then
    picture = image.decompressJPG(picture, 3, 'byte')
  end

  local target = config.ImageMinSide
  local shortSide = math.min(picture:size(2), picture:size(3))
  if shortSide ~= target then
    picture = image.scale(picture, '^' .. target)
  end

  -- Offsets that put the crop window in the middle of the image.
  local cropH, cropW = config.InputSize[2], config.InputSize[3]
  local offsetY = math.ceil((picture:size(2) - cropH + 1) / 2)
  local offsetX = math.ceil((picture:size(3) - cropW + 1) / 2)
  local cropped = picture:narrow(3, offsetX, cropW):narrow(2, offsetY, cropH)
  return cropped, label
end
--- Map each entry along the first dimension of a tensor to a DB key.
-- Entry idx of the result is config.Key(tensor[idx]).
-- Presumably the tensor is 1-D (e.g. the torch.randperm slice used when
-- estimating mean/std) — confirm for other callers.
-- @param tensor tensor whose first dimension is iterated
-- @return Lua array of keys, one per entry of the tensor
function Keys(tensor)
  local keys = {}
  local count = tensor:size(1)
  local idx = 1
  while idx <= count do
    keys[idx] = config.Key(tensor[idx])
    idx = idx + 1
  end
  return keys
end
--- Estimate normalization statistics from a random subset of a DB.
-- Picks numEst distinct random record indices (torch.randperm) and converts
-- them to keys with Keys. DB:CacheRand then fills a FloatTensor buffer of
-- shape numEst x config.InputSize (presumably with the decoded samples —
-- confirm against DataProvider). The buffer is wrapped in a
-- DataProvider.Container and its normalize method is called.
-- @param DB provider exposing size() and CacheRand(keys, buffer)
-- @param typeVal normalization type passed to normalize (default 'simple')
-- @param numEst number of samples to draw (default 10000); clamped to
--   DB:size(), because asking randperm for more indices than exist would fail
-- @return table {typeVal, <all values returned by dp:normalize(typeVal)>}
function EstimateMeanStd(DB, typeVal, numEst)
  typeVal = typeVal or 'simple'
  local dbSize = DB:size()
  numEst = math.min(numEst or 10000, dbSize)
  assert(numEst >= 1, 'EstimateMeanStd: DB is empty or numEst < 1')

  -- Lua 5.1/LuaJIT provide the global unpack; 5.2+ moved it to table.unpack.
  local unpackFn = table.unpack or unpack
  local x = torch.FloatTensor(numEst, unpackFn(config.InputSize))
  local randKeys = Keys(torch.randperm(dbSize):narrow(1, 1, numEst))
  DB:CacheRand(randKeys, x)

  local dp = DataProvider.Container{
    Source = {x, nil}
  }
  return {typeVal, dp:normalize(typeVal)}
end
-- Open a read-only LMDB environment at `path` and wrap it in an LMDBProvider
-- that turns each stored record into a sample with `extractFn`.
local function openReadOnlyLMDB(path, extractFn)
  return DataProvider.LMDBProvider{
    Source = lmdb.env({Path = path, RDONLY = true}),
    ExtractFunction = extractFn
  }
end

-- Training data uses the augmenting extractor; validation uses the center crop.
local TrainDB = openReadOnlyLMDB(config.TRAINING_DIR, ExtractFromLMDBTrain)
local ValDB = openReadOnlyLMDB(config.VALIDATION_DIR, ExtractFromLMDBTest)

-- Module exports: the class mapping plus both providers.
return {
  ImageNetClasses = config.ImageNetClasses,
  ValDB = ValDB,
  TrainDB = TrainDB,
}