@@ -35,34 +35,7 @@ model = BertClassifier(bert_config, len(label_revert_map.keys()))
model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
model.eval()
tokenizer = BertTokenizer.from_pretrained(pre_train_model)
-# class ModelManager:
-#     def __init__(self, pre_train_model, model_save_path, label_revert_map):
-#         self.pre_train_model = pre_train_model
-#         self.model_save_path = model_save_path
-#         self.label_revert_map = label_revert_map
-#         self.model_cache = {}
-#         self.tokenizer_cache = {}
-#     def get_model(self):
-#         # Use the current thread ID as the key so that each thread gets its own model instance
-#         thread_id = threading.get_ident()
-
-#         if thread_id not in self.model_cache:
-#             bert_config = BertConfig.from_pretrained(self.pre_train_model)
-#             model = BertClassifier(bert_config, len(self.label_revert_map.keys()))
-#             model.load_state_dict(torch.load(self.model_save_path, map_location=torch.device('cuda:0')))
-#             model.eval()
-#             self.model_cache[thread_id] = model
-
-#         if thread_id not in self.tokenizer_cache:
-#             tokenizer = BertTokenizer.from_pretrained(self.pre_train_model)
-#             self.tokenizer_cache[thread_id] = tokenizer
-
-#         print(f"Thread {thread_id}: Tokenizer id: {id(self.tokenizer_cache[thread_id])}, Model id: {id(self.model_cache[thread_id])}")
-#         return self.tokenizer_cache[thread_id], self.model_cache[thread_id]
-
-# Usage example
-# model_manager = ModelManager(pre_train_model, model_save_path, label_revert_map)

class ClassificationRequest(BaseModel):
    path: str
    client_id: str
@@ -77,14 +50,6 @@ class ClassificationRequestBert(BaseModel):
    client_id: str
    name_column: str

-# bert_config = BertConfig.from_pretrained(pre_train_model)
-# # Define the model
-# model = BertClassifier(bert_config, len(label_revert_map.keys()))
-# # Load the trained model
-# model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
-# model.eval()
-# tokenizer = BertTokenizer.from_pretrained(pre_train_model)
-
def bert_predict(text):
    if type(text) == str and text != '':
        token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
@@ -106,79 +71,7 @@ def bert_predict(text):
        return label_revert_map[pred_label]
    else:
        return ''
-def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
-    if name_col not in pd.read_csv(origin_csv_path).columns:
-        return '列名错误', None, None
-    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
-        try:
-            # Process each chunk
-            chunk['AI Group'] = chunk[name_col].apply(lambda x: bert_predict(x))
-            # Save the results incrementally
-            if total_processed == 0:
-                chunk.to_csv(temp_path, mode='w', index=False)
-            else:
-                chunk.to_csv(temp_path, mode='a', header=False, index=False)
-            # Update the number of processed rows
-            total_processed += len(chunk)
-        except Exception as e:
-            df_error = pd.concat([df_error, chunk])
-    return temp_path, df_error, total_processed
-async def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
-    # Initialize variables
-    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
-    # Add suffixes
-    error_file = error_file_name + '_error' + error_file_extension
-    origin_csv_file = error_file_name + '_origin.csv'
-    # Build the new file paths
-    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
-    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
-    total_processed = 0
-    df_origin = pd.read_excel(file_path)
-    df_error = pd.DataFrame(columns=df_origin.columns)
-    df_origin.to_csv(origin_csv_path, index=False)
-
-    temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
-    if temp_path == '列名错误':
-        return '列名错误', None, None
-    df_final = pd.read_csv(temp_path)
-    df_final.to_excel(save_path, index=False)
-    try:
-        os.remove(origin_csv_path)
-    except Exception as e:
-        pass
-    if len(df_error) == 0:
-        return save_path, '', 0
-    else:
-        df_error.to_excel(error_file_path)
-        return save_path, error_file_path, len(df_error)
-async def predict_csv(file_path, name_col, temp_path, save_path, chunksize=5):
-    # Initialize variables
-    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
-    # Add suffixes
-    error_file = error_file_name + '_error' + error_file_extension
-    origin_csv_file = error_file_name + '_origin.csv'
-    # Build the new file paths
-    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
-    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
-    total_processed = 0
-    df_origin = pd.read_csv(file_path)
-    df_error = pd.DataFrame(columns=df_origin.columns)
-    df_origin.to_csv(origin_csv_path, index=False)
-    # Read the CSV file in chunks
-
-    # for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
-    temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
-    if temp_path == '列名错误':
-        return '列名错误', None, None
-    df_final = pd.read_csv(temp_path)
-    df_final.to_csv(save_path, index=False)
-    os.remove(origin_csv_path)
-    if len(df_error) == 0:
-        return save_path, '', 0
-    else:
-        df_error.to_csv(error_file_path)
-        return save_path, error_file_path, len(df_error)
def remove_files():
    current_time = time.time()
    TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
