from pathlib import Path

from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
import torch

from utils.logger_config import setup_logger

logger = setup_logger(__name__)
class SenseVoiceTranscriber:
    """Speech-to-text transcriber built on FunASR's AutoModel pipeline.

    Combines ASR with VAD (long-audio splitting), punctuation restoration,
    and speaker diarization, on the best available torch device.
    """

    def __init__(self, model_dir="/data/data/luosy/models/iic/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn"):
        """
        Initialize the transcriber and load the ASR/VAD/punc/speaker models.

        Args:
            model_dir: Local model directory or ModelScope/HuggingFace model name.

        Raises:
            Exception: Re-raised unchanged if FunASR fails to load any model.
        """
        try:
            # Device preference: Apple MPS > CUDA > CPU.
            if torch.backends.mps.is_available():
                device = "mps"
                logger.info("Using MPS acceleration")
            elif torch.cuda.is_available():
                device = "cuda:0"
                logger.info("Using CUDA acceleration")
            else:
                device = "cpu"
                logger.info("Using CPU processing")

            logger.info(f"Loading SenseVoice model (model={model_dir}, device={device})")
            self.model = AutoModel(
                model=model_dir,
                model_revision="v2.0.4",
                vad_model="fsmn-vad",  # voice-activity detection: splits long audio into segments
                vad_model_revision="v2.0.4",
                punc_model="ct-punc-c",  # punctuation restoration model
                punc_model_revision="v2.0.4",
                spk_model="cam++",  # speaker-diarization model (who is speaking)
                trust_remote_code=True,
                disable_update=True,
                vad_kwargs={
                    "max_single_segment_time": 15000,  # ms: cap per VAD segment
                    "min_duration": 500,               # ms: drop shorter segments
                    "speech_pad": 300,                 # ms: padding around detected speech
                },
                punc_kwargs={
                    "window_size": 128,
                    "period_symbol": "。",
                },
                spk_kwargs={
                    "spk_threshold": 0.7,  # similarity threshold for speaker assignment
                },
                device=device,
            )

            logger.info("SenseVoice model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load SenseVoice model: {str(e)}")
            raise

    def transcribe(self, audio_path: str) -> "dict | str":
        """
        Transcribe an audio file.

        Args:
            audio_path: Path to the audio file.

        Returns:
            The first FunASR result entry — a dict containing at least "text"
            (plus sentence timestamps / speaker info as configured) — or the
            empty string "" when no speech is detected or processing fails.
            (The original annotation claimed `str`, but the full result dict
            has always been returned; the annotation now reflects reality.)
        """
        try:
            logger.info(f"开始处理音频文件: {audio_path}")

            # Inference only — no autograd bookkeeping needed.
            with torch.no_grad():
                res = self.model.generate(
                    input=audio_path,
                    cache={},
                    speaker_info={"spk_num": 2},  # expected number of speakers
                    language="zh",
                    use_itn=True,  # inverse text normalization (digits, dates, ...)
                    batch_size_s=30,
                    hotword=["材质", "面料", "版型", "合身"],
                    beam_size=20,
                    merge_vad=True,
                    merge_length_s=10,
                    without_timestamps=False,
                    ban_emo_unk=True,
                    sentence_timestamp=True,
                )

            # Empty result list or empty text field → treat as "no speech".
            if not res or not res[0].get("text"):
                logger.info("No speech detected")
                return ""

            # Return the full result entry (text + timestamps + speaker info).
            transcript = res[0]
            logger.info("STT执行完成!")
            return transcript

        except Exception as e:
            logger.error(f"Audio processing failed: {str(e)}")
            return ""  # best-effort contract: empty string on any failure