wangdalin 10 months ago
parent
commit
bed0e4d284
4 changed files with 142 additions and 29 deletions
  1. +3 -3    .gitignore
  2. +23 -6   client.py
  3. +3 -1    config.py
  4. +113 -19 name_classify_api.py

+ 3 - 3
.gitignore

@@ -1,3 +1,3 @@
-data_before
-data_out
-process
+process
+__pycache__
+bert-train

+ 23 - 6
client.py

@@ -1,6 +1,6 @@
 import requests, os

-def cls_process(path, client_id, one_key, name_column, api_key="sk-iREtaVNjamaBArOTlc_2BfGFJVPiU-9EjSFMUspIPBT3BlbkFJxS0SMmKZD9L9UumPczee4VKawCwVeGBQAr9MgsWGkA",
+def cls_process_openai(path, client_id, one_key, name_column, api_key="sk-iREtaVNjamaBArOTlc_2BfGFJVPiU-9EjSFMUspIPBT3BlbkFJxS0SMmKZD9L9UumPczee4VKawCwVeGBQAr9MgsWGkA",
                proxy=False, chunk_size=100):
     # Define the request payload
     data = {
@@ -13,7 +13,7 @@ def cls_process(path, client_id, one_key, name_column, api_key="sk-iREtaVNjamaBA
         "chunk_size":chunk_size
         "chunk_size":chunk_size
     }
     }
     # 发送 POST 请求
     # 发送 POST 请求
-    response = requests.post("http://10.41.3.69:8070/classify/", json=data)
+    response = requests.post("http://172.28.5.44:8070/classify_openai/", json=data)
 
 
     # 处理响应
     # 处理响应
     if response.status_code == 200:
     if response.status_code == 200:
@@ -23,7 +23,23 @@ def cls_process(path, client_id, one_key, name_column, api_key="sk-iREtaVNjamaBA
         print(f"请求失败,状态码: {response.status_code}")
         print(f"请求失败,状态码: {response.status_code}")
         print("错误信息:", response.text)
         print("错误信息:", response.text)
 
 
+def cls_process_bert(path, client_id, name_column):
+    # Define the request payload
+    data = {
+        "path": path,
+        "client_id": client_id,
+        "name_column": name_column
+    }
+    # Send the POST request
+    response = requests.post("http://172.28.5.44:8070/classify_bert/", json=data)
+
+    # Handle the response
+    if response.status_code == 200:
+        result = response.json()
+        print("Response:", result)
+    else:
+        print(f"Request failed, status code: {response.status_code}")
+        print("Error message:", response.text)
 def upload_process(file_path, client_id):
     # Form data
     data = {
@@ -36,7 +52,7 @@ def upload_process(file_path, client_id):
         files = {
             'file': (name, file, 'text/plain')  # 'text/plain' is the MIME type; adjust as needed
         }
-        response = requests.post("http://10.41.3.69:8070/uploadfile/", data=data, files=files)
+        response = requests.post("http://ip:port/uploadfile/", data=data, files=files)

     # Handle the response
     if response.status_code == 200:
@@ -51,10 +67,11 @@ def upload_process(file_path, client_id):
 def both_process(path, client_id, one_key, name_column):
     file_path = upload_process(file_path=path, client_id=client_id)
     if file_path:
-        cls_process(path=file_path, client_id=client_id, one_key=one_key, name_column=name_column)
+        cls_process_openai(path=file_path, client_id=client_id, one_key=one_key, name_column=name_column)

 if __name__ == "__main__":
-    both_process(path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
+    # both_process(path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
     # upload_process(file_path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin')
-    # cls_process(path="./process/wangdalin\\人群分类-自建站2024.08.xlsx",client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
+    # cls_process_openai(path="./process/wangdalin\\人群分类-自建站2024.08.xlsx",client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
+    cls_process_bert(path="./process/wangdalin/人群分类-自建站2024.08.xlsx",client_id='wangdalin', name_column='收货姓名')


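Note: a minimal sketch of consuming the new endpoint end to end (assumptions: the server address 172.28.5.44:8070 used above, and the "output_file" URL in the JSON response of /classify_bert/ below, which name_classify_api.py serves from its static /data mount):

    import requests

    resp = requests.post("http://172.28.5.44:8070/classify_bert/",
                         json={"path": "./process/wangdalin/人群分类-自建站2024.08.xlsx",
                               "client_id": "wangdalin", "name_column": "收货姓名"})
    resp.raise_for_status()
    result = resp.json()
    # "output_file" points at the classified workbook under the /data static mount
    with open("classified.xlsx", "wb") as f:
        f.write(requests.get(result["output_file"]).content)
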
+ 3 - 1
config.py

@@ -20,4 +20,6 @@ user_prompt = """提供的数据:{chunk}
                 返回的数据:"""
 port = 8070
 file_base_url = f'http://{get_ip_address()}:{port}/data'
-
+model_save_path = '/mnt/e/code/name_classify/bert/model/bert-name-classify/best_model_2024_9_4.pkl'
+pre_train_model = '/mnt/e/code/name_classify/bert/model/bert-name-classify'
+label_revert_map = {0:'亚裔华人', 1:'亚裔非华人', 2:'非亚裔'}
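
Note: these three settings drive the new BERT path in name_classify_api.py below: pre_train_model is the directory holding the tokenizer and BertConfig, model_save_path is the fine-tuned state dict, and label_revert_map turns an argmax index back into a label ('亚裔华人' = Asian Chinese, '亚裔非华人' = Asian non-Chinese, '非亚裔' = non-Asian). A minimal sketch of the lookup, with a hypothetical logits tensor:

    import torch

    label_revert_map = {0: '亚裔华人', 1: '亚裔非华人', 2: '非亚裔'}  # as defined above
    logits = torch.tensor([[0.1, 2.3, 0.4]])  # hypothetical model output for one name
    label = label_revert_map[torch.argmax(logits, dim=1).item()]  # index 1 -> '亚裔非华人'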

+ 113 - 19
name_classify_api.py

@@ -1,5 +1,5 @@
 import pandas as pd 
-import os, time, shutil
+import os, time, shutil, sys
 import openai
 from fastapi import FastAPI, UploadFile, File, Form
 from pydantic import BaseModel
@@ -9,7 +9,14 @@ import uvicorn, socket
 from tqdm import tqdm
 from fastapi.staticfiles import StaticFiles
 from config import *
+
 from functions import split_dataframe_to_dict, extract_list_from_string
+sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
+import torch
+# from model import BertClassifier
+from bert.model import BertClassifier
+from transformers import BertTokenizer, BertConfig
+
 app = FastAPI()
 app.mount("/data", StaticFiles(directory='./process'), name="static")
 class ClassificationRequest(BaseModel):
@@ -21,6 +28,90 @@ class ClassificationRequest(BaseModel):
     proxy: bool = False
     chunk_size: int = 100

+class ClassificationRequestBert(BaseModel):
+    path: str
+    client_id: str
+    name_column: str
+
+bert_config = BertConfig.from_pretrained(pre_train_model)
+# Define the model
+model = BertClassifier(bert_config, len(label_revert_map.keys()))
+# Load the trained model
+model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
+model.eval()
+tokenizer = BertTokenizer.from_pretrained(pre_train_model)
+
+def bert_predict(text):
+    token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
+    input_ids = token['input_ids']
+    attention_mask = token['attention_mask']
+    token_type_ids = token['token_type_ids']
+
+    input_ids = torch.tensor([input_ids], dtype=torch.long)
+    attention_mask = torch.tensor([attention_mask], dtype=torch.long)
+    token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
+
+    predicted = model(
+        input_ids,
+        attention_mask,
+        token_type_ids,
+    )
+    pred_label = torch.argmax(predicted, dim=1).numpy()[0]
+    
+    return label_revert_map[pred_label]
+def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
+    # Initialize variables
+    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
+    # Add the suffix
+    error_file = error_file_name + '_error' + error_file_extension
+    origin_csv_file = error_file_name + '_origin.csv'
+    # Generate the new file paths
+    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
+    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
+    total_processed = 0
+    df_origin = pd.read_excel(file_path)
+    df_error = pd.DataFrame(columns=df_origin.columns)
+    df_origin.to_csv(origin_csv_path, index=False)
+    # Read the CSV file in chunks
+    
+    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), desc='Processing', unit='item'):
+        try:
+            # Process each chunk
+            chunk['classify'] = chunk[name_col].apply(bert_predict)
+            # Incrementally save the processed results
+            if total_processed == 0:
+                chunk.to_csv(temp_path, mode='w', index=False)
+            else:
+                chunk.to_csv(temp_path, mode='a', header=False, index=False)
+            # Update the count of processed rows
+            total_processed += len(chunk)
+        except Exception as e:
+            df_error = pd.concat([df_error, chunk])
+    df_final = pd.read_csv(temp_path)
+    df_final.to_excel(save_path)
+    os.remove(origin_csv_path)
+    if len(df_error) == 0:
+        return save_path, '', 0
+    else:
+        df_error.to_excel(error_file_path)
+        return save_path, error_file_path, len(df_error)
+def remove_files():
+    current_time = time.time()
+    TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
+    TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60
+    for root, dirs, files in os.walk(basic_path, topdown=False):
+        # Delete files
+        for file in files:
+            file_path = os.path.join(root, file)
+            if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
+                print(f"Deleting file: {file_path}")
+                os.remove(file_path)
+        # Delete directories
+        for dir in dirs:
+            dir_path = os.path.join(root, dir)
+            if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
+                print(f"Deleting directory: {dir_path}")
+                shutil.rmtree(dir_path)
 @app.post("/uploadfile/")
 @app.post("/uploadfile/")
 async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
 async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
     user_directory = f'{basic_path}/{client_id}'
     user_directory = f'{basic_path}/{client_id}'
@@ -40,26 +131,10 @@ async def create_upload_file(file: UploadFile = File(...), client_id: str = Form
     except Exception as e:
         return JSONResponse(content={"message": f"An error occurred: {str(e)}"}, status_code=500)
     
-@app.post("/classify/")
+@app.post("/classify_openai/")
 async def classify_data(request: ClassificationRequest):
     try:
-        current_time = time.time()
-        TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
-        TIME_THRESHOLD_FILE = 10 * 24 * 60 * 60
-        for root, dirs, files in os.walk(basic_path, topdown=False):
-            # Delete files
-            for file in files:
-                file_path = os.path.join(root, file)
-                if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
-                    print(f"Deleting file: {file_path}")
-                    os.remove(file_path)
-            # Delete directories
-            for dir in dirs:
-                dir_path = os.path.join(root, dir)
-                if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
-                    print(f"Deleting directory: {dir_path}")
-                    shutil.rmtree(dir_path)
-        
+        remove_files()
         work_path = f'{basic_path}/{request.client_id}'
         if not os.path.exists(work_path):
             os.makedirs(work_path, exist_ok=True)
@@ -111,5 +186,24 @@ async def classify_data(request: ClassificationRequest):
     except Exception as e:
         return {"message": f"An error occurred during processing: {e}"}

+@app.post("/classify_bert/")
+async def classify_data(request: ClassificationRequestBert):
+    remove_files()
+    work_path = f'{basic_path}/{request.client_id}'
+    if not os.path.exists(work_path):
+        os.makedirs(work_path, exist_ok=True)
+    final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
+    # 添加后缀
+    final_file = final_file_name + '_classify' + final_file_extension
+
+    # 生成新的文件路径
+    new_file_path = os.path.join(os.path.dirname(request.path), final_file)
+    timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
+    save_path, error_path, error_len = predict_excel(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
+    if error_len == 0:
+        return {"message": "分类完成", "output_file": file_base_url + save_path.split(basic_path)[1]}
+    else:
+        return {"message": "分类完成只完成部分", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess":file_base_url + save_path.split(error_path)[1],}
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=port)
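
Note: a minimal local smoke test of the new BERT path (assumptions: the paths in config.py exist on disk and bert/model.py defines the BertClassifier imported above; importing the module loads the model once at import time, and uvicorn only starts under __main__):

    # Start the API:  python name_classify_api.py   (serves on 0.0.0.0:8070)
    # Or sanity-check the classifier directly, without starting the server:
    from name_classify_api import bert_predict

    print(bert_predict("王小明"))  # expected: one of '亚裔华人' / '亚裔非华人' / '非亚裔'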