@@ -196,6 +89,7 @@ def remove_files():
        if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
            print(f"删除文件夹: {dir_path}")
            shutil.rmtree(dir_path)
+
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
    user_directory = f'{basic_path}/{client_id}'
@@ -214,24 +108,22 @@ async def create_upload_file(file: UploadFile = File(...), client_id: str = Form
        }, status_code=200)
    except Exception as e:
        return JSONResponse(content={"message": f"发生错误: {str(e)}", "data":{}, "code":500}, status_code=500)
+
async def process_openai_data(deal_result, client, temp_csv):
    for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
-        try:
-            message = [
-                {'role':'system', 'content': cls_system_prompt},
-                {'role':'user', 'content':user_prompt.format(chunk=str(value))}
-            ]
-            # result_string = post_openai(message)
-            response = await client.chat.completions.create(model='gpt-4',messages=message)
-            result_string = response.choices[0].message.content
-            result = extract_list_from_string(result_string)
-            if result:
-                df_output = pd.DataFrame(result)
-                df_output.to_csv(temp_csv, mode='a', header=True, index=False)
-            else:
-                continue
-        except Exception as e:
-            print(f'{name}出现问题啦, 错误为:{e} 请自行调试')
+        message = [
+            {'role':'system', 'content': cls_system_prompt},
+            {'role':'user', 'content':user_prompt.format(chunk=str(value))}
+        ]
+        # result_string = post_openai(message)
+        response = await client.chat.completions.create(model='gpt-4',messages=message)
+        result_string = response.choices[0].message.content
+        result = extract_list_from_string(result_string)
+        if result:
+            df_output = pd.DataFrame(result)
+            df_output.to_csv(temp_csv, mode='a', header=True, index=False)
+        else:
+            continue
@app.post("/classify_openai/")
async def classify_data(request: ClassificationRequest):
    try:
@@ -290,48 +182,112 @@ async def classify_data(request: ClassificationRequest):
@app.post("/classify_bert/")
async def real_process(request: ClassificationRequestBert):
-    task = asyncio.create_task(classify_data_bert(request))
-    result = await task
+    # task = asyncio.create_task(classify_data_bert(request))
+    result = await classify_data_bert(request)
    return result

