123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- import os
- import re
- import time
- import json
- from pathlib import Path
- from collections import Counter
- from typing import List, Dict
- from tqdm import tqdm
- import torch
- from transformers import AutoModelForCausalLM
- from volcenginesdkarkruntime import Ark
- from modules.call_sencevoice import process_audio
- from .common import filter_oral_data, label_data, filter_label_data
- from utils.common import del_key
- from .path_config import PathConfig
- path_config = PathConfig()
- from .logger_config import setup_logger
- logger = setup_logger(__name__)
# Shared Ark client used by text_classifer().
# SECURITY: an API key is committed in this file. Supply the key through the
# ARK_API_KEY environment variable instead; the embedded literal is kept only
# as a fallback so existing deployments keep working.
# TODO: rotate this key and remove the literal from source control.
client = Ark(
    api_key=os.environ.get("ARK_API_KEY", "817dff39-5586-4f9b-acba-55004167c0b1"),
    base_url="https://ark.cn-beijing.volces.com/api/v3",
)
def speaker_extract(audio_path):
    """Return the sentences of the most frequent speaker in a transcript file.

    The JSON file at ``audio_path`` is loaded, the ``"spk"`` value of every
    entry in its ``"sentence_info"`` list is counted, and only the entries
    belonging to the speaker with the highest count are returned.

    Args:
        audio_path: Path of a JSON file holding a ``"sentence_info"`` list
            whose entries each carry a ``"spk"`` key.

    Returns:
        The entries of the most common speaker, in their original order, or
        an empty list when ``"sentence_info"`` is empty.

    Raises:
        KeyError: If ``"sentence_info"`` or an entry's ``"spk"`` is missing.
    """
    # Read explicitly as UTF-8: save_json() in this module writes UTF-8 with
    # ensure_ascii=False, so relying on the locale default can fail to decode.
    with open(audio_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    sentences = data["sentence_info"]

    spk_counts = Counter(d['spk'] for d in sentences)
    if not spk_counts:
        return []

    # The speaker with the most sentences is taken to be the host.
    most_common_spk = spk_counts.most_common(1)[0][0]
    result = [d for d in sentences if d['spk'] == most_common_spk]
    logger.info(f"筛选出主播语句:{len(result)} 条")
    return result
def filter_text_len(data_list: List[Dict], min_length: int = 4) -> List[Dict]:
    """Keep only the items whose ``"text"`` has at least ``min_length`` characters.

    Args:
        data_list: Dicts that may carry a ``"text"`` entry.
        min_length: Minimum number of characters required to keep an item.

    Returns:
        A new list with the qualifying items in their original order. Items
        whose ``"text"`` is missing or ``None`` are treated as empty text.
    """
    return [
        item for item in data_list
        # "or ''" covers a key that is present but None, which would
        # otherwise make len() raise TypeError.
        if len(item.get('text') or '') >= min_length
    ]
def exclude_word_list(data_list, exclude_words=None):
    """Drop every item whose ``"text"`` contains one of the excluded words.

    Args:
        data_list: Dicts that may carry a ``"text"`` entry.
        exclude_words: Words to reject. When ``None``, the list stored under
            ``"exclude_word"`` in the JSON file registered as
            ``"exclude_word_json"`` in ``path_config`` is used.

    Returns:
        A new list with the items that contain none of the words, in their
        original order. If no non-empty word is supplied, nothing is excluded.
    """
    if exclude_words is None:
        with open(path_config.get_path("exclude_word_json"), 'r', encoding='utf-8') as file:
            exclude_words = json.load(file)["exclude_word"]

    # Ignore empty strings: an empty alternative (or an entirely empty
    # pattern) matches every text and would silently discard all items.
    words = [word for word in exclude_words if word]

    if words:
        # Words are escaped so they are matched literally, not as regex syntax.
        pattern = re.compile("|".join(map(re.escape, words)))
        result = [
            item for item in data_list
            if not pattern.search(item.get("text") or "")
        ]
    else:
        result = list(data_list)
    logger.info(f"筛选后主播语句:{len(result)} 条")
    return result
def include_word_list(data_list, include_words=None):
    """Keep only the items whose ``"text"`` contains one of the given words.

    Args:
        data_list: Dicts that may carry a ``"text"`` entry (a missing entry
            is searched as an empty string).
        include_words: Words to look for. When ``None``, the list stored
            under ``"include_word"`` in the JSON file registered as
            ``"include_word_json"`` in ``path_config`` is used.

    Returns:
        A new list with the matching items in their original order. Note that
        an empty word list yields an empty pattern, which matches every text.
    """
    keywords = include_words
    if keywords is None:
        # Fall back to the configured default keyword list.
        with open(path_config.get_path("include_word_json"), 'r', encoding='utf-8') as fh:
            keywords = json.load(fh)["include_word"]

    # One alternation of literally-escaped keywords.
    matcher = re.compile("|".join(re.escape(word) for word in keywords))

    kept = []
    for entry in data_list:
        if matcher.search(entry.get("text", "")):
            kept.append(entry)
    logger.info(f"筛留后主播语句:{len(kept)} 条")
    return kept
class TextClassifier:
    """Locally loaded causal language model used to extract clothing attributes.

    The model is loaded once in the constructor; ``classify`` then sends a
    fixed extraction prompt followed by the input text to the model's
    ``chat`` method.
    """

    def __init__(self, model_path=None):
        """Select the device and load the model.

        Args:
            model_path: Location of the model to load. When omitted, the path
                registered as ``'Megrez-3B-Omni'`` in ``path_config`` is used.
        """
        # Resolve the default at call time rather than in the signature, so
        # the configuration lookup does not run when the class is defined.
        if model_path is None:
            model_path = path_config.get_path('Megrez-3B-Omni')
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model(model_path)

    def load_model(self, model_path):
        """Load the model from ``model_path``, move it to ``self.device`` and compile it."""
        print(f"Loading model from {model_path}...")
        start = time.time()

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # attn_implementation="flash_attention_2"
        ).eval().to(self.device)
        self.model = torch.compile(self.model)
        print(f"Model loaded in {time.time() - start:.2f}s")

    def classify(self, text, max_new_tokens=50, sampling=False, verbose=False):
        """Ask the model to extract fabric, colour, cut and craft from ``text``.

        Args:
            text: Sentence to analyse; it is appended to the extraction prompt.
            max_new_tokens: Forwarded to the model's ``chat`` call.
            sampling: Forwarded to the model's ``chat`` call.
            verbose: When true, print the inference time.

        Returns:
            The model response up to (not including) the first
            ``'<|turn_end|>'`` marker.

        Raises:
            ValueError: If no model has been loaded.
        """
        # Compare against None explicitly so a loaded model is never rejected
        # because of how its truthiness happens to be defined.
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        prompt = """
## 任务:从输入文本中进行关键信息提取,需要提取**面料**、**颜色**、**版型**、**工艺**这四个关键信息。
## 背景:
- 面料:衣服采用的面料材质。
- 颜色:衣服的颜色。
- 版型:衣服的版型是指服装的整体结构、剪裁方式及尺寸比例关系。
- 工艺:涉及衣服设计、材料处理、裁剪、缝制、装饰、定型等。
## 输出格式:{"面料": "无则输出None", "颜色": "无则输出None", "版型": "无则输出None", "工艺": "无则输出None"}
## 输出案例:{"面料": "羊毛", "颜色": "None", "版型": "None", "工艺": "None"}
## 要求:必须以JSON格式输出提取的结果
## 输入文本:
"""
        messages = [{
            "role": "user",
            "content": {
                "text": f"{prompt}{text}",
            }
        }]
        start = time.time()
        response = self.model.chat(
            messages,
            sampling=sampling,
            temperature=0.0,
            repetition_penalty=1.0,
            max_new_tokens=max_new_tokens
        )

        if verbose:
            print(f"Inference time: {time.time() - start:.4f}s")

        return response.split('<|turn_end|>')[0]

    def __del__(self):
        """Release the model and, when CUDA is available, empty its cache."""
        try:
            if hasattr(self, 'model') and self.model is not None:
                del self.model
                if torch.cuda.is_available():  # only touch CUDA when it is available
                    torch.cuda.empty_cache()
        except Exception:
            # Deliberately best-effort: a destructor must not raise and
            # disturb the main flow.
            pass
