Agent tuning: in multi-turn conversations, asking about the same topic again with different parameters causes tool recognition to fail #3734

Closed
1 task done
frankyu83 opened this issue May 14, 2024 · 0 comments
Labels
wontfix This will not be worked on

Comments

@frankyu83

Reminder

  • I have read the README and searched the existing issues.

Reproduction

I. Group 1: training parameters and test cases

SFT parameters

CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train.py \
    --stage sft \
    --do_train True \
    --model_name_or_path 01ai/Yi-9B \
    --finetuning_type lora \
    --template default \
    --use_unsloth True \
    --dataset_dir data \
    --dataset glaive_toolcall,alpaca_gpt4_en,alpaca_gpt4_zh,oaast_sft_zh \
    --cutoff_len 4096 \
    --learning_rate 5e-05 \
    --num_train_epochs 3.0 \
    --max_samples 10000 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 300 \
    --warmup_steps 0 \
    --output_dir saves/Yi-6B/lora/Yi-9B-agent-8 \
    --fp16 True \
    --lora_rank 8 \
    --lora_dropout 0.1 \
    --lora_target all \
    --plot_loss True

After training finished, train_loss = 0.8420096

OpenAI-style API launch script

CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api.py \
    --model_name_or_path ~/models/Yi-9B-agent-8 \
    --template default \
    --infer_backend vllm \
    --vllm_enforce_eager
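For scripted reproduction instead of Postman, the request bodies below can be sent with a small stdlib-only client. This is a sketch assuming the server started above exposes the OpenAI-compatible route `/v1/chat/completions` on `localhost:8000`; `build_payload` and `post_chat` are hypothetical helper names, not part of the project.

```python
import json
import urllib.request

# Assumption: the API started above listens on localhost:8000 and serves
# an OpenAI-compatible /v1/chat/completions route.
API_URL = "http://localhost:8000/v1/chat/completions"


def build_payload(messages, tools, temperature=0.1, top_p=0.9):
    """Build a chat-completion body with the same settings as the Postman tests."""
    return {
        "model": "Yi-9B-agent",
        "messages": messages,
        "tools": tools,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
    }


def post_chat(payload, url=API_URL, timeout=60):
    """POST the payload as UTF-8 JSON and return the decoded JSON response."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

Usage: `post_chat(build_payload(messages, tools))`, where `messages` and `tools` are the lists from one of the request bodies. `ensure_ascii=False` keeps the Chinese tool descriptions readable in the request body.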

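Postman test cases

1. Single-turn conversation: the tool is recognized correctly.

2. Two-turn conversation, question A in turn 1 and question B in turn 2: the tool is recognized correctly in both turns.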

request={
	"model":"Yi-9B-agent",
	"messages":[
	    {
	        "role": "user",
	        "content": "Hi, I need to know the distance between New York and Los Angeles"
	    },
	    {
	        "role": "assistant",
	        "content": null,
	        "tool_calls": [
	            {
	                "id": "call_3a25dba5142746c28b902d2947b03cef",
	                "type": "function",
	                "function": {
	                    "name": "calculate_distance",
	                    "arguments": "{\"location1\": \"New York\", \"location2\": \"Los Angeles\"}"
	                }
	            }
	        ]
	    },
	    {
	        "role": "tool",
	        "tool_call_id": "call_3a25dba5142746c28b902d2947b03cef",
	        "content": "{\"distance\": 2448.2, \"unit\": \"miles\"}"
	    },
	    {
	        "role": "assistant",
	        "content": "The distance between New York and Los Angeles is approximately 2448.2 miles."
	    },
	    {
	        "role": "user",
	        "content": "50路客流详情"
	    }
	],
	"temperature":0.1,
	"top_p":0.9,
	"stream":false,
	"tools":[
	  {
	    "type": "function",
	    "function":{
	      "name": "get_line_passenger_flow_detail",
	      "description": "线路客流详情",
	      "parameters": {
	          "type": "object",
	          "properties": {
	              "lineName": {"type": "string", "description": "线路名,例如 11、20"},
	              "hour": {"type": "string", "description": "小时,24小时制,没有则空"}
	          },
	          "required": ["lineName"]
	      }
	    }
	  },
	  {
	    "type":"function",
	    "function":{
	      "name":"get_line_passenger_flow_rank",
	      "description":"获取线路客流排名",
	      "parameters":{
	        "type":"object",
	        "properties":{
	          "sort":{"type":"string","enum":["最高","最低","第一","最多","最少"],"description":"排序方式"}
	        }
	      }
	    }
	  },
	  {
	    "type":"function",
	    "function":{
	      "name": "calculate_distance",
	      "description": "Calculate the distance between two locations",
	      "parameters": {
	          "type": "object",
	          "properties": {
	              "location1": {
	                  "type": "string",
	                  "description": "The first location"
	              },
	              "location2": {
	                  "type": "string",
	                  "description": "The second location"
	              }
	          },
	          "required": [
	              "location1",
	              "location2"
	          ]
	      }
	    }
	  }
	]
}


response={
    "id": "chatcmpl-45e803572b724bbf8fafe758aef461b8",
    "object": "chat.completion",
    "created": 1715592813,
    "model": "Yi-9B-agent",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": null,
                "tool_calls": [
                    {
                        "id": "call_46e4fc51869e4e648abed252bd36f5ce",
                        "type": "function",
                        "function": {
                            "name": "get_line_passenger_flow_detail",
                            "arguments": "{\"lineName\": \"50\"}"
                        }
                    }
                ]
            },
            "finish_reason": "tool_calls"
        }
    ],
    "usage": {
        "prompt_tokens": 444,
        "completion_tokens": 26,
        "total_tokens": 470
    }
}

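3. Two-turn conversation, question A in turn 1 (the history includes the function_call and observation messages) and question A again in turn 2 with a different parameter: the tool is recognized correctly in turn 1, but recognition fails in turn 2.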

request={
	"model":"Yi-9B-agent",
	"messages":[
		{"role":"user","content":"936路客流详情"},
		{"role": "assistant", "content": null, "tool_calls": [
	      {
	        "id": "call_d6d243b5f6204f66a4c56d0f8c9d4f5f",
	        "type": "function",
	        "function": {
	          "name": "get_line_passenger_flow_detail",
	          "arguments": "{\"lineName\": \"936\", \"hour\": \"\"}"
	        }
	      }
	   ]},
	  {"role": "tool", "tool_call_id": "call_d6d243b5f6204f66a4c56d0f8c9d4f5f", "content":  "{\"时段客流\":[{\"name\":\"14\",\"num\":\"404\",\"averageNum\":\"363\"},{\"name\":\"15\",\"num\":\"311\",\"averageNum\":\"367\"},{\"name\":\"16\",\"num\":\"0\",\"averageNum\":\"428\"}],\"日均客流\":\"6024\"}"},
		{"role":"assistant","content":"936路客流详情如下:\n时段客流:\n- 14: 404人,平均客流为363人\n- 15: 311人,平均客流为367人\n- 16: 0人,平均客流为428人\n日均客流:6024人"},
		{"role":"user","content":"50路客流详情"}
	],
	"temperature":0.1,
	"top_p":0.9,
	"stream":false,
	"tools":[
	  {
	    "type": "function",
	    "function":{
	      "name": "get_line_passenger_flow_detail",
	      "description": "线路客流详情",
	      "parameters": {
	          "type": "object",
	          "properties": {
	              "lineName": {"type": "string", "description": "线路名,例如 11、20"},
	              "hour": {"type": "string", "description": "小时,24小时制,没有则空"}
	          },
	          "required": ["lineName"]
	      }
	    }
	  },
	  {
	    "type":"function",
	    "function":{
	      "name":"get_line_passenger_flow_rank",
	      "description":"获取线路客流排名",
	      "parameters":{
	        "type":"object",
	        "properties":{
	          "sort":{"type":"string","enum":["最高","最低","第一","最多","最少"],"description":"排序方式"}
	        }
	      }
	    }
	  },
	  {
	    "type":"function",
	    "function":{
	      "name": "calculate_distance",
	      "description": "Calculate the distance between two locations",
	      "parameters": {
	          "type": "object",
	          "properties": {
	              "location1": {
	                  "type": "string",
	                  "description": "The first location"
	              },
	              "location2": {
	                  "type": "string",
	                  "description": "The second location"
	              }
	          },
	          "required": [
	              "location1",
	              "location2"
	          ]
	      }
	    }
	  }
	]
}
response={
    "id": "chatcmpl-247c75a2e6e644f48a98bb76d29a0e86",
    "object": "chat.completion",
    "created": 1715595132,
    "model": "Yi-9B-agent",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "Action: \nAction Input: ",
                "tool_calls": null
            },
            "finish_reason": "stop"
        }
    ],
    "usage": {
        "prompt_tokens": 503,
        "completion_tokens": 9,
        "total_tokens": 512
    }
}
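To compare turn 2 of test cases 2 and 3 automatically, each response can be classified as either a parsed tool call or plain text. A sketch: `classify_response` and `is_empty_action` are hypothetical helpers, and treating an `Action:` / `Action Input:` stub with no tool name as the failure signature is an assumption based only on the response above.

```python
import json


def classify_response(response):
    """Return ("tool", name, args) if the reply has a tool call, else ("text", content, None)."""
    message = response["choices"][0]["message"]
    tool_calls = message.get("tool_calls") or []
    if tool_calls:
        function = tool_calls[0]["function"]
        # The API returns arguments as a JSON-encoded string
        return "tool", function["name"], json.loads(function["arguments"])
    return "text", message.get("content"), None


def is_empty_action(content):
    """Heuristic: True for an 'Action:' / 'Action Input:' stub that names no tool."""
    if not content:
        return False
    lines = [line.strip() for line in content.strip().splitlines()]
    return lines == ["Action:", "Action Input:"]
```

Applied to the two turn-2 responses, the first yields `("tool", "get_line_passenger_flow_detail", {"lineName": "50"})`, while the second falls through to plain text whose content is an empty action stub.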

II. Group 2: training parameters and test cases

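Since the Group 1 model might have been undertrained, this model was trained on the glaive_toolcall dataset alone, and training was stopped at the end of the first epoch with train_loss = 0.4242.
SFT parameters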

CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train.py \
    --stage sft \
    --do_train True \
    --model_name_or_path 01ai/Yi-9B \
    --finetuning_type lora \
    --template default \
    --use_unsloth True \
    --dataset_dir data \
    --dataset glaive_toolcall \
    --cutoff_len 4096 \
    --learning_rate 5e-05 \
    --num_train_epochs 1.0 \
    --max_samples 10000 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 300 \
    --warmup_steps 0 \
    --output_dir saves/Yi-6B/lora/Yi-9B-agent-9 \
    --fp16 True \
    --lora_rank 8 \
    --lora_dropout 0.1 \
    --lora_target all \
    --plot_loss True
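The Group 2 schedule can be sanity-checked with a little arithmetic: with batch size 4 and gradient accumulation 4, each optimizer step sees 16 samples. A sketch, assuming a single GPU and that all 10,000 glaive_toolcall samples reach the dataloader; the step count follows the Hugging Face Trainer convention of floor-dividing the batches per epoch by `gradient_accumulation_steps`. Both helper names are hypothetical.

```python
import math


def steps_per_epoch(num_samples, per_device_batch, grad_accum, num_devices=1):
    """Optimizer steps per epoch: batches per epoch floor-divided by grad accumulation."""
    batches = math.ceil(num_samples / (per_device_batch * num_devices))
    return max(batches // grad_accum, 1)


def checkpoint_steps(total_steps, save_steps):
    """Global steps at which a checkpoint is written every `save_steps` steps."""
    return list(range(save_steps, total_steps + 1, save_steps))


# Group 2: 10,000 samples, batch 4, gradient accumulation 4, one GPU, 1 epoch
steps = steps_per_epoch(10000, 4, 4)
print(steps)                         # 625
print(checkpoint_steps(steps, 300))  # [300, 600]
```

Under these assumptions the single epoch is 625 optimizer steps, so `--save_steps 300` produces intermediate checkpoints at steps 300 and 600.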

After training finished, train_loss = 0.4242

Postman test cases
Same as the three Group 1 test cases.
The results were identical.

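III. Testing the glaiveai/glaive-function-calling-v2-small model to check whether the glaive-function-calling-v2 dataset supports tool recognition in multi-turn conversations

Two-turn conversation, question A in turn 1 (the history includes the function_call and observation messages) and question A again in turn 2 with different parameters: the tool is recognized correctly in both turns.

The test Python script is as follows: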

from transformers import AutoModelForCausalLM, AutoTokenizer

prompt = '''SYSTEM: You are an helpful assistant who has access to the following functions to help the user, you can use the functions if needed-
[
{
  "name": "get_current_weather",
  "description": "get the weather of the location",
  "parameters": {
      "type": "object",
      "properties": {
          "location": {"type": "string", "description": "城市"},
          "unit": {"type": "string","enum": ["摄氏度","华氏度"]}
      },
      "required": ["location"]
  }
},
{
  "name": "calculate_distance",
  "description": "Calculate the distance between two locations",
  "parameters": {
      "type": "object",
      "properties": {
          "location1": {
              "type": "string",
              "description": "The first location"
          },
          "location2": {
              "type": "string",
              "description": "The second location"
          }
      },
      "required": [
          "location1",
          "location2"
      ]
  }
}
]
USER: Hi, I need to know the distance between New York and Los Angeles?
ASSISTANT: <functioncall> {"name": "calculate_distance", "arguments": '{"location1": "New York", "location2": "Los Angeles"}'}
FUNCTION CALL: {"distance": 2448.2, "unit": "miles"}
ASSISTANT: The distance between New York and Los Angeles is approximately 2448.2 miles.
USER: What about the distance from London to Paris?

'''


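# Local path to the downloaded glaiveai/glaive-function-calling-v2-small checkpoint
MODEL_PATH = "/models/glaiveai/glaive-function-calling-v2-small"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
# Load the weights in fp16 on the GPU
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True).half().cuda()

# Tokenize the prompt and move the tensors to the model's device
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Low-temperature sampling, close to the settings used in the API tests
outputs = model.generate(**inputs, do_sample=True, temperature=0.1, top_p=0.95, max_new_tokens=100)

# Decodes prompt + continuation; the generated ASSISTANT turn is at the end
print(tokenizer.decode(outputs[0], skip_special_tokens=True))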

Output:

ASSISTANT: <functioncall> {"name": "calculate_distance", "arguments": '{"location1": "London", "location2": "Paris"}'}
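Note that the glaive output wraps `arguments` in single quotes, so the `<functioncall>` payload is not valid JSON and cannot be passed straight to `json.loads`. A minimal sketch for extracting the call from the decoded text; `parse_glaive_functioncall` is a hypothetical helper, and the regex-based parsing assumes the call fits on one line as in the output above.

```python
import json
import re


def parse_glaive_functioncall(text):
    """Extract (name, arguments) from the last glaive '<functioncall> {...}' line, or None."""
    if "<functioncall>" not in text:
        return None
    # Use the last call so that calls echoed from the prompt history are ignored
    lines = text.rsplit("<functioncall>", 1)[1].strip().splitlines()
    if not lines:
        return None
    body = lines[0]
    name = re.search(r'"name":\s*"([^"]+)"', body)
    # The arguments object is wrapped in single quotes: '{"k": "v"}'
    args = re.search(r"\"arguments\":\s*'(.*)'", body)
    if name is None or args is None:
        return None
    return name.group(1), json.loads(args.group(1))
```

Because the script decodes the prompt together with the continuation, the text also contains the earlier New York / Los Angeles call; taking the last `<functioncall>` selects the newly generated London / Paris call.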

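IV. Share of tool-call samples in the data/glaive_toolcall_10k.json dataset

single-turn tool calls = 3303
multi-turn tool calls = 1805
totalSample = 10000
Groups 1 and 2 were both trained with max_samples=10000, so the training data contains more than 1,000 multi-turn samples, which should be enough for training.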
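A sketch of how such counts can be produced, assuming the file is a JSON list of ShareGPT-style samples whose `conversations` turns carry a `from` field with the value `function_call` for tool calls. Defining "single-turn" as exactly one `function_call` turn and "multi-turn" as two or more is an assumption and may differ from how the figures above were counted; `count_tool_samples` and `count_tool_file` are hypothetical helpers.

```python
import json
from collections import Counter


def count_tool_samples(samples):
    """Bucket samples by the number of function_call turns (assumed ShareGPT layout)."""
    counts = Counter()
    for sample in samples:
        calls = sum(
            1 for turn in sample.get("conversations", []) if turn.get("from") == "function_call"
        )
        if calls == 0:
            counts["no_tool"] += 1
        elif calls == 1:
            counts["single_tool"] += 1
        else:
            counts["multi_tool"] += 1
    counts["total"] = len(samples)
    return counts


def count_tool_file(path):
    """Load a JSON list of samples from disk and count it."""
    with open(path, encoding="utf-8") as f:
        return count_tool_samples(json.load(f))
```

Usage: `count_tool_file("data/glaive_toolcall_10k.json")`.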

Expected behavior

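When the same topic is asked about again in a multi-turn conversation with different parameters, the tool should be recognized correctly.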

System Info

  • transformers version: 4.39.3
  • Platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes RTX4090
  • Using distributed or parallel set-up in script?: No

Others

What should I change to get the expected behavior?

@hiyouga labeled this issue wontfix (This will not be worked on) on May 24, 2024
@hiyouga closed this as not planned on May 29, 2024