- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
- from fastapi.staticfiles import StaticFiles
- from fastapi import FastAPI, File, UploadFile, Form, Body
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- import json, os, asyncio
- from milvus_process import update_mulvus_file, get_search_results
- from config import static_dir, upload_path, llm_config, llm_config_ds
- from prompt import output_system_prompt_use, rag_system_prompt, rag_system_prompt_pure, rag_system_prompt_qw
- from file_process import DocumentProcessor
- import traceback
- from autogen import register_function
- from copy import deepcopy
- from openai import AsyncOpenAI
- import autogen
- from agent import get_content_summary
# ---- Application setup -------------------------------------------------------
app = FastAPI()

# Shared document parser used by the upload endpoint (file -> text chunks).
processor = DocumentProcessor()

# Serve static artifacts from `static_dir` under the /workspace URL prefix.
app.mount("/workspace", StaticFiles(directory=static_dir), name="static")

# Fully open CORS: any origin/method/header, credentials allowed.
# NOTE(review): "*" origins together with allow_credentials=True is very permissive —
# confirm this is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Per-client status registry. Not read or written in the visible code — presumably
# used by another module or a later chunk of this file; verify before removing.
status_map = {}
@app.post("/chat")
async def chat(client_id: str = Form(...), prompt: str = Form(...), history: str = Body(...)):
    """Answer a user prompt, letting the model call the knowledge-base search tool.

    Flow: an AssistantAgent may emit `get_search_results` tool calls; each call is
    executed here, its summarized result is appended as a `tool` message, and the
    loop re-queries the agent. Once the agent returns plain text (or the round cap
    is hit), a second RAG agent composes the final answer from history + search
    results.

    Args:
        client_id: Requesting client id (accepted for API compatibility; not used
            in this handler's visible logic).
        prompt: The user's question.
        history: JSON-encoded list of prior messages ({'role', 'content'} dicts).
            Legacy single-quoted pseudo-JSON is also accepted (see parsing below).

    Returns:
        JSONResponse with `total_tokens`, `completion_tokens`, `content`. On any
        failure the handler returns status 200 with a fixed apology message —
        this matches the original contract the frontend relies on.
    """
    try:
        output_agent = autogen.AssistantAgent(
            name="output_answer",
            llm_config=llm_config,
            system_message=output_system_prompt_use,
            code_execution_config=False,
            human_input_mode="NEVER",
        )
        user_proxy = autogen.UserProxyAgent(
            name="user_proxy",
            is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
            # Was "ALWAYS": a server endpoint must never block waiting for console
            # input. The proxy only serves as the tool executor registration here.
            human_input_mode="NEVER",
            max_consecutive_auto_reply=2,
            code_execution_config=False,
        )
        register_function(
            get_search_results,
            caller=output_agent,
            executor=user_proxy,
            name='get_search_results',
            description="搜索专业知识库和联网搜索获取信息, 用户所有的非常识性问题使用这个函数",
        )

        # Parse history robustly: try strict JSON first, then fall back to the
        # legacy single-quoted format. The original unconditional
        # history.replace("'", '"') corrupted valid JSON containing apostrophes.
        try:
            history = json.loads(history)
        except json.JSONDecodeError:
            history = json.loads(history.replace("'", '"'))

        search_results = ''
        message_use = deepcopy(history)
        message_use.append({'role': 'user', 'content': prompt})

        # Bounded tool-call loop. The original `while isinstance(answer, dict)`
        # could spin forever if the model kept emitting tool calls.
        max_rounds = 5
        for _ in range(max_rounds):
            answer = await output_agent.a_generate_reply(messages=message_use)
            if not isinstance(answer, dict):
                # Plain-text reply: the agent is done calling tools.
                break
            message_use.append(answer)
            for call in answer.get('tool_calls', []):
                if not isinstance(call, dict):
                    continue
                function_info = call.get('function', {})
                if not (function_info and isinstance(function_info, dict)):
                    continue
                func_name = function_info.get('name')
                func_args = function_info.get('arguments')
                # Tool arguments arrive as a JSON string; report decode failures
                # back to the model instead of crashing the request.
                try:
                    args = json.loads(func_args)
                except json.JSONDecodeError as e:
                    message_use.append({'role': 'tool', 'name': func_name, 'content': f"Failed to decode arguments: {e}"})
                    continue
                if func_name == 'get_search_results':
                    if args and isinstance(args, dict):
                        query = args.get('query', prompt)
                    else:
                        query = prompt
                    data, search_res = await get_search_results(query=query)
                    # Summarize raw hits into model-ready context.
                    final_data, search_results = await get_content_summary(question=prompt, res_info=search_res, final_data=data)
                    message_use.append({'role': 'tool', 'name': func_name, 'content': search_results})

        # Pick the RAG prompt variant based on whether we gathered any context.
        rag_system_prompt_use = rag_system_prompt_qw if search_results else rag_system_prompt_pure
        rag_summary_agent = autogen.AssistantAgent(
            name="rag_answer",
            llm_config=llm_config,
            system_message=rag_system_prompt_use,
            code_execution_config=False,
            human_input_mode="NEVER",
        )
        message_rag = deepcopy(history)
        message_rag.append({'role': 'user', 'content': prompt + '\n' + search_results if search_results else prompt})
        final_answer = await rag_summary_agent.a_generate_reply(messages=message_rag)
        return JSONResponse(content={
            "total_tokens": 1000,
            "completion_tokens": 1000,
            "content": final_answer,
        }, status_code=200)
    except Exception as e:
        # Log the full traceback server-side (`traceback` is already imported but
        # was unused); keep the fixed 200-with-apology payload for the client.
        print(f"出错啦:{str(e)}")
        traceback.print_exc()
        return JSONResponse(content={
            "total_tokens": 1000,
            "completion_tokens": 1000,
            "content": '出错啦,请联系管理员吧!',
        }, status_code=200)
-
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
    """Save an uploaded file into the client's directory and index it in Milvus.

    Args:
        file: The uploaded document.
        client_id: Client id used to build the per-client upload directory.

    Returns:
        200 JSONResponse with the stored filename and the Milvus update status,
        or 500 with an error message if any step fails.
    """
    temp_directory = upload_path.format(client_id=client_id)
    if not os.path.exists(temp_directory):
        os.makedirs(temp_directory)
        os.chmod(temp_directory, 0o777)  # 设置用户目录权限为777

    # basename() strips any directory components from the client-supplied name,
    # preventing path traversal (e.g. filename = "../../etc/passwd").
    safe_name = os.path.basename(file.filename)
    file_location = os.path.join(temp_directory, safe_name)

    try:
        # Async read instead of the original blocking file.file.read(), which
        # stalled the event loop for the duration of the upload.
        contents = await file.read()
        with open(file_location, "wb") as file_object:
            file_object.write(contents)
        os.chmod(file_location, 0o777)  # 设置文件权限为777
        # File parsing is CPU/IO-bound; run it off the event loop.
        chunks = await asyncio.to_thread(processor.read_file, file_location)
        update_status = await update_mulvus_file(client_id=client_id, file_name=safe_name, chunks=chunks)
        return JSONResponse(content={
            "message": f"文件 '{safe_name}' 上传成功",
            "client_id": client_id,
            "file_path": safe_name,
            "update_status": update_status.get('result', 'succeed')
        }, status_code=200)
    except Exception as e:
        # Restored from the commented-out handler: surface failures as 500
        # instead of an unhandled exception.
        return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 5666.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=5666)