前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >ChatGLM3-6B的Transformers.Model的核心接口说明

ChatGLM3-6B的Transformers.Model的核心接口说明

原创
作者头像
buzzfrog
修改2023-11-13 12:40:55
1.9K0
修改2023-11-13 12:40:55
举报
文章被收录于专栏:云上修行云上修行

背景

ChatGLM3-6B是10月底最新发布的智谱AI语言大模型。效果确实有明显的进步。但从文档上来看,仅有几个Demo以及B站官网视频 https://www.bilibili.com/video/BV1uC4y1J7yA 可供参考。但如果希望深入研究,关键的调用:

代码语言:txt
复制
model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
                            return_past_key_values=True,
                            max_length=max_length, 
                            top_p=top_p,
                            temperature=temperature)

到底每个参数是什么含义?

由于Huggingface上、modelscope.cn上以及chatglm的github上,都没有详细的核心接口说明。全网检索很久,也没有找到答案。最后经过研究,可以通过源码文件来了解:https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py

本文通过给出相关接口注释,帮助大家了解相关接口的用法。

源码溯源

在huggingface的ChatGLM3-6B的主页中,点击Files标签页。

可以发现modeling_chatglm.py文件,接口代码即在其中。

接口注释

聊天函数

代码语言:python
复制
    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs):
        """
        聊天函数,接受一段文本查询,返回模型的响应。

        参数:
            tokenizer: 用于处理输入和输出文本的tokenizer对象。
            query (str): 用户的文本输入。
            history (List[Dict], 可选): 对话历史,每一项都是一个字典,包含角色('role')和内容('content')。默认为None。
            role (str, 可选): 输入文本的角色,可以是'user'或者'assistant'。默认为'user'。
            max_length (int, 可选): 生成文本的最大长度。默认为8192。
            num_beams (int, 可选): Beam搜索的宽度,如果值大于1,则使用Beam搜索。默认为1。
            do_sample (bool, 可选): 是否从预测分布中进行采样。默认为True。
            top_p (float, 可选): 采用nucleus采样时的累积概率阈值。默认为0.8。
            temperature (float, 可选): 控制生成文本的随机性的参数。默认为0.8。
            logits_processor (LogitsProcessorList, 可选): 用于处理和修改生成步骤中的logits的对象。默认为None。
            **kwargs: 其他传递给模型生成函数的参数。

        返回:
            response (str): 模型的响应文本。
            history (List[Dict]): 更新后的对话历史。
        """
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(self.device)
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        response, history = self.process_response(response, history)
        return response, history

流式聊天函数

代码语言:python
复制
    @torch.inference_mode()
    def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
                    past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, return_past_key_values=False, **kwargs):
        """
        流式聊天函数,接受一段文本查询,返回模型的响应。这个函数是一个生成器,可以在流式处理中使用。

        参数:
            tokenizer: 用于处理输入和输出文本的tokenizer对象。
            query (str): 用户的文本输入。
            history (List[Dict], 可选): 对话历史,每一项都是一个字典,包含角色('role')和内容('content')。默认为None。
            role (str, 可选): 输入文本的角色,可以是'user'或者'assistant'。默认为'user'。
            past_key_values (List[Tensor], 可选): 用于transformer模型的过去的键值对。默认为None。
            max_length (int, 可选): 生成文本的最大长度。默认为8192。
            do_sample (bool, 可选): 是否从预测分布中进行采样。默认为True。
            top_p (float, 可选): 采用nucleus采样时的累积概率阈值。默认为0.8。
            temperature (float, 可选): 控制生成文本的随机性的参数。默认为0.8。
            logits_processor (LogitsProcessorList, 可选): 用于处理和修改生成步骤中的logits的对象。默认为None。
            return_past_key_values (bool, 可选): 是否返回过去的键值对,用于下一步的生成。默认为False。
            **kwargs: 其他传递给模型生成函数的参数。

        返回:
            response (str): 模型的响应文本。
            history (List[Dict]): 更新后的对话历史。
            past_key_values (List[Tensor], 可选): 如果return_past_key_values为True,返回用于下一步生成的过去的键值对。
        """
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if past_key_values is None:
            inputs = tokenizer.build_chat_input(query, history=history, role=role)
        else:
            inputs = tokenizer.build_chat_input(query, role=role)
        inputs = inputs.to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask
        history.append({"role": role, "content": query})
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                            **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
            response = tokenizer.decode(outputs)
            if response and response[-1] != "�":
                response, new_history = self.process_response(response, history)
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 背景
  • 源码溯源
  • 接口注释
    • 聊天函数
      • 流式聊天函数
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档