
Business-logic changes, CSV support, async parallel processing

wangdalin 10 months ago
commit 5793e852d5
5 changed files with 188 additions and 73 deletions
  1. client.py (+9 −6)
  2. config.py (+2 −2)
  3. name_classify_api.py (+173 −65)
  4. test.csv (+4 −0)
  5. test.xlsx (BIN)

+ 9 - 6
client.py

@@ -8,12 +8,11 @@ def cls_process_openai(path, client_id, one_key, name_column, api_key="sk-iREtaV
         "client_id": client_id,
         "one_key": one_key,
         "name_column": name_column,
-        "api_key": api_key,
         "proxy": proxy,  
         "chunk_size":chunk_size
     }
     # 发送 POST 请求
-    response = requests.post(f"http://{ip}:8070/classify_openai/", json=data)
+    response = requests.post(f"http://10.41.1.57:8071/classify_openai/", json=data)
 
     # 处理响应
     if response.status_code == 200:
@@ -31,7 +30,7 @@ def cls_process_bert(path, client_id, name_column):
         "name_column": name_column
     }
     # 发送 POST 请求
-    response = requests.post(f"http://{ip}:8070/classify_bert/", json=data)
+    response = requests.post(f"http://10.41.1.57:8071/classify_bert/", json=data)
 
     # 处理响应
     if response.status_code == 200:
@@ -52,13 +51,13 @@ def upload_process(file_path, client_id):
         files = {
             'file': (name, file, 'text/plain')  # 'text/plain' 是文件类型,可以根据需要修改
         }
-        response = requests.post(f"http://{ip}:8070/uploadfile/", data=data, files=files)
+        response = requests.post(f"http://10.41.1.57:8071/uploadfile/", data=data, files=files)
 
     # 处理响应
     if response.status_code == 200:
         result = response.json()
         print("接口响应:", result)
-        return result.get('file_path')
+        return result.get('data').get('file_path')
     else:
         print(f"请求失败,状态码: {response.status_code}")
         print("错误信息:", response.text)
@@ -71,15 +70,19 @@ def both_process_api(path, client_id, one_key, name_column):
 
 def both_process_bert(path, client_id, name_column):
     file_path = upload_process(file_path=path, client_id=client_id)
+    print(file_path)
     if file_path:
         cls_process_bert(path=file_path, client_id=client_id, name_column=name_column)
+    
 
 if __name__ == "__main__":
     # both_process(path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
     # upload_process(file_path="E:/code/name_classify/data_before/人群分类-自建站2024.08.xlsx", client_id='wangdalin')
     # cls_process_openai(path="./process/wangdalin\\人群分类-自建站2024.08.xlsx",client_id='wangdalin', one_key='网店单号', name_column='收货姓名')
     # cls_process_bert(path="./process/wangdalin/人群分类-自建站2024.08.xlsx",client_id='wangdalin', name_column='收货姓名')
-    both_process_bert(path='/dalin/name_classify/process/1售后Shopify退货人群匹配0902-0908.xlsx', client_id='wangdalin', name_column='姓名')
+    # both_process_bert(path='/dalin/name_classify/process/wangdalin/1销售Shopify人群分析0909-0915.xlsx', client_id='wangdalin', name_column='姓名')
+    both_process_api(path='/dalin/name_classify/test.csv', client_id='wangdalin', one_key='one_key', name_column='name')
+    # both_process_bert(path='/dalin/name_classify/test.csv', client_id='wangdalin', name_column='name')
     # cls_process_bert(path='/mnt/e/code/name_classify/process/wangdalin/3销售Shopify人群分析0902-0908_classify_error.xlsx', client_id='wangdalin', name_column='姓名')
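
The upload response is now wrapped in a "data" envelope (see the /uploadfile/ change in name_classify_api.py further down), which is why the client reads result.get('data').get('file_path') instead of result.get('file_path'). A minimal sketch of a slightly more defensive read of that same envelope; the helper name and the fallback to the old flat key are illustrative only:

def extract_file_path(result):
    # New envelope: {"message": ..., "data": {"client_id": ..., "file_path": ...}, "code": 200}
    data = result.get('data') or {}
    # Fall back to the pre-change flat shape purely for illustration.
    return data.get('file_path') or result.get('file_path')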
     
 

+ 2 - 2
config.py

@@ -18,9 +18,9 @@ proxy_url = 'https://fast.bemore.lol/v1'
 cls_system_prompt = '你是一个名字判断专家,你需要根据提供的列表中的每一个字典元素的会员姓名,判断其名字分类,分别为3类: 亚裔华人,亚裔非华人, 非亚裔,并将结果填充到会员分类中, 整合之后返回与提供数据一样的格式给我'
 user_prompt = """提供的数据:{chunk}
                 返回的数据:"""
-port = 8070
+port = 8071
 ip = get_ip_address()
-file_base_url = f'http://{ip}:{port}/data'
+file_base_url = f'https://ncls.gloria.com.cn:8070/data'
 model_save_path = '/dalin/name_classify/bert/model/bert-name-classify/best_model_2024_9_4.pkl'
 pre_train_model = '/dalin/name_classify/bert/model/bert-name-classify'
 label_revert_map = {0:'亚裔华人', 1:'亚裔非华人', 2:'非亚裔'}
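
With this change the service itself listens on port 8071, while download links are built against the fixed public host on 8070. The API composes those links as file_base_url + path.split(basic_path)[1] (see name_classify_api.py below). A worked sketch, assuming basic_path is the process directory configured elsewhere in config.py (not shown in this diff) and using a hypothetical output file:

# Assumed value; basic_path is defined elsewhere in config.py and not shown in this diff.
basic_path = '/dalin/name_classify/process'
file_base_url = 'https://ncls.gloria.com.cn:8070/data'

save_path = '/dalin/name_classify/process/wangdalin/orders_classify.xlsx'  # hypothetical result file
download_url = file_base_url + save_path.split(basic_path)[1]
print(download_url)  # -> https://ncls.gloria.com.cn:8070/data/wangdalin/orders_classify.xlsx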

+ 173 - 65
name_classify_api.py

@@ -9,24 +9,66 @@ import uvicorn, socket
 from tqdm import tqdm
 from fastapi.staticfiles import StaticFiles
 from config import *
-
+import asyncio, threading
 from functions import split_dataframe_to_dict, extract_list_from_string
 sys.path.append(os.path.join(os.path.dirname(__file__), 'bert'))
 import torch
 # from model import BertClassifier
 from bert.model import BertClassifier
 from transformers import BertTokenizer, BertConfig
-
+from fastapi.middleware.cors import CORSMiddleware
 app = FastAPI()
+app.add_middleware(
+   CORSMiddleware,
+   allow_origins=["*"],
+   allow_credentials=True,
+   allow_methods=["*"],
+   allow_headers=["*"],
+)
 if not os.path.exists(basic_path):
     os.makedirs(basic_path, exist_ok=True)
 app.mount("/data", StaticFiles(directory=basic_path), name="static")
+bert_config = BertConfig.from_pretrained(pre_train_model)
+# 定义模型
+model = BertClassifier(bert_config, len(label_revert_map.keys()))
+# 加载训练好的模型
+model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
+model.eval()
+tokenizer = BertTokenizer.from_pretrained(pre_train_model)
+# class ModelManager:
+#     def __init__(self, pre_train_model, model_save_path, label_revert_map):
+#         self.pre_train_model = pre_train_model
+#         self.model_save_path = model_save_path
+#         self.label_revert_map = label_revert_map
+#         self.model_cache = {}
+#         self.tokenizer_cache = {}
+
+#     def get_model(self):
+#         # 使用当前线程ID作为键,确保每个线程有自己的模型实例
+#         thread_id = threading.get_ident()
+        
+#         if thread_id not in self.model_cache:
+#             bert_config = BertConfig.from_pretrained(self.pre_train_model)
+#             model = BertClassifier(bert_config, len(self.label_revert_map.keys()))
+#             model.load_state_dict(torch.load(self.model_save_path, map_location=torch.device('cuda:0')))
+#             model.eval()
+#             self.model_cache[thread_id] = model
+
+#         if thread_id not in self.tokenizer_cache:
+#             tokenizer = BertTokenizer.from_pretrained(self.pre_train_model)
+#             self.tokenizer_cache[thread_id] = tokenizer
+
+#         print(f"Thread {thread_id}: Tokenizer id: {id(self.tokenizer_cache[thread_id])}, Model id: {id(self.model_cache[thread_id])}")
+#         return self.tokenizer_cache[thread_id], self.model_cache[thread_id]
+
+# 使用示例
+# model_manager = ModelManager(pre_train_model, model_save_path, label_revert_map)
 class ClassificationRequest(BaseModel):
     path: str
     client_id: str
     one_key: str
     name_column: str
-    api_key: str = "sk-iREtaVNjamaBArOTlc_2BfGFJVPiU-9EjSFMUspIPBT3BlbkFJxS0SMmKZD9L9UumPczee4VKawCwVeGBQAr9MgsWGkA"
+    api_key: str = "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"
     proxy: bool = False
     chunk_size: int = 100
 
@@ -35,13 +77,13 @@ class ClassificationRequestBert(BaseModel):
     client_id: str
     name_column: str
 
-bert_config = BertConfig.from_pretrained(pre_train_model)
-# 定义模型
-model = BertClassifier(bert_config, len(label_revert_map.keys()))
-# 加载训练好的模型
-model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
-model.eval()
-tokenizer = BertTokenizer.from_pretrained(pre_train_model)
+# bert_config = BertConfig.from_pretrained(pre_train_model)
+# # 定义模型
+# model = BertClassifier(bert_config, len(label_revert_map.keys()))
+# # 加载训练好的模型
+# model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
+# model.eval()
+# tokenizer = BertTokenizer.from_pretrained(pre_train_model)
 
 def bert_predict(text):
     if type(text) == str and text != '':
@@ -64,7 +106,22 @@ def bert_predict(text):
         return label_revert_map[pred_label]
     else:
         return ''
-def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
+def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
+    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
+            try:
+                # 对每个块进行处理
+                chunk['AI Group'] = chunk[name_col].apply(lambda x : bert_predict(x))
+                # 增量保存处理结果
+                if total_processed == 0:
+                    chunk.to_csv(temp_path, mode='w', index=False)
+                else:
+                    chunk.to_csv(temp_path, mode='a', header=False, index=False)
+                # 更新已处理的数据量
+                total_processed += len(chunk)
+            except Exception as e:
+                df_error = pd.concat([df_error, chunk])
+    return temp_path, df_error, total_processed
+async def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
     # 初始化变量
     error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
     # 添加后缀
@@ -77,28 +134,44 @@ def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
     df_origin = pd.read_excel(file_path)
     df_error = pd.DataFrame(columns=df_origin.columns)
     df_origin.to_csv(origin_csv_path, index=False)
+
+    temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
+    df_final = pd.read_csv(temp_path)
+    df_final.to_excel(save_path, index=False)
+    try:
+        os.remove(origin_csv_path)
+    except Exception as e :
+        pass
+    if len(df_error) == 0:
+        return save_path, '', 0
+    else:
+        df_error.to_excel(error_file_path)
+        return save_path, error_file_path, len(df_error)
+
+async def predict_csv(file_path, name_col, temp_path, save_path, chunksize=5):
+    # 初始化变量
+    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
+    # 添加后缀
+    error_file = error_file_name + '_error' + error_file_extension
+    origin_csv_file = error_file_name + '_origin.csv'
+    # 生成新的文件路径
+    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
+    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
+    total_processed = 0
+    df_origin = pd.read_csv(file_path)
+    df_error = pd.DataFrame(columns=df_origin.columns)
+    df_origin.to_csv(origin_csv_path, index=False)
     # 按块读取 CSV 文件
     
-    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
-        try:
-            # 对每个块进行处理
-            chunk['classify'] = chunk[name_col].apply(bert_predict)
-            # 增量保存处理结果
-            if total_processed == 0:
-                chunk.to_csv(temp_path, mode='w', index=False)
-            else:
-                chunk.to_csv(temp_path, mode='a', header=False, index=False)
-            # 更新已处理的数据量
-            total_processed += len(chunk)
-        except Exception as e:
-            df_error = pd.concat([df_error, chunk])
+    # for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
+    temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
     df_final = pd.read_csv(temp_path)
-    df_final.to_excel(save_path)
+    df_final.to_csv(save_path, index=False)
     os.remove(origin_csv_path)
     if len(df_error) == 0:
         return save_path, '', 0
     else:
-        df_error.to_excel(error_file_path)
+        df_error.to_csv(error_file_path)
         return save_path, error_file_path, len(df_error)
 def remove_files():
     current_time = time.time()
@@ -130,12 +203,29 @@ async def create_upload_file(file: UploadFile = File(...), client_id: str = Form
         os.chmod(file_location, 0o777)  # 设置文件权限为777
         return JSONResponse(content={
             "message": f"文件 '{file.filename}' 上传成功",
-            "client_id": client_id,
-            "file_path": file_location
+            "data":{"client_id": client_id,"file_path": file_location},
+            "code":200
         }, status_code=200)
     except Exception as e:
-        return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
-    
+        return JSONResponse(content={"message": f"发生错误: {str(e)}", "data":{}, "code":500}, status_code=500)
+async def process_openai_data(deal_result, client, temp_csv):
+    for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
+        try:
+            message = [
+                {'role':'system', 'content': cls_system_prompt},
+                {'role':'user', 'content':user_prompt.format(chunk=str(value))}
+            ]
+            # result_string = post_openai(message)
+            response = await client.chat.completions.create(model='gpt-4',messages=message)
+            result_string = response.choices[0].message.content
+            result = extract_list_from_string(result_string)
+            if result:
+                df_output = pd.DataFrame(result)
+                df_output.to_csv(temp_csv, mode='a', header=True, index=False)
+            else:
+                continue
+        except Exception as e:
+            print(f'{name}出现问题啦, 错误为:{e} 请自行调试')
 @app.post("/classify_openai/")
 async def classify_data(request: ClassificationRequest):
     try:
@@ -143,16 +233,20 @@ async def classify_data(request: ClassificationRequest):
         work_path = f'{basic_path}/{request.client_id}'
         if not os.path.exists(work_path):
             os.makedirs(work_path, exist_ok=True)
+        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
         timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-        df_origin = pd.read_excel(request.path)
-        df_origin['name'] = df_origin[request.name_column]
-        df_origin['classify'] = ''
-        df_use = df_origin[['name', 'classify']]
+        if final_file_extension == '.csv':
+            df_origin = pd.read_csv(request.path)
+        else:
+            df_origin = pd.read_excel(request.path)
+        df_origin['AI Name'] = df_origin[request.name_column]
+        df_origin['AI Group'] = ''
+        df_use = df_origin[['AI Name', 'AI Group']]
         deal_result = split_dataframe_to_dict(df_use, request.chunk_size)
         # 生成当前时间的时间戳字符串
         
         temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
-        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
+        
         # 添加后缀
         final_file = final_file_name + '_classify' + final_file_extension
 
@@ -161,54 +255,68 @@ async def classify_data(request: ClassificationRequest):
 
         if not request.proxy:
             print(f'用户{request.client_id}正在使用直连的gpt-API')
-            client = openai.OpenAI(api_key=request.api_key, base_url=openai_url)
+            client = openai.AsyncOpenAI(api_key=request.api_key, base_url=openai_url)
         else:
-            client = openai.OpenAI(api_key=request.api_key, base_url=proxy_url)
-        for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
-            try:
-                message = [
-                    {'role':'system', 'content': cls_system_prompt},
-                    {'role':'user', 'content':user_prompt.format(chunk=str(value))}
-                ]
-                # result_string = post_openai(message)
-                response = client.chat.completions.create(model='gpt-4',messages=message)
-                result_string = response.choices[0].message.content
-                result = extract_list_from_string(result_string)
-                if result:
-                    df_output = pd.DataFrame(result)
-                    df_output.to_csv(temp_csv, mode='a', header=True, index=False)
-                else:
-                    continue
-            except Exception as e:
-                print(f'{name}出现问题啦, 错误为:{e} 请自行调试')
+            client = openai.AsyncOpenAI(api_key=request.api_key, base_url=proxy_url)
+        x = await process_openai_data(deal_result, client, temp_csv)
         if os.path.exists(temp_csv):
             df_result = pd.read_csv(temp_csv)
-            df_final = df_origin.merge(df_result, on='name', how='left').drop_duplicates(subset=[request.one_key,'name'], keep='first')
-            df_final.to_excel(new_file_path)
-            return {"message": "分类完成", "output_file": file_base_url + new_file_path.split(basic_path)[1]}
+            df_origin = df_origin.drop(columns='AI Group')
+            df_final = df_origin.merge(df_result, on='AI Name', how='left').drop_duplicates(subset=[request.one_key,'AI Name'], keep='first')
+            df_final = df_final.drop(columns='AI Name')
+            if final_file_extension == '.csv':
+                df_final.to_csv(new_file_path, index=False)
+            else:
+                df_final.to_excel(new_file_path, index=False)
+            return JSONResponse(content={
+            "message": f"分类完成",
+            "data":{"output_file": file_base_url + new_file_path.split(basic_path)[1]},
+            "code":200
+        }, status_code=200)
         else:
-            return {"message": "文件没能处理成功"}
+            print('...')
+            return JSONResponse(content={"message": "文件没能处理成功","data":{},"code":500}, status_code=500)
     except Exception as e:
-        return {"message": f"处理出现错误: {e}"}
+        print('okok')
+        return JSONResponse(content={"message": f"处理出现错误: {e}","data":{}, "code":500}, status_code=500)
 
 @app.post("/classify_bert/")
-async def classify_data(request: ClassificationRequestBert):
+async def real_process(request: ClassificationRequestBert):
+    task = asyncio.create_task(classify_data_bert(request))
+    result = await task
+    return result
+
+async def classify_data_bert(request: ClassificationRequestBert):
     remove_files()
     work_path = f'{basic_path}/{request.client_id}'
     if not os.path.exists(work_path):
         os.makedirs(work_path, exist_ok=True)
     final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
-    # 添加后缀
     final_file = final_file_name + '_classify' + final_file_extension
-
     # 生成新的文件路径
     new_file_path = os.path.join(os.path.dirname(request.path), final_file)
     timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
     temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
-    save_path, error_path, error_len = predict_excel(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
+    if final_file_extension == '.csv':
+        save_path, error_path, error_len = await predict_csv(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
+    else:
+        # 添加后缀
+        save_path, error_path, error_len = await predict_excel(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
     if error_len == 0:
-        return {"message": "分类完成", "output_file": file_base_url + save_path.split(basic_path)[1]}
+        return JSONResponse(content={
+            "message": "分类完成",
+            "data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
+            "code":200
+        }, status_code=200)
     else:
-        return {"message": "分类完成只完成部分", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1],}
+        return JSONResponse(content={
+            "message": "分类完成只完成部分",
+            "data":{"output_file": file_base_url + save_path.split(basic_path)[1],
+                    "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1]},
+            "code":200
+        }, status_code=200)
+@app.get("/test/")
+async def classify_data():
+    return '测试成功'
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=port)
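
The async part of this commit hinges on asyncio.to_thread: the blocking pandas/BERT loop in process_data runs in a worker thread so the FastAPI event loop stays free while a request is being processed. A minimal, self-contained sketch of that pattern, assuming a stand-in classify_name function in place of the real bert_predict:

import asyncio
import pandas as pd
from fastapi import FastAPI

app = FastAPI()

def classify_name(name):
    # Stand-in for the real bert_predict; any blocking, CPU-bound call fits here.
    return 'demo-label' if isinstance(name, str) and name else ''

def process_chunks(csv_path, out_path, name_col, chunksize=100):
    # Blocking, chunked pass over the CSV, mirroring process_data in the commit.
    total = 0
    for chunk in pd.read_csv(csv_path, chunksize=chunksize):
        chunk['AI Group'] = chunk[name_col].apply(classify_name)
        chunk.to_csv(out_path, mode='w' if total == 0 else 'a', header=(total == 0), index=False)
        total += len(chunk)
    return total

@app.post("/classify_demo/")
async def classify_demo(csv_path: str, name_col: str):
    # Offload the blocking loop to a thread so the event loop can serve other requests.
    total = await asyncio.to_thread(process_chunks, csv_path, csv_path + '.out.csv', name_col)
    return {"message": "done", "data": {"rows": total, "output_file": csv_path + '.out.csv'}, "code": 200}

In the commit itself, /classify_bert/ wraps classify_data_bert in asyncio.create_task and awaits it immediately; the concurrency gain comes from the asyncio.to_thread calls inside predict_csv and predict_excel rather than from the task wrapper.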

+ 4 - 0
test.csv

@@ -0,0 +1,4 @@
+name,one_key
+Wang Dalin,0
+Carl,1
+Jack,2
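
test.csv pairs with the new both_process_api call in client.py: one_key is the deduplication key column and name is the column to classify. A hedged sketch of the request that call ultimately issues, reusing the host and port hard-coded in client.py; the path value and timeout are illustrative, and api_key is omitted so the server-side ClassificationRequest default applies:

import requests

payload = {
    "path": "./process/wangdalin/test.csv",   # path returned by /uploadfile/ (illustrative)
    "client_id": "wangdalin",
    "one_key": "one_key",
    "name_column": "name",
    "proxy": False,        # False -> openai_url, True -> proxy_url
    "chunk_size": 100,     # rows handed to GPT-4 per batch
}
response = requests.post("http://10.41.1.57:8071/classify_openai/", json=payload, timeout=600)
print(response.json())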

BIN
test.xlsx