@@ -1,5 +1,6 @@
 import pandas as pd
-import os, time, shutil
+import os, time, shutil, sys
+from datetime import datetime  # used below for timestamped temp files
 import openai
 from fastapi import FastAPI, UploadFile, File, Form
 from pydantic import BaseModel
@@ -9,7 +10,15 @@ import uvicorn, socket
 from tqdm import tqdm
 from fastapi.staticfiles import StaticFiles
 from config import *
+
 from functions import split_dataframe_to_dict, extract_list_from_string
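+# Make the local bert/ training directory importable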
+sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
+import torch
+# from model import BertClassifier
+from bert.model import BertClassifier
+from transformers import BertTokenizer, BertConfig
+
 app = FastAPI()
 app.mount("/data", StaticFiles(directory='./process'), name="static")
 class ClassificationRequest(BaseModel):
@@ -21,6 +30,95 @@ class ClassificationRequest(BaseModel):
     proxy: bool = False
     chunk_size: int = 100
 
+class ClassificationRequestBert(BaseModel):
+    path: str
+    client_id: str
+    name_column: str
+
+bert_config = BertConfig.from_pretrained(pre_train_model)
+# Build the classifier
+model = BertClassifier(bert_config, len(label_revert_map.keys()))
+# Load the fine-tuned weights; map_location makes this work on CPU-only hosts
+model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
+model.eval()
+tokenizer = BertTokenizer.from_pretrained(pre_train_model)
+
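+# Classify one text: tokenize, run the fine-tuned model, and map the argmax logit back to its label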
+def bert_predict(text):
+    token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
+    input_ids = token['input_ids']
+    attention_mask = token['attention_mask']
+    token_type_ids = token['token_type_ids']
+
+    input_ids = torch.tensor([input_ids], dtype=torch.long)
+    attention_mask = torch.tensor([attention_mask], dtype=torch.long)
+    token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
+
+    # Inference only, so skip autograd bookkeeping
+    with torch.no_grad():
+        predicted = model(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+        )
+    pred_label = torch.argmax(predicted, dim=1).numpy()[0]
+
+    return label_revert_map[pred_label]
+
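+# Batch-classify an Excel file: convert to CSV, predict chunk by chunk, then write results back to Excel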
+def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
+    # Derive the error-report and intermediate-CSV paths from save_path
+    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
+    error_file = error_file_name + '_error' + error_file_extension
+    origin_csv_file = error_file_name + '_origin.csv'
+    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
+    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
+    total_processed = 0
+    df_origin = pd.read_excel(file_path)
+    df_error = pd.DataFrame(columns=df_origin.columns)
+    df_origin.to_csv(origin_csv_path, index=False)
+
+    # Round-trip through CSV so the data can be re-read lazily in chunks
+    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), desc='Processing', unit='item'):
+        try:
+            # Classify every row of this chunk
+            chunk['classify'] = chunk[name_col].apply(bert_predict)
+            # Save incrementally; write the CSV header only for the first chunk
+            if total_processed == 0:
+                chunk.to_csv(temp_path, mode='w', index=False)
+            else:
+                chunk.to_csv(temp_path, mode='a', header=False, index=False)
+            total_processed += len(chunk)
+        except Exception:
+            # Collect failed chunks for the error report
+            df_error = pd.concat([df_error, chunk])
+    df_final = pd.read_csv(temp_path)
+    df_final.to_excel(save_path, index=False)
+    os.remove(origin_csv_path)
+    if len(df_error) == 0:
+        return save_path, '', 0
+    else:
+        df_error.to_excel(error_file_path, index=False)
+        return save_path, error_file_path, len(df_error)
+
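+# Housekeeping: purge files older than 10 days and directories older than 30 days under basic_path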
+def remove_files():
+    current_time = time.time()
+    TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
+    TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60
+    for root, dirs, files in os.walk(basic_path, topdown=False):
+        # Delete stale files
+        for file in files:
+            file_path = os.path.join(root, file)
+            if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
+                print(f"删除文件: {file_path}")
+                os.remove(file_path)
+        # Delete stale directories
+        for dir in dirs:
+            dir_path = os.path.join(root, dir)
+            if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
+                print(f"删除文件夹: {dir_path}")
+                shutil.rmtree(dir_path)
@app.post("/uploadfile/")
|
|
@app.post("/uploadfile/")
|
|
async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
|
|
async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
|
|
user_directory = f'{basic_path}/{client_id}'
|
|
user_directory = f'{basic_path}/{client_id}'
|
|
@@ -40,26 +138,10 @@ async def create_upload_file(file: UploadFile = File(...), client_id: str = Form
     except Exception as e:
         return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
 
-@app.post("/classify/")
+@app.post("/classify_openai/")
 async def classify_data(request: ClassificationRequest):
     try:
-        current_time = time.time()
-        TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
-        TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60
-        for root, dirs, files in os.walk(basic_path, topdown=False):
-            # 删除文件
-            for file in files:
-                file_path = os.path.join(root, file)
-                if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
-                    print(f"删除文件: {file_path}")
-                    os.remove(file_path)
-            # 删除文件夹
-            for dir in dirs:
-                dir_path = os.path.join(root, dir)
-                if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
-                    print(f"删除文件夹: {dir_path}")
-                    shutil.rmtree(dir_path)
-
+        remove_files()
         work_path = f'{basic_path}/{request.client_id}'
         if not os.path.exists(work_path):
             os.makedirs(work_path, exist_ok=True)
@@ -111,5 +193,24 @@ async def classify_data(request: ClassificationRequest):
     except Exception as e:
         return {"message": f"处理出现错误: {e}"}
 
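+# Local BERT-based counterpart to /classify_openai/ (no OpenAI call involved)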
+@app.post("/classify_bert/")
+async def classify_data_bert(request: ClassificationRequestBert):
+    remove_files()
+    work_path = f'{basic_path}/{request.client_id}'
+    if not os.path.exists(work_path):
+        os.makedirs(work_path, exist_ok=True)
+    final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
+    # Append the _classify suffix and build the output path next to the input file
+    final_file = final_file_name + '_classify' + final_file_extension
+    new_file_path = os.path.join(os.path.dirname(request.path), final_file)
+    # Timestamped temp file so repeated runs for one client don't collide
+    timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    temp_csv = work_path + '/' + timestamp_str + '_output_temp.csv'
+    save_path, error_path, error_len = predict_excel(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
+    if error_len == 0:
+        return {"message": "分类完成", "output_file": file_base_url + save_path.split(basic_path)[1]}
+    else:
+        return {"message": "分类完成只完成部分", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess": file_base_url + error_path.split(basic_path)[1]}
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=port)