"""FastAPI service that fills an ``AI Group`` label for a name column.

Clients upload a CSV/Excel file (``/uploadfile/``) and then ask for it to be
classified either through the OpenAI chat-completions API
(``/classify_openai/``) or with a locally loaded BERT classifier
(``/classify_bert/``).  Uploaded and generated files live under
``basic_path`` and are served statically at ``/data``.

Configuration names (``basic_path``, ``pre_train_model``, ``model_save_path``,
``label_revert_map``, ``cls_system_prompt``, ``user_prompt``, ``openai_url``,
``proxy_url``, ``file_base_url``, ``port``) come from ``config`` via the star
import below.
"""
import pandas as pd
import os
import time
import shutil
import sys
import openai
from fastapi import FastAPI, UploadFile, File, Form
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from datetime import datetime
import uvicorn
import socket
from tqdm import tqdm
from fastapi.staticfiles import StaticFiles
from config import *
import asyncio
import logging
import threading
from functions import split_dataframe_to_dict, extract_list_from_string

# The bundled ``bert`` directory is put on sys.path before ``bert.model`` is
# imported, exactly as in the original import order.
sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
import torch
# from model import BertClassifier
from bert.model import BertClassifier
from transformers import BertTokenizer, BertConfig
from fastapi.middleware.cors import CORSMiddleware

logger = logging.getLogger(__name__)

# Age limits applied by remove_files(): files older than 10 days and
# directories older than 30 days (by mtime) are deleted.
_FILE_MAX_AGE_SECONDS = 10 * 24 * 60 * 60
_DIR_MAX_AGE_SECONDS = 30 * 24 * 60 * 60

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

os.makedirs(basic_path, exist_ok=True)
app.mount("/data", StaticFiles(directory=basic_path), name="static")

bert_config = BertConfig.from_pretrained(pre_train_model)
# Build the classifier with one output per known label.
model = BertClassifier(bert_config, len(label_revert_map.keys()))
# Load the trained weights.  bert_predict() feeds CPU tensors and converts the
# output with .numpy(), so inference runs on CPU; mapping the checkpoint to
# 'cuda:0' (as before) only made start-up fail on hosts without a GPU.
model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
model.eval()
tokenizer = BertTokenizer.from_pretrained(pre_train_model)


class ClassificationRequest(BaseModel):
    """Request body for ``/classify_openai/``."""

    # Path of the previously uploaded file to classify.
    path: str
    # Caller identifier; also the name of the caller's working directory.
    client_id: str
    # Column used together with 'AI Name' to de-duplicate the merged result.
    one_key: str
    # Column whose values are sent for classification.
    name_column: str
    # SECURITY: a real secret key used to be hard-coded here as the default.
    # It is now read from the environment; the old key must be revoked.
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    # When true, requests go to proxy_url instead of openai_url.
    proxy: bool = False
    # Number of rows per prompt, passed to split_dataframe_to_dict.
    chunk_size: int = 100


class ClassificationRequestBert(BaseModel):
    """Request body for ``/classify_bert/``."""

    path: str
    client_id: str
    name_column: str


def _require_safe_component(value, label):
    """Return *value* if it is a single path component, else raise ValueError.

    Used for client-supplied names that are joined onto ``basic_path``.
    """
    if (
        not value
        or value in ('.', '..')
        or '/' in value
        or '\\' in value
        or '\x00' in value
    ):
        raise ValueError(f'{label} 不合法')
    return value


def _require_path_in_data_dir(path):
    """Return *path* if it resolves inside ``basic_path``, else raise ValueError."""
    base = os.path.realpath(basic_path)
    target = os.path.realpath(path)
    try:
        inside = os.path.commonpath([base, target]) == base
    except ValueError:
        # commonpath raises for paths it cannot compare (e.g. different drives).
        inside = False
    if not inside:
        raise ValueError('文件路径不合法')
    return path


def _public_url(path):
    """Map a file path under ``basic_path`` to its download URL."""
    return file_base_url + path.split(basic_path)[1]


