- import pandas as pd
- import os, time, shutil, sys
- import openai
- from fastapi import FastAPI, UploadFile, File, Form
- from pydantic import BaseModel
- from fastapi.responses import JSONResponse
- from datetime import datetime
- import uvicorn, socket
- from tqdm import tqdm
- from fastapi.staticfiles import StaticFiles
- from config import *
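- # config is expected (based on usage below) to provide basic_path, pre_train_model,
- # model_save_path, label_revert_map, cls_system_prompt, user_prompt, openai_url,
- # proxy_url, file_base_url and port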
- import asyncio, threading
- from functions import split_dataframe_to_dict, extract_list_from_string
- sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
- import torch
- # from model import BertClassifier
- from bert.model import BertClassifier
- from transformers import BertTokenizer, BertConfig
- from fastapi.middleware.cors import CORSMiddleware
- app = FastAPI()
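- # Allow cross-origin requests from any origin so a separately hosted front end can call the API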
- app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
- if not os.path.exists(basic_path):
- os.makedirs(basic_path, exist_ok=True)
- app.mount("/data", StaticFiles(directory=basic_path), name="static")
- bert_config = BertConfig.from_pretrained(pre_train_model)
- # Define the classifier: BERT encoder plus a classification head with one output per label
- model = BertClassifier(bert_config, len(label_revert_map.keys()))
- # Load the trained weights; inference in bert_predict runs on CPU tensors, so map them to CPU
- model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
- model.eval()
- tokenizer = BertTokenizer.from_pretrained(pre_train_model)
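- # Request body for /classify_openai/: 'path' is a previously uploaded file, 'one_key' appears
- # to be a column that uniquely identifies each row (used for de-duplication after the merge),
- # 'name_column' holds the text to classify, and 'proxy' switches the client to proxy_url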
- class ClassificationRequest(BaseModel):
- path: str
- client_id: str
- one_key: str
- name_column: str
- api_key: str = "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"
- proxy: bool = False
- chunk_size: int = 100
- class ClassificationRequestBert(BaseModel):
- path: str
- client_id: str
- name_column: str
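- # Classify a single text with the fine-tuned BERT model and map the predicted index back
- # through label_revert_map; empty or non-string input returns ''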
- def bert_predict(text):
- if isinstance(text, str) and text != '':
- token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
- input_ids = token['input_ids']
- attention_mask = token['attention_mask']
- token_type_ids = token['token_type_ids']
- input_ids = torch.tensor([input_ids], dtype=torch.long)
- attention_mask = torch.tensor([attention_mask], dtype=torch.long)
- token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
- # Inference only: no gradient tracking needed
- with torch.no_grad():
- predicted = model(
- input_ids,
- attention_mask,
- token_type_ids,
- )
- pred_label = torch.argmax(predicted, dim=1).numpy()[0]
-
- return label_revert_map[pred_label]
- else:
- return ''
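- # Housekeeping: delete uploaded files older than 10 days and directories older than 30 days
- # so the data directory does not grow without bound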
- def remove_files():
- current_time = time.time()
- TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
- TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60
- for root, dirs, files in os.walk(basic_path, topdown=False):
- # Remove files older than the file threshold
- for file in files:
- file_path = os.path.join(root, file)
- if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
- print(f"删除文件: {file_path}")
- os.remove(file_path)
- # Remove directories older than the directory threshold
- for dir in dirs:
- dir_path = os.path.join(root, dir)
- if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
- print(f"删除文件夹: {dir_path}")
- shutil.rmtree(dir_path)
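- # Save an uploaded file under a per-client directory and return its path for later classify requests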
- @app.post("/uploadfile/")
- async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
- user_directory = f'{basic_path}/{client_id}'
- if not os.path.exists(user_directory):
- os.makedirs(user_directory)
- os.chmod(user_directory, 0o777)  # Give the per-client directory open permissions (777)
- file_location = os.path.join(user_directory, file.filename)
- try:
- with open(file_location, "wb+") as file_object:
- file_object.write(file.file.read())
- os.chmod(file_location, 0o777)  # Give the uploaded file open permissions (777)
- return JSONResponse(content={
- "message": f"文件 '{file.filename}' 上传成功",
- "data":{"client_id": client_id,"file_path": file_location},
- "code":200
- }, status_code=200)
- except Exception as e:
- return JSONResponse(content={"message": f"发生错误: {str(e)}", "data":{}, "code":500}, status_code=500)
- async def process_openai_data(deal_result, client, temp_csv):
- for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
- message = [
- {'role':'system', 'content': cls_system_prompt},
- {'role':'user', 'content':user_prompt.format(chunk=str(value))}
- ]
- # result_string = post_openai(message)
- response = await client.chat.completions.create(model='gpt-4',messages=message)
- result_string = response.choices[0].message.content
- result = extract_list_from_string(result_string)
- if result:
- df_output = pd.DataFrame(result)
- # Write the header only when the temp file does not exist yet, so appended chunks do not repeat it as data rows
- df_output.to_csv(temp_csv, mode='a', header=not os.path.exists(temp_csv), index=False)
- else:
- continue
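- # OpenAI-based classification: read the uploaded table, split the name column into chunks,
- # classify each chunk via the chat API, then merge the resulting 'AI Group' column back into
- # the original table and write the result next to the input file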
- @app.post("/classify_openai/")
- async def classify_data(request: ClassificationRequest):
- try:
- remove_files()
- work_path = f'{basic_path}/{request.client_id}'
- if not os.path.exists(work_path):
- os.makedirs(work_path, exist_ok=True)
- final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
- timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
- if final_file_extension == '.csv':
- df_origin = pd.read_csv(request.path)
- else:
- df_origin = pd.read_excel(request.path)
- if request.name_column not in df_origin.columns or request.one_key not in df_origin.columns:
- return JSONResponse(content={"message": "用户标识列输入错误,请重新输入","data":{},"code":500}, status_code=500)
- df_origin['AI Name'] = df_origin[request.name_column]
- df_origin['AI Group'] = ''
- df_use = df_origin[['AI Name', 'AI Group']]
- deal_result = split_dataframe_to_dict(df_use, request.chunk_size)
- # Temporary CSV named with the current timestamp so repeated runs do not collide
- temp_csv = os.path.join(work_path, timestamp_str + 'output_temp.csv')
- # Append a '_classify' suffix to build the output file name
- final_file = final_file_name + '_classify' + final_file_extension
- # Build the new output file path
- new_file_path = os.path.join(os.path.dirname(request.path), final_file)
- if not request.proxy:
- print(f'User {request.client_id} is using the direct GPT API connection')
- client = openai.AsyncOpenAI(api_key=request.api_key, base_url=openai_url)
- else:
- client = openai.AsyncOpenAI(api_key=request.api_key, base_url=proxy_url)
- await process_openai_data(deal_result, client, temp_csv)
- if os.path.exists(temp_csv):
- df_result = pd.read_csv(temp_csv)
- df_origin = df_origin.drop(columns='AI Group')
- df_final = df_origin.merge(df_result, on='AI Name', how='left').drop_duplicates(subset=[request.one_key,'AI Name'], keep='first')
- df_final = df_final.drop(columns='AI Name')
- if final_file_extension == '.csv':
- df_final.to_csv(new_file_path, index=False)
- else:
- df_final.to_excel(new_file_path, index=False)
- return JSONResponse(content={
- "message": f"分类完成",
- "data":{"output_file": file_base_url + new_file_path.split(basic_path)[1]},
- "code":200
- }, status_code=200)
- else:
- print('No temporary result file was produced; nothing to merge')
- return JSONResponse(content={"message": "The file could not be processed", "data":{}, "code":500}, status_code=500)
- except Exception as e:
- print(f'classify_openai failed: {e}')
- return JSONResponse(content={"message": f"A processing error occurred: {e}", "data":{}, "code":500}, status_code=500)
- @app.post("/classify_bert/")
- async def real_process(request: ClassificationRequestBert):
- # task = asyncio.create_task(classify_data_bert(request))
- result = await classify_data_bert(request)
- return result
- async def classify_data_bert(request: ClassificationRequestBert):
- try:
- remove_files()
- work_path = f'{basic_path}/{request.client_id}'
- if not os.path.exists(work_path):
- os.makedirs(work_path, exist_ok=True)
- final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
- final_file = final_file_name + '_classify' + final_file_extension
- new_file_path = os.path.join(os.path.dirname(request.path), final_file)
- timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
- temp_csv = os.path.join(work_path, timestamp_str + 'output_temp.csv')
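- # Classify the CSV chunk by chunk and append results to temp_path; chunks that raise are
- # collected in df_error. Note that the column check and the tqdm total each re-read the
- # whole CSV, which is slow for large files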
- def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
- if name_col not in pd.read_csv(origin_csv_path).columns:
- return 'invalid column name', None, None
- for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
- try:
- # Classify each chunk with the BERT model
- chunk['AI Group'] = chunk[name_col].apply(bert_predict)
- # Save results incrementally: write the header only with the first chunk
- if total_processed == 0:
- chunk.to_csv(temp_path, mode='w', index=False)
- else:
- chunk.to_csv(temp_path, mode='a', header=False, index=False)
- # Update the running count of processed rows
- total_processed += len(chunk)
- except Exception as e:
- df_error = pd.concat([df_error, chunk])
- return temp_path, df_error, total_processed
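- # Normalize the input to CSV, run process_data in a worker thread, write the final result
- # back in the original format and report any rows that could not be processed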
- async def process_file(file_path, name_col, temp_path, save_path, chunksize=5):
- error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
- error_file = error_file_name + '_error' + error_file_extension
- origin_csv_file = error_file_name + '_origin.csv'
- error_file_path = os.path.join(os.path.dirname(save_path), error_file)
- origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
-
- if final_file_extension == '.csv':
- try:
- df_origin = pd.read_csv(file_path, encoding='utf-8')
- except UnicodeDecodeError:
- df_origin = pd.read_csv(file_path, encoding='gbk')
- else:
- df_origin = pd.read_excel(file_path)
-
- df_origin.to_csv(origin_csv_path, index=False)
- df_error = pd.DataFrame(columns=df_origin.columns)
-
- if name_col not in df_origin.columns:
- return 'invalid column name', None, None
-
- total_processed = 0
- temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
- df_final = pd.read_csv(temp_path)
- if final_file_extension == '.csv':
- df_final.to_csv(save_path, index=False)
- else:
- df_final.to_excel(save_path, index=False)
-
- try:
- os.remove(origin_csv_path)
- except Exception as e:
- pass
- if len(df_error) == 0:
- return save_path, '', 0
- else:
- if final_file_extension == '.csv':
- df_error.to_csv(error_file_path, index=False)
- else:
- df_error.to_excel(error_file_path, index=False)
- return save_path, error_file_path, len(df_error)
- save_path, error_path, error_len = await process_file(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
- if error_len == 0:
- return JSONResponse(content={
- "message": "分类完成",
- "data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
- "code":200
- }, status_code=200)
- elif save_path == 'invalid column name':
- return JSONResponse(content={
- "message": "用户标识列输入错误,请重新输入",
- "data":{},
- "code":500
- }, status_code=200)
- else:
- return JSONResponse(content={
- "message": "分类完成只完成部分",
- "data":{"output_file": file_base_url + save_path.split(basic_path)[1],
- "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1]},
- "code":200
- }, status_code=200)
- except Exception as e:
- return JSONResponse(content={
- "message": f"出现错误: {e}, 请检查列名和文件重试",
- "data":{},
- "code":500
- }, status_code=200)
- @app.get("/test/")
- async def classify_data():
- return 'Test successful'
- if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=port)
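- # Example usage (a sketch; host, port and file/column names here are assumptions, the real port comes from config):
- #   curl -F "file=@names.xlsx" -F "client_id=demo" http://localhost:8000/uploadfile/
- #   curl -X POST http://localhost:8000/classify_bert/ -H "Content-Type: application/json" \
- #        -d '{"path": "<file_path from the upload response>", "client_id": "demo", "name_column": "Name"}'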