-
Notifications
You must be signed in to change notification settings - Fork 3.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bug] Add Azure-hosted Dalle Image Generation Support #2586
base: main
Are you sure you want to change the base?
Conversation
The failing compressible agent tests is not caused by this PR. |
@@ -287,5 +299,9 @@ def _validate_resolution_format(resolution: str): | |||
|
|||
|
|||
def _validate_dalle_model(model: str): | |||
if model not in ["dall-e-3", "dall-e-2"]: | |||
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'") | |||
if model not in VALID_DALLE_MODELS: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For Azure models, they are not necessarily named as one of VALID_DALLE_MODELS. They can be named as anything.
@@ -108,6 +114,12 @@ def cache_key(self, prompt: str) -> str: | |||
keys = (prompt, self._model, self._resolution, self._quality, self._num_images) | |||
return ",".join([str(k) for k in keys]) | |||
|
|||
def _dalle_client_factory(self, dalle_config: Dict) -> Union[OpenAI, AzureOpenAI]: | |||
if dalle_config.get("api_type") == "azure": | |||
return AzureOpenAI(api_key=dalle_config["api_key"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AzureOpenAI requires more fields than api_key.
if dalle_config.get("api_type") == "azure": | ||
return AzureOpenAI(api_key=dalle_config["api_key"]) | ||
else: | ||
return OpenAI(api_key=dalle_config["api_key"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For OpenAI, I'm not sure if more keys are possible. Why do we pass the api_key
only here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good PR. My biggest concern is mentioned by Chi. Azure endpoints allow users to rename their model, which is not necessarily named as one of VALID_DALLE_MODELS.
@sonichi @BeibinLi thank you for the reviews! Initially, I assumed that OpenAI and AzureOpenAI were identical. I’ve been attempting to get access to Azure OpenAI APIs, but it seems like I need a corporate account/email. As I’m unable to test these changes locally I’ll keep this PR in draft until I manage to access Azure OpenA. |
This PR adds support for Dalle image generation hosted on Azure.
Why are these changes needed?
Related issue number
Closes #2510
Checks