def bert_predict(text):
    """Return the predicted label for *text*, or ``''`` for non-string/empty input."""
    if not isinstance(text, str) or text == '':
        return ''

    encoded = tokenizer(
        text,
        add_special_tokens=True,
        padding='max_length',
        truncation=True,
        max_length=512,
    )
    input_ids = torch.tensor([encoded['input_ids']], dtype=torch.long)
    attention_mask = torch.tensor([encoded['attention_mask']], dtype=torch.long)
    token_type_ids = torch.tensor([encoded['token_type_ids']], dtype=torch.long)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        predicted = model(input_ids, attention_mask, token_type_ids)
    pred_label = torch.argmax(predicted, dim=1)[0].item()
    return label_revert_map[pred_label]


def remove_files():
    """Delete stale files and directories under ``basic_path``.

    Cleanup is best-effort: a failure on one entry (for example a file removed
    concurrently by another request) is logged and does not abort the caller.
    """
    current_time = time.time()
    for root, dirs, files in os.walk(basic_path, topdown=False):
        # Remove stale files.
        for file_name in files:
            file_path = os.path.join(root, file_name)
            try:
                if current_time - os.path.getmtime(file_path) > _FILE_MAX_AGE_SECONDS:
                    print(f"删除文件: {file_path}")
                    os.remove(file_path)
            except OSError as exc:
                logger.warning("could not remove file %s: %s", file_path, exc)
        # Remove stale directories.
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            try:
                if current_time - os.path.getmtime(dir_path) > _DIR_MAX_AGE_SECONDS:
                    print(f"删除文件夹: {dir_path}")
                    shutil.rmtree(dir_path)
            except OSError as exc:
                logger.warning("could not remove directory %s: %s", dir_path, exc)


@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
    """Store an uploaded file under ``basic_path/<client_id>/`` and return its path."""
    try:
        # Both values come from the client and end up in a filesystem path, so
        # reject anything that could escape the data directory.
        _require_safe_component(client_id, 'client_id')
        safe_name = os.path.basename((file.filename or '').replace('\\', '/'))
        _require_safe_component(safe_name, '文件名')

        user_directory = f'{basic_path}/{client_id}'
        if not os.path.exists(user_directory):
            os.makedirs(user_directory, exist_ok=True)
            # NOTE(review): world-writable permissions kept from the original;
            # confirm whether another process really needs them.
            os.chmod(user_directory, 0o777)

        file_location = os.path.join(user_directory, safe_name)
        with open(file_location, "wb") as file_object:
            # Stream the upload instead of loading it into memory at once.
            shutil.copyfileobj(file.file, file_object)
        os.chmod(file_location, 0o777)
        return JSONResponse(content={
            "message": f"文件 '{file.filename}' 上传成功",
            "data": {"client_id": client_id, "file_path": file_location},
            "code": 200
        }, status_code=200)
    except Exception as e:
        return JSONResponse(
            content={"message": f"发生错误: {str(e)}", "data": {}, "code": 500},
            status_code=500,
        )


async def process_openai_data(deal_result, client, temp_csv):
    """Classify every chunk in *deal_result* and append the rows to *temp_csv*.

    Each value of *deal_result* is formatted into ``user_prompt`` and sent to
    the chat-completions API; whatever ``extract_list_from_string`` recovers
    from the reply is appended as CSV rows.  Chunks with no usable reply are
    skipped.
    """
    # Write the header only once; the previous code repeated it for every
    # chunk, which left header rows inside the data when the file was re-read.
    header_written = os.path.exists(temp_csv)
    for value in tqdm(deal_result.values(), desc='Processing', unit='item'):
        message = [
            {'role': 'system', 'content': cls_system_prompt},
            {'role': 'user', 'content': user_prompt.format(chunk=str(value))}
        ]
        response = await client.chat.completions.create(model='gpt-4', messages=message)
        result_string = response.choices[0].message.content
        result = extract_list_from_string(result_string)
        if not result:
            continue
        df_output = pd.DataFrame(result)
        df_output.to_csv(temp_csv, mode='a', header=not header_written, index=False)
        header_written = True


