@@ -18,7 +18,9 @@ from bert.model import BertClassifier
 from transformers import BertTokenizer, BertConfig

 app = FastAPI()
-app.mount("/data", StaticFiles(directory='./process'), name="static")
+if not os.path.exists(basic_path):
+    os.makedirs(basic_path, exist_ok=True)
+app.mount("/data", StaticFiles(directory=basic_path), name="static")
 class ClassificationRequest(BaseModel):
     path: str
     client_id: str
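Note (editorial, not part of the patch): the mount exposes everything under basic_path at the /data URL prefix. A minimal sketch of the mapping this implies, assuming basic_path keeps the pre-patch value './process' and using a hypothetical output file name:

import os

basic_path = './process'                         # assumption; defined elsewhere in the app
saved = os.path.join(basic_path, 'result.xlsx')  # hypothetical output file
# With app.mount("/data", StaticFiles(directory=basic_path), ...), this file is
# served at the URL below -- the same split()-based URL building the endpoints
# use later in this patch.
url = '/data' + saved.split(basic_path)[1].replace(os.sep, '/')
print(url)  # -> /data/result.xlsx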
@@ -42,23 +44,26 @@ model.eval()
 tokenizer = BertTokenizer.from_pretrained(pre_train_model)

 def bert_predict(text):
-    token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
-    input_ids = token['input_ids']
-    attention_mask = token['attention_mask']
-    token_type_ids = token['token_type_ids']
+    if isinstance(text, str) and text != '':
+        token = tokenizer(text, add_special_tokens=True, padding='max_length', truncation=True, max_length=512)
+        input_ids = token['input_ids']
+        attention_mask = token['attention_mask']
+        token_type_ids = token['token_type_ids']

-    input_ids = torch.tensor([input_ids], dtype=torch.long)
-    attention_mask = torch.tensor([attention_mask], dtype=torch.long)
-    token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
+        input_ids = torch.tensor([input_ids], dtype=torch.long)
+        attention_mask = torch.tensor([attention_mask], dtype=torch.long)
+        token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)

-    predicted = model(
-        input_ids,
-        attention_mask,
-        token_type_ids,
-    )
-    pred_label = torch.argmax(predicted, dim=1).numpy()[0]
-
-    return label_revert_map[pred_label]
+        predicted = model(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+        )
+        pred_label = torch.argmax(predicted, dim=1).numpy()[0]
+
+        return label_revert_map[pred_label]
+    else:
+        return ''
 def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
     # Initialize variables
     error_file_name, error_file_extension = os.path.splitext(os.path.basename(save_path))
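Note (editorial, not part of the patch): bert_predict runs the forward pass with autograd enabled. A minimal sketch of the same call wrapped in torch.no_grad(), which skips gradient tracking during inference; it assumes model returns a [1, num_labels] logit tensor, as the argmax above suggests:

import torch

def bert_predict_no_grad(input_ids, attention_mask, token_type_ids):
    # Same forward pass as above, but without building an autograd graph.
    with torch.no_grad():
        predicted = model(input_ids, attention_mask, token_type_ids)
    return torch.argmax(predicted, dim=1).numpy()[0]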
@@ -74,7 +79,7 @@ def predict_excel(file_path, name_col, temp_path, save_path, chunksize=5):
     df_origin.to_csv(origin_csv_path, index=False)
     # Read the CSV file in chunks

-    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), desc='Processing', unit='item'):
+    for chunk in tqdm(pd.read_csv(origin_csv_path, chunksize=chunksize, iterator=True), total=(len(df_origin) + chunksize - 1) // chunksize, desc='Processing', unit='item'):
         try:
             # Process each chunk
             chunk['classify'] = chunk[name_col].apply(bert_predict)
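Note (editorial, not part of the patch): the tqdm total above uses df_origin, which is already in memory, with ceiling division so the count is exact even when the row count divides evenly by chunksize. If only the CSV were available, a sketch of computing the total without re-reading the whole file into a DataFrame (assumes a single header row):

import math

def chunk_total(csv_path: str, chunksize: int) -> int:
    # Count data rows cheaply, then take the ceiling of rows / chunksize.
    with open(csv_path, 'r', encoding='utf-8') as f:
        n_rows = sum(1 for _ in f) - 1  # subtract the header line
    return math.ceil(n_rows / chunksize)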
@@ -204,6 +209,6 @@ async def classify_data(request: ClassificationRequestBert):
     if error_len == 0:
         return {"message": "分类完成", "output_file": file_base_url + save_path.split(basic_path)[1]}
     else:
-        return {"message": "分类完成只完成部分", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess": file_base_url + save_path.split(error_path)[1]}
+        return {"message": "分类完成只完成部分", "output_file": file_base_url + save_path.split(basic_path)[1], "output_file_nonprocess": file_base_url + error_path.split(basic_path)[1]}
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=port)
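Note (editorial, not part of the patch): the corrected + line builds the download URL by splitting error_path on basic_path, which only works while error_path actually starts with basic_path. A sketch of the same mapping via os.path.relpath, offered as an alternative rather than what the code does; file_base_url without a trailing slash is an assumption:

import os

def to_url(path: str, base: str, base_url: str) -> str:
    # Relative path from base, normalized to URL separators.
    rel = os.path.relpath(path, base).replace(os.sep, '/')
    return base_url.rstrip('/') + '/' + rel

# to_url(error_path, basic_path, file_base_url) is roughly equivalent to
# file_base_url + error_path.split(basic_path)[1] whenever error_path lives
# under basic_path.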