# qwen_function_call.py
import base64
import copy
import json
import os
import re
import unicodedata
import urllib.parse
import uuid
from pathlib import Path
from typing import Any, Collection, Dict, List, Literal, Optional, Set, Tuple, Union

import tiktoken
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from pydantic import BaseModel, field_validator, model_validator
  9. VOCAB_FILES_NAMES = {'vocab_file': 'qwen.tiktoken'}
  10. PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
  11. ENDOFTEXT = '<|endoftext|>'
  12. IMSTART = '<|im_start|>'
  13. IMEND = '<|im_end|>'
  14. # as the default behavior is changed to allow special tokens in
  15. # regular texts, the surface forms of special tokens need to be
  16. # as different as possible to minimize the impact
  17. EXTRAS = tuple((f'<|extra_{i}|>' for i in range(205)))
  18. # changed to use actual index to avoid misconfiguration with vocabulary expansion
  19. SPECIAL_START_ID = 151643
  20. SPECIAL_TOKENS = tuple(enumerate(
  21. ((
  22. ENDOFTEXT,
  23. IMSTART,
  24. IMEND,
  25. ) + EXTRAS),
  26. start=SPECIAL_START_ID,
  27. ))
  28. ROLE = 'role'
  29. CONTENT = 'content'
  30. NAME = 'name'
  31. SYSTEM = 'system'
  32. USER = 'user'
  33. ASSISTANT = 'assistant'
  34. FUNCTION = 'function'
  35. FN_NAME = '✿FUNCTION✿'
  36. FN_ARGS = '✿ARGS✿'
  37. FN_RESULT = '✿RESULT✿'
  38. FN_EXIT = '✿RETURN✿'
  39. FN_STOP_WORDS = [FN_RESULT, FN_EXIT]
  40. FN_CALL_TEMPLATE_INFO_ZH = """# 工具
  41. ## 你拥有如下工具:
  42. {tool_descs}"""
  43. FN_CALL_TEMPLATE_INFO_EN = """# Tools
  44. ## You have access to the following tools:
  45. {tool_descs}"""
  46. FN_CALL_TEMPLATE_FMT_ZH = """## 你可以在回复中插入零次、一次或多次以下命令以调用工具:
  47. %s: 工具名称,必须是[{tool_names}]之一。
  48. %s: 工具输入
  49. %s: 工具结果
  50. %s: 根据工具结果进行回复,需将图片用![](url)渲染出来""" % (
  51. FN_NAME,
  52. FN_ARGS,
  53. FN_RESULT,
  54. FN_EXIT,
  55. )
  56. FN_CALL_TEMPLATE_FMT_EN = """## When you need to call a tool, please insert the following command in your reply, which can be called zero or multiple times according to your needs:
  57. %s: The tool to use, should be one of [{tool_names}]
  58. %s: The input of the tool
  59. %s: Tool results
  60. %s: Reply based on tool results. Images need to be rendered as ![](url)""" % (
  61. FN_NAME,
  62. FN_ARGS,
  63. FN_RESULT,
  64. FN_EXIT,
  65. )
  66. FN_CALL_TEMPLATE_FMT_PARA_ZH = """## 你可以在回复中插入以下命令以并行调用N个工具:
  67. %s: 工具1的名称,必须是[{tool_names}]之一
  68. %s: 工具1的输入
  69. %s: 工具2的名称
  70. %s: 工具2的输入
  71. ...
  72. %s: 工具N的名称
  73. %s: 工具N的输入
  74. %s: 工具1的结果
  75. %s: 工具2的结果
  76. ...
  77. %s: 工具N的结果
  78. %s: 根据工具结果进行回复,需将图片用![](url)渲染出来""" % (
  79. FN_NAME,
  80. FN_ARGS,
  81. FN_NAME,
  82. FN_ARGS,
  83. FN_NAME,
  84. FN_ARGS,
  85. FN_RESULT,
  86. FN_RESULT,
  87. FN_RESULT,
  88. FN_EXIT,
  89. )
  90. FN_CALL_TEMPLATE_FMT_PARA_EN = """## Insert the following command in your reply when you need to call N tools in parallel:
  91. %s: The name of tool 1, should be one of [{tool_names}]
  92. %s: The input of tool 1
  93. %s: The name of tool 2
  94. %s: The input of tool 2
  95. ...
  96. %s: The name of tool N
  97. %s: The input of tool N
  98. %s: The result of tool 1
  99. %s: The result of tool 2
  100. ...
  101. %s: The result of tool N
  102. %s: Reply based on tool results. Images need to be rendered as ![](url)""" % (
  103. FN_NAME,
  104. FN_ARGS,
  105. FN_NAME,
  106. FN_ARGS,
  107. FN_NAME,
  108. FN_ARGS,
  109. FN_RESULT,
  110. FN_RESULT,
  111. FN_RESULT,
  112. FN_EXIT,
  113. )
  114. FN_CALL_TEMPLATE = {
  115. 'zh': FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_ZH,
  116. 'en': FN_CALL_TEMPLATE_INFO_EN + '\n\n' + FN_CALL_TEMPLATE_FMT_EN,
  117. 'zh_parallel': FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_PARA_ZH,
  118. 'en_parallel': FN_CALL_TEMPLATE_INFO_EN + '\n\n' + FN_CALL_TEMPLATE_FMT_PARA_EN,
  119. }
  120. CHINESE_CHAR_RE = re.compile(r'[\u4e00-\u9fff]')
  121. class QWenTokenizer:
  122. """QWen tokenizer."""
  123. vocab_files_names = VOCAB_FILES_NAMES
  124. def __init__(
  125. self,
  126. vocab_file=None,
  127. errors='replace',
  128. extra_vocab_file=None,
  129. **kwargs,
  130. ):
  131. if not vocab_file:
  132. vocab_file = VOCAB_FILES_NAMES['vocab_file']
  133. self._decode_use_source_tokenizer = False
  134. # how to handle errors in decoding UTF-8 byte sequences
  135. # use ignore if you are in streaming inference
  136. self.errors = errors
  137. self.mergeable_ranks = self._load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
  138. self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
  139. # try load extra vocab from file
  140. if extra_vocab_file is not None:
  141. used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
  142. extra_mergeable_ranks = self._load_tiktoken_bpe(extra_vocab_file)
  143. for token, index in extra_mergeable_ranks.items():
  144. if token in self.mergeable_ranks:
  145. continue
  146. if index in used_ids:
  147. continue
  148. self.mergeable_ranks[token] = index
  149. # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
  150. enc = tiktoken.Encoding(
  151. 'Qwen',
  152. pat_str=PAT_STR,
  153. mergeable_ranks=self.mergeable_ranks,
  154. special_tokens=self.special_tokens,
  155. )
  156. assert len(self.mergeable_ranks) + len(
  157. self.special_tokens
  158. ) == enc.n_vocab, f'{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding'
  159. self.decoder = {v: k for k, v in self.mergeable_ranks.items()} # type: dict[int, bytes|str]
  160. self.decoder.update({v: k for k, v in self.special_tokens.items()})
  161. self.tokenizer = enc # type: tiktoken.Encoding
  162. self.eod_id = self.tokenizer.eot_token
  163. self.im_start_id = self.special_tokens[IMSTART]
  164. self.im_end_id = self.special_tokens[IMEND]
  165. def _load_tiktoken_bpe(self, tiktoken_bpe_file: str) -> Dict[bytes, int]:
  166. with open(tiktoken_bpe_file, 'rb') as f:
  167. contents = f.read()
  168. return {
  169. base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
  170. }
  171. def __getstate__(self):
  172. # for pickle lovers
  173. state = self.__dict__.copy()
  174. del state['tokenizer']
  175. return state
  176. def __setstate__(self, state):
  177. # tokenizer is not python native; don't pass it; rebuild it
  178. self.__dict__.update(state)
  179. enc = tiktoken.Encoding(
  180. 'Qwen',
  181. pat_str=PAT_STR,
  182. mergeable_ranks=self.mergeable_ranks,
  183. special_tokens=self.special_tokens,
  184. )
  185. self.tokenizer = enc
  186. def __len__(self) -> int:
  187. return self.tokenizer.n_vocab
  188. def get_vocab(self) -> Dict[bytes, int]:
  189. return self.mergeable_ranks
  190. def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> List[int]:
  191. ids = []
  192. if isinstance(tokens, (str, bytes)):
  193. if tokens in self.special_tokens:
  194. return self.special_tokens[tokens]
  195. else:
  196. return self.mergeable_ranks.get(tokens)
  197. for token in tokens:
  198. if token in self.special_tokens:
  199. ids.append(self.special_tokens[token])
  200. else:
  201. ids.append(self.mergeable_ranks.get(token))
  202. return ids
  203. def tokenize(
  204. self,
  205. text: str,
  206. allowed_special: Union[Set, str] = 'all',
  207. disallowed_special: Union[Collection, str] = (),
  208. **kwargs,
  209. ) -> List[Union[bytes, str]]:
  210. """
  211. Converts a string in a sequence of tokens.
  212. Args:
  213. text (`str`):
  214. The sequence to be encoded.
  215. allowed_special (`Literal["all"]` or `set`):
  216. The surface forms of the tokens to be encoded as special tokens in regular texts.
  217. Default to "all".
  218. disallowed_special (`Literal["all"]` or `Collection`):
  219. The surface forms of the tokens that should not be in regular texts and trigger errors.
  220. Default to an empty tuple.
  221. kwargs (additional keyword arguments, *optional*):
  222. Will be passed to the underlying model specific encode method.
  223. Returns:
  224. `List[bytes|str]`: The list of tokens.
  225. """
  226. tokens = []
  227. text = unicodedata.normalize('NFC', text)
  228. # this implementation takes a detour: text -> token id -> token surface forms
  229. for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
  230. tokens.append(self.decoder[t])
  231. return tokens
  232. def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
  233. """
  234. Converts a sequence of tokens in a single string.
  235. """
  236. text = ''
  237. temp = b''
  238. for t in tokens:
  239. if isinstance(t, str):
  240. if temp:
  241. text += temp.decode('utf-8', errors=self.errors)
  242. temp = b''
  243. text += t
  244. elif isinstance(t, bytes):
  245. temp += t
  246. else:
  247. raise TypeError('token should only be of type types or str')
  248. if temp:
  249. text += temp.decode('utf-8', errors=self.errors)
  250. return text
  251. @property
  252. def vocab_size(self):
  253. return self.tokenizer.n_vocab
  254. def _decode(
  255. self,
  256. token_ids: Union[int, List[int]],
  257. skip_special_tokens: bool = False,
  258. errors: str = None,
  259. **kwargs,
  260. ) -> str:
  261. if isinstance(token_ids, int):
  262. token_ids = [token_ids]
  263. if skip_special_tokens:
  264. token_ids = [i for i in token_ids if i < self.eod_id]
  265. return self.tokenizer.decode(token_ids, errors=errors or self.errors)
  266. def encode(self, text: str) -> List[int]:
  267. return self.convert_tokens_to_ids(self.tokenize(text))
  268. def count_tokens(self, text: str) -> int:
  269. return len(self.tokenize(text))
  270. def truncate(self, text: str, max_token: int, start_token: int = 0) -> str:
  271. token_list = self.tokenize(text)
  272. token_list = token_list[start_token:min(len(token_list), start_token + max_token)]
  273. return self.convert_tokens_to_string(token_list)
  274. class BaseModelCompatibleDict(BaseModel):
  275. def __getitem__(self, item):
  276. return getattr(self, item)
  277. def __setitem__(self, key, value):
  278. setattr(self, key, value)
  279. def model_dump(self, **kwargs):
  280. return super().model_dump(exclude_none=True, **kwargs)
  281. def model_dump_json(self, **kwargs):
  282. return super().model_dump_json(exclude_none=True, **kwargs)
  283. def get(self, key, default=None):
  284. try:
  285. value = getattr(self, key)
  286. if value:
  287. return value
  288. else:
  289. return default
  290. except AttributeError:
  291. return default
  292. def __str__(self):
  293. return f'{self.model_dump()}'
  294. class FunctionCall(BaseModelCompatibleDict):
  295. name: str
  296. arguments: str
  297. def __init__(self, name: str, arguments: str):
  298. super().__init__(name=name, arguments=arguments)
  299. def __repr__(self):
  300. return f'FunctionCall({self.model_dump()})'
  301. class ContentItem(BaseModelCompatibleDict):
  302. text: Optional[str] = None
  303. image: Optional[str] = None
  304. file: Optional[str] = None
  305. def __init__(self, text: Optional[str] = None, image: Optional[str] = None, file: Optional[str] = None):
  306. super().__init__(text=text, image=image, file=file)
  307. @model_validator(mode='after')
  308. def check_exclusivity(self):
  309. provided_fields = 0
  310. if self.text is not None:
  311. provided_fields += 1
  312. if self.image:
  313. provided_fields += 1
  314. if self.file:
  315. provided_fields += 1
  316. if provided_fields != 1:
  317. raise ValueError("Exactly one of 'text', 'image', or 'file' must be provided.")
  318. return self
  319. def __repr__(self):
  320. return f'ContentItem({self.model_dump()})'
  321. def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file'], str]:
  322. (t, v), = self.model_dump().items()
  323. assert t in ('text', 'image', 'file')
  324. return t, v
  325. @property
  326. def type(self) -> Literal['text', 'image', 'file']:
  327. t, v = self.get_type_and_value()
  328. return t
  329. @property
  330. def value(self) -> str:
  331. t, v = self.get_type_and_value()
  332. return v
  333. class Message(BaseModelCompatibleDict):
  334. role: str
  335. content: Union[str, List[ContentItem]]
  336. name: Optional[str] = None
  337. function_call: Optional[FunctionCall] = None
  338. def __init__(self,
  339. role: str,
  340. content: Optional[Union[str, List[ContentItem]]],
  341. name: Optional[str] = None,
  342. function_call: Optional[FunctionCall] = None,
  343. **kwargs):
  344. if content is None:
  345. content = ''
  346. super().__init__(role=role, content=content, name=name, function_call=function_call)
  347. def __repr__(self):
  348. return f'Message({self.model_dump()})'
  349. @field_validator('role')
  350. def role_checker(cls, value: str) -> str:
  351. if value not in [USER, ASSISTANT, SYSTEM, FUNCTION]:
  352. raise ValueError(f'{value} must be one of {",".join([USER, ASSISTANT, SYSTEM, FUNCTION])}')
  353. return value
  354. class messages_process:
  355. def __init__(self) -> None:
  356. pass
  357. def preprocess(self, messages, func):
  358. lang: Literal['en', 'zh'] = 'zh' if self.has_chinese_messages(messages) else 'en'
  359. new_messages = []
  360. # Only return dict when all input messages are dict
  361. if not messages:
  362. _return_message_type = 'message'
  363. for msg in messages:
  364. if isinstance(msg, dict):
  365. new_messages.append(Message(**msg))
  366. else:
  367. new_messages.append(msg)
  368. messages = copy.deepcopy(new_messages)
  369. messages = self._format_as_text_messages(messages)
  370. messages = self.prepend_fncall_system(messages, functions=func, lang=lang)
  371. if messages and messages[-1].role == ASSISTANT:
  372. assert len(messages) > 1 and messages[-2].role == USER
  373. assert messages[-1].function_call is None
  374. usr = messages[-2].content
  375. bot = messages[-1].content
  376. sep = '\n\n'
  377. if isinstance(usr, str) and isinstance(bot, str):
  378. usr = usr + sep + bot
  379. elif isinstance(usr, list) and isinstance(bot, list):
  380. usr = usr + [ContentItem(text=sep)] + bot
  381. else:
  382. raise NotImplementedError
  383. text_to_complete = copy.deepcopy(messages[-2])
  384. text_to_complete.content = usr
  385. messages = messages[:-2] + [text_to_complete]
  386. messages = [msg.model_dump() for msg in messages]
  387. return messages
  388. def post_process(self, messages, generate_cfg):
  389. messages = [self.format_as_multimodal_message(msg, add_upload_info=False) for msg in messages]
  390. if not generate_cfg.get('skip_stopword_postproc', False):
  391. stop = generate_cfg.get('stop', [])
  392. messages = self._postprocess_stop_words(messages, stop=stop)
  393. messages = self.postprocess_fncall_messages(messages)
  394. messages = self.convert_messages_to_target_type(messages, 'message')
  395. return messages
  396. def has_chinese_messages(self, messages: List[Union[Message, dict]], check_roles: Tuple[str] = (SYSTEM, USER)) -> bool:
  397. for m in messages:
  398. if m['role'] in check_roles:
  399. if self.has_chinese_chars(m['content']):
  400. return True
  401. return False
  402. def _format_as_text_messages(self, messages: List[Message]) -> List[Message]:
  403. for msg in messages:
  404. if isinstance(msg.content, list):
  405. for item in msg.content:
  406. assert item.type == 'text'
  407. else:
  408. assert isinstance(msg.content, str)
  409. messages = [self.format_as_text_message(msg, add_upload_info=False) for msg in messages]
  410. return messages
  411. def prepend_fncall_system(self, messages: List[Message], functions: List[Dict], lang: Literal['en', 'zh'],
  412. parallel_function_calls: bool = False, ):
  413. tool_desc_template = FN_CALL_TEMPLATE[lang + ('_parallel' if parallel_function_calls else '')]
  414. tool_descs = '\n\n'.join(self.get_function_description(function, lang=lang) for function in functions)
  415. tool_names = ','.join(function.get('name', function.get('name_for_model', '')) for function in functions)
  416. tool_system = tool_desc_template.format(tool_descs=tool_descs, tool_names=tool_names)
  417. assert messages[0].role == SYSTEM
  418. messages = copy.deepcopy(messages[:1]) + messages[1:]
  419. if isinstance(messages[0].content, str):
  420. messages[0].content += '\n\n' + tool_system
  421. else:
  422. messages[0].content.append(ContentItem(text='\n\n' + tool_system))
  423. return messages
  424. def get_function_description(self, function: Dict, lang: Literal['en', 'zh']) -> str:
  425. """
  426. Text description of function
  427. """
  428. tool_desc_template = {
  429. 'zh': '### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}',
  430. 'en': '### {name_for_human}\n\n{name_for_model}: {description_for_model} Parameters: {parameters} {args_format}'
  431. }
  432. tool_desc = tool_desc_template[lang]
  433. name = function.get('name', None)
  434. name_for_human = function.get('name_for_human', name)
  435. name_for_model = function.get('name_for_model', name)
  436. assert name_for_human and name_for_model
  437. if name_for_model == 'code_interpreter':
  438. args_format = {
  439. 'zh': '此工具的输入应为Markdown代码块。',
  440. 'en': 'Enclose the code within triple backticks (`) at the beginning and end of the code.',
  441. }
  442. else:
  443. args_format = {
  444. 'zh': '此工具的输入应为JSON对象。',
  445. 'en': 'Format the arguments as a JSON object.',
  446. }
  447. args_format = function.get('args_format', args_format[lang])
  448. return tool_desc.format(name_for_human=name_for_human,
  449. name_for_model=name_for_model,
  450. description_for_model=function['description'],
  451. parameters=json.dumps(function['parameters'], ensure_ascii=False),
  452. args_format=args_format).rstrip()
  453. def format_as_multimodal_message(
  454. self,
  455. msg: Message,
  456. add_upload_info: bool,
  457. lang: Literal['auto', 'en', 'zh'] = 'auto',
  458. ) -> Message:
  459. assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION)
  460. content: List[ContentItem] = []
  461. if isinstance(msg.content, str): # if text content
  462. if msg.content:
  463. content = [ContentItem(text=msg.content)]
  464. elif isinstance(msg.content, list): # if multimodal content
  465. files = []
  466. for item in msg.content:
  467. k, v = item.get_type_and_value()
  468. if k == 'text':
  469. content.append(ContentItem(text=v))
  470. if k == 'image':
  471. content.append(item)
  472. if k in ('file', 'image'):
  473. # Move 'file' out of 'content' since it's not natively supported by models
  474. files.append(v)
  475. if add_upload_info and files and (msg.role in (SYSTEM, USER)):
  476. if lang == 'auto':
  477. has_zh = self.has_chinese_chars(msg)
  478. else:
  479. has_zh = (lang == 'zh')
  480. upload = []
  481. for f in [self.get_basename_from_url(f) for f in files]:
  482. if self.is_image(f):
  483. if has_zh:
  484. upload.append(f'![图片]({f})')
  485. else:
  486. upload.append(f'![image]({f})')
  487. else:
  488. if has_zh:
  489. upload.append(f'[文件]({f})')
  490. else:
  491. upload.append(f'[file]({f})')
  492. upload = ' '.join(upload)
  493. if has_zh:
  494. upload = f'(上传了 {upload})\n\n'
  495. else:
  496. upload = f'(Uploaded {upload})\n\n'
  497. # Check and avoid adding duplicate upload info
  498. upload_info_already_added = False
  499. for item in content:
  500. if item.text and (upload in item.text):
  501. upload_info_already_added = True
  502. if not upload_info_already_added:
  503. content = [ContentItem(text=upload)] + content
  504. else:
  505. raise TypeError
  506. msg = Message(
  507. role=msg.role,
  508. content=content,
  509. name=msg.name if msg.role == FUNCTION else None,
  510. function_call=msg.function_call,
  511. )
  512. return msg
  513. def format_as_text_message(self,
  514. msg: Message,
  515. add_upload_info: bool,
  516. lang: Literal['auto', 'en', 'zh'] = 'auto',
  517. ) -> Message:
  518. msg = self.format_as_multimodal_message(msg, add_upload_info=add_upload_info, lang=lang)
  519. text = ''
  520. for item in msg.content:
  521. if item.type == 'text':
  522. text += item.value
  523. msg.content = text
  524. return msg
  525. def _postprocess_stop_words(self, messages: List[Message], stop: List[str]) -> List[Message]:
  526. messages = copy.deepcopy(messages)
  527. # Make sure it stops before stop words.
  528. trunc_messages = []
  529. for msg in messages:
  530. truncated = False
  531. trunc_content = []
  532. for i, item in enumerate(msg.content):
  533. item_type, item_text = item.get_type_and_value()
  534. if item_type == 'text':
  535. truncated, item.text = self._truncate_at_stop_word(text=item_text, stop=stop)
  536. trunc_content.append(item)
  537. if truncated:
  538. break
  539. msg.content = trunc_content
  540. trunc_messages.append(msg)
  541. if truncated:
  542. break
  543. messages = trunc_messages
  544. # It may ends with partial stopword 'Observation' when the full stopword is 'Observation:'.
  545. # The following post-processing step removes partial stop words.
  546. partial_stop = []
  547. for s in stop:
  548. s = tokenizer.tokenize(s)[:-1]
  549. if s:
  550. s = tokenizer.convert_tokens_to_string(s)
  551. partial_stop.append(s)
  552. partial_stop = sorted(set(partial_stop))
  553. last_msg = messages[-1].content
  554. for i in range(len(last_msg) - 1, -1, -1):
  555. item_type, item_text = last_msg[i].get_type_and_value()
  556. if item_type == 'text':
  557. for s in partial_stop:
  558. if item_text.endswith(s):
  559. last_msg[i].text = item_text[:-len(s)]
  560. break
  561. return messages
  562. def postprocess_fncall_messages(self, messages: List[Message]) -> List[Message]:
  563. """
  564. If the model calls function by built-in function call template,
  565. convert and display it in function_call format.
  566. """
  567. # Remove ': ' brought by continued generation of function calling
  568. last_msg = messages[-1].content
  569. for i in range(len(last_msg)):
  570. item_type, item_text = last_msg[i].get_type_and_value()
  571. if item_type == 'text':
  572. if item_text.startswith(': '):
  573. last_msg[i].text = item_text[2:]
  574. elif item_text.startswith(':'):
  575. last_msg[i].text = item_text[1:]
  576. break
  577. new_messages = []
  578. for msg in messages:
  579. role, content = msg.role, msg.content
  580. assert isinstance(content, list)
  581. if role in (SYSTEM, USER):
  582. new_messages.append(Message(role=role, content=content))
  583. continue
  584. new_content = []
  585. for item in content:
  586. item_type, item_text = item.get_type_and_value()
  587. if item_type != 'text': # multimodal
  588. new_content.append(item)
  589. continue
  590. for stop_word in [FN_RESULT, FN_EXIT]:
  591. assert stop_word in FN_STOP_WORDS
  592. assert stop_word not in item_text, 'Something wrong, stop words are expected to be excluded.'
  593. i = item_text.find(f'{FN_NAME}:')
  594. # If no function call:
  595. if i < 0:
  596. show_text = self.remove_incomplete_special_tokens(item_text)
  597. if show_text:
  598. new_content.append(ContentItem(text=show_text))
  599. continue
  600. # If it says something before function call:
  601. if i > 0:
  602. answer = item_text[:i].lstrip('\n').rstrip()
  603. if answer.endswith('\n'):
  604. answer = answer[:-1]
  605. show_text = self.remove_incomplete_special_tokens(answer)
  606. if show_text:
  607. new_content.append(ContentItem(text=show_text))
  608. if new_content:
  609. new_messages.append(Message(
  610. role=role,
  611. content=new_content,
  612. )) # split thought and function call
  613. new_content = []
  614. item_text = item_text[i:]
  615. # If has function call:
  616. for part in item_text.split(f'{FN_NAME}:'):
  617. if not part:
  618. continue
  619. if part.endswith('\n'):
  620. part = part[:-1]
  621. arg_sep = f'{FN_ARGS}:'
  622. i = part.find(arg_sep)
  623. if i < 0:
  624. fn_name = part.strip()
  625. list_of_fn_args = ['']
  626. else:
  627. fn_name = part[:i].strip()
  628. list_of_fn_args = [_.strip() for _ in part[i + len(arg_sep):].split(arg_sep)]
  629. fn_name = self.remove_incomplete_special_tokens(fn_name)
  630. for fn_args in list_of_fn_args:
  631. fn_args = self.remove_incomplete_special_tokens(fn_args)
  632. fn_args = self.remove_trailing_comment_of_fn_args(fn_args)
  633. new_messages.append(
  634. Message(
  635. role=ASSISTANT,
  636. content=[],
  637. function_call=FunctionCall(
  638. name=fn_name,
  639. arguments=fn_args,
  640. ),
  641. ))
  642. # Break here and discard the text after function call
  643. return new_messages
  644. if new_content:
  645. new_messages.append(Message(role=role, content=new_content))
  646. return new_messages
  647. def remove_incomplete_special_tokens(self, text: str) -> str:
  648. special_tokens = (FN_NAME, FN_ARGS, FN_RESULT, FN_EXIT)
  649. text = text.rstrip()
  650. if text.endswith(special_tokens):
  651. for s in special_tokens:
  652. if text.endswith(s):
  653. text = text[:-len(s)]
  654. break
  655. else:
  656. trail_start = text.rfind('✿')
  657. trail_token = text[trail_start:]
  658. for s in special_tokens:
  659. if s.startswith(trail_token):
  660. text = text[:trail_start]
  661. break
  662. text = text.lstrip('\n').rstrip()
  663. return text
  664. # For hotfix badcases such as `{"arg1": "value1"} <!-- this is an example comment -->`.
  665. def remove_trailing_comment_of_fn_args(self, fn_args: str):
  666. fn_args = fn_args.strip()
  667. if fn_args.startswith('{'):
  668. k = fn_args.rfind('}')
  669. if k > 0:
  670. fn_args = fn_args[:k + 1]
  671. if fn_args.startswith('```'):
  672. k = fn_args.rfind('\n```')
  673. if k > 0:
  674. fn_args = fn_args[:k + 4]
  675. return fn_args
  676. def convert_messages_to_target_type(self, messages: List[Message],
  677. target_type: str) -> Union[List[Message], List[Dict]]:
  678. if target_type == 'message':
  679. return [Message(**x) if isinstance(x, dict) else x for x in messages]
  680. elif target_type == 'dict':
  681. return [x.model_dump() if not isinstance(x, dict) else x for x in messages]
  682. else:
  683. raise NotImplementedError
  684. def has_chinese_chars(self, data: Any) -> bool:
  685. text = f'{data}'
  686. return bool(CHINESE_CHAR_RE.search(text))
  687. def _truncate_at_stop_word(self, text: str, stop: List[str]):
  688. truncated = False
  689. for s in stop:
  690. k = text.find(s)
  691. if k >= 0:
  692. truncated = True
  693. text = text[:k]
  694. return truncated, text
  695. def is_image(self, path_or_url: str) -> bool:
  696. filename = self.get_basename_from_url(path_or_url).lower()
  697. for ext in ['jpg', 'jpeg', 'png', 'webp']:
  698. if filename.endswith(ext):
  699. return True
  700. return False
  701. def get_basename_from_url(self, path_or_url: str) -> str:
  702. if re.match(r'^[A-Za-z]:\\', path_or_url):
  703. # "C:\\a\\b\\c" -> "C:/a/b/c"
  704. path_or_url = path_or_url.replace('\\', '/')
  705. # "/mnt/a/b/c" -> "c"
  706. # "https://github.com/here?k=v" -> "here"
  707. # "https://github.com/" -> ""
  708. basename = urllib.parse.urlparse(path_or_url).path
  709. basename = os.path.basename(basename)
  710. basename = urllib.parse.unquote(basename)
  711. basename = basename.strip()
  712. # "https://github.com/" -> "" -> "github.com"
  713. if not basename:
  714. basename = [x.strip() for x in path_or_url.split('/') if x.strip()][-1]
  715. return basename
  716. @staticmethod
  717. def detect_tool(message: Message) -> Tuple[bool, str, str, str]:
  718. func_name = None
  719. func_args = None
  720. if message.function_call:
  721. func_call = message.function_call
  722. func_name = func_call.name
  723. func_args = func_call.arguments
  724. text = message.content
  725. if not text:
  726. text = ''
  727. return (func_name is not None), func_name, func_args, text
  728. @staticmethod
  729. def create_chat_completion_message(
  730. role: str,
  731. content: Optional[str] = None,
  732. tool_calls: Optional[List[dict]] = None
  733. ) -> ChatCompletionMessage:
  734. """
  735. 创建一个ChatCompletionMessage对象,支持添加tool_calls。
  736. :param role: 消息的角色("system", "user", "assistant", "tool")
  737. :param content: 消息的内容,对于tool调用可以为None
  738. :param tool_calls: 工具调用列表,每个元素是一个字典,包含 'name' 和 'arguments' 键
  739. :return: ChatCompletionMessage对象
  740. """
  741. if tool_calls:
  742. formatted_tool_calls = [
  743. ChatCompletionMessageToolCall(
  744. id=str(uuid.uuid4()), # 使用UUID生成随机ID
  745. type="function",
  746. function={
  747. "name": call['name'],
  748. "arguments": call['arguments']
  749. }
  750. )
  751. for call in tool_calls
  752. ]
  753. return ChatCompletionMessage(
  754. role=role,
  755. content=content,
  756. tool_calls=formatted_tool_calls
  757. )
  758. else:
  759. return ChatCompletionMessage(role=role, content=content)
  760. tokenizer = QWenTokenizer(Path(__file__).resolve().parent / 'qwen.tiktoken')