@@ -9,24 +9,66 @@ import uvicorn, socket
from tqdm import tqdm
from fastapi.staticfiles import StaticFiles
from config import *
-
+import asyncio, threading
from functions import split_dataframe_to_dict, extract_list_from_string
sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
import torch
# from model import BertClassifier
from bert.model import BertClassifier
from transformers import BertTokenizer, BertConfig
-
+from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
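+# Allow cross-origin requests so a browser-based frontend can call this API directly;
+# the wildcard origins/methods/headers below are convenient for development and worth
+# tightening for production.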
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
if not os.path.exists(basic_path):
    os.makedirs(basic_path, exist_ok=True)
app.mount("/data", StaticFiles(directory=basic_path), name="static")
+bert_config = BertConfig.from_pretrained(pre_train_model)
+# Define the model
+model = BertClassifier(bert_config, len(label_revert_map.keys()))
+# Load the trained model weights
+model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
+model.eval()
+tokenizer = BertTokenizer.from_pretrained(pre_train_model)
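+# The classifier and tokenizer above are created once at import time and shared by all
+# requests. map_location=torch.device('cuda:0') assumes a GPU is available; a common
+# fallback is torch.device('cuda:0' if torch.cuda.is_available() else 'cpu').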
+# class ModelManager:
+#     def __init__(self, pre_train_model, model_save_path, label_revert_map):
+#         self.pre_train_model = pre_train_model
+#         self.model_save_path = model_save_path
+#         self.label_revert_map = label_revert_map
+#         self.model_cache = {}
+#         self.tokenizer_cache = {}
+
+#     def get_model(self):
+#         # Use the current thread ID as the key so each thread gets its own model instance
+#         thread_id = threading.get_ident()
+
+#         if thread_id not in self.model_cache:
+#             bert_config = BertConfig.from_pretrained(self.pre_train_model)
+#             model = BertClassifier(bert_config, len(self.label_revert_map.keys()))
+#             model.load_state_dict(torch.load(self.model_save_path, map_location=torch.device('cuda:0')))
+#             model.eval()
+#             self.model_cache[thread_id] = model
+
+#         if thread_id not in self.tokenizer_cache:
+#             tokenizer = BertTokenizer.from_pretrained(self.pre_train_model)
+#             self.tokenizer_cache[thread_id] = tokenizer
+
+#         print(f"Thread {thread_id}: Tokenizer id: {id(self.tokenizer_cache[thread_id])}, Model id: {id(self.model_cache[thread_id])}")
+#         return self.tokenizer_cache[thread_id], self.model_cache[thread_id]
+
+# Usage example
+# model_manager = ModelManager(pre_train_model, model_save_path, label_revert_map)
class ClassificationRequest(BaseModel):
    path: str
    client_id: str
    one_key: str
    name_column: str
-    api_key: str = "sk-iREtaVNjamaBArOTlc_2BfGFJVPiU-9EjSFMUspIPBT3BlbkFJxS0SMmKZD9L9UumPczee4VKawCwVeGBQAr9MgsWGkA"
+    api_key: str = "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"
    proxy: bool = False
    chunk_size: int = 100
@@ -35,13 +77,13 @@ class ClassificationRequestBert(BaseModel):
    client_id: str
    name_column: str

-bert_config = BertConfig.from_pretrained(pre_train_model)
-# Define the model
-model = BertClassifier(bert_config, len(label_revert_map.keys()))
-# Load the trained model weights
-model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
-model.eval()
-tokenizer = BertTokenizer.from_pretrained(pre_train_model)
+# bert_config = BertConfig.from_pretrained(pre_train_model)
+# # Define the model
+# model = BertClassifier(bert_config, len(label_revert_map.keys()))
+# # Load the trained model weights
+# model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
+# model.eval()
+# tokenizer = BertTokenizer.from_pretrained(pre_train_model)

def bert_predict(text):
    if type(text) == str and text != '':
@@ -64,7 +106,22 @@ def bert_predict(text):
        return label_revert_map[pred_label]
    else:
        return ''
-def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
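+# process_data runs the blocking pandas + BERT loop for one file; predict_excel and
+# predict_csv below hand it to asyncio.to_thread so the event loop is not blocked while
+# chunks are classified. The first chunk writes temp_path with a header, later chunks
+# append without one, and any chunk that raises is collected in df_error.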
+def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
+    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
+        try:
+            # Process each chunk
+            chunk['AI Group'] = chunk[name_col].apply(lambda x : bert_predict(x))
+            # Save the processed results incrementally
+            if total_processed == 0:
+                chunk.to_csv(temp_path, mode='w', index=False)
+            else:
+                chunk.to_csv(temp_path, mode='a', header=False, index=False)
+            # Update the count of processed rows
+            total_processed += len(chunk)
+        except Exception as e:
+            df_error = pd.concat([df_error, chunk])
+    return temp_path, df_error, total_processed
+async def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
    # Initialize variables
    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
    # Append a suffix
@@ -77,28 +134,44 @@ def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
    df_origin = pd.read_excel(file_path)
    df_error = pd.DataFrame(columns=df_origin.columns)
    df_origin.to_csv(origin_csv_path, index=False)
+
+    temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
+    df_final = pd.read_csv(temp_path)
+    df_final.to_excel(save_path, index=False)
+    try:
+        os.remove(origin_csv_path)
+    except Exception as e :
+        pass
+    if len(df_error) == 0:
+        return save_path, '', 0
+    else:
+        df_error.to_excel(error_file_path)
+        return save_path, error_file_path, len(df_error)
+
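+# CSV counterpart of predict_excel: the flow is identical, but the input is read with
+# pd.read_csv and the final and error files are written back out as CSV.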
+async def predict_csv(file_path, name_col, temp_path, save_path, chunksize=5):
+    # Initialize variables
+    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
+    # Append a suffix
+    error_file = error_file_name + '_error' + error_file_extension
+    origin_csv_file = error_file_name + '_origin.csv'
+    # Build the new file paths
+    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
+    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
+    total_processed = 0
+    df_origin = pd.read_csv(file_path)
+    df_error = pd.DataFrame(columns=df_origin.columns)
+    df_origin.to_csv(origin_csv_path, index=False)
    # Read the CSV file in chunks

-    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
-        try:
-            # Process each chunk
-            chunk['classify'] = chunk[name_col].apply(bert_predict)
-            # Save the processed results incrementally
-            if total_processed == 0:
-                chunk.to_csv(temp_path, mode='w', index=False)
-            else:
-                chunk.to_csv(temp_path, mode='a', header=False, index=False)
-            # Update the count of processed rows
-            total_processed += len(chunk)
-        except Exception as e:
-            df_error = pd.concat([df_error, chunk])
+    # for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
+    temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
    df_final = pd.read_csv(temp_path)
-    df_final.to_excel(save_path)
+    df_final.to_csv(save_path, index=False)
    os.remove(origin_csv_path)
    if len(df_error) == 0:
        return save_path, '', 0
    else:
-        df_error.to_excel(error_file_path)
+        df_error.to_csv(error_file_path)
        return save_path, error_file_path, len(df_error)
def remove_files():
    current_time = time.time()
@@ -130,12 +203,29 @@ async def create_upload_file(file: UploadFile = File(...), client_id: str = Form
    os.chmod(file_location, 0o777)  # Set the file permissions to 777
    return JSONResponse(content={
        "message": f"文件 '{file.filename}' 上传成功",
-        "client_id": client_id,
-        "file_path": file_location
+        "data":{"client_id": client_id,"file_path": file_location},
+        "code":200
    }, status_code=200)
    except Exception as e:
-        return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
-
+        return JSONResponse(content={"message": f"发生错误: {str(e)}", "data":{}, "code":500}, status_code=500)
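+# process_openai_data sends each chunk of names to the chat completions API with the
+# classification prompts, parses the returned list, and appends the rows to temp_csv;
+# a chunk that fails is only logged and skipped. Each append is written with
+# header=True, so the temp file carries one header row per chunk.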
+async def process_openai_data(deal_result, client, temp_csv):
+    for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
+        try:
+            message = [
+                {'role':'system', 'content': cls_system_prompt},
+                {'role':'user', 'content':user_prompt.format(chunk=str(value))}
+            ]
+            # result_string = post_openai(message)
+            response = await client.chat.completions.create(model='gpt-4',messages=message)
+            result_string = response.choices[0].message.content
+            result = extract_list_from_string(result_string)
+            if result:
+                df_output = pd.DataFrame(result)
+                df_output.to_csv(temp_csv, mode='a', header=True, index=False)
+            else:
+                continue
+        except Exception as e:
+            print(f'{name}出现问题啦, 错误为:{e} 请自行调试')
@app.post("/classify_openai/")
async def classify_data(request: ClassificationRequest):
    try:
@@ -143,16 +233,20 @@ async def classify_data(request: ClassificationRequest):
        work_path = f'{basic_path}/{request.client_id}'
        if not os.path.exists(work_path):
            os.makedirs(work_path, exist_ok=True)
+        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
        timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-        df_origin = pd.read_excel(request.path)
-        df_origin['name'] = df_origin[request.name_column]
-        df_origin['classify'] = ''
-        df_use = df_origin[['name', 'classify']]
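+        # Support both CSV and Excel inputs: choose the reader from the file extension.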
+        if final_file_extension == '.csv':
+            df_origin = pd.read_csv(request.path)
+        else:
+            df_origin = pd.read_excel(request.path)
+        df_origin['AI Name'] = df_origin[request.name_column]
+        df_origin['AI Group'] = ''
+        df_use = df_origin[['AI Name', 'AI Group']]
        deal_result = split_dataframe_to_dict(df_use, request.chunk_size)
        # Build a timestamp string for the current time

        temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
-        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
+
        # Append a suffix
        final_file = final_file_name + '_classify' + final_file_extension
@@ -161,54 +255,68 @@ async def classify_data(request: ClassificationRequest):

        if not request.proxy:
            print(f'用户{request.client_id}正在使用直连的gpt-API')
-            client = openai.OpenAI(api_key=request.api_key, base_url=openai_url)
+            client = openai.AsyncOpenAI(api_key=request.api_key, base_url=openai_url)
        else:
-            client = openai.OpenAI(api_key=request.api_key, base_url=proxy_url)
-        for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
-            try:
-                message = [
-                    {'role':'system', 'content': cls_system_prompt},
-                    {'role':'user', 'content':user_prompt.format(chunk=str(value))}
-                ]
-                # result_string = post_openai(message)
-                response = client.chat.completions.create(model='gpt-4',messages=message)
-                result_string = response.choices[0].message.content
-                result = extract_list_from_string(result_string)
-                if result:
-                    df_output = pd.DataFrame(result)
-                    df_output.to_csv(temp_csv, mode='a', header=True, index=False)
-                else:
-                    continue
-            except Exception as e:
-                print(f'{name}出现问题啦, 错误为:{e} 请自行调试')
+            client = openai.AsyncOpenAI(api_key=request.api_key, base_url=proxy_url)
+        x = await process_openai_data(deal_result, client, temp_csv)
        if os.path.exists(temp_csv):
            df_result = pd.read_csv(temp_csv)
-            df_final = df_origin.merge(df_result, on='name', how='left').drop_duplicates(subset=[request.one_key,'name'], keep='first')
-            df_final.to_excel(new_file_path)
-            return {"message": "分类完成", "output_file": file_base_url + new_file_path.split(basic_path)[1]}
+            df_origin = df_origin.drop(columns='AI Group')
+            df_final = df_origin.merge(df_result, on='AI Name', how='left').drop_duplicates(subset=[request.one_key,'AI Name'], keep='first')
+            df_final = df_final.drop(columns='AI Name')
+            if final_file_extension == '.csv':
+                df_final.to_csv(new_file_path, index=False)
+            else:
+                df_final.to_excel(new_file_path, index=False)
+            return JSONResponse(content={
+                "message": f"分类完成",
+                "data":{"output_file": file_base_url + new_file_path.split(basic_path)[1]},
+                "code":200
+            }, status_code=200)
        else:
-            return {"message": "文件没能处理成功"}
+            print('...')
+            return JSONResponse(content={"message": "文件没能处理成功","data":{},"code":500}, status_code=500)
    except Exception as e:
-        return {"message": f"处理出现错误: {e}"}
+        print('okok')
+        return JSONResponse(content={"message": f"处理出现错误: {e}","data":{}, "code":500}, status_code=500)

@app.post("/classify_bert/")
-async def classify_data(request: ClassificationRequestBert):
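+# real_process wraps classify_data_bert in an asyncio task; because the task is awaited
+# immediately, this behaves the same as awaiting classify_data_bert directly.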
+async def real_process(request: ClassificationRequestBert):
+    task = asyncio.create_task(classify_data_bert(request))
+    result = await task
+    return result
+
+async def classify_data_bert(request: ClassificationRequestBert):
    remove_files()
    work_path = f'{basic_path}/{request.client_id}'
    if not os.path.exists(work_path):
        os.makedirs(work_path, exist_ok=True)
    final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
-    # Append a suffix
    final_file = final_file_name + '_classify' + final_file_extension
-
    # Build the new file path
    new_file_path = os.path.join(os.path.dirname(request.path), final_file)
    timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
-    save_path, error_path, error_len = predict_excel(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
+    if final_file_extension == '.csv':
+        save_path, error_path, error_len = await predict_csv(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
+    else:
+        # Append a suffix
+        save_path, error_path, error_len = await predict_excel(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
    if error_len == 0:
-        return {"message": "分类完成", "output_file": file_base_url + save_path.split(basic_path)[1]}
+        return JSONResponse(content={
+            "message": "分类完成",
+            "data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
+            "code":200
+        }, status_code=200)
    else:
-        return {"message": "分类完成只完成部分", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1],}
+        return JSONResponse(content={
+            "message": "分类完成只完成部分",
+            "data":{"output_file": file_base_url + save_path.split(basic_path)[1],
+            "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1]},
+            "code":200
+        }, status_code=200)
+@app.get("/test/")
+async def classify_data():
+    return '测试成功'
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=port)