wangdalin пре 10 месеци
родитељ
комит
5ddb5b163b
4 измењених фајлова са 47 додато и 27 уклоњено
  1. 14 6
      client.py
  2. 4 3
      config.py
  3. 23 18
      name_classify_api.py
  4. 6 0
      requirements.txt

+ 14 - 6
client.py

@@ -1,5 +1,5 @@
 import requests, os
-
+from config import *
 def cls_process_openai(path, client_id, one_key, name_column, api_key="sk-iREtaVNjamaBArOTlc_2BfGFJVPiU-9EjSFMUspIPBT3BlbkFJxS0SMmKZD9L9UumPczee4VKawCwVeGBQAr9MgsWGkA",
                 proxy=False, chunk_size=100):
     # 定义请求的数据
@@ -13,7 +13,7 @@ def cls_process_openai(path, client_id, one_key, name_column, api_key="sk-iREtaV
         "chunk_size":chunk_size
     }
     # 发送 POST 请求
-    response = requests.post("http://172.28.5.44:8070/classify_openai/", json=data)
+    response = requests.post(f"http://{ip}:8070/classify_openai/", json=data)
 
     # 处理响应
     if response.status_code == 200:
@@ -31,7 +31,7 @@ def cls_process_bert(path, client_id, name_column):
         "name_column": name_column
     }
     # 发送 POST 请求
-    response = requests.post("http://172.28.5.44:8070/classify_bert/", json=data)
+    response = requests.post(f"http://{ip}:8070/classify_bert/", json=data)
 
     # 处理响应
     if response.status_code == 200:
@@ -52,7 +52,7 @@ def upload_process(file_path, client_id):
         files = {
             'file': (name, file, 'text/plain')  # 'text/plain' 是文件类型,可以根据需要修改
         }
-        response = requests.post("http://ip:port/uploadfile/", data=data, files=files)
+        response = requests.post(f"http://{ip}:8070/uploadfile/", data=data, files=files)
 
     # 处理响应
     if response.status_code == 200:
@@ -64,14 +64,22 @@ def upload_process(file_path, client_id):
         print("错误信息:", response.text)
         return None
 
-def both_process(path, client_id, one_key, name_column):
-    file_path = upload_process(file_path=path, client_id=client_id)
-    if file_path:
-        cls_process_openai(path=file_path, client_id=client_id, one_key=one_key, name_column=name_column)
+def both_process_api(path, client_id, one_key, name_column):
+    """Upload a local file, then classify it through the OpenAI endpoint.
+
+    The classification call is skipped when ``upload_process`` returns a
+    falsy value. This function itself always returns ``None``.
+    """
+    uploaded_path = upload_process(file_path=path, client_id=client_id)
+    if not uploaded_path:
+        return
+    cls_process_openai(
+        path=uploaded_path,
+        client_id=client_id,
+        one_key=one_key,
+        name_column=name_column,
+    )
 
+def both_process_bert(path, client_id, name_column):
+    """Upload a local file, then classify its name column via the BERT endpoint.
+
+    The classification call is skipped when ``upload_process`` returns a
+    falsy value. This function itself always returns ``None``.
+    """
+    uploaded_path = upload_process(file_path=path, client_id=client_id)
+    if not uploaded_path:
+        return
+    cls_process_bert(
+        path=uploaded_path,
+        client_id=client_id,
+        name_column=name_column,
+    )
+
 if __name__ == "__main__":
+    # Manual run against a live server: exactly one call is active; earlier invocations are kept commented out.
     # both_process(path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
     # upload_process(file_path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin')
     # cls_process_openai(path="./process/wangdalin\\人群分类-自建站2024.08.xlsx",client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
-    cls_process_bert(path="./process/wangdalin/人群分类-自建站2024.08.xlsx",client_id='wangdalin', name_column='收货姓名')
+    # cls_process_bert(path="./process/wangdalin/人群分类-自建站2024.08.xlsx",client_id='wangdalin', name_column='收货姓名')
+    both_process_bert(path='/dalin/name_classify/process/1售后Shopify退货人群匹配0902-0908.xlsx', client_id='wangdalin', name_column='姓名')
+    # cls_process_bert(path='/mnt/e/code/name_classify/process/wangdalin/3销售Shopify人群分析0902-0908_classify_error.xlsx', client_id='wangdalin', name_column='姓名')
+    
 