@app.post("/classify_openai/")
async def classify_data(request: ClassificationRequest):
    """Classify ``request.name_column`` with OpenAI and write ``<name>_classify.<ext>``.

    Returns a JSON body with the download URL on success, or ``code: 500`` and
    a message on failure.
    """
    try:
        _require_safe_component(request.client_id, 'client_id')
        _require_path_in_data_dir(request.path)
        if not request.api_key:
            raise ValueError('未提供 api_key')

        remove_files()
        work_path = f'{basic_path}/{request.client_id}'
        os.makedirs(work_path, exist_ok=True)

        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
        timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

        if final_file_extension == '.csv':
            df_origin = pd.read_csv(request.path)
        else:
            df_origin = pd.read_excel(request.path)

        if request.name_column not in df_origin.columns or request.one_key not in df_origin.columns:
            return JSONResponse(
                content={"message": "用户标识列输入错误,请重新输入", "data": {}, "code": 500},
                status_code=500,
            )

        df_origin['AI Name'] = df_origin[request.name_column]
        df_origin['AI Group'] = ''
        df_use = df_origin[['AI Name', 'AI Group']]
        deal_result = split_dataframe_to_dict(df_use, request.chunk_size)

        # Per-request scratch file that collects the model's answers.
        temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
        # The result is written next to the input with a '_classify' suffix.
        final_file = final_file_name + '_classify' + final_file_extension
        new_file_path = os.path.join(os.path.dirname(request.path), final_file)

        if not request.proxy:
            print(f'用户{request.client_id}正在使用直连的gpt-API')
            client = openai.AsyncOpenAI(api_key=request.api_key, base_url=openai_url)
        else:
            client = openai.AsyncOpenAI(api_key=request.api_key, base_url=proxy_url)

        await process_openai_data(deal_result, client, temp_csv)

        if not os.path.exists(temp_csv):
            # No chunk produced a usable answer.
            logger.warning("no classification output produced for %s", request.path)
            return JSONResponse(
                content={"message": "文件没能处理成功", "data": {}, "code": 500},
                status_code=500,
            )

        df_result = pd.read_csv(temp_csv)
        df_origin = df_origin.drop(columns='AI Group')
        df_final = df_origin.merge(df_result, on='AI Name', how='left').drop_duplicates(
            subset=[request.one_key, 'AI Name'], keep='first')
        df_final = df_final.drop(columns='AI Name')
        if final_file_extension == '.csv':
            df_final.to_csv(new_file_path, index=False)
        else:
            df_final.to_excel(new_file_path, index=False)
        return JSONResponse(content={
            "message": "分类完成",
            "data": {"output_file": _public_url(new_file_path)},
            "code": 200
        }, status_code=200)
    except Exception as e:
        logger.exception("classify_openai failed")
        return JSONResponse(
            content={"message": f"处理出现错误: {e}", "data": {}, "code": 500},
            status_code=500,
        )


@app.post("/classify_bert/")
async def real_process(request: ClassificationRequestBert):
    """HTTP entry point for BERT classification; delegates to classify_data_bert."""
    # task = asyncio.create_task(classify_data_bert(request))
    result = await classify_data_bert(request)
    return result


