|
@@ -107,6 +107,8 @@ def bert_predict(text):
|
|
|
else:
|
|
|
return ''
|
|
|
def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
|
|
|
+ if name_col not in pd.read_csv(origin_csv_path).columns:
|
|
|
+ return '列名错误', None, None
|
|
|
for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
|
|
|
try:
|
|
|
# 对每个块进行处理
|
|
@@ -136,6 +138,8 @@ async def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
|
|
|
df_origin.to_csv(origin_csv_path, index=False)
|
|
|
|
|
|
temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
|
|
|
+ if temp_path == '列名错误':
|
|
|
+ return '列名错误', None, None
|
|
|
df_final = pd.read_csv(temp_path)
|
|
|
df_final.to_excel(save_path, index=False)
|
|
|
try:
|
|
@@ -165,6 +169,8 @@ async def predict_csv(file_path, name_col, temp_path, save_path, chunksize=5):
|
|
|
|
|
|
# for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
|
|
|
temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
|
|
|
+ if temp_path == '列名错误':
|
|
|
+ return '列名错误', None, None
|
|
|
df_final = pd.read_csv(temp_path)
|
|
|
df_final.to_csv(save_path, index=False)
|
|
|
os.remove(origin_csv_path)
|
|
@@ -239,6 +245,8 @@ async def classify_data(request: ClassificationRequest):
|
|
|
df_origin = pd.read_csv(request.path)
|
|
|
else:
|
|
|
df_origin = pd.read_excel(request.path)
|
|
|
+ if request.name_column not in df_origin.columns or request.one_key not in df_origin.columns:
|
|
|
+ return JSONResponse(content={"message": "用户标识列输入错误,请重新输入","data":{},"code":500}, status_code=500)
|
|
|
df_origin['AI Name'] = df_origin[request.name_column]
|
|
|
df_origin['AI Group'] = ''
|
|
|
df_use = df_origin[['AI Name', 'AI Group']]
|
|
@@ -308,6 +316,13 @@ async def classify_data_bert(request: ClassificationRequestBert):
|
|
|
"data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
|
|
|
"code":200
|
|
|
}, status_code=200)
|
|
|
+ elif save_path == '列名错误':
|
|
|
+ return JSONResponse(content={
|
|
|
+ "message": "用户标识列输入错误,请重新输入",
|
|
|
+ "data":{},
|
|
|
+ "code":500
|
|
|
+ }, status_code=200)
|
|
|
+
|
|
|
else:
|
|
|
return JSONResponse(content={
|
|
|
"message": "分类完成只完成部分",
|