name_classify_api.py

import pandas as pd
import os, time, shutil, sys
import openai
from fastapi import FastAPI, UploadFile, File, Form
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from datetime import datetime
import uvicorn, socket
from tqdm import tqdm
from fastapi.staticfiles import StaticFiles
from config import *
import asyncio, threading
from functions import split_dataframe_to_dict, extract_list_from_string
sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
import torch
# from model import BertClassifier
from bert.model import BertClassifier
from transformers import BertTokenizer, BertConfig
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if not os.path.exists(basic_path):
    os.makedirs(basic_path, exist_ok=True)
app.mount("/data", StaticFiles(directory=basic_path), name="static")

bert_config = BertConfig.from_pretrained(pre_train_model)
# Define the classifier
model = BertClassifier(bert_config, len(label_revert_map.keys()))
# Load the trained weights
model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
model.eval()
tokenizer = BertTokenizer.from_pretrained(pre_train_model)

# class ModelManager:
#     def __init__(self, pre_train_model, model_save_path, label_revert_map):
#         self.pre_train_model = pre_train_model
#         self.model_save_path = model_save_path
#         self.label_revert_map = label_revert_map
#         self.model_cache = {}
#         self.tokenizer_cache = {}
#
#     def get_model(self):
#         # Use the current thread ID as the key so every thread gets its own model instance
#         thread_id = threading.get_ident()
#         if thread_id not in self.model_cache:
#             bert_config = BertConfig.from_pretrained(self.pre_train_model)
#             model = BertClassifier(bert_config, len(self.label_revert_map.keys()))
#             model.load_state_dict(torch.load(self.model_save_path, map_location=torch.device('cuda:0')))
#             model.eval()
#             self.model_cache[thread_id] = model
#         if thread_id not in self.tokenizer_cache:
#             tokenizer = BertTokenizer.from_pretrained(self.pre_train_model)
#             self.tokenizer_cache[thread_id] = tokenizer
#         print(f"Thread {thread_id}: Tokenizer id: {id(self.tokenizer_cache[thread_id])}, Model id: {id(self.model_cache[thread_id])}")
#         return self.tokenizer_cache[thread_id], self.model_cache[thread_id]
#
# Usage example:
# model_manager = ModelManager(pre_train_model, model_save_path, label_revert_map)

class ClassificationRequest(BaseModel):
    path: str
    client_id: str
    one_key: str
    name_column: str
    api_key: str = "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"
    proxy: bool = False
    chunk_size: int = 100

class ClassificationRequestBert(BaseModel):
    path: str
    client_id: str
    name_column: str

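# Illustrative request bodies for the two classification endpoints (all field values
# below are made up; `path` is the server-side path returned by /uploadfile/):
#
#   POST /classify_openai/
#   {"path": "<basic_path>/<client_id>/companies.xlsx", "client_id": "demo",
#    "one_key": "Company ID", "name_column": "Company Name",
#    "api_key": "sk-...", "proxy": false, "chunk_size": 100}
#
#   POST /classify_bert/
#   {"path": "<basic_path>/<client_id>/companies.xlsx", "client_id": "demo",
#    "name_column": "Company Name"}
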
# bert_config = BertConfig.from_pretrained(pre_train_model)
# # Define the classifier
# model = BertClassifier(bert_config, len(label_revert_map.keys()))
# # Load the trained weights
# model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
# model.eval()
# tokenizer = BertTokenizer.from_pretrained(pre_train_model)

def bert_predict(text):
    if isinstance(text, str) and text != '':
        token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
        input_ids = token['input_ids']
        attention_mask = token['attention_mask']
        token_type_ids = token['token_type_ids']
        input_ids = torch.tensor([input_ids], dtype=torch.long)
        attention_mask = torch.tensor([attention_mask], dtype=torch.long)
        token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
        # Run inference without tracking gradients
        with torch.no_grad():
            predicted = model(
                input_ids,
                attention_mask,
                token_type_ids,
            )
        pred_label = torch.argmax(predicted, dim=1).numpy()[0]
        return label_revert_map[pred_label]
    else:
        return ''

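# Illustrative call (the input string and label are made up; real labels come from
# label_revert_map, which is defined in config):
#   bert_predict("Acme Trading Co.")  # -> e.g. "Retail"; returns '' for empty or non-string input
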
def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
    if name_col not in pd.read_csv(origin_csv_path).columns:
        return '列名错误', None, None
    total_chunks = len(pd.read_csv(origin_csv_path)) // chunksize + 1
    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=total_chunks, desc='Processing', unit='item'):
        try:
            # Classify every row of the chunk
            chunk['AI Group'] = chunk[name_col].apply(lambda x: bert_predict(x))
            # Save the results incrementally
            if total_processed == 0:
                chunk.to_csv(temp_path, mode='w', index=False)
            else:
                chunk.to_csv(temp_path, mode='a', header=False, index=False)
            # Update the number of processed rows
            total_processed += len(chunk)
        except Exception as e:
            df_error = pd.concat([df_error, chunk])
    return temp_path, df_error, total_processed

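# Both predict_excel and predict_csv below return a (save_path, error_file_path, error_count)
# tuple, or ('列名错误', None, None) when the requested name column is missing.
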
async def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
    # Build the derived file names
    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
    # Add the suffixes
    error_file = error_file_name + '_error' + error_file_extension
    origin_csv_file = error_file_name + '_origin.csv'
    # Build the new file paths
    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
    total_processed = 0
    df_origin = pd.read_excel(file_path)
    df_error = pd.DataFrame(columns=df_origin.columns)
    df_origin.to_csv(origin_csv_path, index=False)
    temp_path, df_error, total_processed = await asyncio.to_thread(
        process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
    if temp_path == '列名错误':
        return '列名错误', None, None
    df_final = pd.read_csv(temp_path)
    df_final.to_excel(save_path, index=False)
    try:
        os.remove(origin_csv_path)
    except Exception:
        pass
    if len(df_error) == 0:
        return save_path, '', 0
    else:
        df_error.to_excel(error_file_path)
        return save_path, error_file_path, len(df_error)

async def predict_csv(file_path, name_col, temp_path, save_path, chunksize=5):
    # Build the derived file names
    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
    # Add the suffixes
    error_file = error_file_name + '_error' + error_file_extension
    origin_csv_file = error_file_name + '_origin.csv'
    # Build the new file paths
    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
    total_processed = 0
    df_origin = pd.read_csv(file_path)
    df_error = pd.DataFrame(columns=df_origin.columns)
    df_origin.to_csv(origin_csv_path, index=False)
    # Read and classify the CSV file chunk by chunk in a worker thread
    temp_path, df_error, total_processed = await asyncio.to_thread(
        process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
    if temp_path == '列名错误':
        return '列名错误', None, None
    df_final = pd.read_csv(temp_path)
    df_final.to_csv(save_path, index=False)
    os.remove(origin_csv_path)
    if len(df_error) == 0:
        return save_path, '', 0
    else:
        df_error.to_csv(error_file_path)
        return save_path, error_file_path, len(df_error)

def remove_files():
    current_time = time.time()
    TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60  # directories older than 30 days
    TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60      # files older than 10 days
    for root, dirs, files in os.walk(basic_path, topdown=False):
        # Remove stale files
        for file in files:
            file_path = os.path.join(root, file)
            if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
                print(f"删除文件: {file_path}")
                os.remove(file_path)
        # Remove stale directories
        for dir in dirs:
            dir_path = os.path.join(root, dir)
            if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
                print(f"删除文件夹: {dir_path}")
                shutil.rmtree(dir_path)

@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
    user_directory = f'{basic_path}/{client_id}'
    if not os.path.exists(user_directory):
        os.makedirs(user_directory)
        os.chmod(user_directory, 0o777)  # Make the per-client directory world-writable
    file_location = os.path.join(user_directory, file.filename)
    try:
        with open(file_location, "wb+") as file_object:
            file_object.write(file.file.read())
        os.chmod(file_location, 0o777)  # Make the uploaded file world-writable
        return JSONResponse(content={
            "message": f"文件 '{file.filename}' 上传成功",
            "data": {"client_id": client_id, "file_path": file_location},
            "code": 200
        }, status_code=200)
    except Exception as e:
        return JSONResponse(content={"message": f"发生错误: {str(e)}", "data": {}, "code": 500}, status_code=500)

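# A minimal upload sketch from the client side (assumes the `requests` package and that the
# service is reachable at http://localhost:8000 -- the real port comes from config.port;
# the file name and client_id are made up):
#
#   import requests
#   resp = requests.post("http://localhost:8000/uploadfile/",
#                        files={"file": open("companies.xlsx", "rb")},
#                        data={"client_id": "demo"})
#   print(resp.json())  # -> {"message": ..., "data": {"client_id": "demo", "file_path": "..."}, "code": 200}
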
async def process_openai_data(deal_result, client, temp_csv):
    for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
        try:
            message = [
                {'role': 'system', 'content': cls_system_prompt},
                {'role': 'user', 'content': user_prompt.format(chunk=str(value))}
            ]
            # result_string = post_openai(message)
            response = await client.chat.completions.create(model='gpt-4', messages=message)
            result_string = response.choices[0].message.content
            result = extract_list_from_string(result_string)
            if result:
                df_output = pd.DataFrame(result)
                # Write the header only on the first write so the temp CSV stays parseable
                df_output.to_csv(temp_csv, mode='a', header=not os.path.exists(temp_csv), index=False)
            else:
                continue
        except Exception as e:
            print(f'{name}出现问题啦, 错误为:{e} 请自行调试')

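# Note: extract_list_from_string (from functions.py) is expected to yield a list of dicts
# that pandas can turn into 'AI Name' / 'AI Group' columns -- that is what the merge on
# 'AI Name' in classify_data below relies on.
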
@app.post("/classify_openai/")
async def classify_data(request: ClassificationRequest):
    try:
        remove_files()
        work_path = f'{basic_path}/{request.client_id}'
        if not os.path.exists(work_path):
            os.makedirs(work_path, exist_ok=True)
        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
        timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if final_file_extension == '.csv':
            df_origin = pd.read_csv(request.path)
        else:
            df_origin = pd.read_excel(request.path)
        if request.name_column not in df_origin.columns or request.one_key not in df_origin.columns:
            return JSONResponse(content={"message": "用户标识列输入错误,请重新输入", "data": {}, "code": 500}, status_code=500)
        df_origin['AI Name'] = df_origin[request.name_column]
        df_origin['AI Group'] = ''
        df_use = df_origin[['AI Name', 'AI Group']]
        deal_result = split_dataframe_to_dict(df_use, request.chunk_size)
        # Temp CSV named after the current timestamp
        temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
        # Add the suffix
        final_file = final_file_name + '_classify' + final_file_extension
        # Build the output file path
        new_file_path = os.path.join(os.path.dirname(request.path), final_file)
        if not request.proxy:
            print(f'用户{request.client_id}正在使用直连的gpt-API')
            client = openai.AsyncOpenAI(api_key=request.api_key, base_url=openai_url)
        else:
            client = openai.AsyncOpenAI(api_key=request.api_key, base_url=proxy_url)
        await process_openai_data(deal_result, client, temp_csv)
        if os.path.exists(temp_csv):
            df_result = pd.read_csv(temp_csv)
            df_origin = df_origin.drop(columns='AI Group')
            df_final = df_origin.merge(df_result, on='AI Name', how='left').drop_duplicates(subset=[request.one_key, 'AI Name'], keep='first')
            df_final = df_final.drop(columns='AI Name')
            if final_file_extension == '.csv':
                df_final.to_csv(new_file_path, index=False)
            else:
                df_final.to_excel(new_file_path, index=False)
            return JSONResponse(content={
                "message": "分类完成",
                "data": {"output_file": file_base_url + new_file_path.split(basic_path)[1]},
                "code": 200
            }, status_code=200)
        else:
            return JSONResponse(content={"message": "文件没能处理成功", "data": {}, "code": 500}, status_code=500)
    except Exception as e:
        return JSONResponse(content={"message": f"处理出现错误: {e}", "data": {}, "code": 500}, status_code=500)

@app.post("/classify_bert/")
async def real_process(request: ClassificationRequestBert):
    task = asyncio.create_task(classify_data_bert(request))
    result = await task
    return result

async def classify_data_bert(request: ClassificationRequestBert):
    remove_files()
    work_path = f'{basic_path}/{request.client_id}'
    if not os.path.exists(work_path):
        os.makedirs(work_path, exist_ok=True)
    final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
    final_file = final_file_name + '_classify' + final_file_extension
    # Build the output file path
    new_file_path = os.path.join(os.path.dirname(request.path), final_file)
    timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
    if final_file_extension == '.csv':
        save_path, error_path, error_len = await predict_csv(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
    else:
        save_path, error_path, error_len = await predict_excel(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
    if save_path == '列名错误':
        return JSONResponse(content={
            "message": "用户标识列输入错误,请重新输入",
            "data": {},
            "code": 500
        }, status_code=200)
    elif error_len == 0:
        return JSONResponse(content={
            "message": "分类完成",
            "data": {"output_file": file_base_url + save_path.split(basic_path)[1]},
            "code": 200
        }, status_code=200)
    else:
        return JSONResponse(content={
            "message": "分类完成只完成部分",
            "data": {"output_file": file_base_url + save_path.split(basic_path)[1],
                     "output_file_nonprocess": file_base_url + error_path.split(basic_path)[1]},
            "code": 200
        }, status_code=200)

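# Classification sketch from the client side (same assumptions as the upload sketch above;
# reuses the file_path returned by /uploadfile/):
#
#   body = {"path": resp.json()["data"]["file_path"], "client_id": "demo",
#           "name_column": "Company Name"}
#   print(requests.post("http://localhost:8000/classify_bert/", json=body).json())
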
@app.get("/test/")
async def test_api():
    return '测试成功'

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=port)