# name_classify_api.py

import pandas as pd
import os, time, shutil, sys
import openai
from fastapi import FastAPI, UploadFile, File, Form
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from datetime import datetime
import uvicorn, socket
from tqdm import tqdm
from fastapi.staticfiles import StaticFiles
from config import *
from functions import split_dataframe_to_dict, extract_list_from_string
sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
import torch
# from model import BertClassifier
from bert.model import BertClassifier
from transformers import BertTokenizer, BertConfig

app = FastAPI()

if not os.path.exists(basic_path):
    os.makedirs(basic_path, exist_ok=True)
app.mount("/data", StaticFiles(directory=basic_path), name="static")
class ClassificationRequest(BaseModel):
    path: str
    client_id: str
    one_key: str
    name_column: str
    api_key: str = "sk-iREtaVNjamaBArOTlc_2BfGFJVPiU-9EjSFMUspIPBT3BlbkFJxS0SMmKZD9L9UumPczee4VKawCwVeGBQAr9MgsWGkA"
    proxy: bool = False
    chunk_size: int = 100


class ClassificationRequestBert(BaseModel):
    path: str
    client_id: str
    name_column: str
bert_config = BertConfig.from_pretrained(pre_train_model)
# Define the model
model = BertClassifier(bert_config, len(label_revert_map.keys()))
# Load the trained weights
model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
model.eval()
tokenizer = BertTokenizer.from_pretrained(pre_train_model)
def bert_predict(text):
    if isinstance(text, str) and text != '':
        token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
        input_ids = token['input_ids']
        attention_mask = token['attention_mask']
        token_type_ids = token['token_type_ids']
        input_ids = torch.tensor([input_ids], dtype=torch.long)
        attention_mask = torch.tensor([attention_mask], dtype=torch.long)
        token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
        # Inference only, so disable gradient tracking
        with torch.no_grad():
            predicted = model(
                input_ids,
                attention_mask,
                token_type_ids,
            )
        pred_label = torch.argmax(predicted, dim=1).numpy()[0]
        return label_revert_map[pred_label]
    else:
        return ''
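
# Illustrative use of bert_predict (a sketch, not part of the API surface):
# the label set comes from label_revert_map in config, and the input string
# below is a made-up placeholder.
#
#   >>> bert_predict('Example Network Technology Co., Ltd.')
#   '<one of the labels defined in label_revert_map>'
#   >>> bert_predict('')   # empty or non-string input falls through
#   ''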
def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
    # Initialise output paths
    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
    # Add suffixes for the error file and the intermediate CSV copy
    error_file = error_file_name + '_error' + error_file_extension
    origin_csv_file = error_file_name + '_origin.csv'
    # Build the new file paths
    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
    total_processed = 0
    df_origin = pd.read_excel(file_path)
    df_error = pd.DataFrame(columns=df_origin.columns)
    df_origin.to_csv(origin_csv_path, index=False)
    # Read the CSV copy back in chunks
    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True),
                      total=len(df_origin) // chunksize + 1,
                      desc='Processing', unit='item'):
        try:
            # Classify every name in the chunk
            chunk['classify'] = chunk[name_col].apply(bert_predict)
            # Save results incrementally
            if total_processed == 0:
                chunk.to_csv(temp_path, mode='w', index=False)
            else:
                chunk.to_csv(temp_path, mode='a', header=False, index=False)
            # Update the count of processed rows
            total_processed += len(chunk)
        except Exception as e:
            df_error = pd.concat([df_error, chunk])
    df_final = pd.read_csv(temp_path)
    df_final.to_excel(save_path)
    os.remove(origin_csv_path)
    if len(df_error) == 0:
        return save_path, '', 0
    else:
        df_error.to_excel(error_file_path)
        return save_path, error_file_path, len(df_error)
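
# Sketch of calling predict_excel directly (the paths and column name below
# are hypothetical): it reads an Excel file, classifies the given name column
# chunk by chunk, writes incremental results to temp_path, and returns
# (save_path, error_file_path_or_'', number_of_failed_rows).
#
#   save_path, error_path, n_errors = predict_excel(
#       file_path='/tmp/companies.xlsx',
#       name_col='company_name',
#       temp_path='/tmp/output_temp.csv',
#       save_path='/tmp/companies_classify.xlsx',
#   )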
def remove_files():
    current_time = time.time()
    TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60  # directories older than 30 days
    TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60      # files older than 10 days
    for root, dirs, files in os.walk(basic_path, topdown=False):
        # Remove expired files
        for file in files:
            file_path = os.path.join(root, file)
            if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
                print(f"Removing file: {file_path}")
                os.remove(file_path)
        # Remove expired directories
        for dir in dirs:
            dir_path = os.path.join(root, dir)
            if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
                print(f"Removing directory: {dir_path}")
                shutil.rmtree(dir_path)
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
    user_directory = f'{basic_path}/{client_id}'
    if not os.path.exists(user_directory):
        os.makedirs(user_directory)
        os.chmod(user_directory, 0o777)  # set the user directory permissions to 777
    file_location = os.path.join(user_directory, file.filename)
    try:
        with open(file_location, "wb+") as file_object:
            file_object.write(file.file.read())
        os.chmod(file_location, 0o777)  # set the file permissions to 777
        return JSONResponse(content={
            "message": f"File '{file.filename}' uploaded successfully",
            "client_id": client_id,
            "file_path": file_location
        }, status_code=200)
    except Exception as e:
        return JSONResponse(content={"message": f"An error occurred: {str(e)}"}, status_code=500)
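
# Example client call for /uploadfile/ (a sketch using the requests library;
# the host, port, file name and client_id are assumptions, not values defined
# in this module):
#
#   import requests
#   resp = requests.post(
#       'http://localhost:8000/uploadfile/',
#       files={'file': open('companies.xlsx', 'rb')},
#       data={'client_id': 'demo_client'},
#   )
#   print(resp.json())  # contains "file_path", which feeds the classify endpoints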
@app.post("/classify_openai/")
async def classify_data(request: ClassificationRequest):
    try:
        remove_files()
        work_path = f'{basic_path}/{request.client_id}'
        if not os.path.exists(work_path):
            os.makedirs(work_path, exist_ok=True)
        timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        df_origin = pd.read_excel(request.path)
        df_origin['name'] = df_origin[request.name_column]
        df_origin['classify'] = ''
        df_use = df_origin[['name', 'classify']]
        deal_result = split_dataframe_to_dict(df_use, request.chunk_size)
        # Temporary CSV named with the current timestamp
        temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
        # Add the suffix
        final_file = final_file_name + '_classify' + final_file_extension
        # Build the new output file path
        new_file_path = os.path.join(os.path.dirname(request.path), final_file)
        if not request.proxy:
            print(f'Client {request.client_id} is using the direct GPT API')
            client = openai.OpenAI(api_key=request.api_key, base_url=openai_url)
        else:
            client = openai.OpenAI(api_key=request.api_key, base_url=proxy_url)
        for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
            try:
                message = [
                    {'role': 'system', 'content': cls_system_prompt},
                    {'role': 'user', 'content': user_prompt.format(chunk=str(value))}
                ]
                # result_string = post_openai(message)
                response = client.chat.completions.create(model='gpt-4', messages=message)
                result_string = response.choices[0].message.content
                result = extract_list_from_string(result_string)
                if result:
                    df_output = pd.DataFrame(result)
                    # Write the header only on the first append so the CSV stays parseable
                    df_output.to_csv(temp_csv, mode='a', header=not os.path.exists(temp_csv), index=False)
                else:
                    continue
            except Exception as e:
                print(f'Chunk {name} hit a problem, error: {e}; please debug it manually')
        if os.path.exists(temp_csv):
            df_result = pd.read_csv(temp_csv)
            df_final = df_origin.merge(df_result, on='name', how='left').drop_duplicates(subset=[request.one_key, 'name'], keep='first')
            df_final.to_excel(new_file_path)
            return {"message": "Classification finished", "output_file": file_base_url + new_file_path.split(basic_path)[1]}
        else:
            return {"message": "The file could not be processed"}
    except Exception as e:
        return {"message": f"An error occurred during processing: {e}"}
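
# Example request body for /classify_openai/ (a sketch; the host, port and
# field values are hypothetical — "path" must be the "file_path" previously
# returned by /uploadfile/, and one_key / name_column must be real columns
# of that spreadsheet; api_key falls back to the default above if omitted):
#
#   import requests
#   payload = {
#       "path": "<file_path returned by /uploadfile/>",
#       "client_id": "demo_client",
#       "one_key": "id",
#       "name_column": "company_name",
#       "proxy": False,
#       "chunk_size": 100,
#   }
#   resp = requests.post('http://localhost:8000/classify_openai/', json=payload)
#   print(resp.json())  # {"message": ..., "output_file": ...}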
@app.post("/classify_bert/")
async def classify_data_bert(request: ClassificationRequestBert):
    remove_files()
    work_path = f'{basic_path}/{request.client_id}'
    if not os.path.exists(work_path):
        os.makedirs(work_path, exist_ok=True)
    final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
    # Add the suffix
    final_file = final_file_name + '_classify' + final_file_extension
    # Build the new output file path
    new_file_path = os.path.join(os.path.dirname(request.path), final_file)
    timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
    save_path, error_path, error_len = predict_excel(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
    if error_len == 0:
        return {"message": "Classification finished", "output_file": file_base_url + save_path.split(basic_path)[1]}
    else:
        return {
            "message": "Classification only partially finished",
            "output_file": file_base_url + save_path.split(basic_path)[1],
            "output_file_nonprocess": file_base_url + error_path.split(basic_path)[1],
        }
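
# Example request body for /classify_bert/ (a sketch; host, port and field
# values are hypothetical and must match a file uploaded via /uploadfile/):
#
#   import requests
#   payload = {
#       "path": "<file_path returned by /uploadfile/>",
#       "client_id": "demo_client",
#       "name_column": "company_name",
#   }
#   resp = requests.post('http://localhost:8000/classify_bert/', json=payload)
#   print(resp.json())  # includes "output_file_nonprocess" when some rows fail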
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=port)