wangdalin 2 months ago
Commit
eb364d3df8

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+/dirs

BIN
__pycache__/agent.cpython-310.pyc


BIN
__pycache__/config.cpython-310.pyc


BIN
__pycache__/embedding_search.cpython-310.pyc


BIN
__pycache__/file_process.cpython-310.pyc


BIN
__pycache__/milvus_process.cpython-310.pyc


BIN
__pycache__/prompt.cpython-310.pyc


BIN
__pycache__/util.cpython-310.pyc


+ 41 - 0
agent.py

@@ -0,0 +1,41 @@
+import autogen, json
+from config import llm_config, llm_config_ds
+from prompt import get_summary_system
+search_answer = autogen.AssistantAgent(
+    name="summary_content",
+    llm_config=llm_config_ds,
+    system_message=get_summary_system,
+    code_execution_config=False,
+    human_input_mode="NEVER",
+)
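+# Expected reply shape, per the get_summary_system prompt in prompt.py: a
+# ```json fenced object mapping source indices to rewritten snippets,
+# e.g. {"1": "...", "2": "..."}.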
+async def get_content_summary(question, res_info, final_data):
+    try:
+        data = list(final_data.items())
+        final_chunks = {}
+        prompt = "问题:\n" + question + '\n资料信息:\n' + res_info
+        answer = await search_answer.a_generate_reply(messages=[{'role': 'user', 'content': prompt}])
+        if isinstance(answer, dict):  # a_generate_reply may return a message dict
+            answer = answer.get('content', '')
+        if '```json' in answer:
+            answer = answer.split('```json')[1].split('```')[0]
+            answer = json.loads(answer)
+        elif '{' in answer:
+            answer = answer.split('{')[1].split('}')[0]
+            answer = json.loads("{" + answer + "}")
+        for k, v in answer.items():
+            final_chunks[data[int(k) - 1][0]] = v
+        search_str = "\n".join(f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items()))
+        return final_chunks, search_str
+    except Exception as e:
+        print(f'Summary rewrite failed: {e}')
+        return {}, res_info
+
+if __name__ == '__main__':
+    import asyncio
+
+    answer = asyncio.run(search_answer.a_generate_reply(messages=[{'role': 'user', 'content': '英国的首都在哪里'}]))
+    print(answer)

+ 67 - 0
client.py

@@ -0,0 +1,67 @@
+import argparse
+import os
+import requests
+
+BASE_URL = "http://localhost:5666"
+
+def chat_with_server(client_id: str, prompt: str, history: str):
+    url = f"{BASE_URL}/chat"
+    data = {
+        "client_id": client_id,
+        "prompt": prompt,
+        "history": history
+    }
+    response = requests.post(url, data=data)
+    return response.json()
+
+def upload_file_to_server(client_id: str, file_path: str):
+    url = f"{BASE_URL}/uploadfile/"
+    with open(file_path, "rb") as f:  # close the handle once the upload finishes
+        files = {"file": f}
+        data = {"client_id": client_id}
+        response = requests.post(url, files=files, data=data)
+    return response.json()
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--mode", choices=["chat", "upload"], required=True, help="Which API to call")
+    parser.add_argument("--client_id", required=True, help="Client ID")
+    parser.add_argument("--prompt", help="User question, chat mode only")
+    parser.add_argument("--history", default='[]', help="History messages as a JSON string, chat mode only")
+    parser.add_argument("--file", help="File or directory path to upload, upload mode only")
+    
+    args = parser.parse_args()
+    
+    if args.mode == "chat":
+        if not args.prompt:
+            print("Error: --prompt is required in chat mode!")
+            return
+        response = chat_with_server(args.client_id, args.prompt, args.history)
+        print("Chat Response:", response)
+    elif args.mode == "upload":
+        if not args.file:
+            print("Error: --file is required in upload mode!")
+            return
+            
+        if os.path.isfile(args.file):
+            files = [args.file]
+        elif os.path.isdir(args.file):
+            files = [os.path.join(args.file, f) for f in os.listdir(args.file)
+                     if os.path.isfile(os.path.join(args.file, f))]
+        else:
+            print(f"Error: path {args.file} does not exist or is not accessible!")
+            return
+            
+        for file_path in files:
+            try:
+                print(f"\nUploading file: {file_path}")
+                response = upload_file_to_server(args.client_id, file_path)
+                print(f"Upload result: {response}")
+            except Exception as e:
+                print(f"Upload of {file_path} failed: {e}")
+
+if __name__ == "__main__":
+    main()
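+
+# Example invocations (illustrative client_id and paths):
+#   python client.py --mode chat --client_id test --prompt "每月盘点时间有要求吗?" --history '[]'
+#   python client.py --mode upload --client_id test --file ./docs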

+ 60 - 0
config.py

@@ -0,0 +1,60 @@
+import argparse
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from FlagEmbedding import BGEM3FlagModel
+from util import get_cus_logger
+from marker.converters.pdf import PdfConverter
+from marker.models import create_model_dict
+def load_parse_args():
+    # Build the ArgumentParser
+    parser = argparse.ArgumentParser(description="Runtime configuration for the RAG service.")
+    parser.add_argument('--base_url', type=str, default="https://dashscope.aliyuncs.com/compatible-mode/v1", help='LLM base_url')
+    parser.add_argument('--api_key', type=str, default="sk-04b63960983445f980d85ff185a17876", help='LLM api_key')
+    parser.add_argument('--model', type=str, choices=['qwen-max', 'gpt-4'], default='qwen-max', help='The model to use')
+    parser.add_argument('--static_dir', type=str, default="/workspace", help='the directory for the code to work')
+
+    # Parse the command-line arguments
+    args = parser.parse_args()
+    return args
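+
+# These flags are parsed at import time, so they can be passed to whichever
+# entry point imports this module (server.py in this repo), e.g.:
+#   python server.py --model qwen-max --api_key sk-... --base_url https://dashscope.aliyuncs.com/compatible-mode/v1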
+
+nltk_path = '/root/nltk_data/tokenizers'
+args = load_parse_args()
+static_dir = args.static_dir  # honor the --static_dir flag
+logger_search = get_cus_logger('search')  # shared logger, imported by embedding_search.py and milvus_process.py
+llm_config = {
+    "config_list": [
+        {
+            "model": args.model,
+            "api_key": args.api_key,
+            "base_url": args.base_url
+        }
+    ],
+    "cache_seed": None,  # disable caching, useful when testing different models
+    "temperature": 0.5,
+}
+llm_config_ds = {
+    "config_list": [
+        {
+            "model": 'deepseek-r1',  # reasoning model used for summary rewriting in agent.py
+            "api_key": args.api_key,
+            "base_url": args.base_url
+        }
+    ],
+    "cache_seed": None,  # disable caching, useful when testing different models
+    "temperature": 0.5,
+}
+
+milvus_url = "http://10.41.1.57:19530"
+
+bge_model_path = '/model/bge-m3'
+bge_rerank_path = '/model/bge-reranker-v2-m3'
+BASE_UPLOAD_DIRECTORY = '/workspace'
+upload_path = """/workspace/{client_id}/"""
+converter = PdfConverter(
+    artifact_dict=create_model_dict(),
+)
+
+bge_model = BGEM3FlagModel(bge_model_path, use_fp16=True, device='cuda:0') # Setting use_fp16 to True speeds up computation with a slight performance degradation
+bge_rerank_tokenizer = AutoTokenizer.from_pretrained(bge_rerank_path)
+bge_rerank_model = AutoModelForSequenceClassification.from_pretrained(bge_rerank_path)
+bge_rerank_model.to('cuda:0')
+bge_rerank_model.eval()

+ 74 - 0
embedding_search.py

@@ -0,0 +1,74 @@
+import torch
+import numpy as np
+from config import bge_model, bge_rerank_model, bge_rerank_tokenizer
+def model_embedding_bge(chunk):
+    sentences = [chunk] if isinstance(chunk, str) else chunk
+    try:
+        # Temporarily disable fp16 if the model exposes the flag; encoding in
+        # fp32 avoids dtype mismatches on some setups
+        temp_use_fp16 = None
+        if hasattr(bge_model, 'use_fp16'):
+            temp_use_fp16 = bge_model.use_fp16
+            bge_model.use_fp16 = False
+
+        embeddings = bge_model.encode(
+            sentences,
+            batch_size=12,
+            max_length=8192,
+        )['dense_vecs']
+        result = [emb.tolist() for emb in embeddings]
+        # Restore the original setting
+        if temp_use_fp16 is not None:
+            bge_model.use_fp16 = temp_use_fp16
+
+    except RuntimeError as e:
+        if "expected scalar type Float but found Half" not in str(e):
+            raise
+        # Convert the underlying model to float32, keep fp16 disabled, and retry once
+        if hasattr(bge_model, 'model') and hasattr(bge_model.model, 'float'):
+            bge_model.model.float()
+        if hasattr(bge_model, 'use_fp16'):
+            bge_model.use_fp16 = False
+
+        embeddings = bge_model.encode(
+            sentences,
+            batch_size=12,
+            max_length=8192,
+        )['dense_vecs']
+        result = [emb.tolist() for emb in embeddings]
+    return result if len(result) > 1 else result[0]
+
+def fill_embed_nan(vector):
+    return np.nan_to_num(vector, nan=0.0, posinf=1.0, neginf=-1.0)
+
+
+def bge_rerank(query, indicates, score=0.6):
+    """Score each candidate in `indicates` against `query` with the BGE
+    reranker and return those above `score`, best first."""
+    with torch.no_grad():
+        results = {}
+        for chos in indicates:
+            inputs = bge_rerank_tokenizer([[query, chos]], padding=True, truncation=True, return_tensors='pt', max_length=512).to(device='cuda:0')
+            logit = bge_rerank_model(**inputs, return_dict=True).logits.view(-1, ).float()[0]
+            prob = torch.sigmoid(logit).item()
+            if prob > score:
+                results[chos] = prob
+        sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
+        final_result = [x[0] for x in sorted_results]
+        return final_result
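+
+# Example (candidate_chunks is a hypothetical list of strings):
+#   bge_rerank("每月盘点时间有要求吗?", candidate_chunks, score=0.3)
+# keeps only the candidates whose sigmoid relevance exceeds 0.3, most relevant first.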
+
+if __name__ == '__main__':
+    pass

+ 325 - 0
file_process.py

@@ -0,0 +1,325 @@
+import docx
+import nltk, subprocess, os
+from config import converter
+from typing import List, Union
+from pathlib import Path
+import re
+import pandas as pd
+from nltk.tokenize import sent_tokenize
+import spacy
+from langdetect import detect
+from milvus_process import update_mulvus_file
+from marker.output import text_from_rendered
+try:
+    nltk.data.find('tokenizers/punkt')
+except LookupError:
+    nltk.download('punkt')
+try:
+    nlp = spacy.load("zh_core_web_sm")
+except OSError:
+    nlp = None  # zh model missing; split_text_spacy cannot run until it is installed
+class DocumentProcessor:
+    def __init__(self):
+        pass
+    def get_file_len(self, file_path: Union[str, Path]) -> int:
+        chunks = self.read_file(file_path)
+        # read_file returns chunks; approximate the document length by summing them
+        length = sum(len(chunk) for chunk in chunks)
+        del chunks
+        return length
+    
+    @staticmethod
+    def convert_doc_to_docx(input_path):
+        output_dir = os.path.dirname(input_path)
+        command = [
+            "soffice", "--headless", "--convert-to", "docx", input_path, "--outdir", output_dir
+        ]
+        subprocess.run(command, check=True)
+        
+    def _read_doc(self, file_path) -> List[str]:
+        """Read a legacy Word (.doc) file by converting it to .docx first."""
+        self.convert_doc_to_docx(file_path)
+        old_file = Path(file_path)                # original .doc file
+        new_file = old_file.with_suffix(".docx")  # converted .docx file
+
+        if old_file.exists():
+            old_file.unlink()  # drop the original once converted
+
+        doc = docx.Document(new_file)
+        text = "\n".join(paragraph.text for paragraph in doc.paragraphs)
+        return self.create_chunks(text)
+
+    def read_file(self, file_path: Union[str, Path]) -> List[str]:
+        """
+        Read a document in any supported format.
+        
+        Args:
+            file_path: path to the file
+            
+        Returns:
+            List[str]: extracted text chunks
+        """
+        file_path = Path(file_path)
+        extension = file_path.suffix.lower()
+        
+        if extension == '.pdf':
+            return self._read_pdf(file_path)
+        elif extension == '.docx':
+            return self._read_docx(file_path)
+        elif extension == '.doc':
+            return self._read_doc(file_path)
+        elif extension == '.txt':
+            return self._read_txt(file_path)
+        elif extension == '.csv':
+            return self._read_csv(file_path)
+        elif extension == '.xlsx':
+            return self._read_excel(file_path)
+        else:
+            raise ValueError(f"Unsupported file format: {extension}")
+    
+    def _read_pdf(self, file_path) -> List[str]:
+        """Read a PDF file via the marker converter."""
+        rendered = converter(str(file_path))
+        text, _, images = text_from_rendered(rendered)
+        return self.create_chunks(text=text)
+            
+    def _read_docx(self, file_path: Path) -> List[str]:
+        """Read a Word (.docx) document."""
+        doc = docx.Document(file_path)
+        text = "\n".join(paragraph.text for paragraph in doc.paragraphs)
+        return self.create_chunks(text)
+    
+    def _read_txt(self, file_path: Path) -> List[str]:
+        """Read a plain-text file."""
+        with open(file_path, 'r', encoding='utf-8') as file:
+            return self.create_chunks(file.read())
+    
+    def _read_excel(self, file_path: Path) -> List[str]:
+        """Read an Excel workbook, sheet by sheet."""
+        df = pd.read_excel(file_path, sheet_name=None)
+        text = ""
+        for sheet_name, sheet_df in df.items():
+            text += f"\nSheet: {sheet_name}\n"
+            text += sheet_df.to_csv(index=False, sep=' ', header=True)
+        return self.create_chunks(text)
+
+    def _read_csv(self, file_path: Path) -> List[str]:
+        """Read a CSV file."""
+        df = pd.read_csv(file_path)
+        return self.create_chunks(df.to_csv(index=False, sep=' ', header=True))
+
+    def _clean_text(self, text: str) -> str:
+        """
+        Clean text:
+        - normalize newlines
+        - collapse runs of spaces and tabs
+        - drop empty lines
+        """
+        # Normalize newlines first so the per-line stripping below still works
+        text = text.replace('\r\n', '\n').replace('\r', '\n')
+        # Collapse runs of spaces/tabs but keep newlines
+        text = re.sub(r'[ \t]+', ' ', text)
+        # Remove empty lines
+        text = '\n'.join(line.strip() for line in text.split('\n') if line.strip())
+        return text.strip()
+
+    def split_into_sentences(self, text: str) -> List[str]:
+        """
+        Split text into sentences.
+        
+        Args:
+            text: input text
+            
+        Returns:
+            List[str]: list of sentences
+        """
+        # Sentence segmentation via NLTK
+        sentences = sent_tokenize(text)
+        return sentences
+    def force_split_sentence(self, sentence: str, max_length: int) -> List[str]:
+        """
+        Force-split an over-long sentence by character count.
+        
+        Args:
+            sentence (str): input sentence
+            max_length (int): maximum length
+            
+        Returns:
+            List[str]: list of sentence fragments
+        """
+        # Punctuation marks serve as secondary split points
+        punctuation = '。,;!?,.;!?'
+        parts = []
+        current_part = ''
+        
+        # Prefer splitting at punctuation
+        chars = list(sentence)
+        for i, char in enumerate(chars):
+            current_part += char
+            
+            # Flush when the fragment reaches max_length at a punctuation mark,
+            # or clearly overshoots while hunting for one
+            if (len(current_part) >= max_length and char in punctuation) or \
+            (len(current_part) >= max_length * 1.2):  # allow slight overshoot to find punctuation
+                parts.append(current_part)
+                current_part = ''
+        
+        # Handle the remainder
+        if current_part:
+            # If the remainder is still too long, split strictly by length
+            while len(current_part) > max_length:
+                parts.append(current_part[:max_length] + '...')
+                current_part = '...' + current_part[max_length:]
+            parts.append(current_part)
+        
+        return parts
+
+    def split_text_nltk(self, text: str, chunk_size: int = 1500, overlap_size: int = 100) -> List[str]:
+        """
+        Split text with NLTK, with chunk overlap and over-long-sentence handling.
+        
+        Args:
+            text (str): input text
+            chunk_size (int): approximate characters per chunk
+            overlap_size (int): characters shared between adjacent chunks
+            
+        Returns:
+            List[str]: list of text chunks
+        """
+        text = self._clean_text(text)
+        sentences = nltk.sent_tokenize(text)
+        chunks = self.process_sentences(sentences=sentences, chunk_size=chunk_size, overlap_size=overlap_size)
+        return chunks
+
+    def split_text_spacy(self, text: str, chunk_size: int = 500, overlap_size: int = 100) -> List[str]:
+        """
+        Split Chinese text with spaCy, with chunk overlap and over-long-sentence handling.
+        
+        Args:
+            text (str): input Chinese text
+            chunk_size (int): approximate characters per chunk
+            overlap_size (int): characters shared between adjacent chunks
+            
+        Returns:
+            List[str]: list of text chunks
+        """
+        text = self._clean_text(text)
+        doc = nlp(text)
+        sentences = [sent.text for sent in doc.sents]
+        chunks = self.process_sentences(sentences=sentences, chunk_size=chunk_size, overlap_size=overlap_size)
+        return chunks
+
+
+    def process_sentences(self, sentences, chunk_size: int = 500, overlap_size: int = 100):
+        chunks = []
+        current_chunk = []
+        current_length = 0
+        
+        for sentence in sentences:
+            # Over-long sentence: flush the current chunk, then force-split it
+            if len(sentence) > chunk_size:
+                if current_chunk:
+                    chunks.append("".join(current_chunk))
+                    current_chunk = []
+                    current_length = 0
+                
+                sentence_parts = self.force_split_sentence(sentence, chunk_size)
+                for part in sentence_parts:
+                    chunks.append(part)
+                continue
+            
+            # Normal-length sentence
+            if current_length + len(sentence) <= chunk_size:
+                current_chunk.append(sentence)
+                current_length += len(sentence)
+            else:
+                if current_chunk:
+                    chunks.append("".join(current_chunk))
+                    
+                # Carry trailing sentences over as the overlap for the next chunk
+                overlap_chars = 0
+                overlap_sentences = []
+                for prev_sentence in reversed(current_chunk):
+                    if overlap_chars + len(prev_sentence) <= overlap_size:
+                        overlap_sentences.insert(0, prev_sentence)
+                        overlap_chars += len(prev_sentence)
+                    else:
+                        break
+                
+                current_chunk = overlap_sentences + [sentence]
+                current_length = sum(len(s) for s in current_chunk)
+        
+        if current_chunk:
+            chunks.append("".join(current_chunk))
+        
+        return chunks
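+    # For example, with chunk_size=500 and overlap_size=100, each new chunk
+    # starts with up to ~100 characters of trailing sentences from the previous
+    # chunk, so context is not cut off at chunk boundaries.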
+    def create_chunks(self, text: str, chunk_size=300, overlap_size=100) -> List[str]:
+        if self.is_chinese_text(text):
+            # Chinese text: segment with spaCy
+            chunks = self.split_text_spacy(text, chunk_size=chunk_size, overlap_size=overlap_size)
+        else:
+            # Non-Chinese text: segment with NLTK
+            chunks = self.split_text_nltk(text, chunk_size=chunk_size, overlap_size=overlap_size)
+        return chunks
+
+    def is_chinese_text(self, text: str) -> bool:
+        """
+        Decide whether the text is predominantly Chinese.
+        
+        Args:
+            text (str): input text
+            
+        Returns:
+            bool: True for Chinese text, False otherwise
+        """
+        chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
+        total_chars = len(re.findall(r'\w', text)) + chinese_chars
+        char_ratio = chinese_chars / max(total_chars, 1)
+        if char_ratio > 0.1:
+            return True
+        try:
+            # Fall back to langdetect
+            lang = detect(text)
+            if not lang:
+                raise Exception("Language detection failed")
+            return lang in ('zh-cn', 'zh-tw', 'zh')
+        except Exception:
+            return char_ratio > 0.1
+    def process_document(self, file_path: Union[str, Path], chunk_size=1000, overlap_size=250) -> List[str]:
+        """
+        Main entry point for processing a document.
+        
+        Args:
+            file_path: path to the document
+            
+        Returns:
+            List[str]: processed text chunks
+        """
+        # read_file already chunks the text internally
+        chunks = self.read_file(file_path)
+        return chunks
+
+if __name__ == '__main__':
+    import asyncio
+    processor = DocumentProcessor()
+
+    # Chunk a sample document
+    chunks = processor.read_file("./tests/test.pdf")
+    # Embed the chunks into Milvus
+    status = asyncio.run(update_mulvus_file(client_id='test', file_name='test.pdf', chunks=chunks))

+ 232 - 0
milvus_process.py

@@ -0,0 +1,232 @@
+from pymilvus import MilvusClient, DataType, FieldSchema, CollectionSchema, AnnSearchRequest, WeightedRanker, Function, FunctionType
+from config import milvus_url, logger_search
+from embedding_search import model_embedding_bge, bge_rerank
+import time, asyncio
+import traceback
+from typing_extensions import Annotated
+
+async def create_collection(collection_name):
+    client = MilvusClient(uri=milvus_url)
+    if client.has_collection(collection_name=collection_name):
+        return False
+    fields = [
+        FieldSchema(name="client_id", dtype=DataType.VARCHAR, max_length=65535),
+        FieldSchema(name="file_name", dtype=DataType.VARCHAR, max_length=65535),
+        FieldSchema(name="filename_embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
+        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535, is_primary=True, enable_analyzer=True),
+        FieldSchema(name="content_embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
+        FieldSchema(name="content_sparse", dtype=DataType.SPARSE_FLOAT_VECTOR)
+    ]
+    # BM25 function: derive the sparse vector from the raw content field
+    bm25_function = Function(
+        name="text_bm25_emb",
+        input_field_names=["content"],
+        output_field_names=["content_sparse"],
+        function_type=FunctionType.BM25,
+    )
+    schema = CollectionSchema(fields=fields, enable_dynamic_field=True, auto_id=False)
+    schema.add_function(bm25_function)
+    
+    index_params = client.prepare_index_params()
+    index_params.add_index(
+        field_name="client_id",
+        index_type="INVERTED"
+    )
+    index_params.add_index(
+        field_name="file_name",
+        index_type="INVERTED"
+    )
+    index_params.add_index(
+        field_name="content_embedding",
+        index_type="IVF_FLAT",
+        metric_type="COSINE",
+        params={"nlist": 128}
+    )
+    index_params.add_index(
+        field_name="filename_embedding",
+        index_type="IVF_FLAT",
+        metric_type="COSINE",
+        params={"nlist": 128}
+    )
+    index_params.add_index(
+        field_name="content_sparse",
+        index_type="AUTOINDEX",
+        metric_type="BM25"
+    )
+    client.create_collection(
+        collection_name=collection_name,
+        schema=schema,
+        index_params=index_params
+    )
+    return True
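+
+# Design note: each chunk is stored under three searchable representations --
+# a dense vector of the file name, a dense vector of the content, and a BM25
+# sparse vector derived from the content -- so get_filedata_from_milvus can
+# combine them in a single weighted hybrid search.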
+
+async def update_mulvus_file(client_id, file_name, chunks, collection_name='gloria_files'):
+    try:
+        start_time = time.time()
+        client = MilvusClient(uri=milvus_url)
+        await create_collection(collection_name=collection_name)
+        client.load_collection(collection_name=collection_name)
+        try:
+            # Drop any previous chunks of the same file before re-inserting
+            client.delete(collection_name=collection_name, filter=f'file_name == "{file_name}"')
+        except Exception as e:
+            print(e)
+
+        final_data = []
+        filename_embedding = await asyncio.to_thread(model_embedding_bge, file_name)
+        for chunk in chunks:
+            chunk_embedding = await asyncio.to_thread(model_embedding_bge, chunk)
+            final_data.append({
+                'client_id': client_id,
+                'file_name': file_name,
+                'filename_embedding': filename_embedding,
+                'content': chunk,
+                'content_embedding': chunk_embedding,
+            })
+
+        await asyncio.to_thread(client.insert, collection_name=collection_name, data=final_data)
+        
+        status = {"result": "succeed"}
+        end_time = time.time()
+        logger_search.info(f"{file_name} embedded into Milvus in {end_time - start_time} s")
+        client.release_collection(collection_name=collection_name)
+        client.close()
+        return status
+
+    except Exception as e:
+        logger_search.error(f"Embedding file {file_name} for client {client_id} into Milvus failed: {e}")
+        traceback.print_exc()
+        status = {"result": "failed"}
+        return status
+
+
+async def get_filedata_from_milvus(query, collection_name='gloria_files', file_name=None, threshold=0.5):
+    try:
+        client = MilvusClient(uri=milvus_url)
+        query_embedding = await asyncio.to_thread(model_embedding_bge, query)
+        time0 = time.time()
+        client.load_collection(collection_name=collection_name)
+        time1 = time.time()
+        logger_search.info(f"collection load time: {time1 - time0} s")
+        
+        search_filter = f'file_name == "{file_name}"' if file_name else ''
+
+        # ANN request against the file-name vector field
+        search_param_1 = {
+            "data": [query_embedding],  # must be a 2-D list of vectors
+            "anns_field": "filename_embedding",
+            "param": {
+                "metric_type": "COSINE",
+                "params": {"nprobe": 10}
+            },
+            "limit": 300,
+        }
+        request_1 = AnnSearchRequest(**search_param_1)
+
+        # ANN request against the chunk-content vector field
+        search_param_2 = {
+            "data": [query_embedding],
+            "anns_field": "content_embedding",
+            "param": {
+                "metric_type": "COSINE",
+                "params": {"nprobe": 10}
+            },
+            "limit": 300,
+        }
+        request_2 = AnnSearchRequest(**search_param_2)
+
+        # Sparse BM25 request against the raw content
+        req_temp = AnnSearchRequest(
+                data=[query],
+                anns_field='content_sparse',
+                param={
+                    "metric_type": 'BM25',
+                    "params": {'drop_ratio_search': 0.2}
+                },
+                limit=50,
+            )
+        # Hybrid search, weighted 0.2 (file name) / 0.5 (content) / 0.5 (BM25)
+        rerank = WeightedRanker(0.2, 0.5, 0.5)
+        res = await asyncio.to_thread(
+            client.hybrid_search,
+            collection_name=collection_name,
+            reqs=[request_1, request_2, req_temp],
+            ranker=rerank,
+            output_fields=['content', 'file_name'],
+            filter=search_filter
+        )
+
+        client.release_collection(collection_name=collection_name)
+        client.close()
+        final_answer = []
+        for hits in res:
+            for hit in hits:
+                if hit.get('distance', 0) > threshold:
+                    entity = hit.get('entity')
+                    final_answer.append({'file_name': entity.get("file_name"), 'content': entity.get("content")})
+    except Exception as e:
+        logger_search.error(e)
+        final_answer = []
+    logger_search.info(f'Found {len(final_answer)} hits in the Milvus file knowledge base')
+    return final_answer
+
+async def rerank_file_data(query, data):
+    lookup = {}  # map content -> original record, avoiding a later search
+    chunks = []
+    final_chunks = {}
+    for item in data:
+        content = item.get('content', '')
+        lookup[content] = item
+        chunks.append(content)
+    # Rerank the chunks against the query
+    logger_search.info(chunks)
+    reranked = bge_rerank(query=query, indicates=chunks, score=0.3)
+
+    send_data = [lookup[chunk] for chunk in reranked if chunk in lookup]
+
+    # Deduplicate records
+    send_data = [dict(t) for t in {tuple(d.items()) for d in send_data}]
+    logger_search.info(send_data)
+    for sed in send_data:
+        # Concatenate chunks per file, initializing each file's entry exactly once
+        if sed.get('file_name') not in final_chunks:
+            final_chunks[sed.get('file_name')] = sed.get('content')
+        else:
+            final_chunks[sed.get('file_name')] += sed.get('content')
+    logger_search.info(final_chunks)
+    search_results_str = "\n".join(f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items()))
+
+    return final_chunks, search_results_str
+
+async def get_search_results(query: Annotated[str, "需要搜索的问题"], summary=True):
+    data = await get_filedata_from_milvus(query=query)
+    final_data, search_str = await rerank_file_data(query=query, data=data)
+    if not final_data:
+        # Reranking filtered everything out; fall back to the raw Milvus hits
+        final_chunks = {}
+        for sed in data:
+            if sed.get('file_name') not in final_chunks:
+                final_chunks[sed.get('file_name')] = sed.get('content')
+            else:
+                final_chunks[sed.get('file_name')] += sed.get('content')
+        search_str = "\n".join(f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items()))
+        final_data = final_chunks
+
+    logger_search.info(search_str)
+    return final_data, search_str
+
+
+if __name__ == '__main__':
+    from agent import get_content_summary
+    final_data, result = asyncio.run(get_search_results(query='每月盘点时间有要求吗?'))
+    data, res = asyncio.run(get_content_summary(question='每月盘点时间有要求吗?', res_info=result, final_data=final_data))
+    logger_search.info(data)
+    logger_search.info(res)

+ 129 - 0
prompt.py

@@ -0,0 +1,129 @@
+output_system_prompt_use = """
+# ROLE  
+你叫歌莉娅AI,你是一位擅长使用工具来解决问题的专家:  
+- 擅长通过外部工具调用来获取结果
+
+
+# OBJECTIVE  
+根据用户提出的问题,调用工具来解决问题, 如果工具的信息还不够,可以继续使用多种工具。
+
+# TASK REQUIREMENTS  
+1. 信息获取   
+    - 对于常识性问题直接回答"ok"
+    - 动态问题需调用工具获取准确结果  
+    - 如果寻找不到与问题相关文档片段, 请调用工具查看全文内容
+    - 遇到无法解答的问题, 请使用搜索和推理工具
+    - 得到工具的返回结果后,直接回复"ok", 禁止进行多余的说明。
+
+2. 特定工具说明
+  - get_file_relate_content 是寻找与问题相关的文档片段, 有具体问题的时候使用
+  - get_file_full_content 是获取文件的全文信息, 在没有相关文档片段或者问题不具体的时候使用
+  - get_search_results 是搜索知识库和联网内容, 一般问到一些具体的问题可以用它搜索, 推理的问题不用它
+  - reasoning 这是多步推理,搜索,总结的函数,只有在问题涉及到需要多步推理的时候才用上它, 
+    如"与第五交响曲创作于同一世纪的交通工具是什么?", 他需要先搜索第五交响曲创作于哪个世纪, 然后再搜这个世纪发明的交通工具,要进行多步推理, 就用reasoning推理
+    如"与xx相似的xx是什么", 要进行多步推理, 就用reasoning推理
+
+3. 回答生成  
+  - 使用工具后或遇到常识问题时,直接回复"ok"
+
+4. 质量控制
+    - 尽可能地使用工具解决用户问题,如果工具的信息还不够,可以继续使用多种工具。
+    - 确保工具的参数输入正确
+
+Begin!
+"""
+
+get_summary_system = """
+# ROLE  
+你叫歌莉娅AI,你是一位擅长从多个信息来源中找到与问题相关的来源并根据答案进行重写的专家:  
+
+# OBJECTIVE  
+根据用户提出的问题,从多个信息来源中找到与问题相关的来源,并根据答案进行重写, 返回json格式
+
+# TASK REQUIREMENTS  
+1. 信息筛选   
+    - 如果有多个信息源, 则选择与问题相关的信息源进行内容重写
+
+2. 内容重写
+  - 对于选择的信息源进行重写时需要注意以下几点:
+    - 确保信息源重写信息是准确的, 禁止重写非信息源里面的内容
+    - 确保重写后的内容不丢失信息
+    - 确保重写内容与答案相关,不相关的可以不写进来
+
+3. 回答生成  
+  返回json格式,采用```json ```包裹重写信息
+
+
+# EXAMPLE
+Input:
+问题: "火星离地球多远? 参考资料: 1 xxxxx 2 xxxxx"
+
+Output: ```json
+{
+    "1": "根据问题和资料1提取的信息摘要",
+    "2": "根据问题和资料2提取的信息摘要"
+}
+```
+
+Begin!
+"""
+
+rag_system_prompt_qw = """
+# ROLE
+你叫歌莉娅AI, 是一位专业的知识库助手。
+
+# OBJECTIVE
+基于提供的参考文档,直接回答问题,不要添加冗余说明和回答与答案无关的内容。如果没有参考资料,就用你的知识进行回答, 禁止编造引用数据。
+
+# TASK REQUIREMENTS
+1. 资料解析
+    - 仔细阅读所有提供的参考资料
+    - 识别与问题相关的关键信息
+    - 注意文档的元数据信息(来源、时间、作者等)
+
+2. 引用规范
+    - 使用[1], [2]等格式在关键信息的句尾标注引用来源
+    - 每个关键信息都需要有对应引用
+
+3. 回答生成
+    - 直接回答问题,不要添加冗余说明和回答与答案无关的内容
+    - 基于引用的资料内容构建回答
+    - 确保每个重要论点都有引用支持
+    - 在信息不足时明确指出
+    - 多个来源信息时注意整合和对比
+    - 如果没有参考资料,就用你的知识进行常识回答。
+    - 严格基于提供的参考资料回答问题
+
+
+4. 质量控制
+    - 准确引用信息来源,清晰标注信息来源,确保引用准确无误
+    - 确保每个结论都有明确的引用支持
+    - 禁止透露你的引用能力和其他能力的信息
+    - 请在回答末尾罗列资料信息
+
+# RESPONSE FORMAT
+回答应包含以下部分:
+1. 主体回答(带有引用标注)
+2. 关键引用信息的罗列
+
+# INPUT FORMAT
+你将收到以下格式的输入:
+1. 参考资料信息(包含序号和内容)
+2. 用户问题
+
+# EXAMPLE
+Input:
+问题: "火星离地球多远? 参考资料: 1 xxxxx 2 xxxxx"
+
+Output:
+火星离地球....[1]
+[1] 相关信息的关键提取
+Begin!
+"""
+rag_system_prompt = """
+你叫歌莉娅AI, 是一位专业的知识库助手, 如有相关资料,请使用[1], [2]等格式在相关回答中引用来源,并且一定要在回答末尾添加来源信息的关键参考(对应文件名或来源title + 需总结过的简要关键信息)。
+"""
+rag_system_prompt_pure = """
+你叫歌莉娅AI, 是一位专业的知识库助手。你需要专业,详细地回答用户问题。"""

+ 145 - 0
server.py

@@ -0,0 +1,145 @@
+from fastapi import FastAPI, File, UploadFile, Form, Body
+from fastapi.staticfiles import StaticFiles
+from fastapi.responses import JSONResponse
+from fastapi.middleware.cors import CORSMiddleware
+import json, os, asyncio
+from milvus_process import update_mulvus_file, get_search_results
+from config import static_dir, upload_path, llm_config
+from prompt import output_system_prompt_use, rag_system_prompt_pure, rag_system_prompt_qw
+from file_process import DocumentProcessor
+from autogen import register_function
+from copy import deepcopy
+import autogen
+from agent import get_content_summary
+app = FastAPI()
+processor = DocumentProcessor()
+app.mount("/workspace", StaticFiles(directory=static_dir), name="static")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.post("/chat")
+async def chat(client_id: str = Form(...), prompt: str = Form(...), history: str = Body(...)):
+    try:
+        
+        output_agent = autogen.AssistantAgent(
+            name="output_answer",
+            llm_config=llm_config,
+            system_message=output_system_prompt_use,
+            code_execution_config=False,
+            human_input_mode="NEVER",
+        )
+        user_proxy = autogen.UserProxyAgent(
+                name="user_proxy",
+                is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
+                human_input_mode="ALWAYS",
+                max_consecutive_auto_reply=2,
+                code_execution_config=False
+            )
+        register_function(get_search_results,caller=output_agent,executor=user_proxy,name='get_search_results', description="搜索专业知识库和联网搜索获取信息, 用户所有的非常识性问题使用这个函数")
+        answer = {}
+        search_results = ''
+        history = history.replace("'", '"')
+        history = json.loads(history)  # 解析为列表
+        
+        message_use = deepcopy(history)
+        message_use.append({'role':'user', 'content':prompt})
+        
+        use_tool = 0
+        use_search = 0
+        while isinstance(answer, dict):
+            answer = await output_agent.a_generate_reply(messages=message_use)
+            
+            if isinstance(answer,dict):
+                    message_use.append(answer)
+                    tool_calls = answer.get('tool_calls', [])
+                    for call in tool_calls:
+                        if isinstance(call,dict):
+                            function_info = call.get('function',{})
+                            if function_info and isinstance(function_info,dict):
+                                func_name = function_info.get('name')
+                                func_args = function_info.get('arguments')
+                                
+                                # 将JSON字符串解析为字典
+                                try:
+                                    args = json.loads(func_args)
+                                except json.JSONDecodeError as e:
+                                    
+                                    message_use.append({'role': 'tool','name':func_name,'content':f"Failed to decode arguments: {e}"})
+                                    continue
+                                use_tool += 1
+                                # 查找并执行函数
+                                if func_name == 'get_search_results':
+                                    if args and isinstance(args, dict):
+                                        query = args.get('query',prompt)
+                                    else:
+                                        query = prompt
+                                    data, search_res = await get_search_results(query=query) 
+                                    final_data, search_results = await get_content_summary(question=prompt, res_info=search_res, final_data=data)
+                                    message_use.append({'role': 'tool','name':func_name,'content':search_results,})
+                                    use_search += 1
+        if search_results:
+                rag_system_prompt_use = rag_system_prompt_qw
+        else:
+            rag_system_prompt_use = rag_system_prompt_pure
+        rag_summary_agent = autogen.AssistantAgent(
+                            name="rag_answer",
+                            llm_config=llm_config,
+                            system_message=rag_system_prompt_use,
+                            code_execution_config=False,
+                            human_input_mode="NEVER",
+                        )
+        message_rag = deepcopy(history)
+        message_rag.append({'role':'user', 'content': prompt + '\n' + search_results if search_results else prompt})
+        final_answer = await rag_summary_agent.a_generate_reply(messages=message_rag)
+        return JSONResponse(content={
+                "total_tokens": 1000,
+                "completion_tokens": 1000,
+                "content": final_answer,
+            }, status_code=200)
+    except Exception as e:
+        print(f"出错啦:{str(e)}")
+        return JSONResponse(content={
+                "total_tokens": 1000,
+                "completion_tokens": 1000,
+                "content": '出错啦,请联系管理员吧!',
+            }, status_code=200)
+
+            
+@app.post("/uploadfile/")
+async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
+    
+    temp_directory = upload_path.format(client_id=client_id)
+    if not os.path.exists(temp_directory):
+        os.makedirs(temp_directory)
+        os.chmod(temp_directory, 0o777)  # 设置用户目录权限为777
+    
+    file_location = os.path.join(temp_directory, file.filename)
+    
+    # try:
+    with open(file_location, "wb+") as file_object:
+        file_object.write(file.file.read())
+    os.chmod(file_location, 0o777)  # 设置文件权限为777
+    chunks = await asyncio.to_thread(processor.read_file, file_location)
+    update_status = await update_mulvus_file(client_id=client_id, file_name=file.filename, chunks=chunks)
+    return JSONResponse(content={
+        "message": f"文件 '{file.filename}' 上传成功",
+        "client_id": client_id,
+        "file_path": file.filename,
+        "update_status": update_status.get('result','succeed')
+    }, status_code=200)
+    # except Exception as e:
+    #     return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=5666)
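+
+# Example requests (illustrative; the server listens on port 5666):
+#   curl -X POST http://localhost:5666/chat -F client_id=test -F prompt='每月盘点时间有要求吗?' -F history='[]'
+#   curl -X POST http://localhost:5666/uploadfile/ -F client_id=test -F file=@./tests/test.pdf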

+ 50 - 0
util.py

@@ -0,0 +1,50 @@
+import base64
+import time, shutil, os
+import logging
+from logging.handlers import TimedRotatingFileHandler
+from colorama import Fore, Style, init
+
+def encode_to_base64(text):
+    encoded_bytes = base64.b64encode(text.encode('utf-8'))
+    return encoded_bytes.decode('utf-8').replace('=', '_deng_hao_').replace('+', '_jia_hao_').replace('/', '_xie_gang_')
+
+def decode_from_base64(encoded_text):
+    encoded_text = encoded_text.replace('_deng_hao_', '=').replace('_jia_hao_','+').replace('_xie_gang_', '/')
+    decoded_bytes = base64.b64decode(encoded_text.encode('utf-8'))
+    return decoded_bytes.decode('utf-8')
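+
+# config.py imports get_cus_logger from this module, but this commit does not
+# include its definition. The sketch below is an assumed, minimal
+# implementation: a named logger with a daily-rotating file plus a colorized
+# console echo, using only the imports already present above.
+def get_cus_logger(name, log_dir='./logs'):
+    init(autoreset=True)  # colorama: reset colors after each record
+    logger = logging.getLogger(name)
+    if logger.handlers:  # already configured; reuse it
+        return logger
+    logger.setLevel(logging.INFO)
+    os.makedirs(log_dir, exist_ok=True)
+    fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+    file_handler = TimedRotatingFileHandler(
+        os.path.join(log_dir, f'{name}.log'), when='midnight', backupCount=7, encoding='utf-8')
+    file_handler.setFormatter(logging.Formatter(fmt))
+    console = logging.StreamHandler()
+    console.setFormatter(logging.Formatter(Fore.GREEN + fmt + Style.RESET_ALL))
+    logger.addHandler(file_handler)
+    logger.addHandler(console)
+    return logger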
+def remove_files(basic_path):
+    current_time = time.time()
+    TIME_THRESHOLD_FILEPATH = 10 * 24 * 60 * 60  # directories older than 10 days
+    TIME_THRESHOLD_FILE = 2 * 60                 # files older than 2 minutes
+    for root, dirs, files in os.walk(basic_path, topdown=False):
+        try:
+            if current_time - os.path.getmtime(root) > TIME_THRESHOLD_FILEPATH:
+                print(f"Removing directory: {root}")
+                shutil.rmtree(root)
+                continue
+            for file in files:
+                file_path = os.path.join(root, file)
+                if current_time - os.path.getmtime(file_path) > TIME_THRESHOLD_FILE:
+                    print(f"Removing file: {file_path}")
+                    os.remove(file_path)
+            # Remove stale sub-directories
+            for dir in dirs:
+                dir_path = os.path.join(root, dir)
+                if current_time - os.path.getmtime(dir_path) > TIME_THRESHOLD_FILEPATH:
+                    print(f"Removing directory: {dir_path}")
+                    shutil.rmtree(dir_path)
+        except Exception as e:
+            print(f'Error while removing files: {e}')
+
+if __name__ == "__main__":
+    # 示例
+    chinese_text = "分区名称ggg.pdf"
+    encoded_text = encode_to_base64(chinese_text)
+    print("Base64 编码:", encoded_text)
+    decoded_text = decode_from_base64(encoded_text)
+
+    
+    print("解码后:", decoded_text)