
[Bug - Translation server] - Missing tgt param in translator.translate method (allows some multilingual/seq2seq models to work properly) #2586

Open
medfreeman opened this issue Apr 22, 2024 · 0 comments · May be fixed by #2585

medfreeman commented Apr 22, 2024

Some multilingual/seq2seq models, such as M2M100 (cf. the Generation section in the linked page), require the bos_token to be set to the target language id via the sequence's tgt property.

In the case of the translation server, to be able to specify the requested target language, we need to manipulate the sequence's tgt property directly prior to translation.

But in its current state, the server has a disconnect between the sequence's ref/ref_tok (which, incidentally, can be manipulated through tokenizers/processors) and the tgt string before it is sent to ctranslate2.

cf.

"tgt": {"tgt": ref_tok} if ref_tok is not None else None,

Basically, the tgt parameter of the self.translator.translate method is never provided.

cf.

scores, predictions = self.translator.translate(examples)

I successfully implemented a one-line patch that passes the parameter through properly and allows me to do multilingual translation.
It should have no side effects on other types of models (for which the sequence's ref is empty after tokenizing), since the parameter is set to an empty string in those cases.

Here’s the PR: #2585
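The idea behind the fix can be sketched in plain Python (the helper name make_target_prefix is hypothetical and not the actual code from the PR):

```python
# Hypothetical helper illustrating the intended behavior: forward the
# tokenized ref as the tgt parameter, falling back to an empty string so
# that models without a ref are unaffected.
def make_target_prefix(ref_tok):
    return ref_tok if ref_tok is not None else ""

# With a target-language token, it becomes the decoding prefix:
make_target_prefix("__fr__")  # "__fr__"
# Without one, other model types see an empty string, i.e. no prefix:
make_target_prefix(None)     # ""
```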

Example of multilingual translation with an M2M100 model:

conf.json

{
    "models_root": "./available_models",
    "models": [
        {
            "id": 100,
            "model": "m2m-multi4-ft-ck945k/",
            "ct2_model": "m2m-multi4-ft-ck945k/",
            "load": true,
            "on_timeout": "unload",
            "ct2_translator_args": {
                "inter_threads": 4,
                "intra_threads": 2
            },
            "ct2_translate_batch_args": {},
            "opt": {
                "beam_size": 1,
                "batch_size": 8,
                "tgt_file_prefix": true
            },
            "preprocess": ["available_models.m2m-multi4-ft-ck945k.tokenizer.m2m100_tokenizer.preprocess"],
            "postprocess": ["available_models.m2m-multi4-ft-ck945k.tokenizer.m2m100_tokenizer.postprocess"]
        }
    ]
}

available_models/m2m-multi4-ft-ck945k/tokenizer/m2m100_tokenizer.py

import os
from pathlib import Path

from transformers import M2M100Tokenizer

cache = None

def loadTokenizer(model_root, logger):
    """Load the M2M100 tokenizer once and cache it."""
    global cache
    if cache is not None:
        return cache

    model_path = os.path.join(model_root, "m2m-multi4-ft-ck945k/tokenizer/")
    logger.info("Loading m2m100 tokenizer from %s", model_path)
    cache = M2M100Tokenizer.from_pretrained(model_path)

    return cache

def preprocess(sequence, server_model):
    """Preprocess a single sequence.

    Args:
        sequence (dict[str, Any]): The sequence to preprocess.

    Returns:
        sequence (dict[str, Any]): The preprocessed sequence."""
    server_model.logger.info(f"Running preprocessor '{Path(__file__).stem}'")

    ref = sequence.get("ref", None)
    if ref is not None and ref[0] is not None:
        server_model.logger.debug(f"{ref[0]=}")
        tgt_lang = ref[0].get("tgt_lang", None)
        if tgt_lang is not None:
            server_model.logger.debug(f"{tgt_lang=}")

            tokenizer = loadTokenizer(server_model.model_root, server_model.logger)

            # Tokenize the source segment into space-separated subword tokens.
            seg = sequence.get("seg", None)
            tok = tokenizer.convert_ids_to_tokens(
                tokenizer.encode(seg[0])
            )
            sequence["seg"][0] = " ".join(tok)

            # Store the target-language token as the reference prefix.
            lang_prefix = f"__{tgt_lang}__"
            sequence["ref"][0] = lang_prefix
            server_model.logger.info(f"Added lang prefix to ref: '{lang_prefix}'")
            server_model.logger.debug(f"{sequence['ref'][0]=}")

    return sequence
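The transformation the preprocessor performs can be exercised in isolation with a stub tokenizer (StubTokenizer and preprocess_sketch are illustrative stand-ins, not the real M2M100Tokenizer or server code):

```python
# Self-contained sketch of the preprocessing step: tokenize the segment
# and replace ref with the target-language token.
class StubTokenizer:
    """Stand-in for M2M100Tokenizer: one fake subword id per word."""
    def encode(self, text):
        return list(range(len(text.split())))
    def convert_ids_to_tokens(self, ids):
        return [f"tok{i}" for i in ids]

def preprocess_sketch(sequence, tokenizer):
    ref = sequence.get("ref")
    if ref is not None and ref[0] is not None:
        tgt_lang = ref[0].get("tgt_lang")
        if tgt_lang is not None:
            # Tokenize the segment into space-separated tokens.
            tok = tokenizer.convert_ids_to_tokens(tokenizer.encode(sequence["seg"][0]))
            sequence["seg"][0] = " ".join(tok)
            # Replace ref with the target-language token.
            sequence["ref"][0] = f"__{tgt_lang}__"
    return sequence

seq = {"seg": ["Brian is in the kitchen."], "ref": [{"tgt_lang": "fr"}]}
out = preprocess_sketch(seq, StubTokenizer())
# out["ref"][0] is now "__fr__", and out["seg"][0] holds the joined tokens.
```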

def postprocess(sequence, server_model):
    """Postprocess a single sequence.

    Args:
        sequence (dict[str, Any]): The sequence to postprocess.

    Returns:
        detok (str): The detokenized translation."""
    server_model.logger.info(f"Running postprocessor '{Path(__file__).stem}'")

    tokenizer = loadTokenizer(server_model.model_root, server_model.logger)

    # Drop the leading target-language token, then detokenize the rest.
    seg = sequence.get("seg", None)
    detok = tokenizer.decode(
        tokenizer.convert_tokens_to_ids(seg[0].split()[1:]),
        skip_special_tokens=True,
    )
    return detok
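The postprocessing step can likewise be sketched with a stub (again, StubDetokenizer and postprocess_sketch are illustrative, not the real API):

```python
# Self-contained sketch of the postprocessing step: drop the leading
# target-language token, then detokenize what remains.
class StubDetokenizer:
    """Stand-in for M2M100Tokenizer decoding: identity ids, join on spaces."""
    def convert_tokens_to_ids(self, tokens):
        return tokens  # identity mapping for the sketch
    def decode(self, ids, skip_special_tokens=True):
        return " ".join(ids)

def postprocess_sketch(seg, tokenizer):
    # seg[0] starts with the language token, e.g. "__fr__ ..."; skip it.
    return tokenizer.decode(
        tokenizer.convert_tokens_to_ids(seg.split()[1:]),
        skip_special_tokens=True,
    )

result = postprocess_sketch("__fr__ Brian est dans la cuisine .", StubDetokenizer())
# result == "Brian est dans la cuisine ."
```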

Sample request to server:

[
    {
        "src": "Brian is in the kitchen.",
        "id": 100,
        "ref": {
            "src_lang": "en",
            "tgt_lang": "fr"
        }
    },
    {
        "src": "By the way, do you like to eat pancakes?",
        "id": 100,
        "ref": {
            "src_lang": "en",
            "tgt_lang": "fr"
        }
    }
]
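For reference, a request like the one above could be built and sent from Python roughly as follows (the URL, port, and route are assumptions based on a default local setup, not confirmed by this issue):

```python
import json
from urllib import request as urlrequest

# Assumed default server address and route; adjust to your deployment.
SERVER_URL = "http://localhost:5000/translator/translate"

def build_payload(texts, model_id, src_lang, tgt_lang):
    """Build the request body shown above: one entry per source text."""
    return [
        {"src": text, "id": model_id, "ref": {"src_lang": src_lang, "tgt_lang": tgt_lang}}
        for text in texts
    ]

def translate(texts, model_id=100, src_lang="en", tgt_lang="fr", url=SERVER_URL):
    """POST the payload as JSON and return the parsed response."""
    body = json.dumps(build_payload(texts, model_id, src_lang, tgt_lang)).encode("utf-8")
    req = urlrequest.Request(body and url, data=body,
                             headers={"Content-Type": "application/json"})
    with urlrequest.urlopen(req) as resp:
        return json.loads(resp.read())

payload = build_payload(["Brian is in the kitchen."], 100, "en", "fr")
```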