Allow passing in max_tokens: to Langchain::LLM::Azure #404

Open · wants to merge 2 commits into main
Conversation

andreibondarev (Collaborator)

No description provided.

@@ -64,7 +64,7 @@ def complete(prompt:, **params)
   parameters = compose_parameters @defaults[:completion_model_name], params

   parameters[:messages] = compose_chat_messages(prompt: prompt)
-  parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
+  parameters[:max_tokens] = params[:max_tokens] || validate_max_tokens(parameters[:messages], parameters[:model])
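
For context, a hedged usage sketch of what this change enables. Only the complete call and its max_tokens: keyword follow from the diff; the constructor call is elided as elsewhere in this thread and the prompt is illustrative.

llm = Langchain::LLM::Azure.new ...

# before: max_tokens was always computed from the remaining context
# after: a caller-supplied value takes precedence and is forwarded to the client
llm.complete(prompt: "Hello", max_tokens: 256)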
Reviewer

I think this is an improvement; however, the following might be better.

validate_max_tokens(parameters[:messages], parameters[:model]) # raises exception if maximum exceeded
parameters[:max_tokens] = params[:max_tokens] if params[:max_tokens]

This should still perform the validation and only pass max_tokens to the OpenAI Client if it is set as an argument to this method.

andreibondarev (Collaborator, Author)

Here's another approach that I'd forgotten @bricolage tackled: https://github.com/andreibondarev/langchainrb/pull/388/files

validate_max_tokens() accepts a 3rd argument, the user's passed-in max_tokens, and then selects the smaller (min) of the allowed max_tokens and the passed-in value.

Thoughts on this approach here?
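
A hedged sketch of what that #388-style approach could look like (illustrative only, not the actual #388 diff; max_tokens_remaining_for is a hypothetical stand-in for the existing token-length check):

def validate_max_tokens(messages, model, requested_max_tokens = nil)
  allowed = max_tokens_remaining_for(messages, model) # hypothetical helper: tokens left in the model's context
  requested_max_tokens ? [allowed, requested_max_tokens].min : allowed
end

parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model], params[:max_tokens])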

Reviewer

I think the problem with that approach is when max_tokens is not passed in: the solution sets max_tokens to the maximum number of tokens supported by the model (and Azure will then consume that maximum).

In this scenario I think it would be better not to set anything, so that the Azure defaults are used.
https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
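
To make the contrast concrete (an illustrative sketch reusing the names from the snippets above, not code from either PR):

# min-style approach: an omitted max_tokens falls back to the model's full remaining context
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model], params[:max_tokens])

# alternative described above: still validate, but leave max_tokens unset so Azure applies its own default
validate_max_tokens(parameters[:messages], parameters[:model])
parameters[:max_tokens] = params[:max_tokens] if params[:max_tokens]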

andreibondarev (Collaborator, Author)

We could also skip the whole thing by adding a skip_max_tokens_validation: true/false flag to each LLM class.

llm = Langchain::LLM::Azure.new ...
llm.skip_max_tokens_validation = true

Reviewer

That would provide a solution, however I don't think it's ideal. I think we always want to run the validation part of the validate_max_tokens method, because it's good to know if the payload size exceeds the maximum supported context before making a request.

I think there are 4 use cases we are trying to support:

  1. No configuration - it should not set max_tokens, so the Azure defaults are used
  2. Setting a default max_tokens for an llm, used for all requests
  3. Setting an override for max_tokens per request (e.g. via the complete or chat methods)
  4. (current behaviour) Setting max_tokens to the remaining tokens in the context.

I think the following would satisfy these, but it's messy.

llm = Langchain::LLM::Azure.new ...
llm.default_max_tokens = 300 # use 300 tokens by default in each request
# or
llm.default_max_tokens = :max # use all tokens available in the request context
# or (not set)
llm.default_max_tokens = nil
...

def complete(prompt:, **params) # callers may pass max_tokens: 300 or max_tokens: :max
  ...
  remaining_tokens = validate_max_tokens(parameters[:messages], parameters[:model]) # raises an exception if the maximum is exceeded

  max_tokens = params.fetch(:max_tokens, default_max_tokens) # if not passed in, fall back to the llm default

  parameters[:max_tokens] = remaining_tokens if max_tokens == :max
  parameters[:max_tokens] ||= max_tokens if max_tokens
  ...
  # call the OpenAI client with parameters
end
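
A usage sketch of how the four cases above would map onto that proposal (illustrative only; default_max_tokens and the :max sentinel are part of the proposal, not the current API):

llm = Langchain::LLM::Azure.new ...

# 1. no configuration: max_tokens stays unset, Azure's own default applies
llm.complete(prompt: "...")

# 2. default for every request made through this llm
llm.default_max_tokens = 300

# 3. per-request override
llm.complete(prompt: "...", max_tokens: 50)

# 4. current behaviour: use all remaining tokens in the context
llm.complete(prompt: "...", max_tokens: :max)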
