Skip to content

Commit 36e5afb

Browse files
committed
feat(transformers): add support to Dia
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 09457b9 commit 36e5afb

File tree

1 file changed

+54
-1
lines changed

1 file changed

+54
-1
lines changed

backend/python/transformers/backend.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
XPU=os.environ.get("XPU", "0") == "1"
2424
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
25-
from transformers import AutoProcessor, MusicgenForConditionalGeneration
25+
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
2626
from scipy.io import wavfile
2727
import outetts
2828
from sentence_transformers import SentenceTransformer
@@ -90,6 +90,7 @@ def LoadModel(self, request, context):
9090
self.CUDA = torch.cuda.is_available()
9191
self.OV=False
9292
self.OuteTTS=False
93+
self.DiaTTS=False
9394
self.SentenceTransformer = False
9495

9596
device_map="cpu"
@@ -202,6 +203,11 @@ def LoadModel(self, request, context):
202203
autoTokenizer = False
203204
self.processor = AutoProcessor.from_pretrained(model_name)
204205
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
206+
elif request.Type == "DiaForConditionalGeneration":
207+
autoTokenizer = False
208+
self.processor = AutoProcessor.from_pretrained(model_name)
209+
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
210+
self.DiaTTS = True
205211
elif request.Type == "OuteTTS":
206212
autoTokenizer = False
207213
options = request.Options
@@ -506,6 +512,50 @@ def SoundGeneration(self, request, context):
506512
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
507513
return backend_pb2.Result(success=True)
508514

515+
516+
def DiaTTS(self, request, context):
517+
"""
518+
Generates dialogue audio using the Dia model.
519+
520+
Args:
521+
request: A TTSRequest containing text dialogue and generation parameters
522+
context: The gRPC context
523+
524+
Returns:
525+
A Result object indicating success or failure
526+
"""
527+
try:
528+
print("[DiaTTS] generating dialogue audio", file=sys.stderr)
529+
530+
# Prepare text input - expect dialogue format like [S1] ... [S2] ...
531+
text = [request.text]
532+
533+
# Process the input
534+
inputs = self.processor(text=text, padding=True, return_tensors="pt")
535+
536+
# Generate audio with default Dia parameters
537+
outputs = self.model.generate(
538+
**inputs,
539+
max_new_tokens=3072,
540+
guidance_scale=3.0,
541+
temperature=1.8,
542+
top_p=0.90,
543+
top_k=45
544+
)
545+
546+
# Decode and save audio
547+
outputs = self.processor.batch_decode(outputs)
548+
self.processor.save_audio(outputs, request.dst)
549+
550+
print("[DiaTTS] Generated dialogue audio", file=sys.stderr)
551+
print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr)
552+
print("[DiaTTS] Dialogue generation done", file=sys.stderr)
553+
554+
except Exception as err:
555+
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
556+
return backend_pb2.Result(success=True)
557+
558+
509559
def OuteTTS(self, request, context):
510560
try:
511561
print("[OuteTTS] generating TTS", file=sys.stderr)
@@ -529,6 +579,9 @@ def OuteTTS(self, request, context):
529579
def TTS(self, request, context):
530580
if self.OuteTTS:
531581
return self.OuteTTS(request, context)
582+
583+
if self.DiaTTS:
584+
return self.DiaTTS(request, context)
532585

533586
model_name = request.model
534587
try:

0 commit comments

Comments
 (0)