def _bert_classify_chunks(origin_csv_path, temp_path, name_col, chunksize, total_rows):
    """Label *origin_csv_path* chunk by chunk, writing results to *temp_path*.

    Runs synchronously (it is meant to be called in a worker thread).  Returns
    ``(failed_chunks, total_processed)`` where *failed_chunks* is a list of the
    DataFrame chunks that raised while being processed.
    """
    total_chunks = max(1, -(-total_rows // chunksize))  # ceiling division
    failed_chunks = []
    total_processed = 0
    reader = pd.read_csv(origin_csv_path, chunksize=chunksize)
    for chunk in tqdm(reader, total=total_chunks, desc='Processing', unit='item'):
        try:
            chunk['AI Group'] = chunk[name_col].apply(bert_predict)
            # First successful chunk creates the file with a header; later
            # chunks are appended without one.
            if total_processed == 0:
                chunk.to_csv(temp_path, mode='w', index=False)
            else:
                chunk.to_csv(temp_path, mode='a', header=False, index=False)
            total_processed += len(chunk)
        except Exception:
            logger.exception("failed to classify a chunk of %s", origin_csv_path)
            failed_chunks.append(chunk)
    return failed_chunks, total_processed


async def _bert_classify_file(file_path, name_col, temp_path, save_path, file_extension, chunksize=5):
    """Classify one file with BERT.

    Returns ``(save_path, error_file_path, error_row_count)``; the error path is
    ``''`` when every row was processed.  Returns ``('列名错误', None, None)``
    when *name_col* is not a column of the input.
    """
    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
    out_dir = os.path.dirname(save_path)
    error_file_path = os.path.join(out_dir, error_file_name + '_error' + error_file_extension)
    origin_csv_path = os.path.join(out_dir, error_file_name + '_origin.csv')

    if file_extension == '.csv':
        try:
            df_origin = pd.read_csv(file_path, encoding='utf-8')
        except UnicodeDecodeError:
            df_origin = pd.read_csv(file_path, encoding='gbk')
    else:
        df_origin = pd.read_excel(file_path)

    # Validate before writing anything so no scratch file is left behind.
    if name_col not in df_origin.columns:
        return '列名错误', None, None

    # The input is re-written as CSV so it can be re-read in chunks.
    df_origin.to_csv(origin_csv_path, index=False)
    try:
        failed_chunks, total_processed = await asyncio.to_thread(
            _bert_classify_chunks, origin_csv_path, temp_path, name_col, chunksize, len(df_origin))
    finally:
        try:
            os.remove(origin_csv_path)
        except OSError:
            pass

    if total_processed == 0:
        # temp_path was never created; fail with a clear message.
        raise RuntimeError('no rows could be classified')

    df_final = pd.read_csv(temp_path)
    if file_extension == '.csv':
        df_final.to_csv(save_path, index=False)
    else:
        df_final.to_excel(save_path, index=False)

    if not failed_chunks:
        return save_path, '', 0

    df_error = pd.concat(failed_chunks)
    if file_extension == '.csv':
        df_error.to_csv(error_file_path, index=False)
    else:
        df_error.to_excel(error_file_path, index=False)
    return save_path, error_file_path, len(df_error)


async def classify_data_bert(request: ClassificationRequestBert):
    """Classify ``request.name_column`` with the local BERT model.

    Always answers with HTTP 200; success or failure is reported through the
    ``code`` field of the JSON body (200 or 500).
    """
    try:
        _require_safe_component(request.client_id, 'client_id')
        _require_path_in_data_dir(request.path)

        remove_files()
        work_path = f'{basic_path}/{request.client_id}'
        os.makedirs(work_path, exist_ok=True)

        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
        final_file = final_file_name + '_classify' + final_file_extension
        new_file_path = os.path.join(os.path.dirname(request.path), final_file)
        timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'

        save_path, error_path, error_len = await _bert_classify_file(
            request.path, request.name_column,
            temp_path=temp_csv, save_path=new_file_path,
            file_extension=final_file_extension)

        if save_path == '列名错误':
            return JSONResponse(content={
                "message": "用户标识列输入错误,请重新输入",
                "data": {},
                "code": 500
            }, status_code=200)
        if error_len == 0:
            return JSONResponse(content={
                "message": "分类完成",
                "data": {"output_file": _public_url(save_path)},
                "code": 200
            }, status_code=200)
        return JSONResponse(content={
            "message": "分类完成只完成部分",
            "data": {"output_file": _public_url(save_path),
                     "output_file_nonprocess": _public_url(error_path)},
            "code": 200
        }, status_code=200)
    except Exception as e:
        logger.exception("classify_bert failed")
        return JSONResponse(content={
            "message": f"出现错误: {e}, 请检查列名和文件重试",
            "data": {},
            "code": 500
        }, status_code=200)


@app.get("/test/")
async def test_endpoint():
    """Liveness check.

    Renamed from ``classify_data`` so it no longer shadows the
    ``/classify_openai/`` handler of the same name; the route is unchanged.
    """
    return '测试成功'


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=port)