def text_classifer(user_prompt):
    """Ask the remote chat model whether a sentence describes garment attributes.

    The fixed system prompt below (task description, positive and negative
    examples, required JSON output format) is sent together with
    ``user_prompt`` through the module-level ``client``.

    Args:
        user_prompt: Sentence to be judged; sent as the user message.

    Returns:
        The content of the first choice of the completion, unparsed.
    """
    system_prompt = """
## 任务:判断输入的文本是否在讲解衣服特性、属性
## 背景知识:
---
**正例:**
- 这是假两件的款式
- 我采用的来自澳大利亚进口的美丽诺羊毛是羊毛中的天花板
- 采用立体裁切,A字版型
- 100%新疆长绒棉,亲肤透气,久穿不易起球变形。
- 高腰A字裙版型,腰线提升视觉比例,下摆微蓬显腿细。
- 超短上衣+低腰裤组合,五五分身材慎选,易显腿短
---
**反例:** 没有说明衣服属性、特性的具体内容。
- 看看喜欢的款式。
- 你可以去搜去问去找羊毛,
- 对我们来说工艺更难,
- 顶梁柱面料。
- 利亚在冬天的招牌面料自然不可能便宜。
- 今天水洗绵羊毛的这条背心裙以后来一千五只有一条。
---
## 输出格式:{"讲解衣服": //<文本是否在讲解衣服特性、属性,取值范围:是、否>}
## 要求:必须以JSON格式输出提取的结果
## 注意事项:
- 如果是讲衣服便宜实惠,则输出:{"讲解衣服": "否"}
- 如果没有讲解出衣服属性、特性的实质内容,则输出:{"讲解衣服": "否"}
"""
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    reply = client.chat.completions.create(
        messages=conversation,
        # Endpoint id in use; the original note also lists "deepseek-v3-241226".
        model="ep-20241018084532-cgm84",
        temperature=0.01,
        max_tokens=200,
    )
    first_choice = reply.choices[0]
    return first_choice.message.content
