Skip to content

Commit 37f5e4f

Browse files
authored
feat(whisper): Add diarization (tinydiarize) (#6184)
Signed-off-by: Richard Palethorpe <[email protected]>
1 parent ffa934b commit 37f5e4f

File tree

9 files changed

+76
-47
lines changed

9 files changed

+76
-47
lines changed

backend/backend.proto

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ message TranscriptRequest {
276276
string language = 3;
277277
uint32 threads = 4;
278278
bool translate = 5;
279+
bool diarize = 6;
279280
}
280281

281282
message TranscriptResult {
@@ -305,7 +306,7 @@ message GenerateImageRequest {
305306
// Diffusers
306307
string EnableParameters = 10;
307308
int32 CLIPSkip = 11;
308-
309+
309310
// Reference images for models that support them (e.g., Flux Kontext)
310311
repeated string ref_images = 12;
311312
}

backend/go/whisper/gowhisper.cpp

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,35 @@ static struct whisper_vad_context *vctx;
77
static struct whisper_context *ctx;
88
static std::vector<float> flat_segs;
99

10-
static void ggml_log_cb(enum ggml_log_level level, const char* log, void* data) {
11-
const char* level_str;
12-
13-
if (!log) {
14-
return;
15-
}
16-
17-
switch (level) {
18-
case GGML_LOG_LEVEL_DEBUG:
19-
level_str = "DEBUG";
20-
break;
21-
case GGML_LOG_LEVEL_INFO:
22-
level_str = "INFO";
23-
break;
24-
case GGML_LOG_LEVEL_WARN:
25-
level_str = "WARN";
26-
break;
27-
case GGML_LOG_LEVEL_ERROR:
28-
level_str = "ERROR";
29-
break;
30-
default: /* Potential future-proofing */
31-
level_str = "?????";
32-
break;
33-
}
34-
35-
fprintf(stderr, "[%-5s] ", level_str);
36-
fputs(log, stderr);
37-
fflush(stderr);
10+
// Log sink that forwards ggml/whisper log messages to stderr with a
// severity tag in front of them.
//
// level: severity reported by the library; mapped to a short label below.
// log:   message text; a null pointer is ignored.
// data:  opaque user pointer from the callback signature; not used here.
//
// NOTE(review): the place where this callback is registered is not visible
// in this hunk -- presumably during model loading; confirm against the caller.
static void ggml_log_cb(enum ggml_log_level level, const char *log,
11+
void *data) {
12+
const char *level_str;
13+
14+
// Nothing to print for a null message.
if (!log) {
15+
return;
16+
}
17+
18+
// Translate the severity enum into the label used in the prefix.
switch (level) {
19+
case GGML_LOG_LEVEL_DEBUG:
20+
level_str = "DEBUG";
21+
break;
22+
case GGML_LOG_LEVEL_INFO:
23+
level_str = "INFO";
24+
break;
25+
case GGML_LOG_LEVEL_WARN:
26+
level_str = "WARN";
27+
break;
28+
case GGML_LOG_LEVEL_ERROR:
29+
level_str = "ERROR";
30+
break;
31+
default: /* Potential future-proofing */
32+
level_str = "?????";
33+
break;
34+
}
35+
36+
// "%-5s" left-justifies the label in a 5-character field so that the
// message text lines up regardless of the label's length.
fprintf(stderr, "[%-5s] ", level_str);
37+
// The message is written verbatim; no newline is appended here.
fputs(log, stderr);
38+
// Flush right away so the output is not held back in the stdio buffer.
fflush(stderr);
3839
}
3940

4041
int load_model(const char *const model_path) {
@@ -105,8 +106,8 @@ int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
105106
return 0;
106107
}
107108

108-
int transcribe(uint32_t threads, char *lang, bool translate, float pcmf32[],
109-
size_t pcmf32_len, size_t *segs_out_len) {
109+
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
110+
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len) {
110111
whisper_full_params wparams =
111112
whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
112113

@@ -120,6 +121,9 @@ int transcribe(uint32_t threads, char *lang, bool translate, float pcmf32[],
120121
wparams.translate = translate;
121122
wparams.debug_mode = true;
122123
wparams.print_progress = true;
124+
wparams.tdrz_enable = tdrz;
125+
126+
fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
123127

124128
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
125129
fprintf(stderr, "error: transcription failed\n");
@@ -144,3 +148,7 @@ int n_tokens(int i) { return whisper_full_n_tokens(ctx, i); }
144148
int32_t get_token_id(int i, int j) {
145149
return whisper_full_get_token_id(ctx, i, j);
146150
}
151+
152+
// Returns whisper.cpp's "speaker turn next" flag for segment `i` of the
// file-level whisper context `ctx`. The Go caller uses a true result to
// append " [SPEAKER_TURN]" to that segment's text when diarization was
// requested.
//
// The index is passed through unchecked.
// NOTE(review): presumably only meaningful after transcribe() has run with
// tdrz enabled on a tinydiarize model -- confirm against whisper.cpp docs.
bool get_segment_speaker_turn_next(int i) {
153+
return whisper_full_get_segment_speaker_turn_next(ctx, i);
154+
}

backend/go/whisper/gowhisper.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@ import (
1414
)
1515

1616
var (
17-
CppLoadModel func(modelPath string) int
18-
CppLoadModelVAD func(modelPath string) int
19-
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
20-
CppTranscribe func(threads uint32, lang string, translate bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
21-
CppGetSegmentText func(i int) string
22-
CppGetSegmentStart func(i int) int64
23-
CppGetSegmentEnd func(i int) int64
24-
CppNTokens func(i int) int
25-
CppGetTokenID func(i int, j int) int
17+
CppLoadModel func(modelPath string) int
18+
CppLoadModelVAD func(modelPath string) int
19+
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
20+
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
21+
CppGetSegmentText func(i int) string
22+
CppGetSegmentStart func(i int) int64
23+
CppGetSegmentEnd func(i int) int64
24+
CppNTokens func(i int) int
25+
CppGetTokenID func(i int, j int) int
26+
CppGetSegmentSpeakerTurnNext func(i int) bool
2627
)
2728

2829
type Whisper struct {
@@ -122,7 +123,7 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
122123
segsLen := uintptr(0xdeadbeef)
123124
segsLenPtr := unsafe.Pointer(&segsLen)
124125

125-
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, data, uintptr(len(data)), segsLenPtr); ret != 0 {
126+
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr); ret != 0 {
126127
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
127128
}
128129

@@ -134,6 +135,10 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
134135
txt := strings.Clone(CppGetSegmentText(i))
135136
tokens := make([]int32, CppNTokens(i))
136137

138+
if opts.Diarize && CppGetSegmentSpeakerTurnNext(i) {
139+
txt += " [SPEAKER_TURN]"
140+
}
141+
137142
for j := range tokens {
138143
tokens[j] = int32(CppGetTokenID(i, j))
139144
}
@@ -151,6 +156,6 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
151156

152157
return pb.TranscriptResult{
153158
Segments: segments,
154-
Text: strings.TrimSpace(text),
159+
Text: strings.TrimSpace(text),
155160
}, nil
156161
}

backend/go/whisper/gowhisper.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ int load_model(const char *const model_path);
66
int load_model_vad(const char *const model_path);
77
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
88
size_t *segs_out_len);
9-
int transcribe(uint32_t threads, char *lang, bool translate, float pcmf32[],
10-
size_t pcmf32_len, size_t *segs_out_len);
9+
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
10+
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len);
1111
const char *get_segment_text(int i);
1212
int64_t get_segment_t0(int i);
1313
int64_t get_segment_t1(int i);
1414
int n_tokens(int i);
1515
int32_t get_token_id(int i, int j);
16+
bool get_segment_speaker_turn_next(int i);
1617
}

backend/go/whisper/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ func main() {
3333
{&CppGetSegmentEnd, "get_segment_t1"},
3434
{&CppNTokens, "n_tokens"},
3535
{&CppGetTokenID, "get_token_id"},
36+
{&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"},
3637
}
3738

3839
for _, lf := range libFuncs {

core/backend/transcript.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
"github.com/mudler/LocalAI/pkg/model"
1313
)
1414

15-
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
15+
func ModelTranscription(audio, language string, translate bool, diarize bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
1616

1717
if modelConfig.Backend == "" {
1818
modelConfig.Backend = model.WhisperBackend
@@ -34,6 +34,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
3434
Dst: audio,
3535
Language: language,
3636
Translate: translate,
37+
Diarize: diarize,
3738
Threads: uint32(*modelConfig.Threads),
3839
})
3940
if err != nil {

core/cli/transcript.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ type TranscriptCMD struct {
2020
Model string `short:"m" required:"" help:"Model name to run the TTS"`
2121
Language string `short:"l" help:"Language of the audio file"`
2222
Translate bool `short:"c" help:"Translate the transcription to english"`
23+
Diarize bool `short:"d" help:"Mark speaker turns"`
2324
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
2425
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
2526
}
@@ -56,7 +57,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
5657
}
5758
}()
5859

59-
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, ml, c, opts)
60+
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, ml, c, opts)
6061
if err != nil {
6162
return err
6263
}

