Skip to content

Commit

Permalink
classdom support get_hierarchy_info
Browse files Browse the repository at this point in the history
  • Loading branch information
wufan-tb authored and HanxSmile committed Apr 4, 2023
1 parent 5475958 commit 65bdcaf
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 4 deletions.
18 changes: 16 additions & 2 deletions dsdl/dataset/wrapper_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,23 @@
from .base_dataset import Dataset
from yaml import load as yaml_load
from typing import Sequence, Union, Iterable
from torch.utils.data import Dataset as _Dataset
from torch.utils.data import DataLoader, ConcatDataset
from terminaltables import AsciiTable
try:
from torch.utils.data import Dataset as _Dataset
from torch.utils.data import DataLoader, ConcatDataset
except:
from ..warning import ImportWarning
ImportWarning("'torch' is not installed.")

class _Dataset:
def __init__(self, *args, **kwargs):
pass
class DataLoader:
def __init__(self, *args, **kwargs):
pass
class ConcatDataset:
def __init__(self, *args, **kwargs):
pass

try:
from yaml import CSafeLoader as YAMLSafeLoader
Expand Down
34 changes: 34 additions & 0 deletions dsdl/geometry/classdomain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from dsdl.geometry.registry import CLASSDOMAIN, LABEL
from dsdl.geometry.label import Label
from dsdl.geometry.class_domain_attributes import Skeleton
Expand Down Expand Up @@ -105,6 +107,38 @@ def get_label(cls, name):
def get_attribute(cls, attr_name):
attr_dic = getattr(cls, "__attributes__")
return attr_dic.get(attr_name, None)

def get_hierarchy_info(cls):
class_names_in = [i.category_name for i in cls.__list__]
single_name_dict = {}
name_2_index_dict = {}
for curr_name in class_names_in:
arrs = curr_name.split(".")[1:]
for cname in arrs:
if cname not in single_name_dict:
single_name_dict[cname] = 0
single_class_nums = len(single_name_dict)
class_names_sort = sorted(list(single_name_dict.keys()))
relation_metric = np.eye(single_class_nums, single_class_nums)
for index, key in enumerate(class_names_sort):
name_2_index_dict[key] = index

for idx, item in enumerate(class_names_sort):
curr_dict = {}
index_list = []
for class_name in class_names_in:
name_arrs = class_name.split(".")[1:]
if item in name_arrs:
curr_index = name_arrs.index(item)
curr_used_names = name_arrs[:curr_index+1]
for used_name in curr_used_names:
if used_name not in curr_dict:
curr_dict[used_name] = 0
for ckey in curr_dict:
m_index = name_2_index_dict[ckey]
index_list.append(m_index)
relation_metric[idx, m_index] = 1
return class_names_sort, relation_metric


def ClassDomain(name, **kwargs):
Expand Down
4 changes: 4 additions & 0 deletions dsdl/geometry/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def category_name(self):
"""
return self._name

@property
def leaf_node_name(self):
return self._name.split(".")[-1]

@property
def openmmlabformat(self):
"""
Expand Down
24 changes: 22 additions & 2 deletions dsdl/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,28 @@ def _parse(self, data_file: str, library_path: str):
FIELD_PARSER = ParserField(self.struct_name_params, self.struct_name,struct_params)
field_list = dict()
for raw_field in define_value["$fields"].items():
field_name = raw_field[0].strip()
field_type = raw_field[1].strip()
try:
field_name = raw_field[0].strip()
except:
msg = f"field name must be string, got `{raw_field[0]}` (type: {type(raw_field[0])})"
if self.report_flag:
temp_check_item = CheckLogItem(
def_name=TypeEnum.STRUCT.value, msg=msg
)
CHECK_LOG.sub_struct.append(temp_check_item)
return
raise Exception(msg)
try:
field_type = raw_field[1].strip()
except:
msg = f"field type must be string, got `{raw_field[1]}` (type: {type(raw_field[1])})"
if self.report_flag:
temp_check_item = CheckLogItem(
def_name=TypeEnum.STRUCT.value, msg=msg
)
CHECK_LOG.sub_struct.append(temp_check_item)
return
raise Exception(msg)
# 判断field_name是否为python保留字和是符合命名规范
try:
check_name_format(field_name)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def readme():
"tqdm>=4.65.0",
"scikit-image>=0.19.3",
"tifffile"
"terminaltables>=3.1.10"
],
classifiers=[
"Programming Language :: Python :: 3.8",
Expand Down

0 comments on commit 65bdcaf

Please sign in to comment.