
Optimize code, add GBK decoding

wangdalin 9 months ago
commit
f13f279d92
3 files changed, 116 additions and 164 deletions
  1. + 116 - 160
      name_classify_api.py
  2. + 0 - 4
      test.csv
  3. BIN
      test.xlsx
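
The functional change in this commit is the GBK fallback used when decoding uploaded CSV files. A minimal sketch of that pattern, assuming pandas; the helper name and file path are illustrative only (the actual fallback lives inside process_file in the diff below):

    import pandas as pd

    def read_csv_with_fallback(path: str) -> pd.DataFrame:
        """Try UTF-8 first, then fall back to GBK for legacy Chinese-encoded files."""
        try:
            return pd.read_csv(path, encoding='utf-8')
        except UnicodeDecodeError:
            return pd.read_csv(path, encoding='gbk')

    # df = read_csv_with_fallback('names.csv')  # hypothetical usage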

+ 116 - 160
name_classify_api.py

@@ -35,34 +35,7 @@ model = BertClassifier(bert_config, len(label_revert_map.keys()))
 model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
 model.eval()
 tokenizer = BertTokenizer.from_pretrained(pre_train_model)
-# class ModelManager:
-#     def __init__(self, pre_train_model, model_save_path, label_revert_map):
-#         self.pre_train_model = pre_train_model
-#         self.model_save_path = model_save_path
-#         self.label_revert_map = label_revert_map
-#         self.model_cache = {}
-#         self.tokenizer_cache = {}
 
-#     def get_model(self):
-#         # Use the current thread ID as the key so each thread has its own model instance
-#         thread_id = threading.get_ident()
-        
-#         if thread_id not in self.model_cache:
-#             bert_config = BertConfig.from_pretrained(self.pre_train_model)
-#             model = BertClassifier(bert_config, len(self.label_revert_map.keys()))
-#             model.load_state_dict(torch.load(self.model_save_path, map_location=torch.device('cuda:0')))
-#             model.eval()
-#             self.model_cache[thread_id] = model
-
-#         if thread_id not in self.tokenizer_cache:
-#             tokenizer = BertTokenizer.from_pretrained(self.pre_train_model)
-#             self.tokenizer_cache[thread_id] = tokenizer
-
-#         print(f"Thread {thread_id}: Tokenizer id: {id(self.tokenizer_cache[thread_id])}, Model id: {id(self.model_cache[thread_id])}")
-#         return self.tokenizer_cache[thread_id], self.model_cache[thread_id]
-
-# Usage example
-# model_manager = ModelManager(pre_train_model, model_save_path, label_revert_map)
 class ClassificationRequest(BaseModel):
     path: str
     client_id: str
@@ -77,14 +50,6 @@ class ClassificationRequestBert(BaseModel):
     client_id: str
     name_column: str
 
-# bert_config = BertConfig.from_pretrained(pre_train_model)
-# # Define the model
-# model = BertClassifier(bert_config, len(label_revert_map.keys()))
-# # Load the trained model
-# model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cuda:0')))
-# model.eval()
-# tokenizer = BertTokenizer.from_pretrained(pre_train_model)
-
 def bert_predict(text):
     if type(text) == str and text != '':
         token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
@@ -106,79 +71,7 @@ def bert_predict(text):
         return label_revert_map[pred_label]
     else:
         return ''
-def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
-    if name_col not in pd.read_csv(origin_csv_path).columns:
-        return '列名错误', None, None
-    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
-            try:
-                # Process each chunk
-                chunk['AI Group'] = chunk[name_col].apply(lambda x : bert_predict(x))
-                # Incrementally append the processed results
-                if total_processed == 0:
-                    chunk.to_csv(temp_path, mode='w', index=False)
-                else:
-                    chunk.to_csv(temp_path, mode='a', header=False, index=False)
-                # Update the count of processed rows
-                total_processed += len(chunk)
-            except Exception as e:
-                df_error = pd.concat([df_error, chunk])
-    return temp_path, df_error, total_processed
-async def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
-    # Initialize variables
-    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
-    # Append a suffix for the error file
-    error_file = error_file_name + '_error' + error_file_extension
-    origin_csv_file = error_file_name + '_origin.csv'
-    # Build the new file paths
-    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
-    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
-    total_processed = 0
-    df_origin = pd.read_excel(file_path)
-    df_error = pd.DataFrame(columns=df_origin.columns)
-    df_origin.to_csv(origin_csv_path, index=False)
-
-    temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
-    if temp_path == '列名错误':
-        return '列名错误', None, None
-    df_final = pd.read_csv(temp_path)
-    df_final.to_excel(save_path, index=False)
-    try:
-        os.remove(origin_csv_path)
-    except Exception as e :
-        pass
-    if len(df_error) == 0:
-        return save_path, '', 0
-    else:
-        df_error.to_excel(error_file_path)
-        return save_path, error_file_path, len(df_error)
 
-async def predict_csv(file_path, name_col, temp_path, save_path, chunksize=5):
-    # Initialize variables
-    error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
-    # Append a suffix for the error file
-    error_file = error_file_name + '_error' + error_file_extension
-    origin_csv_file = error_file_name + '_origin.csv'
-    # Build the new file paths
-    error_file_path = os.path.join(os.path.dirname(save_path), error_file)
-    origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
-    total_processed = 0
-    df_origin = pd.read_csv(file_path)
-    df_error = pd.DataFrame(columns=df_origin.columns)
-    df_origin.to_csv(origin_csv_path, index=False)
-    # Read the CSV file in chunks
-    
-    # for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
-    temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
-    if temp_path == '列名错误':
-        return '列名错误', None, None
-    df_final = pd.read_csv(temp_path)
-    df_final.to_csv(save_path, index=False)
-    os.remove(origin_csv_path)
-    if len(df_error) == 0:
-        return save_path, '', 0
-    else:
-        df_error.to_csv(error_file_path)
-        return save_path, error_file_path, len(df_error)
 def remove_files():
     current_time = time.time()
     TIME_THRESHOLD_FILEPATH = 30 * 24 * 60 * 60
@@ -196,6 +89,7 @@ def remove_files():
             if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
                 print(f"删除文件夹: {dir_path}")
                 shutil.rmtree(dir_path)
+
 @app.post("/uploadfile/")
 async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
     user_directory = f'{basic_path}/{client_id}'
@@ -214,24 +108,22 @@ async def create_upload_file(file: UploadFile = File(...), client_id: str = Form
         }, status_code=200)
     except Exception as e:
         return JSONResponse(content={"message": f"发生错误: {str(e)}", "data":{}, "code":500}, status_code=500)
+
 async def process_openai_data(deal_result, client, temp_csv):
     for name, value in tqdm(deal_result.items(), desc='Processing', unit='item'):
-        try:
-            message = [
-                {'role':'system', 'content': cls_system_prompt},
-                {'role':'user', 'content':user_prompt.format(chunk=str(value))}
-            ]
-            # result_string = post_openai(message)
-            response = await client.chat.completions.create(model='gpt-4',messages=message)
-            result_string = response.choices[0].message.content
-            result = extract_list_from_string(result_string)
-            if result:
-                df_output = pd.DataFrame(result)
-                df_output.to_csv(temp_csv, mode='a', header=True, index=False)
-            else:
-                continue
-        except Exception as e:
-            print(f'{name}出现问题啦, 错误为:{e} 请自行调试')
+        message = [
+            {'role':'system', 'content': cls_system_prompt},
+            {'role':'user', 'content':user_prompt.format(chunk=str(value))}
+        ]
+        # result_string = post_openai(message)
+        response = await client.chat.completions.create(model='gpt-4',messages=message)
+        result_string = response.choices[0].message.content
+        result = extract_list_from_string(result_string)
+        if result:
+            df_output = pd.DataFrame(result)
+            df_output.to_csv(temp_csv, mode='a', header=True, index=False)
+        else:
+            continue
 @app.post("/classify_openai/")
 async def classify_data(request: ClassificationRequest):
     try:
@@ -290,48 +182,112 @@ async def classify_data(request: ClassificationRequest):
 
 @app.post("/classify_bert/")
 async def real_process(request: ClassificationRequestBert):
-    task = asyncio.create_task(classify_data_bert(request))
-    result = await task
+    # task = asyncio.create_task(classify_data_bert(request))
+    result = await classify_data_bert(request)
     return result
 
 async def classify_data_bert(request: ClassificationRequestBert):
-    remove_files()
-    work_path = f'{basic_path}/{request.client_id}'
-    if not os.path.exists(work_path):
-        os.makedirs(work_path, exist_ok=True)
-    final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
-    final_file = final_file_name + '_classify' + final_file_extension
-    # Build the new file path
-    new_file_path = os.path.join(os.path.dirname(request.path), final_file)
-    timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-    temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
-    if final_file_extension == '.csv':
-        save_path, error_path, error_len = await predict_csv(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
-    else:
-        # Append the suffix
-        save_path, error_path, error_len = await predict_excel(request.path, request.name_column, temp_path=temp_csv ,save_path=new_file_path)
-    if error_len == 0:
-        return JSONResponse(content={
-            "message": "分类完成",
-            "data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
-            "code":200
-        }, status_code=200)
-    elif save_path == '列名错误':
-        return JSONResponse(content={
-            "message": "用户标识列输入错误,请重新输入",
-            "data":{},
-            "code":500
-        }, status_code=200)
-        
-    else:
+    try:
+        remove_files()
+        work_path = f'{basic_path}/{request.client_id}'
+        if not os.path.exists(work_path):
+            os.makedirs(work_path, exist_ok=True)
+        final_file_name, final_file_extension = os.path.splitext(os.path.basename(request.path))
+        final_file = final_file_name + '_classify' + final_file_extension
+        new_file_path = os.path.join(os.path.dirname(request.path), final_file)
+        timestamp_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+        temp_csv = work_path + '/' + timestamp_str + 'output_temp.csv'
+        def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
+            if name_col not in pd.read_csv(origin_csv_path).columns:
+                return '列名错误', None, None
+            for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
+                    try:
+                        # Process each chunk
+                        chunk['AI Group'] = chunk[name_col].apply(lambda x : bert_predict(x))
+                        # Incrementally append the processed results
+                        if total_processed == 0:
+                            chunk.to_csv(temp_path, mode='w', index=False)
+                        else:
+                            chunk.to_csv(temp_path, mode='a', header=False, index=False)
+                        # Update the count of processed rows
+                        total_processed += len(chunk)
+                    except Exception as e:
+                        df_error = pd.concat([df_error, chunk])
+            return temp_path, df_error, total_processed
+        async def process_file(file_path, name_col, temp_path, save_path, chunksize=5):
+            error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
+            error_file = error_file_name + '_error' + error_file_extension
+            origin_csv_file = error_file_name + '_origin.csv'
+            error_file_path = os.path.join(os.path.dirname(save_path), error_file)
+            origin_csv_path = os.path.join(os.path.dirname(save_path), origin_csv_file)
+            
+            if final_file_extension == '.csv':
+                try:
+                    df_origin = pd.read_csv(file_path, encoding='utf-8')
+                except UnicodeDecodeError:
+                    df_origin = pd.read_csv(file_path, encoding='gbk')
+            else:
+                df_origin = pd.read_excel(file_path)
+            
+            df_origin.to_csv(origin_csv_path, index=False)
+            df_error = pd.DataFrame(columns=df_origin.columns)
+            
+            if name_col not in df_origin.columns:
+                return '列名错误', None, None
+            
+            total_processed = 0
+            temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
+
+            df_final = pd.read_csv(temp_path)
+            if final_file_extension == '.csv':
+                df_final.to_csv(save_path, index=False)
+            else:
+                df_final.to_excel(save_path, index=False)
+            
+            try:
+                os.remove(origin_csv_path)
+            except Exception as e:
+                pass
+
+            if len(df_error) == 0:
+                return save_path, '', 0
+            else:
+                if final_file_extension == '.csv':
+                    df_error.to_csv(error_file_path, index=False)
+                else:
+                    df_error.to_excel(error_file_path, index=False)
+                return save_path, error_file_path, len(df_error)
+
+        save_path, error_path, error_len = await process_file(request.path, request.name_column, temp_path=temp_csv, save_path=new_file_path)
+
+        if error_len == 0:
+            return JSONResponse(content={
+                "message": "分类完成",
+                "data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
+                "code":200
+            }, status_code=200)
+        elif save_path == '列名错误':
+            return JSONResponse(content={
+                "message": "用户标识列输入错误,请重新输入",
+                "data":{},
+                "code":500
+            }, status_code=200)
+        else:
+            return JSONResponse(content={
+                "message": "分类完成只完成部分",
+                "data":{"output_file": file_base_url + save_path.split(basic_path)[1],
+                        "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1]},
+                "code":200
+            }, status_code=200)
+    except Exception as e:
         return JSONResponse(content={
-            "message": "分类完成只完成部分",
-            "data":{"output_file": file_base_url + save_path.split(basic_path)[1],
-                    "output_file_nonprocess":file_base_url + error_path.split(basic_path)[1]},
-            "code":200
-        }, status_code=200)
+                "message": f"出现错误: {e}, 请检查列名和文件重试",
+                "data":{},
+                "code":500
+            }, status_code=200)
+
 @app.get("/test/")
 async def classify_data():
     return '测试成功'
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    uvicorn.run(app, host="0.0.0.0", port=port)
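
For reference, a minimal sketch of how a client might call the reworked /classify_bert/ endpoint. The host, port, and file path are assumptions; only the route and the ClassificationRequestBert fields (path, client_id, name_column) come from the code above:

    import requests

    payload = {
        "path": "/data/uploads/demo/names.xlsx",  # hypothetical server-side file path
        "client_id": "demo",                      # hypothetical client directory id
        "name_column": "name",                    # column holding the names to classify
    }
    resp = requests.post("http://localhost:8000/classify_bert/", json=payload)
    print(resp.json())  # response carries message, data (output file links), and code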

+ 0 - 4
test.csv

@@ -1,4 +0,0 @@
-name,one_key
-Wang Dalin,0
-Carl,1
-Jack,2

BIN
test.xlsx