@@ -152,8 +152,7 @@ async def async_iter_response(
152
152
content = filter_json (content )
153
153
yield ChatCompletion .model_construct (content , finish_reason , completion_id , int (time .time ()))
154
154
finally :
155
- if hasattr (response , 'aclose' ):
156
- await safe_aclose (response )
155
+ await safe_aclose (response )
157
156
158
157
async def async_iter_append_model_and_provider (
159
158
response : AsyncChatCompletionResponseType
@@ -167,8 +166,7 @@ async def async_iter_append_model_and_provider(
167
166
chunk .provider = last_provider .get ("name" )
168
167
yield chunk
169
168
finally :
170
- if hasattr (response , 'aclose' ):
171
- await safe_aclose (response )
169
+ await safe_aclose (response )
172
170
173
171
class Client (BaseClient ):
174
172
def __init__ (
@@ -292,7 +290,7 @@ async def async_generate(
292
290
proxy = self .client .proxy
293
291
294
292
response = None
295
- if isinstance ( provider , type ) and issubclass ( provider , AsyncGeneratorProvider ):
293
+ if hasattr ( provider_handler , "create_async_generator" ):
296
294
messages = [{"role" : "user" , "content" : f"Generate a image: { prompt } " }]
297
295
async for item in provider_handler .create_async_generator (model , messages , prompt = prompt , ** kwargs ):
298
296
if isinstance (item , ImageResponse ):
@@ -354,7 +352,7 @@ async def async_create_variation(
354
352
if proxy is None :
355
353
proxy = self .client .proxy
356
354
357
- if isinstance (provider , type ) and issubclass ( provider , AsyncGeneratorProvider ):
355
+ if hasattr (provider , "create_async_generator" ):
358
356
messages = [{"role" : "user" , "content" : "create a variation of this image" }]
359
357
generator = None
360
358
try :
@@ -364,8 +362,7 @@ async def async_create_variation(
364
362
response = chunk
365
363
break
366
364
finally :
367
- if generator and hasattr (generator , 'aclose' ):
368
- await safe_aclose (generator )
365
+ await safe_aclose (generator )
369
366
elif hasattr (provider , 'create_variation' ):
370
367
if asyncio .iscoroutinefunction (provider .create_variation ):
371
368
response = await provider .create_variation (image , model = model , response_format = response_format , proxy = proxy , ** kwargs )
@@ -454,7 +451,11 @@ def create(
454
451
)
455
452
stop = [stop ] if isinstance (stop , str ) else stop
456
453
457
- response = provider .create_completion (
454
+ if hasattr (provider , "create_async_generator" ):
455
+ create_handler = provider .create_async_generator
456
+ else :
457
+ create_handler = provider .create_completion
458
+ response = create_handler (
458
459
model ,
459
460
messages ,
460
461
stream = stream ,
0 commit comments