142 lines
4.8 KiB
Python
142 lines
4.8 KiB
Python
import requests
|
||
from telegram.ext import ContextTypes
|
||
from tg_bot.Handler import Handler
|
||
from telegram import Update
|
||
from cia_bot.states.DrctStates import DrctStates
|
||
from telegram.ext import MessageHandler, filters, CallbackQueryHandler
|
||
from db_bot.User import save_user, get_user, update_user, decrement_gen_count
|
||
from db_bot.TaskModel import TaskModel
|
||
from image_helper.img import image_url_to_png_bytes, get_filename_without_extension
|
||
from io import BytesIO
|
||
from gen_api.GenApi import GenApi
|
||
from dotenv import dotenv_values
|
||
from cia_bot.keyboards.MenuKetboard import MenuKeyboard
|
||
|
||
|
||
class DrctHandler(Handler):
|
||
|
||
def __init__(self):
|
||
super().__init__()
|
||
|
||
@staticmethod
|
||
async def get_prompt(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||
await context.bot.send_message(
|
||
chat_id=update.effective_chat.id,
|
||
text='''
|
||
Загрузите изображение которое нужно улучшить.
|
||
'''
|
||
)
|
||
|
||
return DrctStates.get_state_by_key("drct_get_base_img")
|
||
|
||
# @staticmethod
|
||
# async def set_prompt(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||
# context.user_data['prompt'] = update.message.text
|
||
# await update.message.reply_text("Теперь отправь мне изображение.")
|
||
#
|
||
# return UssrStates.get_state_by_key("get_base_img")
|
||
|
||
@staticmethod
|
||
async def get_base_img(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||
user = get_user(update)
|
||
|
||
try:
|
||
# Получаем URL изображения из Telegram
|
||
photo_file = await update.message.photo[-1].get_file()
|
||
image_url = photo_file.file_path
|
||
|
||
# Проверяем доступность изображения
|
||
test_resp = requests.head(image_url)
|
||
if test_resp.status_code != 200:
|
||
await update.message.reply_text("Не удалось получить изображение")
|
||
return
|
||
|
||
if user:
|
||
if user.gen_count <= 0:
|
||
await update.message.reply_text("На сегодня художники устали! Можете обратиться к @kirill_bouko")
|
||
return
|
||
|
||
decrement_gen_count(update)
|
||
|
||
await update.message.reply_text("Улучшаем изображение")
|
||
# Генерируем новое изображение
|
||
response = await DrctHandler.run_upscale(image_url)
|
||
|
||
tm = TaskModel()
|
||
tm.save({
|
||
'user_id': update.effective_chat.id,
|
||
'request_id': response['request_id'],
|
||
'send_url': 1,
|
||
'status': 1,
|
||
})
|
||
|
||
# Получаем и отправляем результат
|
||
# result_url = response['response'][0]
|
||
# img_data = requests.get(result_url).content
|
||
#
|
||
# await update.message.reply_photo(photo=BytesIO(img_data))
|
||
#
|
||
# reply_markup = MenuKeyboard()
|
||
#
|
||
# await context.bot.send_message(
|
||
# chat_id=update.effective_chat.id,
|
||
# text='Вернуться в меню.',
|
||
# reply_markup=reply_markup.create_keyboard()
|
||
# )
|
||
|
||
return DrctStates.get_state_by_key("first")
|
||
|
||
except Exception as e:
|
||
print(f"Ошибка: {str(e)}")
|
||
finally:
|
||
context.user_data.pop('prompt', None)
|
||
|
||
@staticmethod
|
||
async def run_upscale(url):
|
||
config = dotenv_values(".env")
|
||
img = image_url_to_png_bytes(url)
|
||
|
||
img_name = get_filename_without_extension(url)
|
||
|
||
input = {
|
||
"image_url": url,
|
||
"is_sync": 0,
|
||
}
|
||
|
||
ga = GenApi(token=config['GEN_API_TOKEN'])
|
||
response = ga.drct(data_input=input)
|
||
|
||
return response
|
||
|
||
@staticmethod
|
||
def create_poster(url):
|
||
config = dotenv_values(".env")
|
||
img = image_url_to_png_bytes(url)
|
||
|
||
img_name = get_filename_without_extension(url)
|
||
|
||
input = {
|
||
"is_sync": 0,
|
||
}
|
||
|
||
files = {
|
||
# "image": [("imgs_file/2.png", image_data, 'image/png')],
|
||
"image[]": ("{name}.png".format(name=img_name), img, 'image/png'),
|
||
}
|
||
|
||
ga = GenApi(token=config['GEN_API_TOKEN'])
|
||
response = ga.drct(data_input=input, files=files)
|
||
|
||
return response
|
||
|
||
@staticmethod
|
||
def get_states(data=None) -> dict:
|
||
return {
|
||
DrctStates.get_state_by_key("first"): [
|
||
CallbackQueryHandler(DrctHandler.get_prompt, pattern='^drct_get_prompt'),
|
||
],
|
||
DrctStates.get_state_by_key("drct_get_base_img"): [
|
||
MessageHandler(filters.PHOTO, DrctHandler.get_base_img),
|
||
],
|
||
}
|