diff --git a/build/launch_training.py b/build/launch_training.py index 57ede1f6b..5592d5888 100644 --- a/build/launch_training.py +++ b/build/launch_training.py @@ -39,8 +39,12 @@ def txt_to_obj(txt): base64_bytes = txt.encode("ascii") message_bytes = base64.b64decode(base64_bytes) - obj = pickle.loads(message_bytes) - return obj + try: + # If the bytes represent JSON string + return json.loads(message_bytes) + except UnicodeDecodeError: + # Otherwise the bytes are a pickled python dictionary + return pickle.loads(message_bytes) def get_highest_checkpoint(dir_path):