add some code
This commit is contained in:
153
managed_components/espressif__esp-sr/model/movemodel.py
Normal file
153
managed_components/espressif__esp-sr/model/movemodel.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import argparse
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
sys.dont_write_bytecode = True
|
||||
from pack_model import pack_models
|
||||
|
||||
|
||||
def calculate_total_size(folder_path):
|
||||
total_size = 0
|
||||
for file_name in os.listdir(folder_path):
|
||||
path = os.path.join(folder_path, file_name)
|
||||
if os.path.isdir(path):
|
||||
total_size = total_size + calculate_total_size(path)
|
||||
if os.path.isfile(path):
|
||||
total_size = total_size + os.path.getsize(path)
|
||||
return total_size
|
||||
|
||||
def copy_wakenet_from_sdkconfig(model_path, sdkconfig_path, target_path):
|
||||
"""
|
||||
Copy wakenet model from model_path to target_path based on sdkconfig
|
||||
"""
|
||||
models = []
|
||||
with io.open(sdkconfig_path, "r") as f:
|
||||
for label in f:
|
||||
label = label.strip("\n")
|
||||
if 'CONFIG_SR_WN' in label and '#' not in label[0]:
|
||||
if '_NONE' in label:
|
||||
continue
|
||||
if '=' in label:
|
||||
label = label.split("=")[0]
|
||||
if '_MULTI' in label:
|
||||
label = label[:-6]
|
||||
model_name = label.split("_SR_WN_")[-1].lower()
|
||||
models.append(model_name)
|
||||
|
||||
for item in models:
|
||||
wakeword_model_path = model_path + '/wakenet_model/' + item
|
||||
if os.path.exists(wakeword_model_path):
|
||||
shutil.copytree(wakeword_model_path, target_path+'/'+item)
|
||||
|
||||
def copy_multinet_from_sdkconfig(model_path, sdkconfig_path, target_path):
|
||||
"""
|
||||
Copy multinet model from model_path to target_path based on sdkconfig
|
||||
"""
|
||||
with io.open(sdkconfig_path, "r") as f:
|
||||
models_string = ''
|
||||
for label in f:
|
||||
label = label.strip("\n")
|
||||
if 'CONFIG_SR_MN' in label and label[0] != '#':
|
||||
models_string += label
|
||||
|
||||
models = []
|
||||
if "CONFIG_SR_MN_CN_MULTINET3_SINGLE_RECOGNITION" in models_string:
|
||||
models.append('mn3_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION_QUANT8" in models_string:
|
||||
models.append('mn4q8_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION" in models_string:
|
||||
models.append('mn4_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET5_RECOGNITION_QUANT8" in models_string:
|
||||
models.append('mn5q8_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET6_QUANT" in models_string:
|
||||
models.append('mn6_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET6_AC_QUANT" in models_string:
|
||||
models.append('mn6_cn_ac')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET7_QUANT" in models_string:
|
||||
models.append('mn7_cn')
|
||||
elif "CONFIG_SR_MN_CN_MULTINET7_AC_QUANT" in models_string:
|
||||
models.append('mn7_cn_ac')
|
||||
|
||||
if "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION_QUANT8" in models_string:
|
||||
models.append('mn5q8_en')
|
||||
elif "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION" in models_string:
|
||||
models.append('mn5_en')
|
||||
elif "CONFIG_SR_MN_EN_MULTINET6_QUANT" in models_string:
|
||||
models.append('mn6_en')
|
||||
elif "CONFIG_SR_MN_EN_MULTINET7_QUANT" in models_string:
|
||||
models.append('mn7_en')
|
||||
|
||||
if "MULTINET6" in models_string or "MULTINET7" in models_string:
|
||||
models.append('fst')
|
||||
|
||||
for item in models:
|
||||
shutil.copytree(model_path + '/multinet_model/' + item, target_path+'/'+item)
|
||||
|
||||
|
||||
def copy_nsnet_from_sdkconfig(model_path, sdkconfig_path, target_path):
|
||||
"""
|
||||
Copy nsnet model from model_path to target_path based on sdkconfig
|
||||
"""
|
||||
with io.open(sdkconfig_path, "r") as f:
|
||||
models_string = ''
|
||||
for label in f:
|
||||
label = label.strip("\n")
|
||||
if 'CONFIG_SR_NSN' in label and label[0] != '#':
|
||||
models_string += label
|
||||
|
||||
models = []
|
||||
if "CONFIG_SR_NSN_NSNET1" in models_string:
|
||||
models.append('nsnet1')
|
||||
if "CONFIG_SR_NSN_NSNET2" in models_string:
|
||||
models.append('nsnet2')
|
||||
|
||||
for item in models:
|
||||
shutil.copytree(model_path + '/nsnet_model/' + item, target_path+'/'+item)
|
||||
|
||||
def copy_vadnet_from_sdkconfig(model_path, sdkconfig_path, target_path):
|
||||
"""
|
||||
Copy vadnet model from model_path to target_path based on sdkconfig
|
||||
"""
|
||||
with io.open(sdkconfig_path, "r") as f:
|
||||
models_string = ''
|
||||
for label in f:
|
||||
label = label.strip("\n")
|
||||
if 'CONFIG_SR_VADN' in label and label[0] != '#':
|
||||
models_string += label
|
||||
|
||||
models = []
|
||||
if "CONFIG_SR_VADN_VADNET1_SMALL" in models_string:
|
||||
models.append('vadnet1_small')
|
||||
elif "CONFIG_SR_VADN_VADNET1_MEDIUM" in models_string:
|
||||
models.append('vadnet1_medium')
|
||||
|
||||
for item in models:
|
||||
shutil.copytree(model_path + '/vadnet_model/' + item, target_path+'/'+item)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Model generator tool')
|
||||
parser.add_argument('-d1', '--sdkconfig_path')
|
||||
parser.add_argument('-d2', '--model_path')
|
||||
parser.add_argument('-d3', '--build_path')
|
||||
args = parser.parse_args()
|
||||
|
||||
sdkconfig_path = args.sdkconfig_path
|
||||
model_path = args.model_path + '/model'
|
||||
target_path = args.build_path + '/srmodels'
|
||||
image_file = "srmodels.bin"
|
||||
|
||||
if os.path.exists(target_path):
|
||||
shutil.rmtree(target_path)
|
||||
os.makedirs(target_path)
|
||||
|
||||
copy_multinet_from_sdkconfig(model_path, sdkconfig_path, target_path)
|
||||
copy_wakenet_from_sdkconfig(model_path, sdkconfig_path, target_path)
|
||||
copy_nsnet_from_sdkconfig(model_path, sdkconfig_path, target_path)
|
||||
copy_vadnet_from_sdkconfig(model_path, sdkconfig_path, target_path)
|
||||
pack_models(target_path, image_file)
|
||||
total_size = os.path.getsize(os.path.join(target_path, image_file))
|
||||
recommended_size = int(math.ceil(total_size/1024))
|
||||
print("Recommended model partition size: %dK" % (recommended_size))
|
||||
Reference in New Issue
Block a user