Allow passing in max_tokens: to Langchain::LLM::Azure #404
base: main
Conversation
@@ -64,7 +64,7 @@ def complete(prompt:, **params)
      parameters = compose_parameters @defaults[:completion_model_name], params

      parameters[:messages] = compose_chat_messages(prompt: prompt)
-     parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
+     parameters[:max_tokens] = params[:max_tokens] || validate_max_tokens(parameters[:messages], parameters[:model])
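For context, the effect of this change on the caller side would be roughly the following (the prompt and constructor arguments are illustrative, not taken from the PR):

llm = Langchain::LLM::Azure.new ...
# With the patch, an explicit max_tokens takes precedence; if it is omitted,
# validate_max_tokens still supplies the value as before.
llm.complete(prompt: "Summarize this document.", max_tokens: 256)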
I think this is an improvement, however I think the following might be better:
validate_max_tokens(parameters[:messages], parameters[:model]) # raises exception if maximum exceeded
parameters[:max_tokens] = params[:max_tokens] if params[:max_tokens]
This should still perform the validation and only pass max_tokens to the OpenAI Client if it is set as an argument to this method.
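Inside complete, that suggestion would look roughly like this (surrounding lines paraphrased from the diff above, not an exact patch):

def complete(prompt:, **params)
  parameters = compose_parameters @defaults[:completion_model_name], params
  parameters[:messages] = compose_chat_messages(prompt: prompt)

  # Always validate, so an oversized payload still raises before the request is made...
  validate_max_tokens(parameters[:messages], parameters[:model])
  # ...but only send max_tokens when the caller explicitly provided it.
  parameters[:max_tokens] = params[:max_tokens] if params[:max_tokens]

  # remainder unchanged: call the OpenAI/Azure client with parameters
end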
Here's another approach that I forgot @bricolage tackled: https://github.com/andreibondarev/langchainrb/pull/388/files
There, validate_max_tokens() accepts a 3rd argument, which is the user's passed-in max_tokens, and then selects the smaller (min) of the allowed max_tokens and the passed-in one.
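A rough sketch of that idea (the helper names here are hypothetical; see the linked PR for the actual implementation):

def validate_max_tokens(messages, model, user_max_tokens = nil)
  max_context = token_limit_for(model)   # hypothetical lookup of the model's context window
  used = count_tokens(messages, model)   # hypothetical token count of the payload
  raise "context length exceeded" if used > max_context

  remaining = max_context - used
  user_max_tokens ? [user_max_tokens, remaining].min : remaining
end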
Thoughts on this approach here?
I think the problem with that approach is when max_tokens is not passed in: that solution will set max_tokens to the maximum number of tokens supported by the model (and then Azure will consume the maximum amount of tokens). In this scenario I think it would be better to not set anything at all, so that the Azure defaults are used:
https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
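In other words, with the min-based version and no caller-supplied value, something like this happens (numbers made up for illustration):

# Suppose the model allows 4096 tokens and the prompt already uses 500.
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
# => 3596, i.e. the entire remaining context is requested from Azure,
#    even though the caller never asked for a max_tokens value.
# Preferred here: leave parameters[:max_tokens] unset and let Azure apply its own default.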
We can also skip the whole thing by adding a skip_max_tokens_validation: true/false flag in each LLM class:
llm = Langchain::LLM::Azure.new ...
llm.skip_max_tokens_validation = true
That would provide a solution, however I don't think it's ideal. I think we always want to run the validation part of the validate_max_tokens method, because it's good to know if the payload size exceeds the maximum supported context before making a request.
I think there are 4 use cases we are trying to support:
- No configuration - it should not set max_tokens, so the Azure defaults are used
- Setting a default max_tokens for an LLM, to be used for all requests
- Setting an override for max_tokens per request (e.g. via the complete or chat methods)
- (current behaviour) Setting max_tokens to the remaining tokens in the context
I think the following would satisfy these, but it's messy.
llm = Langchain::LLM::Azure.new ...
llm.default_max_tokens = 300 # use 300 tokens by default in requests
# or
llm.default_max_tokens = :max # use all tokens available in the request context
# or (not set)
llm.default_max_tokens = nil
...
def complete(prompt:, **params) # callers may pass max_tokens: 300, or max_tokens: :max
  ...
  remaining_tokens = validate_max_tokens(parameters[:messages], parameters[:model]) # raises exception if maximum exceeded
  max_tokens = params.fetch(:max_tokens, default_max_tokens) # if not passed in, use the llm's default max_tokens
  parameters[:max_tokens] = remaining_tokens if max_tokens == :max
  parameters[:max_tokens] ||= max_tokens if max_tokens
  ...
  # call the OpenAI Client with parameters
end
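If that shape were adopted, call sites would look roughly like this (illustrative only; none of this exists in the gem yet):

llm = Langchain::LLM::Azure.new ...
llm.default_max_tokens = 300

llm.complete(prompt: "Hello")                   # uses the 300-token default
llm.complete(prompt: "Hello", max_tokens: 50)   # per-request override
llm.complete(prompt: "Hello", max_tokens: :max) # fill the remaining context
# With default_max_tokens left nil and no argument, max_tokens is simply not sent.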