audio_analysis.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import os
  2. import re
  3. import time
  4. import json
  5. from pathlib import Path
  6. from collections import Counter
  7. from typing import List, Dict
  8. from tqdm import tqdm
  9. import torch
  10. from transformers import AutoModelForCausalLM
  11. from volcenginesdkarkruntime import Ark
  12. from modules.call_sencevoice import process_audio
  13. from .common import filter_oral_data, label_data, filter_label_data
  14. from utils.common import del_key
  15. from .path_config import PathConfig
  16. path_config = PathConfig()
  17. from .logger_config import setup_logger
  18. logger = setup_logger(__name__)
  19. client = Ark(
  20. api_key= "817dff39-5586-4f9b-acba-55004167c0b1",
  21. base_url="https://ark.cn-beijing.volces.com/api/v3"
  22. )
  23. def speaker_extract(audio_path):
  24. with open(audio_path, 'r') as file:
  25. data = json.load(file)
  26. sentences = data["sentence_info"]
  27. # 只保留有效的说话者
  28. valid_speakers = {item['spk'] for item in sentences}
  29. spk_counts = Counter(d['spk'] for d in sentences if d['spk'] in valid_speakers)
  30. if not spk_counts:
  31. return []
  32. # 找到最常见的说话者
  33. most_common_spk = spk_counts.most_common(1)[0][0]
  34. result = [d for d in sentences if d['spk'] == most_common_spk]
  35. logger.info(f"筛选出主播语句:{len(result)} 条")
  36. return result
  37. def filter_text_len(data_list: List[Dict], min_length: int = 4) -> List[Dict]:
  38. return [
  39. item for item in data_list
  40. if len(item.get('text', '')) >= min_length
  41. ]
  42. def exclude_word_list(data_list, exclude_words=None):
  43. if exclude_words is None:
  44. with open(path_config.get_path("exclude_word_json"), 'r', encoding='utf-8') as file:
  45. exclude_words = json.load(file)["exclude_word"]
  46. # 编译正则表达式模式,匹配任意排除词
  47. pattern = re.compile("|".join(map(re.escape, exclude_words)))
  48. result = [
  49. item for item in data_list
  50. if not pattern.search(item.get("text", ""))
  51. ]
  52. logger.info(f"筛选后主播语句:{len(result)} 条")
  53. return result
  54. def include_word_list(data_list, include_words=None):
  55. # 如果没有传入include_words参数,则使用默认的include_words列表
  56. if include_words is None:
  57. with open(path_config.get_path("include_word_json"), 'r', encoding='utf-8') as file:
  58. include_words = json.load(file)["include_word"]
  59. # 构建正则表达式模式(自动转义特殊字符)
  60. pattern = re.compile("|".join(map(re.escape, include_words)))
  61. # 返回包含指定关键词的列表
  62. result = [
  63. item for item in data_list
  64. if pattern.search(item.get("text", ""))
  65. ]
  66. logger.info(f"筛留后主播语句:{len(result)} 条")
  67. return result
  68. class TextClassifier:
  69. def __init__(self, model_path=path_config.get_path('Megrez-3B-Omni')):
  70. self.model = None
  71. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  72. self.load_model(model_path)
  73. def load_model(self, model_path):
  74. print(f"Loading model from {model_path}...")
  75. start = time.time()
  76. self.model = AutoModelForCausalLM.from_pretrained(
  77. model_path,
  78. trust_remote_code=True,
  79. torch_dtype=torch.bfloat16,
  80. # attn_implementation="flash_attention_2"
  81. ).eval().to(self.device)
  82. self.model = torch.compile(self.model)
  83. print(f"Model loaded in {time.time() - start:.2f}s")
  84. def classify(self, text, max_new_tokens=50, sampling=False, verbose=False):
  85. if not self.model:
  86. raise ValueError("Model not loaded. Call load_model() first.")
  87. prompt = """
  88. ## 任务:从输入文本中进行关键信息提取,需要提取**面料**、**颜色**、**版型**、**工艺**这四个关键信息。
  89. ## 背景:
  90. - 面料:衣服采用的面料材质。
  91. - 颜色:衣服的颜色。
  92. - 版型:衣服的版型是指服装的整体结构、剪裁方式及尺寸比例关系。
  93. - 工艺:涉及衣服设计、材料处理、裁剪、缝制、装饰、定型等。
  94. ## 输出格式:{"面料": "无则输出None", "颜色": "无则输出None", "版型": "无则输出None", "工艺": "无则输出None"}
  95. ## 输出案例:{"面料": "羊毛", "颜色": "None", "版型": "None", "工艺": "None"}
  96. ## 要求:必须以JSON格式输出提取的结果
  97. ## 输入文本:
  98. """
  99. messages = [{
  100. "role": "user",
  101. "content": {
  102. "text": f"{prompt}{text}",
  103. }
  104. }]
  105. start = time.time()
  106. response = self.model.chat(
  107. messages,
  108. sampling=sampling,
  109. temperature=0.0,
  110. repetition_penalty=1.0,
  111. max_new_tokens=max_new_tokens
  112. )
  113. if verbose:
  114. print(f"Inference time: {time.time() - start:.4f}s")
  115. return response.split('<|turn_end|>')[0]
  116. def __del__(self):
  117. """释放模型资源"""
  118. try:
  119. if hasattr(self, 'model') and self.model is not None:
  120. del self.model
  121. if torch.cuda.is_available(): # 关键:检查CUDA是否可用
  122. torch.cuda.empty_cache()
  123. except Exception as e:
  124. pass # 避免析构函数抛出异常干扰主流程
  125. def text_classifer(user_prompt):
  126. system_prompt = """
  127. ## 任务:判断输入的文本是否在讲解衣服特性、属性
  128. ## 背景知识:
  129. ---
  130. **正例:**
  131. - 这是假两件的款式
  132. - 我采用的来自澳大利亚进口的美丽诺羊毛是羊毛中的天花板
  133. - 采用立体裁切,A字版型
  134. - 100%新疆长绒棉,亲肤透气,久穿不易起球变形。
  135. - 高腰A字裙版型,腰线提升视觉比例,下摆微蓬显腿细。
  136. - 超短上衣+低腰裤组合,五五分身材慎选,易显腿短
  137. ---
  138. **反例:** 没有说明衣服属性、特性的具体内容。
  139. - 看看喜欢的款式。
  140. - 你可以去搜去问去找羊毛,
  141. - 对我们来说工艺更难,
  142. - 顶梁柱面料。
  143. - 利亚在冬天的招牌面料自然不可能便宜。
  144. - 今天水洗绵羊毛的这条背心裙以后来一千五只有一条。
  145. ---
  146. ## 输出格式:{"讲解衣服": //<文本是否在讲解衣服特性、属性,取值范围:是、否>}
  147. ## 要求:必须以JSON格式输出提取的结果
  148. ## 注意事项:
  149. - 如果是讲衣服便宜实惠,则输出:{"讲解衣服": "否"}
  150. - 如果没有讲解出衣服属性、特性的实质内容,则输出:{"讲解衣服": "否"}
  151. """
  152. completion = client.chat.completions.create(
  153. messages = [
  154. {"role": "system", "content": system_prompt},
  155. {"role": "user", "content": user_prompt},
  156. ],
  157. model="ep-20241018084532-cgm84", # ep-20241018084532-cgm84 deepseek-v3-241226
  158. temperature = 0.01,
  159. max_tokens = 200
  160. )
  161. return completion.choices[0].message.content
  162. def save_json(path: Path, data: dict, raw_video: str = None):
  163. """保存JSON文件的辅助函数"""
  164. content = {"raw_video": raw_video, "oral_dict_list": data} if raw_video else data
  165. with open(path, "w", encoding="utf-8") as f:
  166. json.dump(content, f, ensure_ascii=False, indent=4)
  167. def audio_analysis_pipeline(video_path: str):
  168. audio_name = Path(video_path).stem + ".json"
  169. try:
  170. # 1. 音频处理和保存
  171. transcript = process_audio(video_path)
  172. audio_path = path_config.get_path('audio_json') / audio_name
  173. save_json(audio_path, transcript)
  174. del_key(audio_path)
  175. # 2. 语句筛选流程
  176. speaker_data = speaker_extract(audio_path)
  177. filtered_len = filter_text_len(speaker_data)
  178. # 3. 文本过滤
  179. # 筛除文本
  180. excluded_data = exclude_word_list(filtered_len)
  181. filter_1_path = path_config.get_path('output_filter_1') / audio_name
  182. save_json(filter_1_path, excluded_data, video_path)
  183. # 筛留文本
  184. included_data = include_word_list(excluded_data)
  185. filter_2_path = path_config.get_path('output_filter_2') / audio_name
  186. save_json(filter_2_path, included_data, video_path)
  187. # 4. 文本分类和标注
  188. logger.info("开始语句打标")
  189. classifier = TextClassifier()
  190. labeled_data = [
  191. {**text, "attribute": classifier.classify(text["text"])}
  192. for text in tqdm(included_data, desc="Processing texts")
  193. ]
  194. # 保存一次打标结果
  195. filter_3_path = path_config.get_path('output_filter_3') / audio_name
  196. save_json(filter_3_path, labeled_data, video_path)
  197. # 5. 二次处理流程
  198. # 筛选衣服介绍语句
  199. filter_oral_data(filter_3_path)
  200. # 二次打标
  201. label_data(filter_3_path)
  202. # 二次筛选
  203. filter_4_path = path_config.get_path('output_filter_4') / audio_name
  204. filter_label_data(filter_4_path)
  205. logger.info("音频分析流水线处理完成")
  206. except Exception as e:
  207. logger.error(f"处理文件 {audio_name} 时发生错误: {str(e)}")
  208. raise
  209. if __name__ == "__main__":
  210. # 初始化分类器(模型只会加载一次)
  211. # classifier = TextClassifier()
  212. # 执行多次分类
  213. texts = ["这个颜色很好看", "这款纯羊毛的衣服自然不便宜"]
  214. for text in texts:
  215. result = text_classifer(text) # classifier.classify(text) text_classifer(text)
  216. print(f"Input: {text}\nResult: {result}\n")