diff --git a/utils/data_generator.py b/utils/data_generator.py index 046d3e8..1bde58d 100644 --- a/utils/data_generator.py +++ b/utils/data_generator.py @@ -207,7 +207,13 @@ def transform(self, x): Use the calculated scalar to transform data. """ - return (x - self.mean) / self.std + # the numpy contain zero, + # dividing zero gives NAN, which break the code + # return (x - self.mean) / self.std + x = (x - self.mean) / self.std + x[np.isnan(x)] = 0 + + return x def generate_train(self): """Generate batch data for training. @@ -353,7 +359,14 @@ def transform(self, x): Use the calculated scalar to transform data. """ - return (x - self.mean) / self.std + # return (x - self.mean) / self.std + # the numpy contain zero, + # dividing zero gives NAN, which break the code + # return (x - self.mean) / self.std + x = (x - self.mean) / self.std + x[np.isnan(x)] = 0 + + return x def generate_eval(self): """