MENU

【whisper】生成時のパラメーターを詳細に指定して精度を上げる

この記事はwhisperの生成時パラメーターについての記事です。

目次

パラメーターについて

あわせて読みたい
Whisper We’re on a journey to advance and democratize artificial intelligence through open source and open science.
Qiita
OpenAI Whisper のコマンドオプション - Qiita あらすじ以前は、GPU非搭載 CPU上でOpenAIのWhisperを試して音声データを文字起こししてみた。でWhisperの利用方法を説明しました。今後社内で使っていくことが決まりまし...

input_features

  • 型: torch.Tensor (形状: (batch_size, feature_size, sequence_length))
  • 説明: 音声波形から抽出されたログメル(log-mel)特徴量を表すテンソルです。
    元の音声波形は、.flac.wav ファイルを読み込んで List[float]numpy.ndarray 型の配列として取得できます(例: soundfile ライブラリを使用)。
    特徴量の抽出には、AutoFeatureExtractor を用いてログメル特徴量を生成し、パディングや torch.FloatTensor への変換を行います。

generation_config

  • 型: ~generation.GenerationConfig
  • 説明: 生成プロセスに使用する設定です。この設定を指定しない場合、以下の優先順位でデフォルト値が使用されます:
    1. モデルファイル内の generation_config.json が存在する場合、その値を使用。
    2. モデル設定に基づくデフォルト値。
      generation_config に含まれる属性と一致する kwargs(キーワード引数) を渡すことで、生成設定を上書きできます。

logits_processor

  • 型: LogitsProcessorList
  • 説明: デフォルトのロジットプロセッサに加えて、カスタムロジットプロセッサを追加できます。
    ただし、引数や生成設定によって既に作成されたプロセッサと重複する場合、エラーが発生します。
    上級者向けの機能 です。

stopping_criteria

  • 型: StoppingCriteriaList
  • 説明: デフォルトの停止基準に加えて、カスタム停止基準を追加できます。
    引数や生成設定によって既に作成された停止基準と重複する場合、エラーが発生します。
    上級者向けの機能 です。

prefix_allowed_tokens_fn

  • 型: Callable[[int, torch.Tensor], List[int]]
  • 説明: ビームサーチ中に許可されたトークンを各ステップで制限する関数です。
    この関数は 2 つの引数(バッチ ID batch_id と入力 ID input_ids)を取り、次の生成ステップで許可されるトークンのリストを返します。
    プロンプトに基づいた制約付き生成を行う際に役立ちます。

synced_gpus

  • 型: bool(デフォルト: False
  • 説明: FullyShardedDataParallelDeepSpeed ZeRO Stage 3 を使用する場合に、max_length までループを続行するかどうかを指定します。

return_timestamps

  • 型: bool
  • 説明: 出力テキストと一緒にタイムスタンプを返すかどうかを指定します。
    タイムスタンプは、WhisperTimestampsLogitsProcessor を有効にすることで利用可能になります。

task

  • 型: str
  • 説明: 使用するタスクを指定します。以下のいずれかを指定可能です:
    • "translate": 翻訳タスク
    • "transcribe": 音声文字起こしタスク
      この値に基づいて model.config.forced_decoder_ids が更新されます。

language

  • 型: str または list[str]
  • 説明: 生成時に使用する言語トークンを指定します。以下の形式が可能です:
    • <|en|>
    • en
    • english
      バッチ処理では、言語トークンのリストも指定可能です。利用可能なトークンは model.generation_config.lang_to_id 辞書で確認できます。

is_multilingual

  • 型: bool
  • 説明: モデルが多言語対応かどうかを指定します。

prompt_ids

  • 型: torch.Tensor
  • 説明: トークン ID の rank-1 テンソル。プロンプトを提供するために使用されます。
    例: カスタム辞書や固有名詞を指定し、正確に予測させる目的。
    このオプションは decoder_start_token_id と併用できません。

prompt_condition_type

  • 型: str
  • 説明:長文の文字起こしにおけるプロンプトの条件付け方法を指定します。
    • "first-segment": 最初のセグメントにのみ prompt_ids を条件付け。
    • "all-segments": 各セグメントに prompt_ids を条件付け(condition_on_prev_tokens を有効にする必要があります)。
      デフォルトは "first-segment"

condition_on_prev_tokens

  • 型: bool
  • 説明: 長文の文字起こしにおいて、前のセグメントを次のセグメントの条件として使用するかどうかを指定します。
    Whisper 論文では、この設定を有効にすると性能向上が示されています。

temperature

  • 型: float または list[float]
  • 説明: サンプリング生成に使用する温度を指定します。
    長文の文字起こしの場合、複数の温度値(例: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0])を指定すると、失敗時に温度を変えてリトライします。

compression_ratio_threshold

  • 型: float
  • 説明: 長文の文字起こしにおいて、高い圧縮率(例: 1.35以上)のセグメントを再生成します。

logprob_threshold

  • 型: float
  • 説明: 平均対数確率が低いセグメントを再生成します。一般的な値は -1.0

no_speech_threshold

  • 型: float
  • 説明: 無音と判定されたセグメントをスキップします。

num_segment_frames

  • 型: int
  • 説明: 単一セグメントのフレーム数を指定します。

attention_mask

  • 型: torch.Tensor
  • 説明: バッチサイズが 1 を超える長文の文字起こし時に必要なマスク。

time_precision

  • 型: int(デフォルト: 0.02)
  • 説明: 出力トークンの平均時間幅(秒)。例: 0.02 は 20ms。

return_token_timestamps

  • 型: bool
  • 説明: トークンレベルのタイムスタンプを返すかどうかを指定します。

return_segments

  • 型: bool(デフォルト: False
  • 説明: すべてのセグメントをリスト形式で返すかどうかを指定します。

return_dict_in_generate

  • 型: bool(デフォルト: False
  • 説明: 生成されたトークンだけでなく、ModelOutput を返すかどうかを指定します。

実際の指定

実際はこのような使い方が望ましいかと思われます。

generate_kwargs = {
        "language": "Japanese",
        "num_beams": 6,
        "task": "transcribe",
        "no_repeat_ngram_size": 6,
        "repetition_penalty": 1.3,
        "length_penalty": 1.2,
        "condition_on_prev_tokens": False,
        #"logprob_threshold": -0.5,
        "compression_ratio_threshold": 1.7,
    }

num_beamsはビームサーチです。数を増やすほど精度が上がりますが限界があります。(過去64,128など巨大数値で試しましたが改善は見られませんでした。)

no_repeat_ngram_sizeはハルシネ対策としてちょうどよい数値をおすすめします。

repetition_penalty,length_penaltyも同様です。

また、モデルの性能が基本的には高水準の場合、condition_on_prev_tokensをFalseにしたほうがいいです。Trueにするとハルシネで繰り返しが発生した場合、繰り返し文章を参照してまた繰り返す負の繰り返しが起きます。

compression_ratio_thresholdも同じトークンが固まっている、繰り返しが多い場合有効です。

まとめ

文字起こしタスクの場合、試すとわかりますがハルシネ対策、繰り返し対策が重要です。

モデルの性能を改善のために闇雲にファインチューニングするよりも、生成時の設定を見直すだけで十分な結果を出すことは可能です。

ファインチューニングよりも簡単にできるため、まずは設定を見直し、だめだった場合はファインチューニングするようにすると効率が良いと思います。

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!

この記事を書いた人

プログラミングをそれとなく続けてきて歴だけは10年。
コーディングは基本的な命令文とクラスの概念は理解。
あとはライブラリなどを使ってそれとなく。
最近はAI関連を触ってます。

目次