+ 4 - 3
config.py

@@ -19,7 +19,8 @@ cls_system_prompt = '你是一个名字判断专家,你需要根据提供的
 user_prompt = """提供的数据:{chunk}
                 返回的数据:"""
 port = 8070
-file_base_url = f'http://{get_ip_address()}:{port}/data'
-model_save_path = '/mnt/e/code/name_classify/bert/model/bert-name-classify/best_model_2024_9_4.pkl'
-pre_train_model = '/mnt/e/code/name_classify/bert/model/bert-name-classify'
+ip = get_ip_address()
+file_base_url = f'http://{ip}:{port}/data'
+model_save_path = '/dalin/name_classify/bert/model/bert-name-classify/best_model_2024_9_4.pkl'
+pre_train_model = '/dalin/name_classify/bert/model/bert-name-classify'
 label_revert_map = {0:'亚裔华人', 1:'亚裔非华人', 2:'非亚裔'}

+ 23 - 18
name_classify_api.py

@@ -18,7 +18,9 @@ from bert.model import BertClassifier
 from transformers import BertTokenizer, BertConfig
 
 app = FastAPI()
-app.mount("/data", StaticFiles(directory='./process'), name="static")
+# The served directory must exist before it is mounted. exist_ok=True already
+# makes this call idempotent, so no separate (and racy) existence check is needed.
+os.makedirs(basic_path, exist_ok=True)
+app.mount("/data", StaticFiles(directory=basic_path), name="static")
 class ClassificationRequest(BaseModel):
     path: str
     client_id: str
@@ -42,23 +44,26 @@ model.eval()
 tokenizer = BertTokenizer.from_pretrained(pre_train_model)
 
 def bert_predict(text):
+    """Return the predicted label for one name, or '' when there is no usable text.
+
+    Args:
+        text: Value taken from the name column. Anything that is not a
+            non-empty ``str`` is treated as "nothing to classify"
+            (presumably blank/NaN spreadsheet cells -- confirm with callers).
+
+    Returns:
+        The ``label_revert_map`` entry for the arg-max class, or ``''`` when
+        ``text`` is not a non-empty string.
+    """
+    if not isinstance(text, str) or text == '':
+        return ''
+
     token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
     input_ids = token['input_ids']
     attention_mask = token['attention_mask']
     token_type_ids = token['token_type_ids']
 
     input_ids = torch.tensor([input_ids], dtype=torch.long)
     attention_mask = torch.tensor([attention_mask], dtype=torch.long)
     token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
 
-    predicted = model(
-        input_ids,
-        attention_mask,
-        token_type_ids,
-    )
+    # Inference only: no_grad avoids building an autograd graph for every name.
+    with torch.no_grad():
+        predicted = model(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+        )
     pred_label = torch.argmax(predicted, dim=1).numpy()[0]
     
     return label_revert_map[pred_label]
 def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
     # 初始化变量
     error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
@@ -74,7 +79,7 @@ def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
     df_origin.to_csv(origin_csv_path, index=False)
     # 按块读取 CSV 文件
     
-    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), desc='Processing', unit='item'):
+    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
         try:
             # 对每个块进行处理
             chunk['classify'] = chunk[name_col].apply(bert_predict)
@@ -204,6 +209,6 @@ async def classify_data(request: ClassificationRequestBert):
     if error_len == 0:
         return {"message": "分类完成", "output_file": file_base_url + save_path.split(basic_path)[1]}
     else:
-        return {"message": "分类完成只完成部分", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess":file_base_url + save_path.split(error_path)[1],}
+        return {"message": "分类完成只完成部分", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1],}
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=port)

+ 6 - 0
requirements.txt

@@ -1,3 +1,9 @@
+# CUDA 12.1 builds of the PyTorch packages come from the PyTorch wheel index.
+# --extra-index-url applies to the whole file, so it is declared once.
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.3.1
+torchaudio==2.3.1
+torchvision==0.18.1
 pandas
 openai
 requests