# voice_recognition.py
  1. from funasr import AutoModel
  2. from funasr.utils.postprocess_utils import rich_transcription_postprocess
  3. import torch
  4. from pathlib import Path
  5. from utils.logger_config import setup_logger
  6. logger = setup_logger(__name__)
  7. class SenseVoiceTranscriber:
  8. def __init__(self, model_dir="/data/data/luosy/models/iic/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn"):
  9. """
  10. Initialize SenseVoice transcriber
  11. Args:
  12. model_dir: Model directory or huggingface model name
  13. """
  14. try:
  15. # Detect device
  16. if torch.backends.mps.is_available():
  17. device = "mps"
  18. logger.info("Using MPS acceleration")
  19. elif torch.cuda.is_available():
  20. device = "cuda:0"
  21. logger.info("Using CUDA acceleration")
  22. else:
  23. device = "cpu"
  24. logger.info("Using CPU processing")
  25. logger.info(f"Loading SenseVoice model (model={model_dir}, device={device})")
  26. self.model = AutoModel(
  27. model=model_dir,
  28. model_revision="v2.0.4",
  29. vad_model="fsmn-vad", # 语音活动检测模型,切割长音频
  30. vad_model_revision="v2.0.4",
  31. punc_model="ct-punc-c", # 语音标点模型,添加标点符号
  32. punc_model_revision="v2.0.4",
  33. spk_model="cam++", # 语音识别模型,判断说话人
  34. trust_remote_code=True,
  35. disable_update=True,
  36. vad_kwargs={
  37. "max_single_segment_time": 15000,
  38. "min_duration": 500,
  39. "speech_pad": 300
  40. },
  41. punc_kwargs={
  42. "window_size": 128,
  43. "period_symbol": "。"
  44. },
  45. spk_kwargs={
  46. "spk_threshold": 0.7
  47. },
  48. device=device
  49. )
  50. logger.info("SenseVoice model loaded successfully")
  51. except Exception as e:
  52. logger.error(f"Failed to load SenseVoice model: {str(e)}")
  53. raise
  54. def transcribe(self, audio_path: str) -> str:
  55. """
  56. Transcribe audio to text
  57. Args:
  58. audio_path: Path to audio file
  59. Returns:
  60. str: Transcribed text or empty string if no speech detected
  61. """
  62. try:
  63. logger.info(f"开始处理音频文件: {audio_path}")
  64. # Generate transcription with no gradient computation
  65. with torch.no_grad():
  66. res = self.model.generate(
  67. input=audio_path,
  68. cache={},
  69. speaker_info={"spk_num": 2},
  70. language="zh",
  71. use_itn=True,
  72. batch_size_s=30,
  73. hotword=["材质", "面料", "版型", "合身"],
  74. beam_size=20,
  75. merge_vad=True,
  76. merge_length_s=10,
  77. without_timestamps=False,
  78. ban_emo_unk=True,
  79. sentence_timestamp=True
  80. )
  81. # Return empty string if no results
  82. if not res or not res[0].get("text"):
  83. logger.info("No speech detected")
  84. return ""
  85. # Get transcription text
  86. # transcript = rich_transcription_postprocess(res[0]["timestamp"])
  87. transcript = res[0]
  88. logger.info("STT执行完成!")
  89. # logger.debug(f"Transcription result:\n{transcript}")
  90. return transcript
  91. except Exception as e:
  92. logger.error(f"Audio processing failed: {str(e)}")
  93. return "" # Return empty string on error