def save_json(path: Path, data: dict, raw_video: str = None):
    """Write ``data`` to ``path`` as indented UTF-8 JSON.

    Args:
        path: Destination file. Missing parent directories are created.
        data: JSON-serialisable payload.
        raw_video: When truthy, the payload is wrapped as
            ``{"raw_video": raw_video, "oral_dict_list": data}``; otherwise
            ``data`` is written as is.
    """
    content = {"raw_video": raw_video, "oral_dict_list": data} if raw_video else data
    target = Path(path)
    # Create the output directory up front so a missing folder does not
    # raise FileNotFoundError when the file is opened.
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        json.dump(content, f, ensure_ascii=False, indent=4)
def audio_analysis_pipeline(video_path: str):
    """Run the full audio analysis pipeline for one video.

    The stages run strictly in order: transcription, main-speaker and length
    filtering, exclude/include keyword filtering, attribute labelling with
    ``TextClassifier``, then the second-stage helpers imported from
    ``.common``. Intermediate results are written under the directories
    registered in ``path_config``, using the video's stem plus ``".json"``
    as the file name.

    Args:
        video_path: Path of the video to process; also stored as
            ``raw_video`` in the filtered output files.

    Raises:
        Exception: Any error from a stage is logged and re-raised unchanged.
    """
    audio_name = Path(video_path).stem + ".json"
    try:
        # Stage 1: transcribe the audio and persist the transcript.
        transcript = process_audio(video_path)
        transcript_file = path_config.get_path('audio_json') / audio_name
        save_json(transcript_file, transcript)
        del_key(transcript_file)

        # Stage 2: keep the main speaker's sentences that are long enough.
        main_speaker = speaker_extract(transcript_file)
        long_enough = filter_text_len(main_speaker)

        # Stage 3a: drop sentences containing excluded words.
        without_excluded = exclude_word_list(long_enough)
        stage1_file = path_config.get_path('output_filter_1') / audio_name
        save_json(stage1_file, without_excluded, video_path)

        # Stage 3b: keep sentences containing included words.
        with_included = include_word_list(without_excluded)
        stage2_file = path_config.get_path('output_filter_2') / audio_name
        save_json(stage2_file, with_included, video_path)

        # Stage 4: label every remaining sentence with extracted attributes.
        logger.info("开始语句打标")
        classifier = TextClassifier()
        labeled = []
        for sentence in tqdm(with_included, desc="Processing texts"):
            labeled.append(
                {**sentence, "attribute": classifier.classify(sentence["text"])}
            )

        # Persist the first labelling pass.
        stage3_file = path_config.get_path('output_filter_3') / audio_name
        save_json(stage3_file, labeled, video_path)

        # Stage 5: second pass. Per the original comments these helpers
        # select garment-description sentences, label again, and filter again.
        filter_oral_data(stage3_file)
        label_data(stage3_file)
        stage4_file = path_config.get_path('output_filter_4') / audio_name
        filter_label_data(stage4_file)

        logger.info("音频分析流水线处理完成")
    except Exception as e:
        logger.error(f"处理文件 {audio_name} 时发生错误: {str(e)}")
        raise
if __name__ == "__main__":
    # Manual smoke test: run a few sample sentences through the remote
    # classifier. The local TextClassifier().classify(...) could be used
    # here instead; it would load the model once.
    samples = ("这个颜色很好看", "这款纯羊毛的衣服自然不便宜")
    for sample in samples:
        verdict = text_classifer(sample)
        print(f"Input: {sample}\nResult: {verdict}\n")
|