async def classify_data_bert(request: ClassificationRequestBert):
-    remove_files()
-    work_path = f'{basic_path}/{request.client_id}'
-    if not os.path.exists(work_path):
-        os.makedirs(work_path, exist_ok=True)
-    final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
-    final_file = final_file_name + '_classify' + final_file_extension
-    # Build the new file path
-    new_file_path = os.path.join(os.path.dirname(request.path), final_file)
-    timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-    temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
-    if final_file_extension == '.csv':
-        save_path, error_path, error_len = await predict_csv(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
-    else:
-        # Add suffix
-        save_path, error_path, error_len = await predict_excel(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
-    if error_len == 0:
-        return JSONResponse(content={
-            "message": "分类完成",
-            "data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
-            "code":200
-        }, status_code=200)
-    elif save_path == '列名错误':
-        return JSONResponse(content={
-            "message": "用户标识列输入错误,请重新输入",
-            "data":{},
-            "code":500
-        }, status_code=200)
-
-    else:
+    try:
+        remove_files()
+        work_path = f'{basic_path}/{request.client_id}'
+        if not os.path.exists(work_path):
+            os.makedirs(work_path, exist_ok=True)
+        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
+        final_file = final_file_name + '_classify' + final_file_extension
+        new_file_path = os.path.join(os.path.dirname(request.path), final_file)
+        timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+        temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
+        def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
+            if name_col not in pd.read_csv(origin_csv_path).columns:
+                return '列名错误', None, None
+            for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
+                try:
+                    # Process each chunk
+                    chunk['AI Group'] = chunk[name_col].apply(lambda x: bert_predict(x))
+                    # Save the results incrementally
+                    if total_processed == 0:
+                        chunk.to_csv(temp_path, mode='w', index=False)
+                    else:
+                        chunk.to_csv(temp_path, mode='a', header=False, index=False)
+                    # Update the number of processed rows
+                    total_processed += len(chunk)
+                except Exception as e:
+                    df_error = pd.concat([df_error, chunk])
+            return temp_path, df_error, total_processed
+        async def process_file(file_path, name_col, temp_path, save_path, chunksize=5):
+            error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
+            error_file = error_file_name + '_error' + error_file_extension
+            origin_csv_file = error_file_name + '_origin.csv'
+            error_file_path = os.path.join(os.path.dirname(save_path), error_file)
+            origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
+
+            if final_file_extension == '.csv':
+                try:
+                    df_origin = pd.read_csv(file_path, encoding='utf-8')
+                except UnicodeDecodeError:
+                    df_origin = pd.read_csv(file_path, encoding='gbk')
+            else:
+                df_origin = pd.read_excel(file_path)
+
+            df_origin.to_csv(origin_csv_path, index=False)
+            df_error = pd.DataFrame(columns=df_origin.columns)
+
+            if name_col not in df_origin.columns:
+                return '列名错误', None, None
+
+            total_processed = 0
+            temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
+
+            df_final = pd.read_csv(temp_path)
+            if final_file_extension == '.csv':
+                df_final.to_csv(save_path, index=False)
+            else:
+                df_final.to_excel(save_path, index=False)
+
+            try:
+                os.remove(origin_csv_path)
+            except Exception as e:
+                pass
+
+            if len(df_error) == 0:
+                return save_path, '', 0
+            else:
+                if final_file_extension == '.csv':
+                    df_error.to_csv(error_file_path, index=False)
+                else:
+                    df_error.to_excel(error_file_path, index=False)
+                return save_path, error_file_path, len(df_error)
+
+        save_path, error_path, error_len = await process_file(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
+
+        if error_len == 0:
+            return JSONResponse(content={
+                "message": "分类完成",
+                "data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
+                "code":200
+            }, status_code=200)
+        elif save_path == '列名错误':
+            return JSONResponse(content={
+                "message": "用户标识列输入错误,请重新输入",
+                "data":{},
+                "code":500
+            }, status_code=200)
+        else:
+            return JSONResponse(content={
+                "message": "分类完成只完成部分",
+                "data":{"output_file": file_base_url + save_path.split(basic_path)[1],
+                        "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1]},
+                "code":200
+            }, status_code=200)
+    except Exception as e:
        return JSONResponse(content={
-            "message": "分类完成只完成部分",
-            "data":{"output_file": file_base_url + save_path.split(basic_path)[1],
-                    "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1]},
-            "code":200
-        }, status_code=200)
+            "message": f"出现错误: {e}, 请检查列名和文件重试",
+            "data":{},
+            "code":500
+        }, status_code=200)
+
@app.get("/test/")
async def classify_data():
    return '测试成功'
if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    uvicorn.run(app, host="0.0.0.0", port=port)