feat(diffusers): add 'safety_check' pipeline argument #7862

Open · wants to merge 1 commit into main from add_safety_check_pipeline_argument
Conversation


@rickstaa commented May 5, 2024

What does this PR do?

This pull request adds a safety_check argument to the __call__ method of StableDiffusionPipeline. The new argument lets users enable or disable the safety checker dynamically for each pipeline call. The primary motivation is to give users the option to filter NSFW content when generating images, depending on their specific needs.
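A minimal sketch of how the proposed argument would be used if the PR lands as described (the model id and prompts are only placeholders):

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Proposed behavior: toggle the safety checker per call instead of fixing it
# when the pipeline is constructed.
unfiltered = pipe("a scenic mountain landscape", safety_check=False).images[0]
filtered = pipe("a scenic mountain landscape", safety_check=True).images[0]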

Alternatives Considered

I'm currently using a code snippet provided on the forum to dynamically toggle the safety check, but integrating this functionality directly into the diffusion pipeline would streamline my code and could also benefit others.

from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
import numpy as np
import torch
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the safety checker and the CLIP feature extractor it expects.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
).to(device)
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "openai/clip-vit-base-patch32"
)

def check_nsfw_images(
    images: list[Image.Image],
    output_type: str | None = "pil",
) -> tuple[list[Image.Image], list[bool]]:
    # Preprocess the images with CLIP and run the safety checker; only the
    # per-image NSFW flags are kept and the input images are returned unmodified.
    safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
    images_np = [np.array(img) for img in images]

    _, has_nsfw_concepts = safety_checker(
        images=images_np,
        clip_input=safety_checker_input.pixel_values.to(device),
    )
    if output_type == "pil":
        return images, has_nsfw_concepts
    return images_np, has_nsfw_concepts
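
For reference, a minimal sketch of how this helper could be combined with a pipeline whose built-in checker is disabled (the model id and prompt are only placeholders):

from diffusers import StableDiffusionPipeline

# Load the pipeline without its built-in safety checker ...
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
    requires_safety_checker=False,
).to(device)

# ... generate images, then filter them manually with the helper above.
images = pipe("a photo of an astronaut riding a horse").images
checked_images, nsfw_flags = check_nsfw_images(images)
safe_images = [img for img, nsfw in zip(checked_images, nsfw_flags) if not nsfw]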

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@rickstaa force-pushed the add_safety_check_pipeline_argument branch 3 times, most recently from 02e00b6 to 28c1cee on May 5, 2024 at 18:01
This commit introduces a new `safety_check` pipeline argument to the
Stable Diffusion pipeline. This argument allows users to dynamically
enable or disable the safety check during a pipeline call.
@rickstaa force-pushed the add_safety_check_pipeline_argument branch from 28c1cee to a4b1c85 on May 5, 2024 at 18:02
@sayakpaul (Member) commented:
I don't think we need an argument like this, given that one can already set the safety_checker argument of a StableDiffusionPipeline to None.

@yiyixuxu WDYT?
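
For context, a minimal sketch of the existing approach the comment refers to, disabling the checker when the pipeline is constructed (the model id is a placeholder):

from diffusers import StableDiffusionPipeline

# Existing approach: omit the safety checker at load time; requires_safety_checker=False
# suppresses the warning that would otherwise be emitted.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
    requires_safety_checker=False,
)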
