22
22
23
23
XPU = os .environ .get ("XPU" , "0" ) == "1"
24
24
from transformers import AutoTokenizer , AutoModel , set_seed , TextIteratorStreamer , StoppingCriteriaList , StopStringCriteria , MambaConfig , MambaForCausalLM
25
- from transformers import AutoProcessor , MusicgenForConditionalGeneration
25
+ from transformers import AutoProcessor , MusicgenForConditionalGeneration , DiaForConditionalGeneration
26
26
from scipy .io import wavfile
27
27
import outetts
28
28
from sentence_transformers import SentenceTransformer
@@ -90,13 +90,38 @@ def LoadModel(self, request, context):
90
90
self .CUDA = torch .cuda .is_available ()
91
91
self .OV = False
92
92
self .OuteTTS = False
93
+ self .DiaTTS = False
93
94
self .SentenceTransformer = False
94
95
95
96
device_map = "cpu"
96
97
97
98
quantization = None
98
99
autoTokenizer = True
99
100
101
+ # Parse options from request.Options
102
+ self .options = {}
103
+ options = request .Options
104
+
105
+ # The options are a list of strings in this form optname:optvalue
106
+ # We are storing all the options in a dict so we can use it later when generating
107
+ # Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"]
108
+ for opt in options :
109
+ if ":" not in opt :
110
+ continue
111
+ key , value = opt .split (":" , 1 )
112
+ # if value is a number, convert it to the appropriate type
113
+ try :
114
+ if "." in value :
115
+ value = float (value )
116
+ else :
117
+ value = int (value )
118
+ except ValueError :
119
+ # Keep as string if conversion fails
120
+ pass
121
+ self .options [key ] = value
122
+
123
+ print (f"Parsed options: { self .options } " , file = sys .stderr )
124
+
100
125
if self .CUDA :
101
126
from transformers import BitsAndBytesConfig , AutoModelForCausalLM
102
127
if request .MainGPU :
@@ -202,6 +227,11 @@ def LoadModel(self, request, context):
202
227
autoTokenizer = False
203
228
self .processor = AutoProcessor .from_pretrained (model_name )
204
229
self .model = MusicgenForConditionalGeneration .from_pretrained (model_name )
230
+ elif request .Type == "DiaForConditionalGeneration" :
231
+ autoTokenizer = False
232
+ self .processor = AutoProcessor .from_pretrained (model_name )
233
+ self .model = DiaForConditionalGeneration .from_pretrained (model_name )
234
+ self .DiaTTS = True
205
235
elif request .Type == "OuteTTS" :
206
236
autoTokenizer = False
207
237
options = request .Options
@@ -262,7 +292,7 @@ def LoadModel(self, request, context):
262
292
elif hasattr (self .model , 'config' ) and hasattr (self .model .config , 'max_position_embeddings' ):
263
293
self .max_tokens = self .model .config .max_position_embeddings
264
294
else :
265
- self .max_tokens = 512
295
+ self .max_tokens = self . options . get ( "max_new_tokens" , 512 )
266
296
267
297
if autoTokenizer :
268
298
self .tokenizer = AutoTokenizer .from_pretrained (model_name , use_safetensors = True )
@@ -485,16 +515,15 @@ def SoundGeneration(self, request, context):
485
515
return_tensors = "pt" ,
486
516
)
487
517
488
- tokens = 256
489
518
if request .HasField ('duration' ):
490
519
tokens = int (request .duration * 51.2 ) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
491
- guidance = 3.0
520
+ guidance = self . options . get ( "guidance_scale" , 3.0 )
492
521
if request .HasField ('temperature' ):
493
522
guidance = request .temperature
494
- dosample = True
523
+ dosample = self . options . get ( "do_sample" , True )
495
524
if request .HasField ('sample' ):
496
525
dosample = request .sample
497
- audio_values = self .model .generate (** inputs , do_sample = dosample , guidance_scale = guidance , max_new_tokens = tokens )
526
+ audio_values = self .model .generate (** inputs , do_sample = dosample , guidance_scale = guidance , max_new_tokens = self . max_tokens )
498
527
print ("[transformers-musicgen] SoundGeneration generated!" , file = sys .stderr )
499
528
sampling_rate = self .model .config .audio_encoder .sampling_rate
500
529
wavfile .write (request .dst , rate = sampling_rate , data = audio_values [0 , 0 ].numpy ())
@@ -506,13 +535,59 @@ def SoundGeneration(self, request, context):
506
535
return backend_pb2 .Result (success = False , message = f"Unexpected { err = } , { type (err )= } " )
507
536
return backend_pb2 .Result (success = True )
508
537
538
+
539
+ def DiaTTS (self , request , context ):
540
+ """
541
+ Generates dialogue audio using the Dia model.
542
+
543
+ Args:
544
+ request: A TTSRequest containing text dialogue and generation parameters
545
+ context: The gRPC context
546
+
547
+ Returns:
548
+ A Result object indicating success or failure
549
+ """
550
+ try :
551
+ print ("[DiaTTS] generating dialogue audio" , file = sys .stderr )
552
+
553
+ # Prepare text input - expect dialogue format like [S1] ... [S2] ...
554
+ text = [request .text ]
555
+
556
+ # Process the input
557
+ inputs = self .processor (text = text , padding = True , return_tensors = "pt" )
558
+
559
+ # Generate audio with parameters from options or defaults
560
+ generation_params = {
561
+ ** inputs ,
562
+ "max_new_tokens" : self .max_tokens ,
563
+ "guidance_scale" : self .options .get ("guidance_scale" , 3.0 ),
564
+ "temperature" : self .options .get ("temperature" , 1.8 ),
565
+ "top_p" : self .options .get ("top_p" , 0.90 ),
566
+ "top_k" : self .options .get ("top_k" , 45 )
567
+ }
568
+
569
+ outputs = self .model .generate (** generation_params )
570
+
571
+ # Decode and save audio
572
+ outputs = self .processor .batch_decode (outputs )
573
+ self .processor .save_audio (outputs , request .dst )
574
+
575
+ print ("[DiaTTS] Generated dialogue audio" , file = sys .stderr )
576
+ print ("[DiaTTS] Audio saved to" , request .dst , file = sys .stderr )
577
+ print ("[DiaTTS] Dialogue generation done" , file = sys .stderr )
578
+
579
+ except Exception as err :
580
+ return backend_pb2 .Result (success = False , message = f"Unexpected { err = } , { type (err )= } " )
581
+ return backend_pb2 .Result (success = True )
582
+
583
+
509
584
def OuteTTS (self , request , context ):
510
585
try :
511
586
print ("[OuteTTS] generating TTS" , file = sys .stderr )
512
587
gen_cfg = outetts .GenerationConfig (
513
588
text = "Speech synthesis is the artificial production of human speech." ,
514
- temperature = 0.1 ,
515
- repetition_penalty = 1.1 ,
589
+ temperature = self . options . get ( "temperature" , 0.1 ) ,
590
+ repetition_penalty = self . options . get ( "repetition_penalty" , 1.1 ) ,
516
591
max_length = self .max_tokens ,
517
592
speaker = self .speaker ,
518
593
# voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
@@ -529,6 +604,9 @@ def OuteTTS(self, request, context):
529
604
def TTS (self , request , context ):
530
605
if self .OuteTTS :
531
606
return self .OuteTTS (request , context )
607
+
608
+ if self .DiaTTS :
609
+ return self .DiaTTS (request , context )
532
610
533
611
model_name = request .model
534
612
try :
0 commit comments