"""Fill in the model-conversion YAML templates from a JSON task config.

Reads a JSON configuration, then writes ``img_adaptor.yaml`` and ``dit.yaml``
into ``<quant.output_path>/<task_id>/`` based on the templates found under
``ptq_yaml/<model_type>/`` (resolved relative to the current working
directory).
"""

import json
import os
import sys
from dataclasses import dataclass
from typing import Optional

import yaml

# Section names inside the "quant" part of the JSON config; the same strings
# are written as ``working_dir`` into the generated YAML files.
DIT = "DiT_Policy"
IMG_ADAPTOR = "Img_Adaptor"

# Placeholder in the DiT template that is substituted with the calibration
# data value taken from the JSON config.
_DIT_CAL_PLACEHOLDER = "{dit_cal_name}"


@dataclass
class QuantConfig:
    """Flattened view of the JSON task configuration.

    Every field defaults to ``None`` when the corresponding JSON key is
    absent.
    """

    task_id: Optional[str] = None  # top-level "task_id"
    gpu_id: Optional[str] = None  # top-level "gpu_id"
    march: Optional[str] = None  # quant["march"]
    model_type: Optional[str] = None  # quant["model_type"]; selects the template dir
    output_path: Optional[str] = None  # quant["output_path"] joined with task_id
    DiT_Policy_ONNX: Optional[str] = None  # quant["DiT_Policy"]["onnx_model"]
    DiT_Policy_CALIBRATION: Optional[str] = None  # quant["DiT_Policy"]["calibration_data"]
    Img_Adaptor_ONNX: Optional[str] = None  # quant["Img_Adaptor"]["onnx_model"]
    Img_Adaptor_CALIBRATION: Optional[str] = None  # quant["Img_Adaptor"]["calibration_data"]


def _parse_quant_config(config):
    """Build a ``QuantConfig`` from the decoded JSON dictionary."""
    quant_info = config.get("quant")
    if quant_info is None:
        raise ValueError('config is missing the required "quant" section')

    task_id = config.get("task_id")
    output_root = quant_info.get("output_path")
    if task_id is None or output_root is None:
        # Both are needed to build the output directory; fail with a clear
        # message instead of a TypeError from os.path.join(None, ...).
        raise ValueError(
            'config must define "task_id" and "quant.output_path"'
        )

    # A missing model section yields None fields rather than an
    # UnboundLocalError further down.
    dit_policy = quant_info.get(DIT) or {}
    img_adaptor = quant_info.get(IMG_ADAPTOR) or {}

    return QuantConfig(
        task_id=task_id,
        gpu_id=config.get("gpu_id"),
        march=quant_info.get("march"),
        model_type=quant_info.get("model_type"),
        output_path=os.path.join(output_root, task_id),
        DiT_Policy_ONNX=dit_policy.get("onnx_model"),
        DiT_Policy_CALIBRATION=dit_policy.get("calibration_data"),
        Img_Adaptor_ONNX=img_adaptor.get("onnx_model"),
        Img_Adaptor_CALIBRATION=img_adaptor.get("calibration_data"),
    )


def _write_yaml(data, path):
    """Dump ``data`` as block-style YAML to ``path``."""
    # allow_unicode=True emits non-ASCII characters verbatim, so the file
    # must be opened as UTF-8 rather than with the platform default encoding.
    with open(path, "w", encoding="utf-8") as f:
        yaml.safe_dump(data, f, default_flow_style=False, allow_unicode=True)


def _prepare_img_adaptor_yaml(opt):
    """Fill the img_adaptor template and write it to the output directory."""
    template_path = f"ptq_yaml/{opt.model_type}/img_adaptor.yaml"
    with open(template_path, "r", encoding="utf-8") as file:
        img_adaptor_yaml = yaml.safe_load(file)

    model_params = img_adaptor_yaml["model_parameters"]
    model_params["onnx_model"] = opt.Img_Adaptor_ONNX
    model_params["march"] = opt.march
    model_params["output_model_file_prefix"] = "rdt_img_adaptor"
    model_params["working_dir"] = IMG_ADAPTOR
    img_adaptor_yaml["calibration_parameters"]["cal_data_dir"] = (
        opt.Img_Adaptor_CALIBRATION
    )

    _write_yaml(
        img_adaptor_yaml, os.path.join(opt.output_path, "img_adaptor.yaml")
    )


def _prepare_dit_yaml(opt):
    """Fill the DiT template and write it to the output directory."""
    with open(
        f"ptq_yaml/{opt.model_type}/dit.yaml", "r", encoding="utf-8"
    ) as file:
        dit_yaml = yaml.safe_load(file)

    # setdefault keeps the placeholder loop and the quant_config assignment
    # below consistent when the template has no calibration section.
    calibration = dit_yaml.setdefault("calibration_parameters", {})
    for key, value in list(calibration.items()):
        if not (isinstance(value, str) and _DIT_CAL_PLACEHOLDER in value):
            continue
        if opt.DiT_Policy_CALIBRATION is None:
            raise ValueError(
                f"DiT_Policy_CALIBRATION is None, cannot replace {{dit_cal_name}} in {key}"
            )
        calibration[key] = value.replace(
            _DIT_CAL_PLACEHOLDER, opt.DiT_Policy_CALIBRATION
        )

    model_params = dit_yaml["model_parameters"]
    model_params["onnx_model"] = opt.DiT_Policy_ONNX
    model_params["march"] = opt.march
    model_params["working_dir"] = DIT

    # The per-op configuration is stored as JSON next to the template and is
    # embedded verbatim into the generated YAML.
    with open(
        f"ptq_yaml/{opt.model_type}/dit_op_config.json", "r", encoding="utf-8"
    ) as file:
        calibration["quant_config"] = json.load(file)

    _write_yaml(dit_yaml, os.path.join(opt.output_path, "dit.yaml"))


def load_config(config_path):
    """Load the JSON task config and generate both conversion YAML files.

    Creates ``<quant.output_path>/<task_id>/`` if needed and writes
    ``img_adaptor.yaml`` and ``dit.yaml`` into it.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The ``QuantConfig`` parsed from the file.

    Raises:
        ValueError: If the "quant" section, "task_id" or "quant.output_path"
            is missing, or if the DiT template contains ``{dit_cal_name}``
            but no DiT calibration data is configured.
        OSError: If the config or a template file cannot be opened.
        KeyError: If a template lacks a section that is written to.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config = json.load(file)

    opt = _parse_quant_config(config)
    os.makedirs(opt.output_path, exist_ok=True)

    _prepare_img_adaptor_yaml(opt)
    _prepare_dit_yaml(opt)
    return opt


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <config.json>")
    config_path = sys.argv[1]
    config = load_config(config_path)