Files
2025-11-25 10:51:26 -05:00

248 lines
7.5 KiB
Python

import json
import os
import socket
import sys
import time
from urllib.parse import urlparse

import requests
class RSCAuth:
    """Client-credentials authentication for Rubrik Security Cloud (RSC).

    Reads ``client_id``, ``client_secret`` and ``access_token_uri`` from a JSON
    config file and exchanges them for a bearer token at ``access_token_uri``.
    The token is cached in memory and in ``~/.rbkRscsession.<id>``, so later
    calls (and later runs) can reuse it until it is close to expiring.
    """

    # A cached token is treated as stale once fewer than this many seconds remain.
    TOKEN_REFRESH_MARGIN = 1800
    # Default timeout (seconds) for the token request. Without a timeout,
    # requests.post() can block forever on an unresponsive endpoint.
    DEFAULT_TIMEOUT = 30
    timeout = DEFAULT_TIMEOUT
    # Path suffix stripped from access_token_uri to obtain the API base.
    _TOKEN_PATH = '/api/client_token'

    def __init__(self, config_file='rsc.json', timeout=DEFAULT_TIMEOUT):
        """Load credentials from *config_file*.

        Args:
            config_file: Path to the JSON configuration file.
            timeout: Seconds to wait for the token endpoint (passed to requests).
        """
        self.config_file = config_file
        self.timeout = timeout
        self.load_config()
        # Filled in lazily by get_token() / _fetch_token().
        self.token = None
        self.token_expiration = None  # epoch seconds

    def load_config(self):
        """Read credentials from ``self.config_file`` and derive ``self.host``.

        Raises:
            FileNotFoundError: if the config file does not exist.
            ValueError: if the file is not valid JSON, is not a JSON object, or
                lacks client_id, client_secret or access_token_uri.
        """
        if not os.path.exists(self.config_file):
            raise FileNotFoundError(f"Configuration file {self.config_file} not found")
        with open(self.config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
        if not isinstance(config, dict):
            raise ValueError(f"Configuration file {self.config_file} must contain a JSON object")
        self.client_id = config.get('client_id')
        self.client_secret = config.get('client_secret')
        self.access_token_uri = config.get('access_token_uri')
        if not all([self.client_id, self.client_secret, self.access_token_uri]):
            raise ValueError("Missing required fields in config: client_id, client_secret, access_token_uri")
        self.host = self._derive_host(self.access_token_uri)

    @classmethod
    def _derive_host(cls, access_token_uri):
        """Return the API base (host plus any path prefix) for a token URI.

        For example, 'https://acme.my.rubrik.com/api/client_token' becomes
        'acme.my.rubrik.com'. Parsing the URL instead of doing string
        replacement tolerates a trailing slash, an http:// scheme, or no scheme.
        """
        uri = access_token_uri.strip()
        parsed = urlparse(uri if '://' in uri else f'https://{uri}')
        path = parsed.path.rstrip('/')
        if path.endswith(cls._TOKEN_PATH):
            path = path[:-len(cls._TOKEN_PATH)]
        return parsed.netloc + path

    def get_token(self):
        """Return a bearer token, reusing a still-fresh one when possible.

        Lookup order: the in-memory token, then the on-disk cache, then a new
        token from the token endpoint. A token is reused only while more than
        TOKEN_REFRESH_MARGIN seconds remain before it expires.
        """
        if self.token and self._is_fresh(self.token_expiration):
            return self.token
        cached = self._read_cache(self._get_cache_file())
        if cached is not None:
            expiration, token = cached
            if self._is_fresh(expiration):
                self.token = token
                self.token_expiration = expiration
                return token
        return self._fetch_token()

    def _is_fresh(self, expiration):
        """True if *expiration* (epoch seconds) lies beyond the refresh margin."""
        return expiration is not None and time.time() < expiration - self.TOKEN_REFRESH_MARGIN

    @staticmethod
    def _read_cache(cache_file):
        """Return ``(expiration, token)`` from *cache_file*, or None if unusable."""
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                expiration_text, token = f.read().strip().split(' ', 1)
            return int(expiration_text), token
        except (OSError, ValueError):
            # A missing, unreadable or corrupt cache is not fatal:
            # the caller falls back to requesting a new token.
            return None

    def _fetch_token(self):
        """Request a new token from access_token_uri, cache it, and return it.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the endpoint does not answer within self.timeout.
            KeyError: if the response lacks 'access_token' or 'expires_in'.
        """
        payload = {
            'client_id': self.client_id,
            'client_secret': self.client_secret
        }
        headers = {'accept': 'application/json', 'Content-Type': 'application/json'}
        response = requests.post(self.access_token_uri, json=payload, headers=headers,
                                 timeout=self.timeout)
        response.raise_for_status()
        data = response.json()
        # Read both fields before mutating state so a malformed response
        # cannot leave a token paired with a stale expiration.
        token = data['access_token']
        expires_in = int(data['expires_in'])
        self.token = token
        self.token_expiration = int(time.time()) + expires_in
        self._write_cache(self._get_cache_file())
        return self.token

    def _write_cache(self, cache_file):
        """Persist the current token as '<expiration> <token>', readable only by the owner."""
        directory = os.path.dirname(cache_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Create the file with mode 0600 from the start, so the token is never
        # readable by other users, not even briefly before a chmod.
        fd = os.open(cache_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            # The mode argument to os.open does not affect an existing file,
            # so tighten its permissions before writing the secret.
            os.chmod(cache_file, 0o600)
            f.write(f"{self.token_expiration} {self.token}")

    def _get_cache_file(self):
        """Return the per-client cache path ~/.rbkRscsession.<id>.

        <id> is the part after the first '|' in client_id (e.g. 'client|abc'
        gives 'abc'), or the whole client_id when it contains no '|'.
        """
        if '|' in self.client_id:
            id_part = self.client_id.split('|')[1]
        else:
            id_part = self.client_id
        return os.path.expanduser(f"~/.rbkRscsession.{id_part}")

    def get_headers(self):
        """Return HTTP headers carrying a valid bearer token for RSC API calls."""
        return {
            'Authorization': f'Bearer {self.get_token()}',
            'Content-Type': 'application/json'
        }
class RSCGraphQL:
    """Minimal client for the RSC GraphQL endpoint https://<auth.host>/api/graphql."""

    # Default timeout (seconds) for GraphQL requests. Schema introspection can
    # return large payloads, so this is generous; without any timeout a
    # stalled connection would block forever.
    DEFAULT_TIMEOUT = 120
    timeout = DEFAULT_TIMEOUT

    def __init__(self, auth, timeout=DEFAULT_TIMEOUT):
        """Create a client.

        Args:
            auth: Object exposing ``host`` and ``get_headers()`` (e.g. RSCAuth).
            timeout: Seconds to wait for each GraphQL request.
        """
        self.auth = auth
        self.timeout = timeout
        self.endpoint = f"https://{self.auth.host}/api/graphql"

    def query(self, query, variables=None):
        """POST a GraphQL query and return the decoded JSON response.

        Args:
            query: GraphQL document text.
            variables: Optional variables mapping; omitted from the payload when empty.

        Raises:
            requests.HTTPError: on a non-2xx HTTP status.
            requests.Timeout: if the server does not answer within self.timeout.
            RuntimeError: if the response carries a non-empty 'errors' entry.
        """
        payload = {'query': query}
        if variables:
            payload['variables'] = variables
        headers = self.auth.get_headers()
        response = requests.post(self.endpoint, json=payload, headers=headers,
                                 timeout=self.timeout)
        response.raise_for_status()
        data = response.json()
        # GraphQL reports failures in the body, typically with HTTP 200.
        # A null or empty 'errors' entry means no errors occurred.
        if data.get('errors'):
            raise RuntimeError(f"GraphQL errors: {data['errors']}")
        return data

    def get_local_database_id(self):
        """Return the ID of the Oracle database on this host, preferring one protected by an SLA.

        Candidates are non-replicated Oracle databases whose first logicalPath
        element is named like this machine's hostname. The comparison is
        case-insensitive because host names are case-insensitive. Relic
        databases are ignored whenever a non-relic candidate exists.

        Raises:
            ValueError: if no database matches this host, or several match and
                none has an effective SLA domain.
        """
        hostname = socket.gethostname()
        query = """
        query OracleDatabases($filter: [Filter!]) {
            oracleDatabases(filter: $filter) {
                nodes {
                    id
                    dbUniqueName
                    isRelic
                    effectiveSlaDomain {
                        id
                        name
                    }
                    cluster {
                        id
                        name
                    }
                    logicalPath {
                        fid
                        name
                        objectType
                    }
                }
            }
        }
        """
        variables = {
            "filter": [
                {"texts": ["false"], "field": "IS_REPLICATED"}
            ]
        }
        response = self.query(query, variables)
        # NOTE(review): only the nodes of a single response are read. If the
        # API pages oracleDatabases, databases beyond the first page are
        # missed; confirm the default page size against the schema.
        all_dbs = response['data']['oracleDatabases']['nodes']
        wanted = hostname.lower()

        def on_this_host(db):
            path = db.get('logicalPath')
            name = path[0].get('name') if path else None
            return isinstance(name, str) and name.lower() == wanted

        dbs = [db for db in all_dbs if on_this_host(db)]
        if not dbs:
            raise ValueError(f"No databases found on host {hostname}")
        # Prefer databases still present on the host over relics.
        live_dbs = [db for db in dbs if not db.get('isRelic')]
        if live_dbs:
            dbs = live_dbs
        protected_dbs = [db for db in dbs if db.get('effectiveSlaDomain')]
        if protected_dbs:
            if len(protected_dbs) > 1:
                # Send the warning to stderr so it cannot pollute captured stdout.
                print(f"WARN: Multiple protected databases on {hostname}, using {protected_dbs[0]['dbUniqueName']}",
                      file=sys.stderr)
            return protected_dbs[0]['id']
        if len(dbs) == 1:
            return dbs[0]['id']
        raise ValueError(f"Multiple databases on {hostname}, none protected by SLA")

    def introspect_schema(self):
        """Run a schema introspection query: every type with its fields and field arguments."""
        introspection_query = """
        query IntrospectionQuery {
            __schema {
                types {
                    name
                    kind
                    description
                    fields(includeDeprecated: true) {
                        name
                        description
                        type {
                            name
                            kind
                            ofType {
                                name
                                kind
                            }
                        }
                        args {
                            name
                            description
                            type {
                                name
                                kind
                                ofType {
                                    name
                                    kind
                                }
                            }
                        }
                    }
                }
            }
        }
        """
        return self.query(introspection_query)

    def get_type_info(self, type_name):
        """Return the fields and arguments of GraphQL type *type_name*, unwrapping type modifiers two levels deep."""
        query = """
        query GetTypeInfo($typeName: String!) {
            __type(name: $typeName) {
                name
                kind
                description
                fields(includeDeprecated: true) {
                    name
                    description
                    type {
                        name
                        kind
                        ofType {
                            name
                            kind
                            ofType {
                                name
                                kind
                            }
                        }
                    }
                    args {
                        name
                        description
                        type {
                            name
                            kind
                            ofType {
                                name
                                kind
                                ofType {
                                    name
                                    kind
                                }
                            }
                        }
                    }
                }
            }
        }
        """
        return self.query(query, {"typeName": type_name})