88 lines
3.4 KiB
Python
88 lines
3.4 KiB
Python
import json
import os
import sys
from dataclasses import dataclass
from typing import Optional

import yaml
|
|
|
|
# Names written into the ``working_dir`` entry of the generated conversion
# YAML files (see ``load_config``); they match the model section names used
# under "quant" in the JSON task config.
DIT = "DiT_Policy"
IMG_ADAPTOR = "Img_Adaptor"
|
|
|
|
@dataclass
class QuantConfig:
    """Settings for one quantization task, as collected by ``load_config``.

    Every field defaults to ``None`` so an instance can be created before all
    values are known; the annotations are therefore ``Optional[str]``.
    """

    # Top-level "task_id" entry of the JSON config.
    task_id: Optional[str] = None
    # Top-level "gpu_id" entry of the JSON config.
    gpu_id: Optional[str] = None
    # "march" entry of the "quant" section; copied into both generated YAMLs.
    march: Optional[str] = None
    # "model_type" entry of the "quant" section; selects the
    # ``ptq_yaml/<model_type>/`` template directory.
    model_type: Optional[str] = None
    # Directory the generated YAML files are written to
    # (``<quant.output_path>/<task_id>``).
    output_path: Optional[str] = None
    # "onnx_model" / "calibration_data" entries of the "DiT_Policy" section.
    DiT_Policy_ONNX: Optional[str] = None
    DiT_Policy_CALIBRATION: Optional[str] = None
    # "onnx_model" / "calibration_data" entries of the "Img_Adaptor" section.
    Img_Adaptor_ONNX: Optional[str] = None
    Img_Adaptor_CALIBRATION: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
def _require_entry(mapping, key, context):
    """Return ``mapping[key]``; raise ``ValueError`` if it is absent or null."""
    value = mapping.get(key)
    if value is None:
        raise ValueError(f"Missing required entry '{key}' in {context}")
    return value


def load_config(config_path):
    """Generate the PTQ conversion YAML files described by a JSON task config.

    Reads the JSON file at ``config_path``, collects its settings into a
    ``QuantConfig``, creates the directory ``<quant.output_path>/<task_id>``
    and writes two files into it:

    * ``img_adaptor.yaml`` - the ``ptq_yaml/<model_type>/img_adaptor.yaml``
      template with the ONNX model, march, output file prefix, calibration
      data directory and working directory filled in.
    * ``dit.yaml`` - the ``ptq_yaml/<model_type>/dit.yaml`` template with its
      ``{dit_cal_name}`` placeholders, ONNX model, march and working directory
      filled in, and the content of ``ptq_yaml/<model_type>/dit_op_config.json``
      embedded under ``calibration_parameters.quant_config``.

    The ``ptq_yaml/...`` template paths are relative to the current working
    directory.

    Args:
        config_path: Path of the JSON task configuration.

    Returns:
        The populated ``QuantConfig``.

    Raises:
        ValueError: If the config lacks ``task_id``, the ``quant`` section,
            ``quant.output_path``, or the ``DiT_Policy`` / ``Img_Adaptor``
            sections, or if the DiT template contains a ``{dit_cal_name}``
            placeholder but no DiT calibration data was configured.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config = json.load(file)

    # Fail early with a clear message instead of an UnboundLocalError /
    # TypeError further down when a required part of the config is missing.
    quant_info = _require_entry(config, "quant", config_path)
    dit_policy = _require_entry(quant_info, "DiT_Policy", "the 'quant' section")
    img_adaptor = _require_entry(quant_info, "Img_Adaptor", "the 'quant' section")
    task_id = _require_entry(config, "task_id", config_path)
    output_root = _require_entry(quant_info, "output_path", "the 'quant' section")

    opt = QuantConfig(
        task_id=task_id,
        gpu_id=config.get("gpu_id"),
        march=quant_info.get("march"),
        model_type=quant_info.get("model_type"),
        output_path=os.path.join(output_root, task_id),
        DiT_Policy_ONNX=dit_policy.get("onnx_model"),
        DiT_Policy_CALIBRATION=dit_policy.get("calibration_data"),
        Img_Adaptor_ONNX=img_adaptor.get("onnx_model"),
        Img_Adaptor_CALIBRATION=img_adaptor.get("calibration_data"),
    )
    os.makedirs(opt.output_path, exist_ok=True)

    # Prepare the image adaptor conversion YAML.
    with open(f"ptq_yaml/{opt.model_type}/img_adaptor.yaml", "r", encoding="utf-8") as file:
        img_adaptor_yaml = yaml.safe_load(file)
    img_model_params = img_adaptor_yaml["model_parameters"]
    img_model_params["onnx_model"] = opt.Img_Adaptor_ONNX
    img_model_params["march"] = opt.march
    img_model_params["output_model_file_prefix"] = "rdt_img_adaptor"
    img_adaptor_yaml["calibration_parameters"]["cal_data_dir"] = opt.Img_Adaptor_CALIBRATION
    img_model_params["working_dir"] = IMG_ADAPTOR
    img_adaptor_yaml_path = os.path.join(opt.output_path, "img_adaptor.yaml")
    # allow_unicode=True emits non-ASCII characters verbatim, so the file is
    # written as UTF-8 explicitly rather than in the locale's encoding.
    with open(img_adaptor_yaml_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(img_adaptor_yaml, f, default_flow_style=False, allow_unicode=True)

    # Prepare the DiT conversion YAML.
    with open(f"ptq_yaml/{opt.model_type}/dit.yaml", "r", encoding="utf-8") as file:
        dit_yaml = yaml.safe_load(file)
    # Substitute the calibration data location into every templated string.
    # Only values of existing keys are replaced, so mutating the dict while
    # iterating over it is safe.
    calibration_params = dit_yaml.get("calibration_parameters", {})
    for key, value in calibration_params.items():
        if isinstance(value, str) and "{dit_cal_name}" in value:
            if opt.DiT_Policy_CALIBRATION is None:
                raise ValueError(f"DiT_Policy_CALIBRATION is None, cannot replace {{dit_cal_name}} in {key}")
            calibration_params[key] = value.replace("{dit_cal_name}", opt.DiT_Policy_CALIBRATION)
    dit_yaml["model_parameters"]["onnx_model"] = opt.DiT_Policy_ONNX
    dit_yaml["model_parameters"]["march"] = opt.march
    dit_yaml["model_parameters"]["working_dir"] = DIT

    # Embed the per-operator quantization settings directly in the YAML.
    with open(f"ptq_yaml/{opt.model_type}/dit_op_config.json", "r", encoding="utf-8") as file:
        dit_json = json.load(file)
    dit_yaml["calibration_parameters"]["quant_config"] = dit_json

    dit_yaml_path = os.path.join(opt.output_path, "dit.yaml")
    with open(dit_yaml_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(dit_yaml, f, default_flow_style=False, allow_unicode=True)

    return opt
|
|
|
|
if __name__ == "__main__":
    # The first command-line argument is the path of the JSON task config;
    # exit with a usage message instead of an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: python {sys.argv[0]} <config.json>")
    config_path = sys.argv[1]
    config = load_config(config_path)