core/http/endpoints/openai/transcription.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
3636
return fiber.ErrBadRequest
3737
}
3838

39+
diarize := c.FormValue("diarize", "false") != "false"
40+
3941
// retrieve the file data from the request
4042
file, err := c.FormFile("file")
4143
if err != nil {
@@ -67,7 +69,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
6769

6870
log.Debug().Msgf("Audio file copied to: %+v", dst)
6971

70-
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, ml, *config, appConfig)
72+
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, ml, *config, appConfig)
7173
if err != nil {
7274
return err
7375
}

gallery/index.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20174,6 +20174,15 @@
2017420174
- filename: "ggml-small.bin"
2017520175
uri: "huggingface://ggerganov/whisper.cpp/ggml-small.bin"
2017620176
sha256: 1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b
20177+
# tinydiarize (tdrz) fine-tune of whisper small.en; marks speaker turns when
# transcription is requested with diarization enabled.
- !!merge <<: *whisper
20178+
name: "whisper-small-en-tdrz"
20179+
overrides:
20180+
parameters:
20181+
model: ggml-small.en-tdrz.bin
20182+
files:
20183+
# The local filename must equal overrides.parameters.model above. Reusing
# "ggml-small.bin" would also collide with the whisper-small entry, whose
# file has a different sha256.
- filename: "ggml-small.en-tdrz.bin"
20184+
uri: "huggingface://akashmjn/tinydiarize-whisper.cpp/ggml-small.en-tdrz.bin"
20185+
sha256: ceac3ec06d1d98ef71aec665283564631055fd6129b79d8e1be4f9cc33cc54b4
2017720186
- !!merge <<: *whisper
2017820187
name: "whisper-small-en-q5_1"
2017920188
overrides:

0 commit comments

Comments
 (0)