22
22
23
23
XPU = os .environ .get ("XPU" , "0" ) == "1"
24
24
from transformers import AutoTokenizer , AutoModel , set_seed , TextIteratorStreamer , StoppingCriteriaList , StopStringCriteria , MambaConfig , MambaForCausalLM
25
- from transformers import AutoProcessor , MusicgenForConditionalGeneration
25
+ from transformers import AutoProcessor , MusicgenForConditionalGeneration , DiaForConditionalGeneration
26
26
from scipy .io import wavfile
27
27
import outetts
28
28
from sentence_transformers import SentenceTransformer
@@ -90,6 +90,7 @@ def LoadModel(self, request, context):
90
90
self .CUDA = torch .cuda .is_available ()
91
91
self .OV = False
92
92
self .OuteTTS = False
93
+ self .DiaTTS = False
93
94
self .SentenceTransformer = False
94
95
95
96
device_map = "cpu"
@@ -202,6 +203,11 @@ def LoadModel(self, request, context):
202
203
autoTokenizer = False
203
204
self .processor = AutoProcessor .from_pretrained (model_name )
204
205
self .model = MusicgenForConditionalGeneration .from_pretrained (model_name )
206
+ elif request .Type == "DiaForConditionalGeneration" :
207
+ autoTokenizer = False
208
+ self .processor = AutoProcessor .from_pretrained (model_name )
209
+ self .model = DiaForConditionalGeneration .from_pretrained (model_name )
210
+ self .DiaTTS = True
205
211
elif request .Type == "OuteTTS" :
206
212
autoTokenizer = False
207
213
options = request .Options
@@ -506,6 +512,50 @@ def SoundGeneration(self, request, context):
506
512
return backend_pb2 .Result (success = False , message = f"Unexpected { err = } , { type (err )= } " )
507
513
return backend_pb2 .Result (success = True )
508
514
515
+
516
+ def DiaTTS (self , request , context ):
517
+ """
518
+ Generates dialogue audio using the Dia model.
519
+
520
+ Args:
521
+ request: A TTSRequest containing text dialogue and generation parameters
522
+ context: The gRPC context
523
+
524
+ Returns:
525
+ A Result object indicating success or failure
526
+ """
527
+ try :
528
+ print ("[DiaTTS] generating dialogue audio" , file = sys .stderr )
529
+
530
+ # Prepare text input - expect dialogue format like [S1] ... [S2] ...
531
+ text = [request .text ]
532
+
533
+ # Process the input
534
+ inputs = self .processor (text = text , padding = True , return_tensors = "pt" )
535
+
536
+ # Generate audio with default Dia parameters
537
+ outputs = self .model .generate (
538
+ ** inputs ,
539
+ max_new_tokens = 3072 ,
540
+ guidance_scale = 3.0 ,
541
+ temperature = 1.8 ,
542
+ top_p = 0.90 ,
543
+ top_k = 45
544
+ )
545
+
546
+ # Decode and save audio
547
+ outputs = self .processor .batch_decode (outputs )
548
+ self .processor .save_audio (outputs , request .dst )
549
+
550
+ print ("[DiaTTS] Generated dialogue audio" , file = sys .stderr )
551
+ print ("[DiaTTS] Audio saved to" , request .dst , file = sys .stderr )
552
+ print ("[DiaTTS] Dialogue generation done" , file = sys .stderr )
553
+
554
+ except Exception as err :
555
+ return backend_pb2 .Result (success = False , message = f"Unexpected { err = } , { type (err )= } " )
556
+ return backend_pb2 .Result (success = True )
557
+
558
+
509
559
def OuteTTS (self , request , context ):
510
560
try :
511
561
print ("[OuteTTS] generating TTS" , file = sys .stderr )
@@ -529,6 +579,9 @@ def OuteTTS(self, request, context):
529
579
def TTS (self , request , context ):
530
580
if self .OuteTTS :
531
581
return self .OuteTTS (request , context )
582
+
583
+ if self .DiaTTS :
584
+ return self .DiaTTS (request , context )
532
585
533
586
model_name = request .model
534
587
try :
0 commit comments