
Commit ffedb72

feat(transformers): add support to Dia
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 09457b9 commit ffedb72

File tree: 1 file changed (+86 −8)
backend/python/transformers/backend.py

Lines changed: 86 additions & 8 deletions
@@ -22,7 +22,7 @@
 
 XPU=os.environ.get("XPU", "0") == "1"
 from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
-from transformers import AutoProcessor, MusicgenForConditionalGeneration
+from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
 from scipy.io import wavfile
 import outetts
 from sentence_transformers import SentenceTransformer
@@ -90,13 +90,38 @@ def LoadModel(self, request, context):
         self.CUDA = torch.cuda.is_available()
         self.OV=False
         self.OuteTTS=False
+        self.DiaTTS=False
         self.SentenceTransformer = False
 
         device_map="cpu"
 
         quantization = None
         autoTokenizer = True
 
+        # Parse options from request.Options
+        self.options = {}
+        options = request.Options
+
+        # The options are a list of strings of the form optname:optvalue.
+        # We store all the options in a dict so we can use them later when generating.
+        # Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
+        for opt in options:
+            if ":" not in opt:
+                continue
+            key, value = opt.split(":", 1)
+            # If the value is a number, convert it to the appropriate type
+            try:
+                if "." in value:
+                    value = float(value)
+                else:
+                    value = int(value)
+            except ValueError:
+                # Keep as string if conversion fails
+                pass
+            self.options[key] = value
+
+        print(f"Parsed options: {self.options}", file=sys.stderr)
+
         if self.CUDA:
             from transformers import BitsAndBytesConfig, AutoModelForCausalLM
             if request.MainGPU:
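Note: for illustration only, the example options list quoted in the comment above would be parsed by this loop into a dict along these lines (keys and values depend entirely on what the request carries):

    self.options = {
        "max_new_tokens": 3072,   # int, no "." in the value
        "guidance_scale": 3.0,    # float, "." present
        "temperature": 1.8,
        "top_p": 0.90,
        "top_k": 45,
    }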
@@ -202,6 +227,11 @@ def LoadModel(self, request, context):
             autoTokenizer = False
             self.processor = AutoProcessor.from_pretrained(model_name)
             self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+        elif request.Type == "DiaForConditionalGeneration":
+            autoTokenizer = False
+            self.processor = AutoProcessor.from_pretrained(model_name)
+            self.model = DiaForConditionalGeneration.from_pretrained(model_name)
+            self.DiaTTS = True
         elif request.Type == "OuteTTS":
             autoTokenizer = False
             options = request.Options
@@ -262,7 +292,7 @@ def LoadModel(self, request, context):
         elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
             self.max_tokens = self.model.config.max_position_embeddings
         else:
-            self.max_tokens = 512
+            self.max_tokens = self.options.get("max_new_tokens", 512)
 
         if autoTokenizer:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
@@ -485,16 +515,15 @@ def SoundGeneration(self, request, context):
                 return_tensors="pt",
             )
 
-            tokens = 256
             if request.HasField('duration'):
                 tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
-            guidance = 3.0
+            guidance = self.options.get("guidance_scale", 3.0)
             if request.HasField('temperature'):
                 guidance = request.temperature
-            dosample = True
+            dosample = self.options.get("do_sample", True)
             if request.HasField('sample'):
                 dosample = request.sample
-            audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
+            audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens)
             print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
             sampling_rate = self.model.config.audio_encoder.sampling_rate
             wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
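Note: stripped of the gRPC plumbing, the MusicGen path above reduces to roughly the following sketch (the checkpoint name facebook/musicgen-small and the prompt text are examples only, and the hard-coded generation values stand in for whatever self.options provides):

    from transformers import AutoProcessor, MusicgenForConditionalGeneration
    from scipy.io import wavfile

    processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

    # Build the text-conditioning inputs
    inputs = processor(text=["lo-fi music with a soothing melody"], padding=True, return_tensors="pt")

    # do_sample and guidance_scale default to the option fallbacks used in the backend
    audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3.0, max_new_tokens=256)

    # Write the first generated waveform to disk at the model's native sampling rate
    sampling_rate = model.config.audio_encoder.sampling_rate
    wavfile.write("out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())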
@@ -506,13 +535,59 @@ def SoundGeneration(self, request, context):
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         return backend_pb2.Result(success=True)
 
+
+    def DiaTTS(self, request, context):
+        """
+        Generates dialogue audio using the Dia model.
+
+        Args:
+            request: A TTSRequest containing the text dialogue and generation parameters
+            context: The gRPC context
+
+        Returns:
+            A Result object indicating success or failure
+        """
+        try:
+            print("[DiaTTS] generating dialogue audio", file=sys.stderr)
+
+            # Prepare the text input - a dialogue format like [S1] ... [S2] ... is expected
+            text = [request.text]
+
+            # Process the input
+            inputs = self.processor(text=text, padding=True, return_tensors="pt")
+
+            # Generate audio with parameters from options or defaults
+            generation_params = {
+                **inputs,
+                "max_new_tokens": self.max_tokens,
+                "guidance_scale": self.options.get("guidance_scale", 3.0),
+                "temperature": self.options.get("temperature", 1.8),
+                "top_p": self.options.get("top_p", 0.90),
+                "top_k": self.options.get("top_k", 45)
+            }
+
+            outputs = self.model.generate(**generation_params)
+
+            # Decode the generated codes and save the audio
+            outputs = self.processor.batch_decode(outputs)
+            self.processor.save_audio(outputs, request.dst)
+
+            print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
+            print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
+            print("[DiaTTS] Dialogue generation done", file=sys.stderr)
+
+        except Exception as err:
+            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+        return backend_pb2.Result(success=True)
+
+
     def OuteTTS(self, request, context):
         try:
             print("[OuteTTS] generating TTS", file=sys.stderr)
             gen_cfg = outetts.GenerationConfig(
                 text="Speech synthesis is the artificial production of human speech.",
-                temperature=0.1,
-                repetition_penalty=1.1,
+                temperature=self.options.get("temperature", 0.1),
+                repetition_penalty=self.options.get("repetition_penalty", 1.1),
                 max_length=self.max_tokens,
                 speaker=self.speaker,
                 # voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
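Note: outside the gRPC wrapper, the new DiaTTS method follows the same flow one would use with transformers directly, roughly as sketched below (the checkpoint name nari-labs/Dia-1.6B-0626 is only an example; the processor and generation calls mirror the ones added above):

    from transformers import AutoProcessor, DiaForConditionalGeneration

    model_name = "nari-labs/Dia-1.6B-0626"
    processor = AutoProcessor.from_pretrained(model_name)
    model = DiaForConditionalGeneration.from_pretrained(model_name)

    # Dia expects speaker-tagged dialogue text
    text = ["[S1] Dia generates dialogue from text. [S2] Each speaker gets its own tag."]
    inputs = processor(text=text, padding=True, return_tensors="pt")

    outputs = model.generate(**inputs, max_new_tokens=3072, guidance_scale=3.0,
                             temperature=1.8, top_p=0.90, top_k=45)

    # Decode the generated codes back to audio and write the result to disk
    outputs = processor.batch_decode(outputs)
    processor.save_audio(outputs, "dialogue.wav")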
@@ -529,6 +604,9 @@ def OuteTTS(self, request, context):
     def TTS(self, request, context):
         if self.OuteTTS:
             return self.OuteTTS(request, context)
+
+        if self.DiaTTS:
+            return self.DiaTTS(request, context)
 
         model_name = request.model
         try:
