diff --git a/nnabla_rl/__init__.py b/nnabla_rl/__init__.py index 20806c2f..bd9971ee 100644 --- a/nnabla_rl/__init__.py +++ b/nnabla_rl/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '0.14.0.dev1' +__version__ = '0.15.0.dev1' from nnabla_rl.logger import enable_logging, disable_logging # noqa from nnabla_rl.scopes import eval_scope, is_eval_scope # noqa diff --git a/nnabla_rl/models/pybullet/q_functions.py b/nnabla_rl/models/pybullet/q_functions.py index 6157e7c6..98ccc942 100644 --- a/nnabla_rl/models/pybullet/q_functions.py +++ b/nnabla_rl/models/pybullet/q_functions.py @@ -107,7 +107,7 @@ def argmax_q(self, s: Tuple[nn.Variable, nn.Variable]) -> nn.Variable: def objective_function(a: nn.Variable) -> nn.Variable: batch_size, sample_size, action_dim = a.shape a = a.reshape((batch_size*sample_size, action_dim)) - q_value = self.q(tiled_s, a) # type: ignore + q_value = self.q(tiled_s, a) q_value = q_value.reshape((batch_size, sample_size, 1)) return q_value