import base64
import copy
import json
import os
import re
import unicodedata
import urllib.parse
import uuid
from pathlib import Path
from typing import Any, Collection, Dict, List, Literal, Optional, Set, Tuple, Union

import tiktoken
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from pydantic import BaseModel, field_validator, model_validator

VOCAB_FILES_NAMES = {'vocab_file': 'qwen.tiktoken'}

PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

ENDOFTEXT = '<|endoftext|>'
IMSTART = '<|im_start|>'
IMEND = '<|im_end|>'
# As the default behavior now allows special tokens in regular text, the
# surface forms of special tokens need to be as distinctive as possible
# to minimize the impact.
EXTRAS = tuple(f'<|extra_{i}|>' for i in range(205))
# Changed to use explicit indices to avoid misconfiguration after vocabulary expansion.
SPECIAL_START_ID = 151643
SPECIAL_TOKENS = tuple(enumerate(
    (
        ENDOFTEXT,
        IMSTART,
        IMEND,
    ) + EXTRAS,
    start=SPECIAL_START_ID,
))

ROLE = 'role'
CONTENT = 'content'
NAME = 'name'

SYSTEM = 'system'
USER = 'user'
ASSISTANT = 'assistant'
FUNCTION = 'function'

FN_NAME = '✿FUNCTION✿'
FN_ARGS = '✿ARGS✿'
FN_RESULT = '✿RESULT✿'
FN_EXIT = '✿RETURN✿'
FN_STOP_WORDS = [FN_RESULT, FN_EXIT]

FN_CALL_TEMPLATE_INFO_ZH = """# 工具
## 你拥有如下工具:
{tool_descs}"""

FN_CALL_TEMPLATE_INFO_EN = """# Tools
## You have access to the following tools:
{tool_descs}"""

FN_CALL_TEMPLATE_FMT_ZH = """## 你可以在回复中插入零次、一次或多次以下命令以调用工具:
%s: 工具名称,必须是[{tool_names}]之一。
%s: 工具输入
%s: 工具结果
%s: 根据工具结果进行回复,需将图片用![](url)渲染出来""" % (
    FN_NAME,
    FN_ARGS,
    FN_RESULT,
    FN_EXIT,
)

FN_CALL_TEMPLATE_FMT_EN = """## When you need to call a tool, please insert the following command in your reply, which can be called zero or multiple times according to your needs:
%s: The tool to use, should be one of [{tool_names}]
%s: The input of the tool
%s: Tool results
%s: Reply based on tool results. Images need to be rendered as ![](url)""" % (
    FN_NAME,
    FN_ARGS,
    FN_RESULT,
    FN_EXIT,
)

FN_CALL_TEMPLATE_FMT_PARA_ZH = """## 你可以在回复中插入以下命令以并行调用N个工具:
%s: 工具1的名称,必须是[{tool_names}]之一
%s: 工具1的输入
%s: 工具2的名称
%s: 工具2的输入
...
%s: 工具N的名称
%s: 工具N的输入
%s: 工具1的结果
%s: 工具2的结果
...
%s: 工具N的结果
%s: 根据工具结果进行回复,需将图片用![](url)渲染出来""" % (
    FN_NAME,
    FN_ARGS,
    FN_NAME,
    FN_ARGS,
    FN_NAME,
    FN_ARGS,
    FN_RESULT,
    FN_RESULT,
    FN_RESULT,
    FN_EXIT,
)

FN_CALL_TEMPLATE_FMT_PARA_EN = """## Insert the following command in your reply when you need to call N tools in parallel:
%s: The name of tool 1, should be one of [{tool_names}]
%s: The input of tool 1
%s: The name of tool 2
%s: The input of tool 2
...
%s: The name of tool N
%s: The input of tool N
%s: The result of tool 1
%s: The result of tool 2
...
%s: The result of tool N
%s: Reply based on tool results. Images need to be rendered as ![](url)""" % (
    FN_NAME,
    FN_ARGS,
    FN_NAME,
    FN_ARGS,
    FN_NAME,
    FN_ARGS,
    FN_RESULT,
    FN_RESULT,
    FN_RESULT,
    FN_EXIT,
)

FN_CALL_TEMPLATE = {
    'zh': FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_ZH,
    'en': FN_CALL_TEMPLATE_INFO_EN + '\n\n' + FN_CALL_TEMPLATE_FMT_EN,
    'zh_parallel': FN_CALL_TEMPLATE_INFO_ZH + '\n\n' + FN_CALL_TEMPLATE_FMT_PARA_ZH,
    'en_parallel': FN_CALL_TEMPLATE_INFO_EN + '\n\n' + FN_CALL_TEMPLATE_FMT_PARA_EN,
}
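# For illustration only: FN_CALL_TEMPLATE['en'].format(tool_descs=..., tool_names='get_weather')
# expands to a system-prompt suffix roughly like ('get_weather' is a made-up tool name):
#
#   # Tools
#   ## You have access to the following tools:
#   ### get_weather
#   get_weather: ... Parameters: {...} Format the arguments as a JSON object.
#   ## When you need to call a tool, please insert the following command in your reply, ...
#   ✿FUNCTION✿: The tool to use, should be one of [get_weather]
#   ✿ARGS✿: The input of the tool
#   ...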
CHINESE_CHAR_RE = re.compile(r'[\u4e00-\u9fff]')

class QWenTokenizer:
    """QWen tokenizer."""

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file=None,
        errors='replace',
        extra_vocab_file=None,
        **kwargs,
    ):
        if not vocab_file:
            vocab_file = VOCAB_FILES_NAMES['vocab_file']
        self._decode_use_source_tokenizer = False

        # How to handle errors in decoding UTF-8 byte sequences:
        # use 'ignore' if you are doing streaming inference.
        self.errors = errors

        self.mergeable_ranks = self._load_tiktoken_bpe(vocab_file)  # type: Dict[bytes, int]
        self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}

        # Try to load extra vocab from file.
        if extra_vocab_file is not None:
            used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
            extra_mergeable_ranks = self._load_tiktoken_bpe(extra_vocab_file)
            for token, index in extra_mergeable_ranks.items():
                if token in self.mergeable_ranks:
                    continue
                if index in used_ids:
                    continue
                self.mergeable_ranks[token] = index
            # The indices may be sparse after this, but tiktoken.Encoding will handle that.

        enc = tiktoken.Encoding(
            'Qwen',
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        assert len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab, \
            f'{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding'

        self.decoder = {v: k for k, v in self.mergeable_ranks.items()}  # type: dict[int, bytes|str]
        self.decoder.update({v: k for k, v in self.special_tokens.items()})

        self.tokenizer = enc  # type: tiktoken.Encoding

        self.eod_id = self.tokenizer.eot_token
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]

    def _load_tiktoken_bpe(self, tiktoken_bpe_file: str) -> Dict[bytes, int]:
        with open(tiktoken_bpe_file, 'rb') as f:
            contents = f.read()
        return {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }
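
    # The vocab file follows the standard tiktoken BPE format: one token per line,
    # base64-encoded token bytes followed by an integer rank, e.g. (illustrative values):
    #   IQ== 0
    #   Ig== 1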

    def __getstate__(self):
        # For pickle lovers.
        state = self.__dict__.copy()
        del state['tokenizer']
        return state

    def __setstate__(self, state):
        # The tiktoken encoder is not Python-native; don't pickle it, rebuild it.
        self.__dict__.update(state)
        enc = tiktoken.Encoding(
            'Qwen',
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        self.tokenizer = enc

    def __len__(self) -> int:
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        return self.mergeable_ranks

    def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> Union[int, List[int]]:
        # A single token maps to a single id; a list of tokens maps to a list of ids.
        if isinstance(tokens, (str, bytes)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.mergeable_ranks.get(tokens)
        ids = []
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.mergeable_ranks.get(token))
        return ids

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = 'all',
        disallowed_special: Union[Collection, str] = (),
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """
        Converts a string into a sequence of tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            allowed_special (`Literal["all"]` or `set`):
                The surface forms of the tokens to be encoded as special tokens in regular texts.
                Defaults to "all".
            disallowed_special (`Literal["all"]` or `Collection`):
                The surface forms of the tokens that should not appear in regular texts and should trigger errors.
                Defaults to an empty tuple.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model-specific encode method.

        Returns:
            `List[bytes|str]`: The list of tokens.
        """
        tokens = []
        text = unicodedata.normalize('NFC', text)

        # This implementation takes a detour: text -> token id -> token surface form.
        for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
            tokens.append(self.decoder[t])
        return tokens

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """
        Converts a sequence of tokens into a single string.
        """
        text = ''
        temp = b''
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode('utf-8', errors=self.errors)
                    temp = b''
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError('token should only be of type bytes or str')
        if temp:
            text += temp.decode('utf-8', errors=self.errors)
        return text

    @property
    def vocab_size(self):
        return self.tokenizer.n_vocab

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: Optional[str] = None,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            # All special token ids start at SPECIAL_START_ID (== eod_id), so this drops them all.
            token_ids = [i for i in token_ids if i < self.eod_id]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)

    def encode(self, text: str) -> List[int]:
        return self.convert_tokens_to_ids(self.tokenize(text))

    def count_tokens(self, text: str) -> int:
        return len(self.tokenize(text))

    def truncate(self, text: str, max_token: int, start_token: int = 0) -> str:
        token_list = self.tokenize(text)
        token_list = token_list[start_token:min(len(token_list), start_token + max_token)]
        return self.convert_tokens_to_string(token_list)
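
# Example round trip (a sketch; assumes the 'qwen.tiktoken' vocab file is available):
#   tok = QWenTokenizer()
#   ids = tok.encode('Hello, world!')
#   assert tok._decode(ids) == 'Hello, world!'
#   tok.truncate('a long piece of text', max_token=3)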

class BaseModelCompatibleDict(BaseModel):
    """A pydantic model that can also be read and written like a dict."""

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def model_dump(self, **kwargs):
        return super().model_dump(exclude_none=True, **kwargs)

    def model_dump_json(self, **kwargs):
        return super().model_dump_json(exclude_none=True, **kwargs)

    def get(self, key, default=None):
        try:
            value = getattr(self, key)
            if value:
                return value
            else:
                return default
        except AttributeError:
            return default

    def __str__(self):
        return f'{self.model_dump()}'

class FunctionCall(BaseModelCompatibleDict):
    name: str
    arguments: str

    def __init__(self, name: str, arguments: str):
        super().__init__(name=name, arguments=arguments)

    def __repr__(self):
        return f'FunctionCall({self.model_dump()})'

class ContentItem(BaseModelCompatibleDict):
    text: Optional[str] = None
    image: Optional[str] = None
    file: Optional[str] = None

    def __init__(self, text: Optional[str] = None, image: Optional[str] = None, file: Optional[str] = None):
        super().__init__(text=text, image=image, file=file)

    @model_validator(mode='after')
    def check_exclusivity(self):
        provided_fields = 0
        if self.text is not None:
            provided_fields += 1
        if self.image:
            provided_fields += 1
        if self.file:
            provided_fields += 1
        if provided_fields != 1:
            raise ValueError("Exactly one of 'text', 'image', or 'file' must be provided.")
        return self

    def __repr__(self):
        return f'ContentItem({self.model_dump()})'

    def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file'], str]:
        (t, v), = self.model_dump().items()
        assert t in ('text', 'image', 'file')
        return t, v

    @property
    def type(self) -> Literal['text', 'image', 'file']:
        t, v = self.get_type_and_value()
        return t

    @property
    def value(self) -> str:
        t, v = self.get_type_and_value()
        return v
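
# For example (illustrative values):
#   ContentItem(text='hello').get_type_and_value()       -> ('text', 'hello')
#   ContentItem(image='https://example.com/a.png').type  -> 'image'
#   ContentItem()  # raises ValueError: exactly one field must be provided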

class Message(BaseModelCompatibleDict):
    role: str
    content: Union[str, List[ContentItem]]
    name: Optional[str] = None
    function_call: Optional[FunctionCall] = None

    def __init__(self,
                 role: str,
                 content: Optional[Union[str, List[ContentItem]]],
                 name: Optional[str] = None,
                 function_call: Optional[FunctionCall] = None,
                 **kwargs):
        if content is None:
            content = ''
        super().__init__(role=role, content=content, name=name, function_call=function_call)

    def __repr__(self):
        return f'Message({self.model_dump()})'

    @field_validator('role')
    def role_checker(cls, value: str) -> str:
        if value not in [USER, ASSISTANT, SYSTEM, FUNCTION]:
            raise ValueError(f'{value} must be one of {",".join([USER, ASSISTANT, SYSTEM, FUNCTION])}')
        return value

class messages_process:
    """Helpers for converting messages to and from the Qwen function-call prompt format."""

    def __init__(self) -> None:
        pass

    def preprocess(self, messages, func):
        lang: Literal['en', 'zh'] = 'zh' if self.has_chinese_messages(messages) else 'en'

        # Normalize dict messages into Message objects.
        new_messages = []
        for msg in messages:
            if isinstance(msg, dict):
                new_messages.append(Message(**msg))
            else:
                new_messages.append(msg)
        messages = copy.deepcopy(new_messages)

        messages = self._format_as_text_messages(messages)
        messages = self.prepend_fncall_system(messages, functions=func, lang=lang)

        # If the last message is a partial assistant reply, merge it into the
        # preceding user message so that generation continues from the partial text.
        if messages and messages[-1].role == ASSISTANT:
            assert len(messages) > 1 and messages[-2].role == USER
            assert messages[-1].function_call is None
            usr = messages[-2].content
            bot = messages[-1].content
            sep = '\n\n'
            if isinstance(usr, str) and isinstance(bot, str):
                usr = usr + sep + bot
            elif isinstance(usr, list) and isinstance(bot, list):
                usr = usr + [ContentItem(text=sep)] + bot
            else:
                raise NotImplementedError
            text_to_complete = copy.deepcopy(messages[-2])
            text_to_complete.content = usr
            messages = messages[:-2] + [text_to_complete]

        messages = [msg.model_dump() for msg in messages]
        return messages

    def post_process(self, messages, generate_cfg):
        messages = [self.format_as_multimodal_message(msg, add_upload_info=False) for msg in messages]
        if not generate_cfg.get('skip_stopword_postproc', False):
            stop = generate_cfg.get('stop', [])
            messages = self._postprocess_stop_words(messages, stop=stop)
        messages = self.postprocess_fncall_messages(messages)
        messages = self.convert_messages_to_target_type(messages, 'message')
        return messages

    def has_chinese_messages(self, messages: List[Union[Message, dict]], check_roles: Tuple[str, ...] = (SYSTEM, USER)) -> bool:
        for m in messages:
            if m['role'] in check_roles:
                if self.has_chinese_chars(m['content']):
                    return True
        return False

    def _format_as_text_messages(self, messages: List[Message]) -> List[Message]:
        for msg in messages:
            if isinstance(msg.content, list):
                for item in msg.content:
                    assert item.type == 'text'
            else:
                assert isinstance(msg.content, str)
        messages = [self.format_as_text_message(msg, add_upload_info=False) for msg in messages]
        return messages

    def prepend_fncall_system(self, messages: List[Message], functions: List[Dict], lang: Literal['en', 'zh'],
                              parallel_function_calls: bool = False):
        tool_desc_template = FN_CALL_TEMPLATE[lang + ('_parallel' if parallel_function_calls else '')]
        tool_descs = '\n\n'.join(self.get_function_description(function, lang=lang) for function in functions)
        tool_names = ','.join(function.get('name', function.get('name_for_model', '')) for function in functions)
        tool_system = tool_desc_template.format(tool_descs=tool_descs, tool_names=tool_names)

        assert messages[0].role == SYSTEM
        messages = copy.deepcopy(messages[:1]) + messages[1:]
        if isinstance(messages[0].content, str):
            messages[0].content += '\n\n' + tool_system
        else:
            messages[0].content.append(ContentItem(text='\n\n' + tool_system))
        return messages

    def get_function_description(self, function: Dict, lang: Literal['en', 'zh']) -> str:
        """
        Text description of a function.
        """
        tool_desc_template = {
            'zh': '### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}',
            'en': '### {name_for_human}\n\n{name_for_model}: {description_for_model} Parameters: {parameters} {args_format}'
        }
        tool_desc = tool_desc_template[lang]
        name = function.get('name', None)
        name_for_human = function.get('name_for_human', name)
        name_for_model = function.get('name_for_model', name)
        assert name_for_human and name_for_model

        if name_for_model == 'code_interpreter':
            args_format = {
                'zh': '此工具的输入应为Markdown代码块。',
                'en': 'Enclose the code within triple backticks (`) at the beginning and end of the code.',
            }
        else:
            args_format = {
                'zh': '此工具的输入应为JSON对象。',
                'en': 'Format the arguments as a JSON object.',
            }
        args_format = function.get('args_format', args_format[lang])

        return tool_desc.format(name_for_human=name_for_human,
                                name_for_model=name_for_model,
                                description_for_model=function['description'],
                                parameters=json.dumps(function['parameters'], ensure_ascii=False),
                                args_format=args_format).rstrip()
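
    # For a function dict like {'name': 'get_weather', 'description': 'Query weather.',
    # 'parameters': {...}} (illustrative), the English description renders roughly as:
    #   ### get_weather
    #
    #   get_weather: Query weather. Parameters: {...} Format the arguments as a JSON object.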

    def format_as_multimodal_message(
        self,
        msg: Message,
        add_upload_info: bool,
        lang: Literal['auto', 'en', 'zh'] = 'auto',
    ) -> Message:
        assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION)
        content: List[ContentItem] = []
        if isinstance(msg.content, str):  # text content
            if msg.content:
                content = [ContentItem(text=msg.content)]
        elif isinstance(msg.content, list):  # multimodal content
            files = []
            for item in msg.content:
                k, v = item.get_type_and_value()
                if k == 'text':
                    content.append(ContentItem(text=v))
                if k == 'image':
                    content.append(item)
                if k in ('file', 'image'):
                    # Move 'file' out of 'content' since it's not natively supported by models.
                    files.append(v)
            if add_upload_info and files and (msg.role in (SYSTEM, USER)):
                if lang == 'auto':
                    has_zh = self.has_chinese_chars(msg)
                else:
                    has_zh = (lang == 'zh')
                upload = []
                for f in [self.get_basename_from_url(f) for f in files]:
                    if self.is_image(f):
                        if has_zh:
                            upload.append(f'![图片]({f})')
                        else:
                            upload.append(f'![image]({f})')
                    else:
                        if has_zh:
                            upload.append(f'[文件]({f})')
                        else:
                            upload.append(f'[file]({f})')
                upload = ' '.join(upload)
                if has_zh:
                    upload = f'(上传了 {upload})\n\n'
                else:
                    upload = f'(Uploaded {upload})\n\n'

                # Check and avoid adding duplicate upload info.
                upload_info_already_added = False
                for item in content:
                    if item.text and (upload in item.text):
                        upload_info_already_added = True
                if not upload_info_already_added:
                    content = [ContentItem(text=upload)] + content
        else:
            raise TypeError
        msg = Message(
            role=msg.role,
            content=content,
            name=msg.name if msg.role == FUNCTION else None,
            function_call=msg.function_call,
        )
        return msg

    def format_as_text_message(self,
                               msg: Message,
                               add_upload_info: bool,
                               lang: Literal['auto', 'en', 'zh'] = 'auto',
                               ) -> Message:
        msg = self.format_as_multimodal_message(msg, add_upload_info=add_upload_info, lang=lang)
        text = ''
        for item in msg.content:
            if item.type == 'text':
                text += item.value
        msg.content = text
        return msg

    def _postprocess_stop_words(self, messages: List[Message], stop: List[str]) -> List[Message]:
        messages = copy.deepcopy(messages)

        # Make sure the output stops before the stop words.
        trunc_messages = []
        for msg in messages:
            truncated = False
            trunc_content = []
            for i, item in enumerate(msg.content):
                item_type, item_text = item.get_type_and_value()
                if item_type == 'text':
                    truncated, item.text = self._truncate_at_stop_word(text=item_text, stop=stop)
                trunc_content.append(item)
                if truncated:
                    break
            msg.content = trunc_content
            trunc_messages.append(msg)
            if truncated:
                break
        messages = trunc_messages

        # The output may end with a partial stop word, e.g. 'Observation' when the full
        # stop word is 'Observation:'. The following step removes such partial stop words.
        partial_stop = []
        for s in stop:
            s = tokenizer.tokenize(s)[:-1]
            if s:
                s = tokenizer.convert_tokens_to_string(s)
                partial_stop.append(s)
        partial_stop = sorted(set(partial_stop))
        last_msg = messages[-1].content
        for i in range(len(last_msg) - 1, -1, -1):
            item_type, item_text = last_msg[i].get_type_and_value()
            if item_type == 'text':
                for s in partial_stop:
                    if item_text.endswith(s):
                        last_msg[i].text = item_text[:-len(s)]
                break
        return messages
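
    # Illustration: with stop=['✿RESULT✿'], a completion ending in the stop word minus
    # its final token (roughly '...✿RESULT') has that partial suffix stripped as well.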

    def postprocess_fncall_messages(self, messages: List[Message]) -> List[Message]:
        """
        If the model calls a function via the built-in function-call template,
        convert it to the structured function_call format.
        """
        # Remove a leading ': ' brought by continued generation of function calling.
        last_msg = messages[-1].content
        for i in range(len(last_msg)):
            item_type, item_text = last_msg[i].get_type_and_value()
            if item_type == 'text':
                if item_text.startswith(': '):
                    last_msg[i].text = item_text[2:]
                elif item_text.startswith(':'):
                    last_msg[i].text = item_text[1:]
                break

        new_messages = []
        for msg in messages:
            role, content = msg.role, msg.content
            assert isinstance(content, list)

            if role in (SYSTEM, USER):
                new_messages.append(Message(role=role, content=content))
                continue

            new_content = []
            for item in content:
                item_type, item_text = item.get_type_and_value()

                if item_type != 'text':  # multimodal
                    new_content.append(item)
                    continue

                for stop_word in [FN_RESULT, FN_EXIT]:
                    assert stop_word in FN_STOP_WORDS
                    assert stop_word not in item_text, 'Something is wrong: stop words are expected to be excluded.'

                i = item_text.find(f'{FN_NAME}:')

                # If there is no function call:
                if i < 0:
                    show_text = self.remove_incomplete_special_tokens(item_text)
                    if show_text:
                        new_content.append(ContentItem(text=show_text))
                    continue

                # If the model says something before the function call:
                if i > 0:
                    answer = item_text[:i].lstrip('\n').rstrip()
                    show_text = self.remove_incomplete_special_tokens(answer)
                    if show_text:
                        new_content.append(ContentItem(text=show_text))
                    if new_content:
                        new_messages.append(Message(
                            role=role,
                            content=new_content,
                        ))  # split thought and function call
                        new_content = []
                    item_text = item_text[i:]

                # If there is a function call:
                for part in item_text.split(f'{FN_NAME}:'):
                    if not part:
                        continue
                    if part.endswith('\n'):
                        part = part[:-1]

                    arg_sep = f'{FN_ARGS}:'
                    i = part.find(arg_sep)
                    if i < 0:
                        fn_name = part.strip()
                        list_of_fn_args = ['']
                    else:
                        fn_name = part[:i].strip()
                        list_of_fn_args = [_.strip() for _ in part[i + len(arg_sep):].split(arg_sep)]
                    fn_name = self.remove_incomplete_special_tokens(fn_name)
                    for fn_args in list_of_fn_args:
                        fn_args = self.remove_incomplete_special_tokens(fn_args)
                        fn_args = self.remove_trailing_comment_of_fn_args(fn_args)
                        new_messages.append(
                            Message(
                                role=ASSISTANT,
                                content=[],
                                function_call=FunctionCall(
                                    name=fn_name,
                                    arguments=fn_args,
                                ),
                            ))
                # Break here and discard any text after the function call.
                return new_messages

            if new_content:
                new_messages.append(Message(role=role, content=new_content))
        return new_messages
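
    # Illustration ('get_weather' is a made-up tool name): an assistant text like
    #   '✿FUNCTION✿: get_weather\n✿ARGS✿: {"city": "Beijing"}'
    # becomes Message(role='assistant', content=[],
    #                 function_call=FunctionCall(name='get_weather',
    #                                            arguments='{"city": "Beijing"}')).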

    def remove_incomplete_special_tokens(self, text: str) -> str:
        special_tokens = (FN_NAME, FN_ARGS, FN_RESULT, FN_EXIT)
        text = text.rstrip()
        if text.endswith(special_tokens):
            for s in special_tokens:
                if text.endswith(s):
                    text = text[:-len(s)]
                    break
        else:
            trail_start = text.rfind('✿')
            trail_token = text[trail_start:]
            for s in special_tokens:
                if s.startswith(trail_token):
                    text = text[:trail_start]
                    break
        text = text.lstrip('\n').rstrip()
        return text

    # Hotfix for bad cases such as `{"arg1": "value1"} <!-- this is an example comment -->`.
    def remove_trailing_comment_of_fn_args(self, fn_args: str):
        fn_args = fn_args.strip()

        if fn_args.startswith('{'):
            k = fn_args.rfind('}')
            if k > 0:
                fn_args = fn_args[:k + 1]

        if fn_args.startswith('```'):
            k = fn_args.rfind('\n```')
            if k > 0:
                fn_args = fn_args[:k + 4]
        return fn_args
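
    # Illustration:
    #   '{"arg1": "value1"} <!-- comment -->'  ->  '{"arg1": "value1"}'
    #   '```\nprint(1)\n``` trailing notes'    ->  '```\nprint(1)\n```'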

    def convert_messages_to_target_type(self, messages: List[Message],
                                        target_type: str) -> Union[List[Message], List[Dict]]:
        if target_type == 'message':
            return [Message(**x) if isinstance(x, dict) else x for x in messages]
        elif target_type == 'dict':
            return [x.model_dump() if not isinstance(x, dict) else x for x in messages]
        else:
            raise NotImplementedError

    def has_chinese_chars(self, data: Any) -> bool:
        text = f'{data}'
        return bool(CHINESE_CHAR_RE.search(text))

    def _truncate_at_stop_word(self, text: str, stop: List[str]):
        truncated = False
        for s in stop:
            k = text.find(s)
            if k >= 0:
                truncated = True
                text = text[:k]
        return truncated, text

    def is_image(self, path_or_url: str) -> bool:
        filename = self.get_basename_from_url(path_or_url).lower()
        for ext in ['jpg', 'jpeg', 'png', 'webp']:
            if filename.endswith(ext):
                return True
        return False

    def get_basename_from_url(self, path_or_url: str) -> str:
        if re.match(r'^[A-Za-z]:\\', path_or_url):
            # "C:\\a\\b\\c" -> "C:/a/b/c"
            path_or_url = path_or_url.replace('\\', '/')

        # "/mnt/a/b/c" -> "c"
        # "https://github.com/here?k=v" -> "here"
        # "https://github.com/" -> ""
        basename = urllib.parse.urlparse(path_or_url).path
        basename = os.path.basename(basename)
        basename = urllib.parse.unquote(basename)
        basename = basename.strip()

        # "https://github.com/" -> "" -> "github.com"
        if not basename:
            basename = [x.strip() for x in path_or_url.split('/') if x.strip()][-1]
        return basename

    @staticmethod
    def detect_tool(message: Message) -> Tuple[bool, str, str, str]:
        func_name = None
        func_args = None

        if message.function_call:
            func_call = message.function_call
            func_name = func_call.name
            func_args = func_call.arguments
        text = message.content
        if not text:
            text = ''

        return (func_name is not None), func_name, func_args, text

    @staticmethod
    def create_chat_completion_message(
        role: str,
        content: Optional[str] = None,
        tool_calls: Optional[List[dict]] = None,
    ) -> ChatCompletionMessage:
        """
        Create a ChatCompletionMessage object, optionally with tool_calls attached.

        :param role: The role of the message ("system", "user", "assistant", "tool").
        :param content: The content of the message; may be None for tool calls.
        :param tool_calls: A list of tool calls, each a dict with 'name' and 'arguments' keys.
        :return: A ChatCompletionMessage object.
        """
        if tool_calls:
            formatted_tool_calls = [
                ChatCompletionMessageToolCall(
                    id=str(uuid.uuid4()),  # generate a random ID with UUID
                    type='function',
                    function={
                        'name': call['name'],
                        'arguments': call['arguments'],
                    },
                )
                for call in tool_calls
            ]
            return ChatCompletionMessage(
                role=role,
                content=content,
                tool_calls=formatted_tool_calls,
            )
        else:
            return ChatCompletionMessage(role=role, content=content)
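
    # Usage sketch (illustrative values):
    #   messages_process.create_chat_completion_message(
    #       'assistant', content=None,
    #       tool_calls=[{'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}])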

tokenizer = QWenTokenizer(Path(__file__).resolve().parent / 'qwen.tiktoken')
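
if __name__ == '__main__':
    # Minimal smoke test (a sketch; assumes the 'qwen.tiktoken' vocab file sits next
    # to this file; 'get_weather' is an illustrative tool, not part of this module).
    print(tokenizer.count_tokens('Hello, world!'))

    demo_function = {
        'name': 'get_weather',
        'description': 'Query the current weather for a city.',
        'parameters': {
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    }
    demo_messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': "What's the weather in Beijing?"},
    ]
    processed = messages_process().preprocess(demo_messages, [demo_function])
    print(json.dumps(processed, ensure_ascii=False, indent=2))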