Skip to content

Commit

Permalink
Add utilities for MetaDriveType (#526)
Browse files Browse the repository at this point in the history
* introducing more type class functions

* update scenario description

* format
  • Loading branch information
pengzhenghao authored Oct 24, 2023
1 parent d3a3432 commit eabe101
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 10 deletions.
3 changes: 2 additions & 1 deletion metadrive/scenario/scenario_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,8 @@ def get_export_file_name(dataset, dataset_version, scenario_name):
@staticmethod
def is_scenario_file(file_name):
file_name = os.path.basename(file_name)
assert file_name[-4:] == ".pkl", "{} is not .pkl file".format(file_name)
if not file_name.endswith(".pkl"):
return False
file_name = file_name.replace(".pkl", "")
return os.path.basename(file_name)[:3] == "sd_" or all(char.isdigit() for char in file_name)

Expand Down
12 changes: 8 additions & 4 deletions metadrive/scenario/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import os
import pathlib
import pickle

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -356,8 +357,9 @@ def read_dataset_summary(file_folder, check_file_existence=True):
2) the list of all scenarios IDs, and
3) a dict mapping from scenario IDs to the folder that hosts their files.
"""
summary_file = os.path.join(file_folder, SD.DATASET.SUMMARY_FILE)
mapping_file = os.path.join(file_folder, SD.DATASET.MAPPING_FILE)
file_folder = pathlib.Path(file_folder)
summary_file = file_folder / SD.DATASET.SUMMARY_FILE
mapping_file = file_folder / SD.DATASET.MAPPING_FILE
if os.path.isfile(summary_file):
with open(summary_file, "rb") as f:
summary_dict = pickle.load(f)
Expand All @@ -375,18 +377,20 @@ def read_dataset_summary(file_folder, check_file_existence=True):
files = [p for p in files]
summary_dict = {f: read_scenario_data(os.path.join(file_folder, f))["metadata"] for f in files}

mapping = None
if os.path.exists(mapping_file):
with open(mapping_file, "rb") as f:
mapping = pickle.load(f)
else:

if not mapping:
# Create a fake one
mapping = {k: "" for k in summary_dict}

if check_file_existence:
for file in summary_dict:
assert file in mapping, "FileName in mapping mismatch with summary"
assert SD.is_scenario_file(file), "File:{} is not sd scenario file".format(file)
file_path = os.path.join(file_folder, mapping[file], file)
file_path = file_folder / mapping[file] / file
assert os.path.exists(file_path), "Can not find file: {}".format(file_path)

return summary_dict, list(summary_dict.keys()), mapping
Expand Down
34 changes: 29 additions & 5 deletions metadrive/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class MetaDriveType:

# ===== Lane, Road =====
LANE_SURFACE_STREET = "LANE_SURFACE_STREET"
# Unlike a set of lanes separated by broken/solid line, this includes intersection and some unstrcutured roads.
# Unlike a set of lanes separated by broken/solid line, this includes intersection and some unstructured roads.
LANE_SURFACE_UNSTRUCTURE = "LANE_SURFACE_UNSTRUCTURE"
# use them as less frequent as possible, it is for waymo compatibility
LANE_UNKNOWN = "LANE_UNKNOWN"
Expand Down Expand Up @@ -161,10 +161,38 @@ def is_road_boundary_line(cls, edge):
def is_sidewalk(cls, edge):
return edge == cls.BOUNDARY_SIDEWALK

@classmethod
def is_stop_sign(cls, type):
return type == MetaDriveType.STOP_SIGN

@classmethod
def is_speed_bump(cls, type):
return type == MetaDriveType.SPEED_BUMP

@classmethod
def is_driveway(cls, type):
return type == MetaDriveType.DRIVEWAY

@classmethod
def is_crosswalk(cls, type):
return type == MetaDriveType.CROSSWALK

@classmethod
def is_vehicle(cls, type):
return type == cls.VEHICLE

@classmethod
def is_pedestrian(cls, type):
return type == cls.PEDESTRIAN

@classmethod
def is_cyclist(cls, type):
return type == cls.CYCLIST

@classmethod
def is_participant(cls, type):
return type in (cls.CYCLIST, cls.PEDESTRIAN, cls.VEHICLE, cls.UNSET, cls.OTHER)

@classmethod
def is_traffic_light_in_yellow(cls, light):
return cls.simplify_light_status(light) == cls.LIGHT_YELLOW
Expand Down Expand Up @@ -211,10 +239,6 @@ def simplify_light_status(cls, status: str):
logger.warning("TrafficLightStatus: {} is not MetaDriveType".format(status))
return cls.LIGHT_UNKNOWN

@classmethod
def is_crosswalk(cls, type):
return type == MetaDriveType.CROSSWALK

def __init__(self, type=None):
# TODO extend this base class to all objects! It is only affect lane so far.
# TODO Or people can only know the type with isinstance()
Expand Down

0 comments on commit eabe101

Please sign in to comment.