Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
HanxSmile committed Apr 4, 2023
1 parent 2eb4e92 commit 5475958
Showing 1 changed file with 75 additions and 46 deletions.
121 changes: 75 additions & 46 deletions dsdl/converter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pycocotools import mask as mask_util
import xml.etree.ElementTree as ET


DSDL_TASK_NAMES_TEMPLATE = {
"Image Classification": "class_template",
"Object Detection": "detection_template",
Expand All @@ -28,6 +27,7 @@

DSDL_META_KEYS = ("Dataset Name", "HomePage", "Modality", "Task", "Subset Name")


def load_yaml(yaml_path):
try:
with open(yaml_path, 'r', encoding='utf-8') as fp:
Expand All @@ -36,6 +36,7 @@ def load_yaml(yaml_path):
except:
return None


def load_json(json_path):
try:
with open(json_path, 'r', encoding='utf-8') as fp:
Expand All @@ -44,6 +45,7 @@ def load_json(json_path):
except:
return None


def load_text(text_path):
try:
with open(text_path, 'r', encoding='utf-8') as fp:
Expand All @@ -53,6 +55,7 @@ def load_text(text_path):
except:
return None


def bbox_xymin_xymax_to_xymin_w_h(bbox_value):
bbox_result = None
if len(bbox_value) == 4:
Expand All @@ -67,6 +70,7 @@ def bbox_xymin_xymax_to_xymin_w_h(bbox_value):
raise ValueError(f"bbox_value length should be 4.")
return bbox_result


def bbox_xymin_w_h_to_xymin_xymax(bbox_value):
bbox_result = None
if len(bbox_value) == 4:
Expand All @@ -81,48 +85,52 @@ def bbox_xymin_w_h_to_xymin_xymax(bbox_value):
raise ValueError(f"bbox_value length should be 4.")
return bbox_result


def bbox_xycenter_w_h_to_xymin_w_h(bbox_value):
bbox_result = None
if len(bbox_value) == 4:
x_center = float(bbox_value[0])
y_center = float(bbox_value[1])
width = float(bbox_value[2])
height = float(bbox_value[3])
xmin = x_center - (width/2)
ymin = y_center - (height/2)
xmin = x_center - (width / 2)
ymin = y_center - (height / 2)
bbox_result = [xmin, ymin, width, height]
else:
raise ValueError(f"bbox_value length should be 4.")
return bbox_result


def bbox_xycenter_w_h_normal_to_xymin_w_h(bbox_value, image_width, image_height):
bbox_result = None
if len(bbox_value) == 4:
x_center = float(bbox_value[0]) * float(image_width)
y_center = float(bbox_value[1]) * float(image_height)
width = float(bbox_value[2]) * float(image_width)
height = float(bbox_value[3]) * float(image_height)
xmin = x_center - (width/2)
ymin = y_center - (height/2)
xmin = x_center - (width / 2)
ymin = y_center - (height / 2)
bbox_result = [xmin, ymin, width, height]
else:
raise ValueError(f"bbox_value length should be 4.")
return bbox_result


def bbox_xymin_w_h_to_xycenter_w_h(bbox_value):
bbox_result = None
if len(bbox_value) == 4:
xmin = float(bbox_value[0])
ymin = float(bbox_value[1])
width = float(bbox_value[2])
height = float(bbox_value[3])
x_center = xmin + (width/2)
y_center = ymin + (height/2)
x_center = xmin + (width / 2)
y_center = ymin + (height / 2)
bbox_result = [x_center, y_center, width, height]
else:
raise ValueError(f"bbox_value length should be 4.")
return bbox_result


def bbox_xymin_w_h_to_xycenter_w_h_normal(bbox_value, image_width, image_height):
bbox_result = None
if len(bbox_value) == 4:
Expand All @@ -131,11 +139,11 @@ def bbox_xymin_w_h_to_xycenter_w_h_normal(bbox_value, image_width, image_height)
width = float(bbox_value[2])
height = float(bbox_value[3])

x_center = xmin + (width/2)
y_center = ymin + (height/2)
dw = 1/float(image_width)
dh = 1/float(image_height)
x_center = xmin + (width / 2)
y_center = ymin + (height / 2)
dw = 1 / float(image_width)
dh = 1 / float(image_height)

x_center *= dw
y_center *= dh
width *= dw
Expand All @@ -145,29 +153,32 @@ def bbox_xymin_w_h_to_xycenter_w_h_normal(bbox_value, image_width, image_height)
raise ValueError(f"bbox_value length should be 4.")
return bbox_result


def replace_special_characters(str_in):
str_out = re.sub("\W", "_", str_in)
return str_out


def annToRLE(ann, img_height, img_width):
"""
Convert annotation which can be polygons, uncompressed RLE to RLE.
:return: binary mask (numpy 2D array)
"""
h, w = img_height, img_width
segm = ann['segmentation']
if type(segm) == list:
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = mask_util.frPyObjects(segm, h, w)
rle = mask_util.merge(rles)
elif type(segm['counts']) == list:
# uncompressed RLE
rle = mask_util.frPyObjects(segm, h, w)
else:
# rle
rle = ann['segmentation']
return rle
"""
Convert annotation which can be polygons, uncompressed RLE to RLE.
:return: binary mask (numpy 2D array)
"""
h, w = img_height, img_width
segm = ann['segmentation']
if type(segm) == list:
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = mask_util.frPyObjects(segm, h, w)
rle = mask_util.merge(rles)
elif type(segm['counts']) == list:
# uncompressed RLE
rle = mask_util.frPyObjects(segm, h, w)
else:
# rle
rle = ann['segmentation']
return rle


def annToMask(ann, h, w):
"""
Expand All @@ -178,6 +189,7 @@ def annToMask(ann, h, w):
m = mask_util.decode(rle)
return m


def mask2polygon(mask):
# fortran_ground_truth_binary_mask = np.asfortranarray(mask)
# encoded_ground_truth = mask_util.encode(fortran_ground_truth_binary_mask)
Expand All @@ -192,18 +204,20 @@ def mask2polygon(mask):
segmentations.append(segmentation)
return segmentations


def rle2polygon(ann, img_height, img_width):
curr_mask = annToMask(ann, img_height, img_width)
curr_polygon = mask2polygon(curr_mask)
result = []
for k, v in enumerate(curr_polygon):
curr_v = np.array(v).reshape(-1,2).tolist()
curr_v = np.array(v).reshape(-1, 2).tolist()
result.append(curr_v)
return result

def generate_class_dom(dsdl_root_path, names_list, class_dom_name="ClassDom"):

def generate_class_dom(dsdl_root_path, names_list, class_dom_name="ClassDom", add_quote=False):
# the dsdl_version should get from dsdl sdk.
dsdl_version="0.5.3"
dsdl_version = "0.5.3"
if not class_dom_name:
raise ValueError(f"The class-dom name is {class_dom_name}, please specify the correct class-dom name !")
if len(names_list) == 0:
Expand All @@ -220,17 +234,21 @@ def generate_class_dom(dsdl_root_path, names_list, class_dom_name="ClassDom"):
code_str += " $def: class_domain\n"
code_str += " classes:\n"
for cname in names_list:
code_str += f" - {cname}\n"
if add_quote:
code_str += f' - "{cname}"\n'
else:
code_str += f' - {cname}\n'
with open(class_dom_path, "w", encoding='utf-8') as fp1:
fp1.write(code_str)
else:
print(f"{class_dom_path} already exists !")
print("class_dom.yaml is generated.")


def generate_global_info(dsdl_root_path, class_info_list):
if len(class_info_list) == 0:
raise ValueError("The class_info_list is an empty list, please specify the correct class_info_list !")

save_path_p = Path(dsdl_root_path)
save_defs_p = save_path_p.joinpath("defs")
if not save_defs_p.exists():
Expand All @@ -245,6 +263,7 @@ def generate_global_info(dsdl_root_path, class_info_list):
else:
print(f"The {global_info_path} already exists !")


def get_dsdl_template_file_name(dsdl_root_path):
save_path_p = Path(dsdl_root_path)
save_defs_p = save_path_p.joinpath("defs")
Expand All @@ -259,6 +278,7 @@ def get_dsdl_template_file_name(dsdl_root_path):
template_file = item_file.name
return template_file


def get_subset_yaml_str(meta_info, template_file_name, sample_struct_name, class_dom_name, dsdl_version):
dataset_name = meta_info["Dataset Name"]
homepage = meta_info["HomePage"]
Expand All @@ -281,6 +301,7 @@ def get_subset_yaml_str(meta_info, template_file_name, sample_struct_name, class
yaml_str += f' sample-path: {subset_name}_samples.json\n'
return yaml_str


def generate_subset_yaml_and_json(meta_dict, dsdl_root_path, samples_list):
save_path_p = Path(dsdl_root_path)
sub_name = meta_dict['Subset Name']
Expand All @@ -296,27 +317,28 @@ def generate_subset_yaml_and_json(meta_dict, dsdl_root_path, samples_list):
template_file_name = get_dsdl_template_file_name(dsdl_root_path)
if not template_file_name:
raise FileNotFoundError("The task template file is missing .")

sample_struct_name = get_dsdl_sample_struct_name(dsdl_root_path)
class_dom_name = get_dsdl_class_dom_name(dsdl_root_path)

# the dsdl_version should get from dsdl sdk.
dsdl_version="0.5.3"
dsdl_version = "0.5.3"
yaml_str = get_subset_yaml_str(meta_dict, template_file_name, sample_struct_name, class_dom_name, dsdl_version)
if not sub_yaml_path.exists():
with open(sub_yaml_path, "w", encoding='utf-8') as fp1:
fp1.write(yaml_str)

samples_result = {"samples": samples_list}
with open(sub_sample_json_path, "w", encoding="utf-8") as fp:
json.dump(samples_result, fp)


def parse_xml_det_task(xml_path):
tree = ET.parse(xml_path)
root = tree.getroot()
sample_dict = {}
object_list = []

size_node = root.find("size")
sample_dict["width"] = int(size_node.find("width").text)
sample_dict["height"] = int(size_node.find("height").text)
Expand All @@ -327,7 +349,7 @@ def parse_xml_det_task(xml_path):
curr_obj_dict = {}
name = obj_node.find("name").text
curr_obj_dict["name"] = name

bndbox_node = obj_node.find("bndbox")
xmin = float(bndbox_node.find("xmin").text)
xmax = float(bndbox_node.find("xmax").text)
Expand All @@ -338,8 +360,8 @@ def parse_xml_det_task(xml_path):

difficult = obj_node.find("difficult")
occluded = obj_node.find("occluded")
truncated = obj_node.find("truncated")
pose = obj_node.find("pose")
truncated = obj_node.find("truncated")
pose = obj_node.find("pose")
if difficult is not None:
curr_obj_dict["difficult"] = int(difficult.text)
if occluded is not None:
Expand All @@ -352,6 +374,7 @@ def parse_xml_det_task(xml_path):
sample_dict["objects"] = object_list
return sample_dict


def check_dsdl_meta_info(meta_info_in):
if isinstance(meta_info_in, dict):
meta_keys_set = set(DSDL_META_KEYS)
Expand All @@ -360,19 +383,23 @@ def check_dsdl_meta_info(meta_info_in):
if not meta_info_in["Dataset Name"]:
raise ValueError(f"Dataset Name is null, please specify the dataset name !")
if meta_info_in["Modality"] not in DSDL_MODALITYS:
raise ValueError(f"DSDL Modality' value must in {DSDL_MODALITYS}, but current value is {meta_info_in['Modality']}")
raise ValueError(
f"DSDL Modality' value must in {DSDL_MODALITYS}, but current value is {meta_info_in['Modality']}")
if meta_info_in["Task"] not in DSDL_TASK_NAMES_TEMPLATE:
raise ValueError(f"Task value must in {DSDL_TASK_NAMES_TEMPLATE}, but current value is {meta_info_in['Task']}")
raise ValueError(
f"Task value must in {DSDL_TASK_NAMES_TEMPLATE}, but current value is {meta_info_in['Task']}")
else:
raise ValueError(f"meta_info_in's keys must in {DSDL_META_KEYS}, but current keys is {curr_meta_keys_set}")
else:
raise TypeError("The param meta_info_in's type must be dict !")


def check_task_template_file(dsdl_root_path):
template_name = get_dsdl_template_file_name(dsdl_root_path)
if not template_name:
raise FileNotFoundError(f"The task template file is missing .")


def struct_sort(struct_dict_in):
digraph_obj = nx.DiGraph()
struct_names = list(struct_dict_in.keys())
Expand All @@ -382,12 +409,13 @@ def struct_sort(struct_dict_in):
for field_value in s_fields_values:
for k in struct_names:
if k in field_value:
digraph_obj.add_edge(k, s_name)
digraph_obj.add_edge(k, s_name)
if not nx.is_directed_acyclic_graph(digraph_obj):
raise "define cycle found."
ordered_struct_name = list(nx.topological_sort(digraph_obj))
return ordered_struct_name


def get_dsdl_sample_struct_name(dsdl_root_path):
save_path_p = Path(dsdl_root_path)
save_defs_p = save_path_p.joinpath("defs")
Expand All @@ -400,6 +428,7 @@ def get_dsdl_sample_struct_name(dsdl_root_path):
order_names = struct_sort(curr_data)
return order_names[-1]


def get_dsdl_class_dom_name(dsdl_root_path):
save_path_p = Path(dsdl_root_path)
save_defs_p = save_path_p.joinpath("defs")
Expand Down

0 comments on commit 5475958

Please sign in to comment.