wangdalin 10 months ago
Commit
bed0e4d284
4 files changed: 142 additions and 29 deletions
  1. .gitignore (+3 −3)
  2. client.py (+23 −6)
  3. config.py (+3 −1)
  4. name_classify_api.py (+113 −19)

+ 3 - 3
.gitignore

@@ -1,3 +1,3 @@
-data_before
-data_out
-process
+process
+__pycache__
+bert-train

+ 23 - 6
client.py

@@ -1,6 +1,6 @@
 import requests, os
 
-def cls_process(path, client_id, one_key, name_column, api_key="sk-REDACTED",
+def cls_process_openai(path, client_id, one_key, name_column, api_key="sk-REDACTED",
                 proxy=False, chunk_size=100):
    # Define the request payload
     data = {
@@ -13,7 +13,7 @@ def cls_process(path, client_id, one_key, name_column, api_key="sk-REDACTED
         "chunk_size":chunk_size
     }
    # Send the POST request
-    response = requests.post("http://10.41.3.69:8070/classify/", json=data)
+    response = requests.post("http://172.28.5.44:8070/classify_openai/", json=data)
 
    # Handle the response
    if response.status_code == 200:
@@ -23,7 +23,23 @@ def cls_process(path, client_id, one_key, name_column, api_key="sk-REDACTED
        print(f"Request failed, status code: {response.status_code}")
        print("Error message:", response.text)
 
+def cls_process_bert(path, client_id, name_column):
+    # Define the request payload
+    data = {
+        "path": path,
+        "client_id": client_id,
+        "name_column": name_column
+    }
+    # Send the POST request
+    response = requests.post("http://172.28.5.44:8070/classify_bert/", json=data)
 
+    # Handle the response
+    if response.status_code == 200:
+        result = response.json()
+        print("API response:", result)
+    else:
+        print(f"Request failed, status code: {response.status_code}")
+        print("Error message:", response.text)
 def upload_process(file_path, client_id):
    # Form data
     data = {
@@ -36,7 +52,7 @@ def upload_process(file_path, client_id):
         files = {
            'file': (name, file, 'text/plain')  # 'text/plain' is the MIME type; change it as needed
         }
-        response = requests.post("http://10.41.3.69:8070/uploadfile/", data=data, files=files)
+        response = requests.post("http://ip:port/uploadfile/", data=data, files=files)
 
    # Handle the response
     if response.status_code == 200:
@@ -51,10 +67,11 @@ def upload_process(file_path, client_id):
 def both_process(path, client_id, one_key, name_column):
     file_path = upload_process(file_path=path, client_id=client_id)
     if file_path:
-        cls_process(path=file_path, client_id=client_id, one_key=one_key, name_column=name_column)
+        cls_process_openai(path=file_path, client_id=client_id, one_key=one_key, name_column=name_column)
 
 if __name__ == "__main__":
-    both_process(path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
+    # both_process(path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
     # upload_process(file_path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin')
-    # cls_process(path="./process/wangdalin\\人群分类-自建站2024.08.xlsx",client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
+    # cls_process_openai(path="./process/wangdalin\\人群分类-自建站2024.08.xlsx",client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
+    cls_process_bert(path="./process/wangdalin/人群分类-自建站2024.08.xlsx",client_id='wangdalin', name_column='收货姓名')
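
For reference, a minimal client-side sketch (not part of this commit) that chains the existing upload_process with the new cls_process_bert, mirroring how both_process wraps the OpenAI variant:

def both_process_bert(path, client_id, name_column):
    # Upload the workbook first, then classify the server-side copy via the BERT endpoint.
    file_path = upload_process(file_path=path, client_id=client_id)
    if file_path:
        cls_process_bert(path=file_path, client_id=client_id, name_column=name_column)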
 

+ 3 - 1
config.py

@@ -20,4 +20,6 @@ user_prompt = """Provided data: {chunk}
                Returned data:"""
 port = 8070
 file_base_url = f'http://{get_ip_address()}:{port}/data'
-
+model_save_path = '/mnt/e/code/name_classify/bert/model/bert-name-classify/best_model_2024_9_4.pkl'
+pre_train_model = '/mnt/e/code/name_classify/bert/model/bert-name-classify'
+label_revert_map = {0:'亚裔华人', 1:'亚裔非华人', 2:'非亚裔'}
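
As a usage note: label_revert_map maps the classifier's output indices back to label strings, and its length also fixes the model's class count (see len(label_revert_map.keys()) in name_classify_api.py below). A small sketch, assuming the forward map used when encoding training targets is simply the inverse (that map is not part of this commit):

# Hypothetical inverse map (label -> index), e.g. for encoding training labels.
label_map = {label: idx for idx, label in label_revert_map.items()}
assert label_map['亚裔华人'] == 0
assert label_revert_map[0] == '亚裔华人'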

+ 113 - 19
name_classify_api.py

@@ -1,5 +1,5 @@
 import pandas as pd 
-import os, time, shutil
+import os, time, shutil, sys
 import openai
 from fastapi import FastAPI, UploadFile, File, Form
 from pydantic import BaseModel
@@ -9,7 +9,14 @@ import uvicorn, socket
 from tqdm import tqdm
 from fastapi.staticfiles import StaticFiles
 from config import *
+
 from functions import split_dataframe_to_dict, extract_list_from_string
+sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
+import torch
+# from model import BertClassifier
+from bert.model import BertClassifier
+from transformers import BertTokenizer, BertConfig
+
 app = FastAPI()
 app.mount("/data", StaticFiles(directory='./process'), name="static")
 class ClassificationRequest(BaseModel):
@@ -21,6 +28,90 @@ class ClassificationRequest(BaseModel):
     proxy: bool = False
     chunk_size: int = 100
 
+class ClassificationRequestBert(BaseModel):
+    path: str
+    client_id: str
+    name_column: str
+
+bert_config = BertConfig.from_pretrained(pre_train_model)
+# Instantiate the classifier
+model = BertClassifier(bert_config, len(label_revert_map.keys()))
+# Load the trained model weights
+model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
+model.eval()
+tokenizer = BertTokenizer.from_pretrained(pre_train_model)
+
+def bert_predict(text):
+    token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
+    input_ids = token['input_ids']
+    attention_mask = token['attention_mask']
+    token_type_ids = token['token_type_ids']
+
+    input_ids = torch.tensor([input_ids], dtype=torch.long)
+    attention_mask = torch.tensor([attention_mask], dtype=torch.long)
+    token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
+
+    predicted = model(
+        input_ids,
+        attention_mask,
+        token_type_ids,
+    )
+    pred_label = torch.argmax(predicted, dim=1).numpy()[0]
+    
+    return label_revert_map[pred_label]
+def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
+    # Build file names for the error report and the intermediate origin CSV
+    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
+    # Append the suffixes
+    error_file = error_file_name + '_error' + error_file_extension
+    origin_csv_file = error_file_name + '_origin.csv'
+    # Build the new file paths
+    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
+    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
+    total_processed = 0
+    df_origin = pd.read_excel(file_path)
+    df_error = pd.DataFrame(columns=df_origin.columns)
+    df_origin.to_csv(origin_csv_path, index=False)
+    # Read the origin CSV back in chunks
+    
+    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), desc='Processing', unit='item'):
+        try:
+            # Classify each row in the chunk
+            chunk['classify'] = chunk[name_col].apply(bert_predict)
+            # Append the results incrementally
+            if total_processed == 0:
+                chunk.to_csv(temp_path, mode='w', index=False)
+            else:
+                chunk.to_csv(temp_path, mode='a', header=False, index=False)
+            # Update the count of processed rows
+            total_processed += len(chunk)
+        except Exception as e:
+            df_error = pd.concat([df_error, chunk])  # collect failed chunks for the error report
+    df_final = pd.read_csv(temp_path)
+    df_final.to_excel(save_path, index=False)
+    os.remove(origin_csv_path)
+    if len(df_error) == 0:
+        return save_path, '', 0
+    else:
+        df_error.to_excel(error_file_path, index=False)
+        return save_path, error_file_path, len(df_error)
+def remove_files():
+    current_time = time.time()
+    TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
+    TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60
+    for root, dirs, files in os.walk(basic_path, topdown=False):
+        # Remove stale files
+        for file in files:
+            file_path = os.path.join(root, file)
+            if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
+                print(f"Deleting file: {file_path}")
+                os.remove(file_path)
+        # Remove stale directories
+        for dir in dirs:
+            dir_path = os.path.join(root, dir)
+            if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
+                print(f"Deleting directory: {dir_path}")
+                shutil.rmtree(dir_path)
 @app.post("/uploadfile/")
 async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
     user_directory = f'{basic_path}/{client_id}'
@@ -40,26 +131,10 @@ async def create_upload_file(file: UploadFile = File(...), client_id: str = Form
     except Exception as e:
        return JSONResponse(content={"message": f"An error occurred: {str(e)}"}, status_code=500)
     
-@app.post("/classify/")
+@app.post("/classify_openai/")
 async def classify_data(request: ClassificationRequest):
     try:
-        current_time = time.time()
-        TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
-        TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60
-        for root, dirs, files in os.walk(basic_path, topdown=False):
-            # Remove stale files
-            for file in files:
-                file_path = os.path.join(root, file)
-                if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
-                    print(f"Deleting file: {file_path}")
-                    os.remove(file_path)
-            # Remove stale directories
-            for dir in dirs:
-                dir_path = os.path.join(root, dir)
-                if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
-                    print(f"Deleting directory: {dir_path}")
-                    shutil.rmtree(dir_path)
-        
+        remove_files()
         work_path = f'{basic_path}/{request.client_id}'
         if not os.path.exists(work_path):
             os.makedirs(work_path, exist_ok=True)
@@ -111,5 +186,24 @@ async def classify_data(request: ClassificationRequest):
     except Exception as e:
        return {"message": f"Processing error: {e}"}
 
+@app.post("/classify_bert/")
+async def classify_data_bert(request: ClassificationRequestBert):
+    remove_files()
+    work_path = f'{basic_path}/{request.client_id}'
+    if not os.path.exists(work_path):
+        os.makedirs(work_path, exist_ok=True)
+    final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
+    # Append the suffix
+    final_file = final_file_name + '_classify' + final_file_extension
+
+    # Build the new output path
+    new_file_path = os.path.join(os.path.dirname(request.path), final_file)
+    timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
+    save_path, error_path, error_len = predict_excel(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
+    if error_len == 0:
+        return {"message": "Classification complete", "output_file": file_base_url + save_path.split(basic_path)[1]}
+    else:
+        return {"message": "Classification partially complete", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess": file_base_url + error_path.split(basic_path)[1]}
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=port)
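
A possible refinement, not part of this commit: predict_excel applies bert_predict row by row, so every name is tokenized and run through the model one at a time, padded to max_length=512, and with gradient tracking enabled. A hedged sketch of a batched variant, reusing the model/tokenizer globals above (the shorter max_length is an assumption based on names being short strings):

def bert_predict_batch(texts, max_length=64):
    # Tokenize the whole batch in one call; pad only up to max_length.
    tokens = tokenizer(list(texts), add_special_tokens=True, padding=True,
                       truncation=True, max_length=max_length, return_tensors='pt')
    with torch.no_grad():  # inference only, no gradients needed
        predicted = model(tokens['input_ids'], tokens['attention_mask'], tokens['token_type_ids'])
    return [label_revert_map[i] for i in torch.argmax(predicted, dim=1).tolist()]

Inside predict_excel, chunk['classify'] = bert_predict_batch(chunk[name_col].tolist()) could then replace the per-row apply, assuming BertClassifier's forward accepts batched tensors as it already does in bert_predict.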