mirror of
https://github.com/Xtao-Labs/telegram-oauth.git
synced 2024-11-21 06:47:16 +00:00
🎉 Support oidc by telegram
This commit is contained in:
parent
43d0f26dda
commit
4fbe5eb0dc
12
.env.example
Normal file
12
.env.example
Normal file
@ -0,0 +1,12 @@
|
||||
CONN_URI=sqlite+aiosqlite:///data/db.sqlite3
|
||||
DEBUG=True
|
||||
PROJECT_URL=http://127.0.0.1
|
||||
PROJECT_LOGIN_SUCCESS_URL=http://google.com
|
||||
PROJECT_PORT=80
|
||||
JWT_PRIVATE_KEY='data/private_key'
|
||||
JWT_PUBLIC_KEY='data/public_key'
|
||||
BOT_TOKEN=xxx
|
||||
BOT_USERNAME=xxxxBot
|
||||
BOT_API_ID=111
|
||||
BOT_API_HASH=aaa
|
||||
BOT_MANAGER_IDS=[111,222]
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -157,4 +157,6 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
data/
|
||||
|
32
README.md
Normal file
32
README.md
Normal file
@ -0,0 +1,32 @@
|
||||
# Telegram OAuth
|
||||
|
||||
## Configuration
|
||||
|
||||
```dotenv
|
||||
CONN_URI=sqlite+aiosqlite:///data/db.sqlite3 # 数据库 uri
|
||||
DEBUG=True # 调试模式
|
||||
PROJECT_URL=http://127.0.0.1 # 项目可访问的地址
|
||||
PROJECT_LOGIN_SUCCESS_URL=http://google.com # 登录成功后跳转的地址
|
||||
PROJECT_PORT=80 # 项目运行的端口
|
||||
JWT_PRIVATE_KEY='data/private_key' # jwt 私钥
|
||||
JWT_PUBLIC_KEY='data/public_key' # jwt 公钥
|
||||
BOT_TOKEN=xxx # 机器人 token
|
||||
BOT_USERNAME=xxxxBot # 机器人用户名
|
||||
BOT_API_ID=111 # api id
|
||||
BOT_API_HASH=aaa # api hash
|
||||
BOT_MANAGER_IDS=[111,222] # 管理员 id
|
||||
```
|
||||
|
||||
## OIDC Endpoints
|
||||
|
||||
Auth URL : `/oauth2/authorize`
|
||||
|
||||
Token URL : `/oauth2/token`
|
||||
|
||||
Cert URL : `/oauth2/keys`
|
||||
|
||||
## OIDC Client
|
||||
|
||||
```sql
|
||||
INSERT INTO "client" ("grant_types", "response_types", "redirect_uris", "id", "client_id", "client_secret", "scope") VALUES ('authorization_code', 'code', 'https://127.0.0.1/access/callback', 'UUID', '123456', '123456', 'openid profile email');
|
||||
```
|
0
aioauth_fastapi/__init__.py
Normal file
0
aioauth_fastapi/__init__.py
Normal file
8
aioauth_fastapi/__version__.py
Normal file
8
aioauth_fastapi/__version__.py
Normal file
@ -0,0 +1,8 @@
|
||||
__title__ = "aioauth_fastapi"
|
||||
__description__ = "aioauth integration for FastAPI."
|
||||
__url__ = "https://github.com/aliev/aioauth-fastapi"
|
||||
__version__ = "0.1.2"
|
||||
__author__ = "Ali Aliyev"
|
||||
__author_email__ = "ali@aliev.me"
|
||||
__license__ = "The MIT License (MIT)"
|
||||
__copyright__ = "Copyright 2021 Ali Aliyev"
|
38
aioauth_fastapi/forms.py
Normal file
38
aioauth_fastapi/forms.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""
|
||||
.. code-block:: python
|
||||
|
||||
from aioauth_fastapi import forms
|
||||
|
||||
FastAPI oauth2 forms.
|
||||
|
||||
Used to generate an OpenAPI schema.
|
||||
|
||||
----
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from aioauth.types import GrantType, TokenType
|
||||
from fastapi.params import Form
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenForm:
|
||||
grant_type: Optional[GrantType] = Form(None) # type: ignore
|
||||
client_id: Optional[str] = Form(None) # type: ignore
|
||||
client_secret: Optional[str] = Form(None) # type: ignore
|
||||
redirect_uri: Optional[str] = Form(None) # type: ignore
|
||||
scope: Optional[str] = Form(None) # type: ignore
|
||||
username: Optional[str] = Form(None) # type: ignore
|
||||
password: Optional[str] = Form(None) # type: ignore
|
||||
refresh_token: Optional[str] = Form(None) # type: ignore
|
||||
code: Optional[str] = Form(None) # type: ignore
|
||||
token: Optional[str] = Form(None) # type: ignore
|
||||
code_verifier: Optional[str] = Form(None) # type: ignore
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenIntrospectForm:
|
||||
token: Optional[str] = Form(None) # type: ignore
|
||||
token_type_hint: Optional[TokenType] = Form(None) # type: ignore
|
113
aioauth_fastapi/router.py
Normal file
113
aioauth_fastapi/router.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""
|
||||
.. code-block:: python
|
||||
|
||||
from aioauth_fastapi import router
|
||||
|
||||
FastAPI routing of oauth2.
|
||||
|
||||
Usage example
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from aioauth_fastapi.router import get_oauth2_router
|
||||
from aioauth.storage import BaseStorage
|
||||
from aioauth.config import Settings
|
||||
from aioauth.server import AuthorizationServer
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
class SQLAlchemyCRUD(BaseStorage):
|
||||
'''
|
||||
SQLAlchemyCRUD methods must be implemented here.
|
||||
'''
|
||||
|
||||
# NOTE: Redefinition of the default aioauth settings
|
||||
# INSECURE_TRANSPORT must be enabled for local development only!
|
||||
settings = Settings(
|
||||
INSECURE_TRANSPORT=True,
|
||||
)
|
||||
|
||||
storage = SQLAlchemyCRUD()
|
||||
authorization_server = AuthorizationServer(storage)
|
||||
|
||||
# Include FastAPI router with oauth2 endpoints.
|
||||
app.include_router(
|
||||
get_oauth2_router(authorization_server, settings),
|
||||
prefix="/oauth2",
|
||||
tags=["oauth2"],
|
||||
)
|
||||
|
||||
----
|
||||
"""
|
||||
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from aioauth.config import Settings
|
||||
from aioauth.requests import TRequest
|
||||
from aioauth.server import AuthorizationServer
|
||||
from aioauth.storage import TStorage
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from .utils import (
|
||||
RequestArguments,
|
||||
default_request_factory,
|
||||
to_fastapi_response,
|
||||
to_oauth2_request,
|
||||
)
|
||||
|
||||
ARequest = TypeVar("ARequest", bound=TRequest)
|
||||
|
||||
|
||||
def get_oauth2_router(
|
||||
authorization_server: AuthorizationServer[ARequest, TStorage],
|
||||
settings: Settings = Settings(),
|
||||
request_factory: Callable[[RequestArguments], ARequest] = default_request_factory,
|
||||
) -> APIRouter:
|
||||
"""Function will create FastAPI router with the following oauth2 endpoints:
|
||||
|
||||
* POST /token
|
||||
* Endpoint creates a token response by :py:meth:`aioauth.server.AuthorizationServer.create_token_response`
|
||||
* POST `/token/introspect`
|
||||
* Endpoint creates a token introspection by :py:meth:`aioauth.server.AuthorizationServer.create_token_introspection_response`
|
||||
* GET `/authorize`
|
||||
* Endpoint creates an authorization response by :py:meth:`aioauth.server.AuthorizationServer.create_authorization_response`
|
||||
|
||||
Returns:
|
||||
:py:class:`fastapi.APIRouter`.
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/token")
|
||||
async def token(request: Request):
|
||||
oauth2_request = await to_oauth2_request(
|
||||
request=request, request_factory=request_factory, settings=settings
|
||||
)
|
||||
oauth2_response = await authorization_server.create_token_response(
|
||||
oauth2_request
|
||||
)
|
||||
return await to_fastapi_response(oauth2_response)
|
||||
|
||||
@router.post("/token/introspect")
|
||||
async def token_introspect(request: Request):
|
||||
oauth2_request = await to_oauth2_request(
|
||||
request=request, request_factory=request_factory, settings=settings
|
||||
)
|
||||
oauth2_response = (
|
||||
await authorization_server.create_token_introspection_response(
|
||||
oauth2_request
|
||||
)
|
||||
)
|
||||
return await to_fastapi_response(oauth2_response)
|
||||
|
||||
@router.get("/authorize")
|
||||
async def authorize(request: Request):
|
||||
oauth2_request = await to_oauth2_request(
|
||||
request=request, request_factory=request_factory, settings=settings
|
||||
)
|
||||
oauth2_response = await authorization_server.create_authorization_response(
|
||||
oauth2_request
|
||||
)
|
||||
return await to_fastapi_response(oauth2_response)
|
||||
|
||||
return router
|
98
aioauth_fastapi/utils.py
Normal file
98
aioauth_fastapi/utils.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""
|
||||
.. code-block:: python
|
||||
|
||||
from aioauth_fastapi import utils
|
||||
|
||||
Core utils for integration with FastAPI
|
||||
|
||||
----
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from aioauth.collections import HTTPHeaderDict
|
||||
from aioauth.config import Settings
|
||||
from aioauth.requests import Post, Query, TRequest, TUser
|
||||
from aioauth.requests import Request as OAuth2Request
|
||||
from aioauth.responses import Response as OAuth2Response
|
||||
from fastapi import Request, Response
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestArguments:
|
||||
headers: HTTPHeaderDict
|
||||
method: str
|
||||
post_args: Dict
|
||||
query_args: Dict
|
||||
settings: Settings
|
||||
url: str
|
||||
user: Optional[TUser]
|
||||
|
||||
|
||||
def default_request_factory(request_args: RequestArguments) -> OAuth2Request:
|
||||
return OAuth2Request(
|
||||
headers=request_args.headers,
|
||||
method=request_args.method, # type: ignore
|
||||
post=Post(**request_args.post_args), # type: ignore
|
||||
query=Query(**request_args.query_args), # type: ignore
|
||||
settings=request_args.settings,
|
||||
url=request_args.url,
|
||||
user=request_args.user,
|
||||
)
|
||||
|
||||
|
||||
async def to_oauth2_request(
|
||||
request: Request,
|
||||
settings: Settings = Settings(),
|
||||
request_factory: Callable[[RequestArguments], TRequest] = default_request_factory,
|
||||
) -> TRequest:
|
||||
"""Converts :py:class:`fastapi.Request` instance to :py:class:`aioauth.requests.Request` instance"""
|
||||
form = await request.form()
|
||||
|
||||
post_args = dict(form)
|
||||
query_args = dict(request.query_params)
|
||||
need_args = [
|
||||
"client_id",
|
||||
"redirect_uri",
|
||||
"response_type",
|
||||
"state",
|
||||
"scope",
|
||||
"nonce",
|
||||
"code_challenge_method",
|
||||
"code_challenge",
|
||||
"response_mode",
|
||||
]
|
||||
for arg in list(query_args.keys()):
|
||||
if arg not in need_args:
|
||||
del query_args[arg]
|
||||
method = request.method
|
||||
headers = HTTPHeaderDict(**request.headers)
|
||||
url = str(request.url)
|
||||
|
||||
user = None
|
||||
|
||||
if request.user.is_authenticated:
|
||||
user = request.user
|
||||
|
||||
request_args = RequestArguments(
|
||||
headers=headers,
|
||||
method=method,
|
||||
post_args=post_args,
|
||||
query_args=query_args,
|
||||
settings=settings,
|
||||
url=url,
|
||||
user=user,
|
||||
)
|
||||
return request_factory(request_args)
|
||||
|
||||
|
||||
async def to_fastapi_response(oauth2_response: OAuth2Response) -> Response:
|
||||
"""Converts :py:class:`aioauth.responses.Response` instance to :py:class:`fastapi.Response` instance"""
|
||||
response_content = oauth2_response.content
|
||||
headers = dict(oauth2_response.headers)
|
||||
status_code = oauth2_response.status_code
|
||||
content = json.dumps(response_content)
|
||||
|
||||
return Response(content=content, headers=headers, status_code=status_code)
|
89
alembic.ini
Normal file
89
alembic.ini
Normal file
@ -0,0 +1,89 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = alembic/
|
||||
|
||||
# template used to generate migration files
|
||||
# file_template = %%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date
|
||||
# within the migration file as well as the filename.
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; this defaults
|
||||
# to ./versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path
|
||||
# version_locations = %(here)s/bar %(here)s/bat ./versions
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = sqlite+aiosqlite:///data/db.sqlite3
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
90
alembic/env.py
Normal file
90
alembic/env.py
Normal file
@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from src.config import settings
|
||||
from src.oauth2.models import * # noqa
|
||||
from src.users.models import * # noqa
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
config.set_main_option("sqlalchemy.url", str(settings.CONN_URI))
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = SQLModel.metadata # type: ignore
|
||||
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline():
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection):
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online():
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = AsyncEngine(
|
||||
engine_from_config(
|
||||
config.get_section(config.config_ini_section),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
future=True,
|
||||
)
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
asyncio.run(run_migrations_online())
|
25
alembic/script.py.mako
Normal file
25
alembic/script.py.mako
Normal file
@ -0,0 +1,25 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade():
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade():
|
||||
${downgrades if downgrades else "pass"}
|
226
alembic/versions/07a7ace268a7_initial_migrations.py
Normal file
226
alembic/versions/07a7ace268a7_initial_migrations.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""Initial migrations
|
||||
|
||||
Revision ID: 07a7ace268a7
|
||||
Revises:
|
||||
Create Date: 2021-10-02 22:50:10.418498
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "07a7ace268a7"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=True),
|
||||
sa.Column("is_blocked", sa.Boolean(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=True),
|
||||
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_users_id"), "users", ["id"], unique=True)
|
||||
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||
op.create_index(op.f("ix_users_is_blocked"), "users", ["is_blocked"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||
)
|
||||
op.create_index(op.f("ix_users_password"), "users", ["password"], unique=False)
|
||||
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
|
||||
op.create_table(
|
||||
"authorizationcode",
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("code", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("redirect_uri", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("response_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("auth_time", sa.Integer(), nullable=False),
|
||||
sa.Column("expires_in", sa.Integer(), nullable=False),
|
||||
sa.Column("code_challenge", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column(
|
||||
"code_challenge_method", sqlmodel.sql.sqltypes.AutoString(), nullable=True
|
||||
),
|
||||
sa.Column("nonce", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_auth_time"),
|
||||
"authorizationcode",
|
||||
["auth_time"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_client_id"),
|
||||
"authorizationcode",
|
||||
["client_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_code"), "authorizationcode", ["code"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_code_challenge"),
|
||||
"authorizationcode",
|
||||
["code_challenge"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_code_challenge_method"),
|
||||
"authorizationcode",
|
||||
["code_challenge_method"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_expires_in"),
|
||||
"authorizationcode",
|
||||
["expires_in"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_id"), "authorizationcode", ["id"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_nonce"), "authorizationcode", ["nonce"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_redirect_uri"),
|
||||
"authorizationcode",
|
||||
["redirect_uri"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_response_type"),
|
||||
"authorizationcode",
|
||||
["response_type"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_scope"), "authorizationcode", ["scope"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_authorizationcode_user_id"),
|
||||
"authorizationcode",
|
||||
["user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"client",
|
||||
sa.Column("grant_types", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("response_types", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("redirect_uris", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("client_secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_client_client_id"), "client", ["client_id"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_client_client_secret"), "client", ["client_secret"], unique=False
|
||||
)
|
||||
op.create_index(op.f("ix_client_id"), "client", ["id"], unique=True)
|
||||
op.create_index(op.f("ix_client_scope"), "client", ["scope"], unique=False)
|
||||
op.create_table(
|
||||
"token",
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("access_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("refresh_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("issued_at", sa.Integer(), nullable=False),
|
||||
sa.Column("expires_in", sa.Integer(), nullable=False),
|
||||
sa.Column("refresh_token_expires_in", sa.Integer(), nullable=False),
|
||||
sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("token_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_token_client_id"), "token", ["client_id"], unique=False)
|
||||
op.create_index(op.f("ix_token_expires_in"), "token", ["expires_in"], unique=False)
|
||||
op.create_index(op.f("ix_token_id"), "token", ["id"], unique=True)
|
||||
op.create_index(op.f("ix_token_issued_at"), "token", ["issued_at"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_token_refresh_token_expires_in"),
|
||||
"token",
|
||||
["refresh_token_expires_in"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(op.f("ix_token_revoked"), "token", ["revoked"], unique=False)
|
||||
op.create_index(op.f("ix_token_scope"), "token", ["scope"], unique=False)
|
||||
op.create_index(op.f("ix_token_token_type"), "token", ["token_type"], unique=False)
|
||||
op.create_index(op.f("ix_token_user_id"), "token", ["user_id"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_token_user_id"), table_name="token")
|
||||
op.drop_index(op.f("ix_token_token_type"), table_name="token")
|
||||
op.drop_index(op.f("ix_token_scope"), table_name="token")
|
||||
op.drop_index(op.f("ix_token_revoked"), table_name="token")
|
||||
op.drop_index(op.f("ix_token_refresh_token_expires_in"), table_name="token")
|
||||
op.drop_index(op.f("ix_token_issued_at"), table_name="token")
|
||||
op.drop_index(op.f("ix_token_id"), table_name="token")
|
||||
op.drop_index(op.f("ix_token_expires_in"), table_name="token")
|
||||
op.drop_index(op.f("ix_token_client_id"), table_name="token")
|
||||
op.drop_table("token")
|
||||
op.drop_index(op.f("ix_client_scope"), table_name="client")
|
||||
op.drop_index(op.f("ix_client_id"), table_name="client")
|
||||
op.drop_index(op.f("ix_client_client_secret"), table_name="client")
|
||||
op.drop_index(op.f("ix_client_client_id"), table_name="client")
|
||||
op.drop_table("client")
|
||||
op.drop_index(op.f("ix_authorizationcode_user_id"), table_name="authorizationcode")
|
||||
op.drop_index(op.f("ix_authorizationcode_scope"), table_name="authorizationcode")
|
||||
op.drop_index(
|
||||
op.f("ix_authorizationcode_response_type"), table_name="authorizationcode"
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_authorizationcode_redirect_uri"), table_name="authorizationcode"
|
||||
)
|
||||
op.drop_index(op.f("ix_authorizationcode_nonce"), table_name="authorizationcode")
|
||||
op.drop_index(op.f("ix_authorizationcode_id"), table_name="authorizationcode")
|
||||
op.drop_index(
|
||||
op.f("ix_authorizationcode_expires_in"), table_name="authorizationcode"
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_authorizationcode_code_challenge_method"),
|
||||
table_name="authorizationcode",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_authorizationcode_code_challenge"), table_name="authorizationcode"
|
||||
)
|
||||
op.drop_index(op.f("ix_authorizationcode_code"), table_name="authorizationcode")
|
||||
op.drop_index(
|
||||
op.f("ix_authorizationcode_client_id"), table_name="authorizationcode"
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_authorizationcode_auth_time"), table_name="authorizationcode"
|
||||
)
|
||||
op.drop_table("authorizationcode")
|
||||
op.drop_index(op.f("ix_users_username"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_password"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_blocked"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_id"), table_name="users")
|
||||
op.drop_table("users")
|
||||
# ### end Alembic commands ###
|
27
alembic/versions/c76c4cbb0b3b_tgid.py
Normal file
27
alembic/versions/c76c4cbb0b3b_tgid.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""tgid
|
||||
|
||||
Revision ID: c76c4cbb0b3b
|
||||
Revises: 07a7ace268a7
|
||||
Create Date: 2024-01-13 16:12:45.884304
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'c76c4cbb0b3b'
|
||||
down_revision = '07a7ace268a7'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('users', sa.Column('tg_id', sa.BigInteger(), nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('users', 'tg_id')
|
||||
# ### end Alembic commands ###
|
26
gen_keys.py
Normal file
26
gen_keys.py
Normal file
@ -0,0 +1,26 @@
|
||||
from pathlib import Path
|
||||
|
||||
data_path = Path("data")
|
||||
data_path.mkdir(exist_ok=True)
|
||||
private_key_path = data_path / "private_key"
|
||||
public_key_path = data_path / "public_key"
|
||||
|
||||
|
||||
def gen_keys():
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
key = RSA.generate(2048)
|
||||
private_key = key.export_key().decode("utf-8")
|
||||
public_key = key.publickey().export_key().decode("utf-8")
|
||||
|
||||
if private_key_path.is_file() and public_key_path.is_file():
|
||||
print("Keys already exist")
|
||||
return
|
||||
with open(private_key_path, "w") as f:
|
||||
f.write(private_key)
|
||||
with open(public_key_path, "w") as f:
|
||||
f.write(public_key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gen_keys()
|
40
html/login.jinja
Normal file
40
html/login.jinja
Normal file
@ -0,0 +1,40 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Title</title>
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.3.2/css/bootstrap.min.css"
|
||||
integrity="sha512-b2QcS5SsA8tZodcDtGRELiGv5SaKSk1vDHDaQRda0htPYWZ6046lr3kJ5bAAQdpV2mmA/4v0wQF9MyU6/pDIAg=="
|
||||
crossorigin="anonymous" referrerpolicy="no-referrer"/>
|
||||
<style>
|
||||
.main {
|
||||
text-align: center;
|
||||
background-color: #fff;
|
||||
border-radius: 20px;
|
||||
width: 260px;
|
||||
height: 20px;
|
||||
margin: auto;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="main">
|
||||
<script async src="https://telegram.org/js/telegram-widget.js?22"
|
||||
data-telegram-login="{{ username }}"
|
||||
data-size="large"
|
||||
data-auth-url="{{ callback_url }}"
|
||||
data-request-access="write">
|
||||
</script>
|
||||
<div>
|
||||
<a href="https://t.me/{{ username }}?start=login">
|
||||
<button type="button" class="btn btn-primary">通过 BOT 登录</button>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
50
main.py
Normal file
50
main.py
Normal file
@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
from signal import signal as signal_fn, SIGINT, SIGTERM, SIGABRT
|
||||
|
||||
from src.app import web
|
||||
from src.bot import bot
|
||||
from src.logs import logs
|
||||
|
||||
|
||||
async def idle():
|
||||
task = None
|
||||
|
||||
def signal_handler(_, __):
|
||||
if web.web_server_task:
|
||||
web.web_server_task.cancel()
|
||||
task.cancel()
|
||||
|
||||
for s in (SIGINT, SIGTERM, SIGABRT):
|
||||
signal_fn(s, signal_handler)
|
||||
|
||||
while True:
|
||||
task = asyncio.create_task(asyncio.sleep(600))
|
||||
web.bot_main_task = task
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
|
||||
async def main():
|
||||
logs.info("正在启动 Web Server")
|
||||
await web.start()
|
||||
logs.info("正在启动 Bot")
|
||||
await bot.start()
|
||||
try:
|
||||
logs.info("正在运行")
|
||||
await idle()
|
||||
finally:
|
||||
try:
|
||||
await bot.stop()
|
||||
except ConnectionError:
|
||||
pass
|
||||
if web.web_server:
|
||||
try:
|
||||
await web.web_server.shutdown()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot.run(main())
|
21
pyromod/__init__.py
Normal file
21
pyromod/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
pyromod - A monkeypatched add-on for Pyrogram
|
||||
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
|
||||
|
||||
This file is part of pyromod.
|
||||
|
||||
pyromod is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
pyromod is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
__version__ = "1.5"
|
21
pyromod/listen/__init__.py
Normal file
21
pyromod/listen/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
pyromod - A monkeypatcher add-on for Pyrogram
|
||||
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
|
||||
|
||||
This file is part of pyromod.
|
||||
|
||||
pyromod is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
pyromod is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from .listen import Client, MessageHandler, Chat, User
|
157
pyromod/listen/listen.py
Normal file
157
pyromod/listen/listen.py
Normal file
@ -0,0 +1,157 @@
|
||||
"""
|
||||
pyromod - A monkeypatcher add-on for Pyrogram
|
||||
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
|
||||
|
||||
This file is part of pyromod.
|
||||
|
||||
pyromod is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
pyromod is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
import pyrogram
|
||||
|
||||
from src.scheduler import add_delete_message_job
|
||||
from ..utils import patch, patchable
|
||||
from ..utils.errors import ListenerCanceled, TimeoutConversationError
|
||||
|
||||
pyrogram.errors.ListenerCanceled = ListenerCanceled
|
||||
|
||||
|
||||
@patch(pyrogram.client.Client)
|
||||
class Client:
|
||||
@patchable
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.listening = {}
|
||||
self.using_mod = True
|
||||
|
||||
self.old__init__(*args, **kwargs)
|
||||
|
||||
@patchable
|
||||
async def listen(self, chat_id, filters=None, timeout=None):
|
||||
if type(chat_id) != int:
|
||||
chat = await self.get_chat(chat_id)
|
||||
chat_id = chat.id
|
||||
|
||||
future = self.loop.create_future()
|
||||
future.add_done_callback(functools.partial(self.clear_listener, chat_id))
|
||||
self.listening.update({chat_id: {"future": future, "filters": filters}})
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout)
|
||||
except asyncio.exceptions.TimeoutError as e:
|
||||
raise TimeoutConversationError() from e
|
||||
|
||||
@patchable
|
||||
async def ask(self, chat_id, text, filters=None, timeout=None, *args, **kwargs):
|
||||
request = await self.send_message(chat_id, text, *args, **kwargs)
|
||||
response = await self.listen(chat_id, filters, timeout)
|
||||
response.request = request
|
||||
return response
|
||||
|
||||
@patchable
|
||||
def clear_listener(self, chat_id, future):
|
||||
if future == self.listening[chat_id]["future"]:
|
||||
self.listening.pop(chat_id, None)
|
||||
|
||||
@patchable
|
||||
def cancel_listener(self, chat_id):
|
||||
listener = self.listening.get(chat_id)
|
||||
if not listener or listener["future"].done():
|
||||
return
|
||||
|
||||
listener["future"].set_exception(ListenerCanceled())
|
||||
self.clear_listener(chat_id, listener["future"])
|
||||
|
||||
@patchable
|
||||
def cancel_all_listener(self):
|
||||
for chat_id in self.listening:
|
||||
self.cancel_listener(chat_id)
|
||||
|
||||
|
||||
@patch(pyrogram.handlers.message_handler.MessageHandler)
|
||||
class MessageHandler:
|
||||
@patchable
|
||||
def __init__(self, callback: callable, filters=None):
|
||||
self.user_callback = callback
|
||||
self.old__init__(self.resolve_listener, filters)
|
||||
|
||||
@patchable
|
||||
async def resolve_listener(self, client, message, *args):
|
||||
listener = client.listening.get(message.chat.id)
|
||||
if listener and not listener["future"].done():
|
||||
listener["future"].set_result(message)
|
||||
else:
|
||||
if listener and listener["future"].done():
|
||||
client.clear_listener(message.chat.id, listener["future"])
|
||||
await self.user_callback(client, message, *args)
|
||||
|
||||
@patchable
|
||||
async def check(self, client, update):
|
||||
listener = client.listening.get(update.chat.id)
|
||||
|
||||
if listener and not listener["future"].done():
|
||||
return (
|
||||
await listener["filters"](client, update)
|
||||
if callable(listener["filters"])
|
||||
else True
|
||||
)
|
||||
|
||||
return await self.filters(client, update) if callable(self.filters) else True
|
||||
|
||||
|
||||
@patch(pyrogram.types.user_and_chats.chat.Chat)
|
||||
class Chat(pyrogram.types.Chat):
|
||||
@patchable
|
||||
def listen(self, *args, **kwargs):
|
||||
return self._client.listen(self.id, *args, **kwargs)
|
||||
|
||||
@patchable
|
||||
def ask(self, *args, **kwargs):
|
||||
return self._client.ask(self.id, *args, **kwargs)
|
||||
|
||||
@patchable
|
||||
def cancel_listener(self):
|
||||
return self._client.cancel_listener(self.id)
|
||||
|
||||
|
||||
@patch(pyrogram.types.user_and_chats.user.User)
|
||||
class User(pyrogram.types.User):
|
||||
@patchable
|
||||
def listen(self, *args, **kwargs):
|
||||
return self._client.listen(self.id, *args, **kwargs)
|
||||
|
||||
@patchable
|
||||
def ask(self, *args, **kwargs):
|
||||
return self._client.ask(self.id, *args, **kwargs)
|
||||
|
||||
@patchable
|
||||
def cancel_listener(self):
|
||||
return self._client.cancel_listener(self.id)
|
||||
|
||||
|
||||
@patch(pyrogram.types.messages_and_media.Message)
|
||||
class Message(pyrogram.types.Message):
|
||||
@patchable
|
||||
async def safe_delete(self, revoke: bool = True):
|
||||
try:
|
||||
return await self._client.delete_messages(
|
||||
chat_id=self.chat.id, message_ids=self.id, revoke=revoke
|
||||
)
|
||||
except Exception as e: # noqa
|
||||
return False
|
||||
|
||||
@patchable
|
||||
async def delay_delete(self, delay: int = 60):
|
||||
add_delete_message_job(self, delay)
|
21
pyromod/utils/__init__.py
Normal file
21
pyromod/utils/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
pyromod - A monkeypatcher add-on for Pyrogram
|
||||
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
|
||||
|
||||
This file is part of pyromod.
|
||||
|
||||
pyromod is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
pyromod is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from .utils import patch, patchable
|
16
pyromod/utils/errors.py
Normal file
16
pyromod/utils/errors.py
Normal file
@ -0,0 +1,16 @@
|
||||
class TimeoutConversationError(Exception):
|
||||
"""
|
||||
Occurs when the conversation times out.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Response read timed out")
|
||||
|
||||
|
||||
class ListenerCanceled(Exception):
|
||||
"""
|
||||
Occurs when the listener is canceled.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Listener was canceled")
|
38
pyromod/utils/utils.py
Normal file
38
pyromod/utils/utils.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""
|
||||
pyromod - A monkeypatcher add-on for Pyrogram
|
||||
Copyright (C) 2020 Cezar H. <https://github.com/usernein>
|
||||
|
||||
This file is part of pyromod.
|
||||
|
||||
pyromod is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
pyromod is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with pyromod. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
def patch(obj):
|
||||
def is_patchable(item):
|
||||
return getattr(item[1], "patchable", False)
|
||||
|
||||
def wrapper(container):
|
||||
for name, func in filter(is_patchable, container.__dict__.items()):
|
||||
old = getattr(obj, name, None)
|
||||
setattr(obj, f"old{name}", old)
|
||||
setattr(obj, name, func)
|
||||
return container
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def patchable(func):
|
||||
func.patchable = True
|
||||
return func
|
22
requirements.txt
Normal file
22
requirements.txt
Normal file
@ -0,0 +1,22 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn==0.25.0
|
||||
git+https://github.com/aliev/aioauth
|
||||
sqlmodel==0.0.14
|
||||
alembic==1.13.1
|
||||
aiosqlite==0.19.0
|
||||
PyCryptodome==3.20.0
|
||||
python-jose[cryptography]==3.3.0
|
||||
python-multipart==0.0.6
|
||||
orjson==3.9.10
|
||||
jinja2==3.1.3
|
||||
pydantic~=2.5.3
|
||||
pydantic-settings==2.1.0
|
||||
SQLAlchemy~=2.0.25
|
||||
starlette~=0.35.1
|
||||
pyrogram==2.0.106
|
||||
tgcrypto==1.2.5
|
||||
pytz~=2023.3.post1
|
||||
APScheduler~=3.10.4
|
||||
coloredlogs~=15.0.1
|
||||
httpx==0.26.0
|
||||
asyncmy==0.2.9
|
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
87
src/app.py
Normal file
87
src/app.py
Normal file
@ -0,0 +1,87 @@
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from .config import settings
|
||||
from .events import on_shutdown, on_startup
|
||||
from .logs import logs
|
||||
from .oauth2 import endpoints as oauth2_endpoints
|
||||
from .users import endpoints as users_endpoints
|
||||
from .users.backends import TokenAuthenticationBackend
|
||||
|
||||
|
||||
class Web:
|
||||
def __init__(self):
|
||||
self.app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
docs_url="/api/openapi",
|
||||
openapi_url="/api/openapi.json",
|
||||
default_response_class=ORJSONResponse,
|
||||
on_startup=on_startup,
|
||||
on_shutdown=on_shutdown,
|
||||
)
|
||||
self.web_server = None
|
||||
self.web_server_task = None
|
||||
self.bot_main_task = None
|
||||
|
||||
def init_web(self):
|
||||
# Include API router
|
||||
self.app.include_router(users_endpoints.router, prefix="/api/users", tags=["users"])
|
||||
|
||||
# Define aioauth-fastapi endpoints
|
||||
self.app.include_router(
|
||||
oauth2_endpoints.router,
|
||||
prefix="/oauth2",
|
||||
tags=["oauth2"],
|
||||
)
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
self.app.add_middleware(AuthenticationMiddleware, backend=TokenAuthenticationBackend())
|
||||
|
||||
async def start(self):
|
||||
import uvicorn
|
||||
|
||||
self.init_web()
|
||||
self.web_server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
self.app,
|
||||
host=settings.PROJECT_HOST,
|
||||
port=settings.PROJECT_PORT,
|
||||
reload=settings.DEBUG,
|
||||
)
|
||||
)
|
||||
server_config = self.web_server.config
|
||||
server_config.setup_event_loop()
|
||||
if not server_config.loaded:
|
||||
server_config.load()
|
||||
self.web_server.lifespan = server_config.lifespan_class(server_config)
|
||||
try:
|
||||
await self.web_server.startup()
|
||||
except OSError as e:
|
||||
if e.errno == 10048:
|
||||
logs.error("Web Server 端口被占用:%s", e)
|
||||
logs.error("Web Server 启动失败,正在退出")
|
||||
raise SystemExit from None
|
||||
|
||||
if self.web_server.should_exit:
|
||||
logs.error("Web Server 启动失败,正在退出")
|
||||
raise SystemExit from None
|
||||
logs.info("Web Server 启动成功")
|
||||
self.web_server_task = asyncio.create_task(self.web_server.main_loop())
|
||||
|
||||
async def stop(self):
|
||||
if self.web_server_task:
|
||||
self.web_server_task.cancel()
|
||||
if self.bot_main_task:
|
||||
self.bot_main_task.cancel()
|
||||
|
||||
|
||||
web = Web()
|
13
src/bot.py
Normal file
13
src/bot.py
Normal file
@ -0,0 +1,13 @@
|
||||
import pyromod.listen
|
||||
from pyrogram import Client
|
||||
|
||||
from .config import settings
|
||||
|
||||
bot = Client(
|
||||
"bot",
|
||||
bot_token=settings.BOT_TOKEN,
|
||||
api_id=settings.BOT_API_ID,
|
||||
api_hash=settings.BOT_API_HASH,
|
||||
plugins={"root": "src.telegram.plugins"},
|
||||
workdir="data",
|
||||
)
|
35
src/config.py
Normal file
35
src/config.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "Telegram OAuth"
|
||||
PROJECT_URL: str = "http://127.0.0.1:8081"
|
||||
PROJECT_LOGIN_SUCCESS_URL: str = "http://google.com"
|
||||
PROJECT_HOST: str = "127.0.0.1"
|
||||
PROJECT_PORT: int = 8001
|
||||
DEBUG: bool = True
|
||||
|
||||
CONN_URI: str
|
||||
|
||||
JWT_PUBLIC_KEY: str
|
||||
JWT_PRIVATE_KEY: str
|
||||
|
||||
ACCESS_TOKEN_EXP: int = 900 # 15 minutes
|
||||
REFRESH_TOKEN_EXP: int = 86400 # 1 day
|
||||
|
||||
CORS_ORIGINS: List[str] = ["*"]
|
||||
|
||||
BOT_TOKEN: str
|
||||
BOT_USERNAME: str
|
||||
BOT_API_ID: int
|
||||
BOT_API_HASH: str
|
||||
BOT_MANAGER_IDS: List[int]
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
|
||||
settings = Settings()
|
27
src/events.py
Normal file
27
src/events.py
Normal file
@ -0,0 +1,27 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from .config import settings
|
||||
from .storage import sqlalchemy
|
||||
|
||||
|
||||
async def create_sqlalchemy_connection():
|
||||
# NOTE: https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html#using-multiple-asyncio-event-loops
|
||||
engine = create_async_engine(settings.CONN_URI, echo=True, poolclass=NullPool)
|
||||
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||
sqlalchemy.sqlalchemy_session = async_session()
|
||||
|
||||
|
||||
async def close_sqlalchemy_connection():
|
||||
if sqlalchemy.sqlalchemy_session is not None:
|
||||
await sqlalchemy.sqlalchemy_session.close()
|
||||
|
||||
|
||||
on_startup = [
|
||||
create_sqlalchemy_connection,
|
||||
]
|
||||
on_shutdown = [
|
||||
close_sqlalchemy_connection,
|
||||
]
|
3
src/html.py
Normal file
3
src/html.py
Normal file
@ -0,0 +1,3 @@
|
||||
from starlette.templating import Jinja2Templates
|
||||
|
||||
templates = Jinja2Templates(directory="html")
|
21
src/logs.py
Normal file
21
src/logs.py
Normal file
@ -0,0 +1,21 @@
|
||||
from logging import getLogger, StreamHandler, basicConfig, INFO, CRITICAL, ERROR
|
||||
|
||||
from coloredlogs import ColoredFormatter
|
||||
|
||||
logs = getLogger("telegram-oauth")
|
||||
logging_format = "%(levelname)s [%(asctime)s] [%(name)s] %(message)s"
|
||||
logging_handler = StreamHandler()
|
||||
logging_handler.setFormatter(ColoredFormatter(logging_format))
|
||||
root_logger = getLogger()
|
||||
root_logger.setLevel(CRITICAL)
|
||||
root_logger.addHandler(logging_handler)
|
||||
pyro_logger = getLogger("pyrogram")
|
||||
pyro_logger.setLevel(INFO)
|
||||
sql_logger = getLogger("sqlalchemy")
|
||||
sql_logger.setLevel(CRITICAL)
|
||||
sql_engine_logger = getLogger("sqlalchemy.engine.Engine")
|
||||
sql_engine_logger.setLevel(CRITICAL)
|
||||
aioauth_logger = getLogger("aioauth")
|
||||
aioauth_logger.setLevel(INFO)
|
||||
basicConfig(level=ERROR)
|
||||
logs.setLevel(INFO)
|
0
src/oauth2/__init__.py
Normal file
0
src/oauth2/__init__.py
Normal file
56
src/oauth2/endpoints.py
Normal file
56
src/oauth2/endpoints.py
Normal file
@ -0,0 +1,56 @@
|
||||
from aioauth.config import Settings
|
||||
from aioauth.oidc.core.grant_type import AuthorizationCodeGrantType
|
||||
from aioauth.requests import Request as OAuth2Request
|
||||
from aioauth.server import AuthorizationServer
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from aioauth_fastapi.utils import to_fastapi_response, to_oauth2_request
|
||||
from .storage import Storage
|
||||
from ..config import settings as local_settings
|
||||
from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
|
||||
from ..users.crypto import get_pub_key_resp
|
||||
from ..utils.oauth import to_login_request
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
settings = Settings(
|
||||
TOKEN_EXPIRES_IN=local_settings.ACCESS_TOKEN_EXP,
|
||||
REFRESH_TOKEN_EXPIRES_IN=local_settings.REFRESH_TOKEN_EXP,
|
||||
INSECURE_TRANSPORT=local_settings.DEBUG,
|
||||
)
|
||||
grant_types = {
|
||||
"authorization_code": AuthorizationCodeGrantType,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
async def token(
|
||||
request: Request,
|
||||
storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
|
||||
):
|
||||
oauth2_storage = Storage(storage=storage)
|
||||
authorization_server = AuthorizationServer(storage=oauth2_storage, grant_types=grant_types)
|
||||
oauth2_request: OAuth2Request = await to_oauth2_request(request, settings)
|
||||
oauth2_response = await authorization_server.create_token_response(oauth2_request)
|
||||
return await to_fastapi_response(oauth2_response)
|
||||
|
||||
|
||||
@router.get("/authorize")
|
||||
async def authorize(
|
||||
request: Request,
|
||||
storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
|
||||
):
|
||||
if not request.user.is_authenticated:
|
||||
return await to_login_request(request)
|
||||
oauth2_storage = Storage(storage=storage)
|
||||
authorization_server = AuthorizationServer(storage=oauth2_storage, grant_types=grant_types)
|
||||
oauth2_request: OAuth2Request = await to_oauth2_request(request, settings)
|
||||
oauth2_response = await authorization_server.create_authorization_response(
|
||||
oauth2_request
|
||||
)
|
||||
return await to_fastapi_response(oauth2_response)
|
||||
|
||||
|
||||
@router.get("/keys")
|
||||
async def keys():
|
||||
return get_pub_key_resp()
|
50
src/oauth2/models.py
Normal file
50
src/oauth2/models.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic.types import UUID4
|
||||
from sqlmodel.main import Field, Relationship
|
||||
|
||||
from ..storage.models import BaseTable
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ..users.models import User
|
||||
|
||||
|
||||
class Client(BaseTable, table=True): # type: ignore
|
||||
client_id: str
|
||||
client_secret: str
|
||||
grant_types: str
|
||||
response_types: str
|
||||
redirect_uris: str
|
||||
|
||||
scope: str
|
||||
|
||||
|
||||
class AuthorizationCode(BaseTable, table=True): # type: ignore
|
||||
code: str
|
||||
client_id: str
|
||||
redirect_uri: str
|
||||
response_type: str
|
||||
scope: str
|
||||
auth_time: int
|
||||
expires_in: int
|
||||
code_challenge: Optional[str]
|
||||
code_challenge_method: Optional[str]
|
||||
nonce: Optional[str]
|
||||
|
||||
user_id: UUID4 = Field(foreign_key="users.id", nullable=False)
|
||||
user: "User" = Relationship(back_populates="user_authorization_codes")
|
||||
|
||||
|
||||
class Token(BaseTable, table=True): # type: ignore
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
scope: str
|
||||
issued_at: int
|
||||
expires_in: int
|
||||
refresh_token_expires_in: int
|
||||
client_id: str
|
||||
token_type: str
|
||||
revoked: bool
|
||||
|
||||
user_id: UUID4 = Field(foreign_key="users.id", nullable=False)
|
||||
user: "User" = Relationship(back_populates="user_tokens")
|
289
src/oauth2/storage.py
Normal file
289
src/oauth2/storage.py
Normal file
@ -0,0 +1,289 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from aioauth.models import AuthorizationCode, Client, Token
|
||||
from aioauth.requests import Request
|
||||
from aioauth.storage import BaseStorage
|
||||
from aioauth.types import CodeChallengeMethod, ResponseType, TokenType
|
||||
from aioauth.utils import enforce_list
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.sql.expression import delete
|
||||
|
||||
from .models import AuthorizationCode as AuthorizationCodeDB
|
||||
from .models import Client as ClientDB
|
||||
from .models import Token as TokenDB
|
||||
from ..config import settings
|
||||
from ..storage.sqlalchemy import SQLAlchemyStorage
|
||||
from ..users.crypto import encode_jwt, get_jwt
|
||||
from ..users.models import User
|
||||
|
||||
|
||||
class Storage(BaseStorage):
|
||||
def __init__(self, storage: SQLAlchemyStorage):
|
||||
self.storage = storage
|
||||
|
||||
async def get_user(self, request: Request):
|
||||
user: Optional[User] = None
|
||||
|
||||
if request.query.response_type == "token":
|
||||
# If ResponseType is token get the user from current session
|
||||
user = request.user
|
||||
|
||||
if request.post.grant_type == "authorization_code":
|
||||
# If GrantType is authorization code get user from DB by code
|
||||
q_results = await self.storage.select(
|
||||
select(AuthorizationCodeDB).where(
|
||||
AuthorizationCodeDB.code == request.post.code
|
||||
)
|
||||
)
|
||||
|
||||
authorization_code: Optional[AuthorizationCodeDB]
|
||||
authorization_code = q_results.scalars().one_or_none()
|
||||
|
||||
if not authorization_code:
|
||||
return
|
||||
|
||||
q_results = await self.storage.select(
|
||||
select(User).where(User.id == authorization_code.user_id)
|
||||
)
|
||||
|
||||
user = q_results.scalars().one_or_none()
|
||||
|
||||
if request.post.grant_type == "refresh_token":
|
||||
# Get user from token
|
||||
q_results = await self.storage.select(
|
||||
select(TokenDB)
|
||||
.where(TokenDB.refresh_token == request.post.refresh_token)
|
||||
.options(selectinload(TokenDB.user))
|
||||
)
|
||||
|
||||
token: Optional[TokenDB]
|
||||
|
||||
token = q_results.scalars().one_or_none()
|
||||
|
||||
if not token:
|
||||
return
|
||||
|
||||
user = token.user
|
||||
|
||||
return user
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
request: Request,
|
||||
client_id: str,
|
||||
scope: str,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
) -> Token:
|
||||
"""
|
||||
Create token and store it in storage.
|
||||
"""
|
||||
user = await self.get_user(request)
|
||||
|
||||
_access_token, _refresh_token = get_jwt(user)
|
||||
|
||||
token = Token(
|
||||
access_token=_access_token,
|
||||
client_id=client_id,
|
||||
expires_in=300,
|
||||
issued_at=int(datetime.now(tz=timezone.utc).timestamp()),
|
||||
refresh_token=_refresh_token,
|
||||
refresh_token_expires_in=900,
|
||||
revoked=False,
|
||||
scope=scope,
|
||||
token_type="Bearer",
|
||||
user=user,
|
||||
)
|
||||
|
||||
token_record = TokenDB(
|
||||
access_token=token.access_token,
|
||||
refresh_token=token.refresh_token,
|
||||
scope=token.scope,
|
||||
issued_at=token.issued_at,
|
||||
expires_in=token.expires_in,
|
||||
refresh_token_expires_in=token.refresh_token_expires_in,
|
||||
client_id=token.client_id,
|
||||
token_type=token.token_type,
|
||||
revoked=token.revoked,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
await self.storage.add(token_record)
|
||||
|
||||
return token
|
||||
|
||||
async def revoke_token(
|
||||
self,
|
||||
request: Request,
|
||||
token_type: Optional[TokenType] = "refresh_token",
|
||||
access_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Remove refresh_token from whitelist.
|
||||
"""
|
||||
q_results = await self.storage.select(
|
||||
select(TokenDB).where(TokenDB.refresh_token == refresh_token)
|
||||
)
|
||||
token_record: Optional[TokenDB]
|
||||
token_record = q_results.scalars().one_or_none()
|
||||
|
||||
if token_record:
|
||||
token_record.revoked = True
|
||||
await self.storage.add(token_record)
|
||||
|
||||
async def get_token(
|
||||
self,
|
||||
request: Request,
|
||||
client_id: str,
|
||||
token_type: Optional[str] = "refresh_token",
|
||||
access_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
) -> Optional[Token]:
|
||||
if token_type == "refresh_token":
|
||||
q = select(TokenDB).where(TokenDB.refresh_token == refresh_token)
|
||||
else:
|
||||
q = select(TokenDB).where(TokenDB.access_token == access_token)
|
||||
|
||||
q_results = await self.storage.select(
|
||||
q.where(TokenDB.revoked == False).options( # noqa
|
||||
selectinload(TokenDB.user)
|
||||
)
|
||||
)
|
||||
|
||||
token_record: Optional[TokenDB]
|
||||
token_record = q_results.scalars().one_or_none()
|
||||
|
||||
if token_record:
|
||||
return Token(
|
||||
access_token=token_record.access_token,
|
||||
refresh_token=token_record.refresh_token,
|
||||
scope=token_record.scope,
|
||||
issued_at=token_record.issued_at,
|
||||
expires_in=token_record.expires_in,
|
||||
refresh_token_expires_in=token_record.refresh_token_expires_in,
|
||||
client_id=client_id,
|
||||
)
|
||||
|
||||
async def create_authorization_code(
|
||||
self,
|
||||
request: Request,
|
||||
client_id: str,
|
||||
scope: str,
|
||||
response_type: ResponseType,
|
||||
redirect_uri: str,
|
||||
code_challenge_method: Optional[CodeChallengeMethod],
|
||||
code_challenge: Optional[str],
|
||||
code: str,
|
||||
**kwargs,
|
||||
) -> AuthorizationCode:
|
||||
authorization_code = AuthorizationCode(
|
||||
auth_time=int(datetime.now(tz=timezone.utc).timestamp()),
|
||||
client_id=client_id,
|
||||
code=code,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
expires_in=300,
|
||||
redirect_uri=redirect_uri,
|
||||
response_type=response_type,
|
||||
scope=scope,
|
||||
user=request.user,
|
||||
)
|
||||
|
||||
authorization_code_record = AuthorizationCodeDB(
|
||||
code=authorization_code.code,
|
||||
client_id=authorization_code.client_id,
|
||||
redirect_uri=authorization_code.redirect_uri,
|
||||
response_type=authorization_code.response_type,
|
||||
scope=authorization_code.scope,
|
||||
auth_time=authorization_code.auth_time,
|
||||
expires_in=authorization_code.expires_in,
|
||||
code_challenge_method=authorization_code.code_challenge_method,
|
||||
code_challenge=authorization_code.code_challenge,
|
||||
nonce=authorization_code.nonce,
|
||||
user_id=request.user.id,
|
||||
)
|
||||
|
||||
await self.storage.add(authorization_code_record)
|
||||
|
||||
return authorization_code
|
||||
|
||||
async def get_client(
|
||||
self, request: Request, client_id: str, client_secret: Optional[str] = None
|
||||
) -> Optional[Client]:
|
||||
q_results = await self.storage.select(
|
||||
select(ClientDB).where(ClientDB.client_id == client_id)
|
||||
)
|
||||
|
||||
client_record: Optional[ClientDB]
|
||||
client_record = q_results.scalars().one_or_none()
|
||||
|
||||
if not client_record:
|
||||
return None
|
||||
|
||||
return Client(
|
||||
client_id=client_record.client_id,
|
||||
client_secret=client_record.client_secret,
|
||||
grant_types=[client_record.grant_types],
|
||||
response_types=[client_record.response_types],
|
||||
redirect_uris=[client_record.redirect_uris],
|
||||
scope=client_record.scope,
|
||||
)
|
||||
|
||||
async def get_authorization_code(
|
||||
self, request: Request, client_id: str, code: str
|
||||
) -> Optional[AuthorizationCode]:
|
||||
q_results = await self.storage.select(
|
||||
select(AuthorizationCodeDB).where(AuthorizationCodeDB.code == code)
|
||||
)
|
||||
|
||||
authorization_code_record: Optional[AuthorizationCode]
|
||||
authorization_code_record = q_results.scalars().one_or_none()
|
||||
|
||||
if not authorization_code_record:
|
||||
return None
|
||||
|
||||
return AuthorizationCode(
|
||||
code=authorization_code_record.code,
|
||||
client_id=authorization_code_record.client_id,
|
||||
redirect_uri=authorization_code_record.redirect_uri,
|
||||
response_type=authorization_code_record.response_type,
|
||||
scope=authorization_code_record.scope,
|
||||
auth_time=authorization_code_record.auth_time,
|
||||
expires_in=authorization_code_record.expires_in,
|
||||
code_challenge=authorization_code_record.code_challenge,
|
||||
code_challenge_method=authorization_code_record.code_challenge_method,
|
||||
nonce=authorization_code_record.nonce,
|
||||
)
|
||||
|
||||
async def delete_authorization_code(
|
||||
self, request: Request, client_id: str, code: str
|
||||
) -> None:
|
||||
await self.storage.delete(
|
||||
delete(AuthorizationCodeDB).where(AuthorizationCodeDB.code == code)
|
||||
)
|
||||
|
||||
async def get_id_token(
|
||||
self,
|
||||
request: Request,
|
||||
client_id: str,
|
||||
scope: str,
|
||||
response_type: ResponseType,
|
||||
redirect_uri: str,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
scopes = enforce_list(scope)
|
||||
user = await self.get_user(request)
|
||||
user_data = {}
|
||||
|
||||
if "email" in scopes:
|
||||
user_data["email"] = user.username
|
||||
user_data["username"] = user.username
|
||||
|
||||
return encode_jwt(
|
||||
expires_delta=settings.ACCESS_TOKEN_EXP,
|
||||
sub=str(user.id),
|
||||
additional_claims=user_data,
|
||||
)
|
33
src/scheduler.py
Normal file
33
src/scheduler.py
Normal file
@ -0,0 +1,33 @@
|
||||
import contextlib
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytz
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.telegram.enums import Message
|
||||
|
||||
scheduler = AsyncIOScheduler(timezone="Asia/ShangHai")
|
||||
if not scheduler.running:
|
||||
scheduler.start()
|
||||
|
||||
|
||||
async def delete_message(message: "Message") -> bool:
|
||||
with contextlib.suppress(Exception):
|
||||
await message.delete()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def add_delete_message_job(message: "Message", delete_seconds: int = 60):
|
||||
scheduler.add_job(
|
||||
delete_message,
|
||||
"date",
|
||||
id=f"{message.chat.id}|{message.id}|delete_message",
|
||||
name=f"{message.chat.id}|{message.id}|delete_message",
|
||||
args=[message],
|
||||
run_date=datetime.datetime.now(pytz.timezone("Asia/Shanghai"))
|
||||
+ datetime.timedelta(seconds=delete_seconds),
|
||||
replace_existing=True,
|
||||
)
|
0
src/storage/__init__.py
Normal file
0
src/storage/__init__.py
Normal file
14
src/storage/models.py
Normal file
14
src/storage/models.py
Normal file
@ -0,0 +1,14 @@
|
||||
import uuid
|
||||
|
||||
from pydantic.types import UUID4
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class BaseTable(SQLModel):
|
||||
id: UUID4 = Field(
|
||||
primary_key=True,
|
||||
default_factory=uuid.uuid4,
|
||||
nullable=False,
|
||||
index=True,
|
||||
sa_column_kwargs={"unique": True},
|
||||
)
|
70
src/storage/sqlalchemy.py
Normal file
70
src/storage/sqlalchemy.py
Normal file
@ -0,0 +1,70 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.engine.result import Result
|
||||
from sqlalchemy.sql.expression import Delete, Update
|
||||
from sqlalchemy.sql.selectable import Select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class SQLAlchemyTransaction:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def __aenter__(self) -> "SQLAlchemyTransaction":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
|
||||
if exc_type is None:
|
||||
await self.commit()
|
||||
else:
|
||||
await self.rollback()
|
||||
|
||||
await self.close()
|
||||
|
||||
async def rollback(self):
|
||||
await self.session.rollback()
|
||||
|
||||
async def commit(self):
|
||||
await self.session.commit()
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
|
||||
|
||||
class SQLAlchemyStorage:
|
||||
def __init__(
|
||||
self, session: AsyncSession, transaction: SQLAlchemyTransaction
|
||||
) -> None:
|
||||
self.session = session
|
||||
self.transaction = transaction
|
||||
|
||||
async def select(self, q: Select) -> Result:
|
||||
async with self.transaction:
|
||||
return await self.session.execute(q)
|
||||
|
||||
async def add(self, model) -> None:
|
||||
async with self.transaction:
|
||||
self.session.add(model)
|
||||
|
||||
async def delete(self, q: Delete) -> None:
|
||||
async with self.transaction:
|
||||
await self.session.execute(q)
|
||||
|
||||
async def update(self, q: Update):
|
||||
async with self.transaction:
|
||||
await self.session.execute(q)
|
||||
|
||||
|
||||
sqlalchemy_session: Optional[AsyncSession] = None
|
||||
|
||||
|
||||
def get_sqlalchemy_storage() -> SQLAlchemyStorage:
|
||||
"""Get SQLAlchemy storage instance.
|
||||
|
||||
Returns:
|
||||
SQLAlchemyStorage: SQLAlchemy storage instance
|
||||
"""
|
||||
sqllachemy_trancation = SQLAlchemyTransaction(session=sqlalchemy_session)
|
||||
return SQLAlchemyStorage(
|
||||
session=sqlalchemy_session, transaction=sqllachemy_trancation
|
||||
)
|
0
src/telegram/__init__.py
Normal file
0
src/telegram/__init__.py
Normal file
25
src/telegram/enums.py
Normal file
25
src/telegram/enums.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Optional
|
||||
|
||||
from pyrogram import Client as PyroClient
|
||||
from pyrogram.types import Message as PyroMessage
|
||||
|
||||
|
||||
class Client(PyroClient): # noqa
|
||||
async def listen(self, chat_id, filters=None, timeout=None) -> Optional["Message"]:
|
||||
return
|
||||
|
||||
async def ask(
|
||||
self, chat_id, text, filters=None, timeout=None, *args, **kwargs
|
||||
) -> Optional["Message"]:
|
||||
return
|
||||
|
||||
def cancel_listener(self, chat_id):
|
||||
"""Cancel the conversation with the given chat_id."""
|
||||
|
||||
|
||||
class Message(PyroMessage): # noqa
|
||||
async def delay_delete(self, delete_seconds: int = 60) -> Optional[bool]:
|
||||
return
|
||||
|
||||
async def safe_delete(self, revoke: bool = True) -> None:
|
||||
return
|
15
src/telegram/filters.py
Normal file
15
src/telegram/filters.py
Normal file
@ -0,0 +1,15 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pyrogram.filters import create
|
||||
|
||||
from src.config import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .enums import Message
|
||||
|
||||
|
||||
async def admin_filter(_, __, m: "Message"):
|
||||
return bool(m.from_user and m.from_user.id in settings.BOT_MANAGER_IDS)
|
||||
|
||||
|
||||
admin = create(admin_filter)
|
8
src/telegram/message.py
Normal file
8
src/telegram/message.py
Normal file
@ -0,0 +1,8 @@
|
||||
import re
|
||||
|
||||
NO_ACCOUNT_MSG = """UID `%s` 还没有注册账号,请联系管理员注册账号。"""
|
||||
ACCOUNT_MSG = """UID: `%s`\n邮箱: `%s`"""
|
||||
REG_MSG = """请发送需要使用的邮箱"""
|
||||
MAIL_REGEX = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$")
|
||||
LOGIN_MSG = """请点击下面的按钮登录:"""
|
||||
LOGIN_BUTTON = """跳转登录"""
|
0
src/telegram/plugins/__init__.py
Normal file
0
src/telegram/plugins/__init__.py
Normal file
24
src/telegram/plugins/account.py
Normal file
24
src/telegram/plugins/account.py
Normal file
@ -0,0 +1,24 @@
|
||||
from pyrogram import filters
|
||||
|
||||
from src.bot import bot
|
||||
from src.config import settings
|
||||
from src.telegram.enums import Client, Message
|
||||
from src.telegram.message import ACCOUNT_MSG, NO_ACCOUNT_MSG
|
||||
from src.users.crud import get_user_crud
|
||||
|
||||
|
||||
async def account(message: Message, uid: int):
|
||||
crud = get_user_crud()
|
||||
user = await crud.get_by_tg_id(uid)
|
||||
if user:
|
||||
await message.reply(ACCOUNT_MSG % (user.tg_id, user.username), quote=True)
|
||||
else:
|
||||
await message.reply(NO_ACCOUNT_MSG % uid, quote=True)
|
||||
|
||||
|
||||
@bot.on_message(filters=filters.private & filters.command("account"))
|
||||
async def get_account(_: Client, message: Message):
|
||||
uid = message.from_user.id
|
||||
if uid in settings.BOT_MANAGER_IDS and len(message.command) >= 2 and message.command[1].isnumeric():
|
||||
uid = int(message.command[1])
|
||||
await account(message, uid)
|
44
src/telegram/plugins/edit.py
Normal file
44
src/telegram/plugins/edit.py
Normal file
@ -0,0 +1,44 @@
|
||||
from pyrogram import filters
|
||||
|
||||
from pyromod.utils.errors import TimeoutConversationError
|
||||
from src.bot import bot
|
||||
from src.logs import logs
|
||||
from src.telegram.enums import Client, Message
|
||||
from src.telegram.filters import admin
|
||||
from src.telegram.message import REG_MSG, MAIL_REGEX
|
||||
from src.users.crud import get_user_crud
|
||||
|
||||
|
||||
async def reg(client: Client, from_id: int, uid: int):
|
||||
msg_ = await client.send_message(from_id, REG_MSG)
|
||||
try:
|
||||
msg = await client.listen(from_id, filters=filters.text, timeout=60)
|
||||
except TimeoutConversationError:
|
||||
await msg_.edit("响应超时,请重试")
|
||||
return
|
||||
if msg.text and MAIL_REGEX.match(msg.text):
|
||||
crud = get_user_crud()
|
||||
try:
|
||||
user = await crud.get_by_tg_id(uid)
|
||||
if user:
|
||||
await crud.update(user, username=msg.text)
|
||||
else:
|
||||
await crud.create(
|
||||
username=msg.text,
|
||||
password="1",
|
||||
tg_id=uid,
|
||||
)
|
||||
except Exception as e:
|
||||
logs.exception("注册失败", exc_info=e)
|
||||
await msg.reply_text("注册失败")
|
||||
await msg.reply_text("注册成功")
|
||||
else:
|
||||
await msg.reply_text("邮箱格式错误")
|
||||
|
||||
|
||||
@bot.on_message(filters=filters.private & filters.command("edit") & admin)
|
||||
async def edit_account(client: Client, message: Message):
|
||||
uid = from_id = message.from_user.id
|
||||
if len(message.command) >= 2 and message.command[1].isnumeric():
|
||||
uid = int(message.command[1])
|
||||
await reg(client, from_id, uid)
|
41
src/telegram/plugins/start.py
Normal file
41
src/telegram/plugins/start.py
Normal file
@ -0,0 +1,41 @@
|
||||
from httpx import URL
|
||||
from pyrogram import filters, Client
|
||||
from pyrogram.types import InlineKeyboardMarkup, InlineKeyboardButton
|
||||
|
||||
from src.bot import bot
|
||||
from src.config import settings
|
||||
from src.telegram.enums import Message
|
||||
from src.telegram.message import NO_ACCOUNT_MSG, LOGIN_MSG, LOGIN_BUTTON
|
||||
from src.users.crud import get_user_crud
|
||||
from src.utils.telegram import encode_telegram_auth_data
|
||||
|
||||
|
||||
async def login(message: Message):
|
||||
uid = message.from_user.id
|
||||
crud = get_user_crud()
|
||||
user = await crud.get_by_tg_id(uid)
|
||||
if not user:
|
||||
await message.reply(NO_ACCOUNT_MSG % uid, quote=True)
|
||||
return
|
||||
token = await encode_telegram_auth_data(uid)
|
||||
url = settings.PROJECT_URL + "/api/users/auth"
|
||||
url = URL(url).copy_add_param("jwt", token)
|
||||
url = str(url)
|
||||
await message.reply(
|
||||
LOGIN_MSG,
|
||||
quote=True,
|
||||
reply_markup=InlineKeyboardMarkup(
|
||||
[[InlineKeyboardButton(LOGIN_BUTTON, url=url)]]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@bot.on_message(filters=filters.private & filters.command("start"))
|
||||
async def start(client: Client, message: Message):
|
||||
if message.command and len(message.command) >= 2:
|
||||
action = message.command[1]
|
||||
if action == "login":
|
||||
await login(message)
|
||||
return
|
||||
me = await client.get_me()
|
||||
await message.reply(f"Hello, I'm {me.first_name}. ")
|
0
src/users/__init__.py
Normal file
0
src/users/__init__.py
Normal file
32
src/users/backends.py
Normal file
32
src/users/backends.py
Normal file
@ -0,0 +1,32 @@
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from starlette.authentication import AuthCredentials, AuthenticationBackend
|
||||
|
||||
from .crypto import authenticate, read_rsa_key_from_env
|
||||
from .models import User, UserAnonymous
|
||||
from ..config import settings
|
||||
|
||||
|
||||
class TokenAuthenticationBackend(AuthenticationBackend):
|
||||
async def authenticate(self, request):
|
||||
authorization: str = request.headers.get("Authorization")
|
||||
_, bearer_token = get_authorization_scheme_param(authorization)
|
||||
|
||||
token: str = request.cookies.get("access_token") or bearer_token
|
||||
|
||||
if not token:
|
||||
return AuthCredentials(), UserAnonymous()
|
||||
|
||||
key = read_rsa_key_from_env(settings.JWT_PUBLIC_KEY)
|
||||
|
||||
is_authenticated, decoded_token = authenticate(token=token, key=key)
|
||||
|
||||
if is_authenticated:
|
||||
return AuthCredentials(), User(
|
||||
id=decoded_token["sub"],
|
||||
is_superuser=decoded_token["is_superuser"],
|
||||
is_blocked=decoded_token["is_blocked"],
|
||||
is_active=decoded_token["is_active"],
|
||||
username=decoded_token["username"],
|
||||
)
|
||||
|
||||
return AuthCredentials(), UserAnonymous()
|
40
src/users/crud.py
Normal file
40
src/users/crud.py
Normal file
@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.sql.expression import Update
|
||||
|
||||
from .models import User
|
||||
from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
|
||||
|
||||
|
||||
class SQLAlchemyCRUD:
|
||||
def __init__(self, storage: SQLAlchemyStorage):
|
||||
self.storage = storage
|
||||
|
||||
async def get_by_tg_id(self, tg_id: int) -> Optional[User]:
|
||||
q_results = await self.storage.select(
|
||||
select(User)
|
||||
.options(
|
||||
# for relationship loading, eager loading should be applied.
|
||||
selectinload(User.user_tokens)
|
||||
)
|
||||
.where(User.tg_id == tg_id)
|
||||
)
|
||||
|
||||
return q_results.scalars().one_or_none()
|
||||
|
||||
async def create(self, **kwargs) -> None:
|
||||
user = User(**kwargs)
|
||||
await self.storage.add(user)
|
||||
|
||||
async def update(self, user: User, **kwargs) -> None:
|
||||
await self.storage.update(
|
||||
Update(User).where(User.id == user.id).values(**kwargs)
|
||||
)
|
||||
|
||||
|
||||
def get_user_crud(storage: SQLAlchemyStorage = None) -> SQLAlchemyCRUD:
|
||||
if storage is None:
|
||||
storage = get_sqlalchemy_storage()
|
||||
return SQLAlchemyCRUD(storage=storage)
|
166
src/users/crypto.py
Normal file
166
src/users/crypto.py
Normal file
@ -0,0 +1,166 @@
|
||||
import base64
|
||||
import pathlib
|
||||
import re
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from Crypto.PublicKey import RSA
|
||||
from jose import constants, jwt
|
||||
from jose.exceptions import JWTError
|
||||
|
||||
from ..config import settings
|
||||
|
||||
RANDOM_STRING_CHARS = string.ascii_lowercase + string.ascii_uppercase + string.digits
|
||||
KEYS = {}
|
||||
|
||||
|
||||
def reformat_rsa_key(rsa_key: str) -> str:
|
||||
"""Reformat an RSA PEM key without newlines to one with correct newline characters
|
||||
|
||||
@param rsa_key: the PEM RSA key lacking newline characters
|
||||
@return: the reformatted PEM RSA key with appropriate newline characters
|
||||
"""
|
||||
# split headers from the body
|
||||
split_rsa_key = re.split(r"(-+)", rsa_key)
|
||||
|
||||
# add newlines between headers and body
|
||||
split_rsa_key.insert(4, "\n")
|
||||
split_rsa_key.insert(6, "\n")
|
||||
|
||||
reformatted_rsa_key = "".join(split_rsa_key)
|
||||
|
||||
# reformat body
|
||||
return RSA.importKey(reformatted_rsa_key).exportKey().decode("utf-8")
|
||||
|
||||
|
||||
def read_rsa_key_from_env(file_path: str) -> str:
|
||||
if file_path in KEYS:
|
||||
return KEYS[file_path]
|
||||
path = pathlib.Path(file_path)
|
||||
|
||||
# path to rsa key file
|
||||
if path.is_file():
|
||||
with open(file_path, "rb") as key_file:
|
||||
jwt_private_key = RSA.importKey(key_file.read()).exportKey()
|
||||
k = jwt_private_key.decode("utf-8")
|
||||
KEYS[file_path] = k
|
||||
return k
|
||||
|
||||
# rsa key without newlines
|
||||
if "\n" not in file_path:
|
||||
k = reformat_rsa_key(file_path)
|
||||
KEYS[file_path] = k
|
||||
return k
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
def get_n(rsa: RSA):
|
||||
bytes_data = rsa.n.to_bytes((rsa.n.bit_length() + 7) // 8, 'big')
|
||||
return base64.urlsafe_b64encode(bytes_data).decode('utf-8')
|
||||
|
||||
|
||||
def get_pub_key_resp():
|
||||
pub_key = RSA.importKey(read_rsa_key_from_env(settings.JWT_PUBLIC_KEY))
|
||||
return {
|
||||
"keys": [
|
||||
{
|
||||
"n": get_n(pub_key),
|
||||
"kty": "RSA",
|
||||
"alg": "RS256",
|
||||
"kid": "sig",
|
||||
"e": "AQAB",
|
||||
"use": "sig"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def encode_jwt(
|
||||
expires_delta,
|
||||
sub,
|
||||
secret=None,
|
||||
additional_claims=None,
|
||||
algorithm=constants.ALGORITHMS.RS256,
|
||||
):
|
||||
if additional_claims is None:
|
||||
additional_claims = {}
|
||||
if secret is None:
|
||||
secret = read_rsa_key_from_env(settings.JWT_PRIVATE_KEY)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
claims = {
|
||||
"iat": now,
|
||||
"jti": str(uuid.uuid4()),
|
||||
"nbf": now,
|
||||
"sub": sub,
|
||||
"exp": now + timedelta(seconds=expires_delta),
|
||||
**additional_claims,
|
||||
}
|
||||
|
||||
return jwt.encode(
|
||||
claims,
|
||||
secret,
|
||||
algorithm,
|
||||
)
|
||||
|
||||
|
||||
def decode_jwt(
|
||||
encoded_token,
|
||||
secret=None,
|
||||
algorithms=None,
|
||||
):
|
||||
if algorithms is None:
|
||||
algorithms = constants.ALGORITHMS.RS256
|
||||
if secret is None:
|
||||
secret = read_rsa_key_from_env(settings.JWT_PRIVATE_KEY)
|
||||
return jwt.decode(
|
||||
encoded_token,
|
||||
secret,
|
||||
algorithms=algorithms,
|
||||
)
|
||||
|
||||
|
||||
def get_jwt(user):
|
||||
access_token = encode_jwt(
|
||||
sub=str(user.id),
|
||||
expires_delta=settings.ACCESS_TOKEN_EXP,
|
||||
additional_claims={
|
||||
"token_type": "access",
|
||||
"is_blocked": user.is_blocked,
|
||||
"is_superuser": user.is_superuser,
|
||||
"username": user.username,
|
||||
"is_active": user.is_active,
|
||||
},
|
||||
)
|
||||
|
||||
refresh_token = encode_jwt(
|
||||
sub=str(user.id),
|
||||
expires_delta=settings.REFRESH_TOKEN_EXP,
|
||||
additional_claims={
|
||||
"token_type": "refresh",
|
||||
"is_blocked": user.is_blocked,
|
||||
"is_superuser": user.is_superuser,
|
||||
"username": user.username,
|
||||
"is_active": user.is_active,
|
||||
},
|
||||
)
|
||||
|
||||
return access_token, refresh_token
|
||||
|
||||
|
||||
def authenticate(
|
||||
*,
|
||||
token: str,
|
||||
key: str,
|
||||
) -> Tuple[bool, Dict]:
|
||||
"""Authenticate user by token"""
|
||||
try:
|
||||
token_header = jwt.get_unverified_header(token)
|
||||
decoded_token = jwt.decode(token, key, algorithms=token_header.get("alg"))
|
||||
except JWTError:
|
||||
return False, {}
|
||||
else:
|
||||
return True, decoded_token
|
79
src/users/endpoints.py
Normal file
79
src/users/endpoints.py
Normal file
@ -0,0 +1,79 @@
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from jose import JWTError
|
||||
from starlette.requests import Request
|
||||
|
||||
from .crud import SQLAlchemyCRUD
|
||||
from .crypto import get_jwt
|
||||
from ..config import settings
|
||||
from ..html import templates
|
||||
from ..storage.sqlalchemy import SQLAlchemyStorage, get_sqlalchemy_storage
|
||||
from ..utils.oauth import back_auth_request
|
||||
from ..utils.redirect import RedirectResponseBuilder
|
||||
from ..utils.telegram import decode_telegram_auth_data, verify_telegram_auth_data
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/login", name="users:login:get")
|
||||
async def user_login_get(request: Request):
|
||||
if request.user.is_authenticated:
|
||||
if resp := await back_auth_request(request):
|
||||
return resp
|
||||
return RedirectResponseBuilder().build(settings.PROJECT_LOGIN_SUCCESS_URL)
|
||||
url = request.url
|
||||
callback_url = str(url).replace("/login", "/callback")
|
||||
return templates.TemplateResponse(
|
||||
"login.jinja",
|
||||
{"request": request, "callback_url": callback_url, "username": settings.BOT_USERNAME}
|
||||
)
|
||||
|
||||
|
||||
async def auth(
|
||||
tg_id: int,
|
||||
request: Request,
|
||||
storage: SQLAlchemyStorage,
|
||||
):
|
||||
if tg_id is None:
|
||||
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
|
||||
crud = SQLAlchemyCRUD(storage=storage)
|
||||
user = await crud.get_by_tg_id(tg_id=tg_id)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
access_token, refresh_token = get_jwt(user)
|
||||
# NOTE: Setting expire causes an exception for requests library:
|
||||
# https://github.com/psf/requests/issues/6004
|
||||
if resp := await back_auth_request(request, access_token, refresh_token):
|
||||
return resp
|
||||
resp = RedirectResponseBuilder()
|
||||
resp.set_cookie(
|
||||
key="access_token", value=access_token, max_age=settings.ACCESS_TOKEN_EXP
|
||||
)
|
||||
resp.set_cookie(
|
||||
key="refresh_token", value=refresh_token, max_age=settings.REFRESH_TOKEN_EXP
|
||||
)
|
||||
return resp.build(settings.PROJECT_LOGIN_SUCCESS_URL)
|
||||
|
||||
|
||||
@router.get("/callback", name="users:login")
|
||||
async def user_login(
|
||||
request: Request,
|
||||
storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
|
||||
):
|
||||
tg_id = await verify_telegram_auth_data(request.query_params)
|
||||
return await auth(tg_id, request, storage)
|
||||
|
||||
|
||||
@router.get("/auth", name="users:auth")
|
||||
async def user_auth(
|
||||
request: Request,
|
||||
storage: SQLAlchemyStorage = Depends(get_sqlalchemy_storage),
|
||||
):
|
||||
try:
|
||||
tg_id = await decode_telegram_auth_data(request.query_params)
|
||||
except JWTError:
|
||||
tg_id = None
|
||||
return await auth(tg_id, request, storage)
|
37
src/users/models.py
Normal file
37
src/users/models.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column, BigInteger
|
||||
from sqlmodel.main import Field, Relationship
|
||||
|
||||
from ..storage.models import BaseTable
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ..oauth2.models import AuthorizationCode, Token
|
||||
|
||||
|
||||
class UserAnonymous(BaseModel):
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class User(BaseTable, table=True): # type: ignore
|
||||
__tablename__ = "users"
|
||||
|
||||
is_superuser: bool = False
|
||||
is_blocked: bool = False
|
||||
is_active: bool = False
|
||||
|
||||
username: str = Field(nullable=False, sa_column_kwargs={"unique": True}, index=True)
|
||||
password: Optional[str] = None
|
||||
tg_id: int = Field(sa_column=Column(BigInteger(), nullable=False))
|
||||
|
||||
user_authorization_codes: List["AuthorizationCode"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
user_tokens: List["Token"] = Relationship(back_populates="user")
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return True
|
0
src/utils/__init__.py
Normal file
0
src/utils/__init__.py
Normal file
38
src/utils/oauth.py
Normal file
38
src/utils/oauth.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from src.config import settings
|
||||
from src.users.crypto import encode_jwt, decode_jwt
|
||||
from src.utils.redirect import RedirectResponseBuilder
|
||||
|
||||
|
||||
async def to_login_request(request: Request) -> RedirectResponse:
|
||||
query_params = dict(request.query_params)
|
||||
params = ""
|
||||
for key, value in query_params.items():
|
||||
params += f"{key}={value}&"
|
||||
params = params[:-1]
|
||||
jwt = encode_jwt(settings.ACCESS_TOKEN_EXP, "", additional_claims={"params": params})
|
||||
resp = RedirectResponseBuilder()
|
||||
resp.set_cookie("SEND", jwt, max_age=settings.ACCESS_TOKEN_EXP)
|
||||
return resp.build("/api/users/login")
|
||||
|
||||
|
||||
async def back_auth_request(
|
||||
request: Request,
|
||||
access_token: str = None,
|
||||
refresh_token: str = None,
|
||||
) -> Optional[RedirectResponse]:
|
||||
cookie = request.cookies.get("SEND")
|
||||
if cookie is None:
|
||||
return None
|
||||
params = decode_jwt(cookie)["params"]
|
||||
resp = RedirectResponseBuilder()
|
||||
if access_token:
|
||||
resp.set_cookie("access_token", access_token, max_age=settings.ACCESS_TOKEN_EXP)
|
||||
if refresh_token:
|
||||
resp.set_cookie("refresh_token", refresh_token, max_age=settings.ACCESS_TOKEN_EXP)
|
||||
resp.delete_cookie("SEND")
|
||||
return resp.build(f"/oauth2/authorize?{params}", status_code=303)
|
78
src/utils/redirect.py
Normal file
78
src/utils/redirect.py
Normal file
@ -0,0 +1,78 @@
|
||||
import http.cookies
|
||||
import typing
|
||||
from datetime import datetime
|
||||
from email.utils import format_datetime
|
||||
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
|
||||
class RedirectResponseBuilder:
|
||||
def __init__(self):
|
||||
self.raw_headers = []
|
||||
|
||||
def set_cookie(
|
||||
self,
|
||||
key: str,
|
||||
value: str = "",
|
||||
max_age: typing.Optional[int] = None,
|
||||
expires: typing.Optional[typing.Union[datetime, str, int]] = None,
|
||||
path: str = "/",
|
||||
domain: typing.Optional[str] = None,
|
||||
secure: bool = False,
|
||||
httponly: bool = False,
|
||||
samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax",
|
||||
) -> None:
|
||||
cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie()
|
||||
cookie[key] = value
|
||||
if max_age is not None:
|
||||
cookie[key]["max-age"] = max_age
|
||||
if expires is not None:
|
||||
if isinstance(expires, datetime):
|
||||
cookie[key]["expires"] = format_datetime(expires, usegmt=True)
|
||||
else:
|
||||
cookie[key]["expires"] = expires
|
||||
if path is not None:
|
||||
cookie[key]["path"] = path
|
||||
if domain is not None:
|
||||
cookie[key]["domain"] = domain
|
||||
if secure:
|
||||
cookie[key]["secure"] = True
|
||||
if httponly:
|
||||
cookie[key]["httponly"] = True
|
||||
if samesite is not None:
|
||||
assert samesite.lower() in [
|
||||
"strict",
|
||||
"lax",
|
||||
"none",
|
||||
], "samesite must be either 'strict', 'lax' or 'none'"
|
||||
cookie[key]["samesite"] = samesite
|
||||
cookie_val = cookie.output(header="").strip()
|
||||
self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
|
||||
|
||||
def delete_cookie(
|
||||
self,
|
||||
key: str,
|
||||
path: str = "/",
|
||||
domain: typing.Optional[str] = None,
|
||||
secure: bool = False,
|
||||
httponly: bool = False,
|
||||
samesite: typing.Optional[typing.Literal["lax", "strict", "none"]] = "lax",
|
||||
) -> None:
|
||||
self.set_cookie(
|
||||
key,
|
||||
max_age=0,
|
||||
expires=0,
|
||||
path=path,
|
||||
domain=domain,
|
||||
secure=secure,
|
||||
httponly=httponly,
|
||||
samesite=samesite,
|
||||
)
|
||||
|
||||
@property
|
||||
def headers(self) -> MutableHeaders:
|
||||
return MutableHeaders(raw=self.raw_headers)
|
||||
|
||||
def build(self, url: str, status_code: int = 307):
|
||||
return RedirectResponse(url, headers=self.headers, status_code=status_code)
|
45
src/utils/telegram.py
Normal file
45
src/utils/telegram.py
Normal file
@ -0,0 +1,45 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from starlette.datastructures import QueryParams
|
||||
|
||||
from src.config import settings
|
||||
from src.users.crypto import encode_jwt, decode_jwt
|
||||
|
||||
|
||||
async def verify_telegram_auth_data(params: QueryParams) -> Optional[int]:
|
||||
data = list(params.items())
|
||||
hash_str = ""
|
||||
text_list = []
|
||||
for key, value in data:
|
||||
if key == "hash":
|
||||
hash_str = value
|
||||
else:
|
||||
text_list.append(f"{key}={value}")
|
||||
check_string = "\n".join(sorted(text_list))
|
||||
|
||||
secret_key = hashlib.sha256(str.encode(settings.BOT_TOKEN)).digest()
|
||||
hmac_hash = hmac.new(secret_key, str.encode(check_string), hashlib.sha256).hexdigest()
|
||||
|
||||
return int(params.get("id")) if hmac_hash == hash_str else None
|
||||
|
||||
|
||||
async def encode_telegram_auth_data(uid: int) -> str:
|
||||
jwt = encode_jwt(settings.ACCESS_TOKEN_EXP, str(uid))
|
||||
return jwt
|
||||
|
||||
|
||||
async def decode_telegram_auth_data(params: QueryParams) -> Optional[int]:
|
||||
jwt = params.get("jwt")
|
||||
if not jwt:
|
||||
return None
|
||||
if not jwt:
|
||||
return None
|
||||
data = decode_jwt(jwt)
|
||||
now = datetime.now(timezone.utc)
|
||||
uid, exp = data["sub"], data["exp"]
|
||||
if exp < now.timestamp():
|
||||
return None
|
||||
return int(uid)
|
Loading…
Reference in New Issue
Block a user