Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Added a bool fold_lowercase to whisper_context_params #2005

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -3,7 +3,7 @@
import com.sun.jna.Structure;
import com.sun.jna.ptr.PointerByReference;
import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel;
import io.github.ggerganov.whispercpp.model.WhisperModel;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;

import java.util.List;
Expand Down
@@ -1,4 +1,4 @@
package io.github.ggerganov.whispercpp;
package io.github.ggerganov.whispercpp.model;

import io.github.ggerganov.whispercpp.ggml.GgmlTensor;
import io.github.ggerganov.whispercpp.model.EModel;
Expand Down
@@ -0,0 +1,16 @@
package io.github.ggerganov.whispercpp.params;

import com.sun.jna.*;

import java.util.Arrays;
import java.util.List;

/**
 * JNA mapping of the native {@code whisper_aheads} struct: a count plus a
 * native pointer — presumably to an array of alignment-head entries
 * (confirm against whisper.h).
 *
 * NOTE(review): JNA maps fields in the order returned by
 * {@link #getFieldOrder()}, which must match the native struct layout.
 */
public class WhisperAheads extends Structure {
    /** Number of entries reachable through {@code heads}. */
    public long n_heads;

    /** Native pointer to the first alignment-head entry. */
    public Pointer heads;

    @Override
    protected List<String> getFieldOrder() {
        final List<String> order = Arrays.asList("n_heads", "heads");
        return order;
    }
}
@@ -0,0 +1,18 @@
package io.github.ggerganov.whispercpp.params;

/**
 * Java mirror of the native {@code whisper_alignment_heads_preset} enum,
 * selecting which attention heads are used for DTW token-level timestamps.
 *
 * NOTE(review): these constants presumably map to the native enum by
 * ordinal, so the declaration order must match whisper.h — do not reorder;
 * confirm against the JNA binding.
 */
public enum WhisperAlignmentHeadsPreset {
    WHISPER_AHEADS_NONE,
    WHISPER_AHEADS_N_TOP_MOST, // All heads from the N-top-most text-layers
    WHISPER_AHEADS_CUSTOM,
    WHISPER_AHEADS_TINY_EN,
    WHISPER_AHEADS_TINY,
    WHISPER_AHEADS_BASE_EN,
    WHISPER_AHEADS_BASE,
    WHISPER_AHEADS_SMALL_EN,
    WHISPER_AHEADS_SMALL,
    WHISPER_AHEADS_MEDIUM_EN,
    WHISPER_AHEADS_MEDIUM,
    WHISPER_AHEADS_LARGE_V1,
    WHISPER_AHEADS_LARGE_V2,
    WHISPER_AHEADS_LARGE_V3
}
Expand Up @@ -19,13 +19,26 @@ public WhisperContextParams(Pointer p) {
/** Use GPU for inference (default = true). */
public CBool use_gpu;

/** CUDA device index. */
public int gpu_device;

/** Fold vocab tokens to lowercase while loading the model. */
public CBool vocab_lc;

/** [EXPERIMENTAL] Enable token-level timestamps via DTW. */
public CBool dtw_token_timestamps;
// Preset choosing which alignment heads DTW uses (see WhisperAlignmentHeadsPreset).
public WhisperAlignmentHeadsPreset dtw_aheads_preset;
// Presumably the number of top text-layers used with the N_TOP_MOST preset — confirm against whisper.h.
public int dtw_n_top;
// Custom alignment heads, presumably used with the CUSTOM preset — confirm against whisper.h.
public WhisperAheads dtw_aheads;
// DTW working-memory size — units not visible here; TODO confirm (bytes?).
public long dtw_mem_size;

/** Enable or disable GPU inference (default = true). */
public void useGpu(boolean enable) {
    if (enable) {
        use_gpu = CBool.TRUE;
    } else {
        use_gpu = CBool.FALSE;
    }
}

@Override
protected List<String> getFieldOrder() {
    // JNA maps fields in this exact order; it must match the native
    // whisper_context_params struct layout in whisper.h.
    // (Removed a stale duplicate `return Arrays.asList("use_gpu");` line left
    // over from the diff — a second return made the statement unreachable.)
    return Arrays.asList("use_gpu", "gpu_device", "vocab_lc", "dtw_token_timestamps", "dtw_aheads_preset",
            "dtw_n_top", "dtw_aheads", "dtw_mem_size");
}
}
4 changes: 4 additions & 0 deletions examples/command/command.cpp
Expand Up @@ -44,6 +44,7 @@ struct whisper_params {
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
bool model_fold_lc = false;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change name to vocab_lc


std::string language = "en";
std::string model = "models/ggml-base.en.bin";
Expand Down Expand Up @@ -77,6 +78,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if ( arg == "--model-fold-lc") { params.model_fold_lc = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
Expand Down Expand Up @@ -114,6 +116,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " --model-fold-lc [%-7s] fold all model tokens to lowercase\n", params.model_fold_lc ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
Expand Down Expand Up @@ -690,6 +693,7 @@ int main(int argc, char ** argv) {

struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.vocab_lc = params.model_fold_lc;

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

Expand Down
4 changes: 4 additions & 0 deletions examples/main/main.cpp
Expand Up @@ -65,6 +65,7 @@ struct whisper_params {
bool no_timestamps = false;
bool log_score = false;
bool use_gpu = true;
bool model_fold_lc = false;

std::string language = "en";
std::string prompt;
Expand Down Expand Up @@ -145,6 +146,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if ( arg == "--model-fold-lc") { params.model_fold_lc = true; }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
Expand Down Expand Up @@ -205,6 +207,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " --model-fold-lc [%-7s] fold all model tokens to lowercase\n", params.model_fold_lc ? "true" : "false");
Comment on lines 209 to +210
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " --model-fold-lc [%-7s] fold all model tokens to lowercase\n", params.model_fold_lc ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " --vocab-lc [%-7s] fold all vocab tokens to lowercase\n", params.vocab_lc ? "true" : "false");

fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
Expand Down Expand Up @@ -893,6 +896,7 @@ int main(int argc, char ** argv) {

struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.vocab_lc = params.model_fold_lc;

if (!params.dtw.empty()) {
cparams.dtw_token_timestamps = true;
Expand Down
8 changes: 7 additions & 1 deletion whisper.cpp
Expand Up @@ -1383,6 +1383,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
word = "";
}

// If requested, output all text as lowercase.
if (wctx.params.vocab_lc) {
std::transform(word.begin(), word.end(), word.begin(),
[](unsigned char c) { return std::tolower(c); });
}

vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;

Expand Down Expand Up @@ -3374,7 +3380,7 @@ struct whisper_context_params whisper_context_default_params() {
struct whisper_context_params result = {
/*.use_gpu =*/ true,
/*.gpu_device =*/ 0,

/*.vocab_lc =*/ false,
/*.dtw_token_timestamps =*/ false,
/*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE,
/*.dtw_n_top =*/ -1,
Expand Down
3 changes: 3 additions & 0 deletions whisper.h
Expand Up @@ -115,6 +115,9 @@ extern "C" {
bool use_gpu;
int gpu_device; // CUDA device

// Fold language tokens to lowercase
bool vocab_lc;

// [EXPERIMENTAL] Token-level timestamps with DTW
bool dtw_token_timestamps;
enum whisper_alignment_heads_preset dtw_aheads_preset;
Expand Down