wangdalin · 9 months ago
Commit 3e104729c2
2 files changed, 16 insertions(+), 1 deletion(-)
  1. config.py (+1 -1)
  2. name_classify_api.py (+15 -0)

+ 1 - 1
config.py

@@ -11,7 +11,7 @@ def get_ip_address():
         return ip_address
     except Exception as e:
         return str(e)
-api_key = "sk-iREtaVNjamaBArOTlc_2BfGFJVPiU-9EjSFMUspIPBT3BlbkFJxS0SMmKZD9L9UumPczee4VKawCwVeGBQAr9MgsWGkA"
+api_key = "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"
 basic_path = './process'
 openai_url = 'https://api.openai.com/v1'
 proxy_url = 'https://fast.bemore.lol/v1'
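
The config.py hunk only swaps one hard-coded OpenAI key for another. A minimal sketch of the usual alternative, reading the key from the environment so it never has to be committed at all; the variable name OPENAI_API_KEY is an assumption, not something this commit introduces:

    import os

    # Assumed variable name; whatever the deployment actually exports goes here.
    api_key = os.environ.get("OPENAI_API_KEY", "")
    basic_path = './process'
    openai_url = 'https://api.openai.com/v1'
    proxy_url = 'https://fast.bemore.lol/v1'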

+ 15 - 0
name_classify_api.py

@@ -107,6 +107,8 @@ def bert_predict(text):
     else:
         return ''
 def process_data(origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error):
+    if name_col not in pd.read_csv(origin_csv_path).columns:
+        return '列名错误', None, None
     for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
             try:
                 # 对每个块进行处理
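
The new guard in process_data returns the sentinel string '列名错误' ("column name error") when name_col is missing, but it does so by loading the whole CSV a second time just to look at the header. A header-only check is enough for that purpose; this sketch assumes plain pandas and is an alternative, not what the commit does:

    import pandas as pd

    def column_exists(csv_path: str, name_col: str) -> bool:
        # nrows=0 parses only the header row, so the file is not scanned twice
        header_df = pd.read_csv(csv_path, nrows=0)
        return name_col in header_df.columns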
@@ -136,6 +138,8 @@ async def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
     df_origin.to_csv(origin_csv_path, index=False)
 
     temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
+    if temp_path == '列名错误':
+        return '列名错误', None, None
     df_final = pd.read_csv(temp_path)
     df_final.to_excel(save_path, index=False)
     try:
@@ -165,6 +169,8 @@ async def predict_csv(file_path, name_col, temp_path, save_path, chunksize=5):
     
     # for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=len(pd.read_csv(origin_csv_path)) // chunksize + 1, desc='Processing', unit='item'):
     temp_path, df_error, total_processed = await asyncio.to_thread(process_data, origin_csv_path, temp_path, name_col, chunksize, total_processed, df_error)
+    if temp_path == '列名错误':
+        return '列名错误', None, None
     df_final = pd.read_csv(temp_path)
     df_final.to_csv(save_path, index=False)
     os.remove(origin_csv_path)
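
Both predict_excel and predict_csv now repeat the same check: if process_data hands back '列名错误' as the first element of its tuple, the coroutine returns the same sentinel instead of treating it as a file path. An alternative that avoids the per-caller string comparison is a dedicated exception caught once at the route level; a sketch under that assumption, not part of the commit:

    class ColumnNotFoundError(ValueError):
        """Raised when the requested name column is missing from the input file."""

    def require_column(columns, name_col: str) -> None:
        # process_data would call this once instead of returning a sentinel string
        if name_col not in columns:
            raise ColumnNotFoundError(f"column not found: {name_col}")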
@@ -239,6 +245,8 @@ async def classify_data(request: ClassificationRequest):
             df_origin = pd.read_csv(request.path)
         else:
             df_origin = pd.read_excel(request.path)
+        if request.name_column not in df_origin.columns or request.one_key not in df_origin.columns:
+            return JSONResponse(content={"message": "用户标识列输入错误,请重新输入","data":{},"code":500}, status_code=500)
         df_origin['AI Name'] = df_origin[request.name_column]
         df_origin['AI Group'] = ''
         df_use = df_origin[['AI Name', 'AI Group']]
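
classify_data now rejects requests whose name_column or one_key is not a column of the uploaded file, answering with 用户标识列输入错误,请重新输入 ("the user identifier column is wrong, please re-enter") and HTTP 500. If more endpoints grow the same guard, a small helper keeps the checks uniform; a sketch, with missing_columns being a hypothetical name:

    from fastapi.responses import JSONResponse

    def missing_columns(df_columns, required):
        # Report every absent column rather than stopping at the first one
        return [col for col in required if col not in df_columns]

    # Usage mirroring the commit's check inside classify_data:
    # missing = missing_columns(df_origin.columns, [request.name_column, request.one_key])
    # if missing:
    #     return JSONResponse(content={"message": "用户标识列输入错误,请重新输入",
    #                                  "data": {}, "code": 500}, status_code=500)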
@@ -308,6 +316,13 @@ async def classify_data_bert(request: ClassificationRequestBert):
             "data":{"output_file": file_base_url + save_path.split(basic_path)[1]},
             "code":200
         }, status_code=200)
+    elif save_path == '列名错误':
+        return JSONResponse(content={
+            "message": "用户标识列输入错误,请重新输入",
+            "data":{},
+            "code":500
+        }, status_code=200)
+        
     else:
         return JSONResponse(content={
             "message": "分类完成只完成部分",