14
14
import torch
15
15
import torchaudio as ta
16
16
from chatterbox .tts import ChatterboxTTS
17
-
17
+ from chatterbox . mtl_tts import ChatterboxMultilingualTTS
18
18
import grpc
19
19
20
+ def is_float (s ):
21
+ try :
22
+ float (s )
23
+ return True
24
+ except ValueError :
25
+ return False
20
26
21
27
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
22
28
@@ -47,6 +53,27 @@ def LoadModel(self, request, context):
47
53
if not torch .cuda .is_available () and request .CUDA :
48
54
return backend_pb2 .Result (success = False , message = "CUDA is not available" )
49
55
56
+
57
+ options = request .Options
58
+
59
+ # empty dict
60
+ self .options = {}
61
+
62
+ # The options are a list of strings in this form optname:optvalue
63
+ # We are storing all the options in a dict so we can use it later when
64
+ # generating the images
65
+ for opt in options :
66
+ if ":" not in opt :
67
+ continue
68
+ key , value = opt .split (":" )
69
+ # if value is a number, convert it to the appropriate type
70
+ if is_float (value ):
71
+ if value .is_integer ():
72
+ value = int (value )
73
+ else :
74
+ value = float (value )
75
+ self .options [key ] = value
76
+
50
77
self .AudioPath = None
51
78
52
79
if os .path .isabs (request .AudioPath ):
@@ -56,10 +83,14 @@ def LoadModel(self, request, context):
56
83
modelFileBase = os .path .dirname (request .ModelFile )
57
84
# modify LoraAdapter to be relative to modelFileBase
58
85
self .AudioPath = os .path .join (modelFileBase , request .AudioPath )
59
-
60
86
try :
61
87
print ("Preparing models, please wait" , file = sys .stderr )
62
- self .model = ChatterboxTTS .from_pretrained (device = device )
88
+ if "multilingual" in self .options :
89
+ # remove key from options
90
+ del self .options ["multilingual" ]
91
+ self .model = ChatterboxMultilingualTTS .from_pretrained (device = device )
92
+ else :
93
+ self .model = ChatterboxTTS .from_pretrained (device = device )
63
94
except Exception as err :
64
95
return backend_pb2 .Result (success = False , message = f"Unexpected { err = } , { type (err )= } " )
65
96
# Implement your logic here for the LoadModel service
@@ -68,12 +99,18 @@ def LoadModel(self, request, context):
68
99
69
100
def TTS (self , request , context ):
70
101
try :
71
- # Generate audio using ChatterboxTTS
102
+ kwargs = {}
103
+
104
+ if "language" in self .options :
105
+ kwargs ["language_id" ] = self .options ["language" ]
72
106
if self .AudioPath is not None :
73
- wav = self .model .generate (request .text , audio_prompt_path = self .AudioPath )
74
- else :
75
- wav = self .model .generate (request .text )
76
-
107
+ kwargs ["audio_prompt_path" ] = self .AudioPath
108
+
109
+ # add options to kwargs
110
+ kwargs .update (self .options )
111
+
112
+ # Generate audio using ChatterboxTTS
113
+ wav = self .model .generate (request .text , ** kwargs )
77
114
# Save the generated audio
78
115
ta .save (request .dst , wav , self .model .sr )
79
116
0 commit comments